import math
import os
from struct import pack, unpack, calcsize
from typing import BinaryIO, Dict, Iterable, List, Optional, Tuple, Union, cast

# Fixed part of a segment header, in on-disk order: (struct format, field name).
# Readers/writers may extend a field through a parse_<name>/encode_<name> hook.
SEG_STRUCT = [
    (">L", "number"),
    (">B", "flags"),
    (">B", "retention_flags"),
    (">B", "page_assoc"),
    (">L", "data_length"),
]

# Bits of the segment-header flags byte.
HEADER_FLAG_DEFERRED = 0x80  # reported as the "deferred" flag
HEADER_FLAG_PAGE_ASSOC_LONG = 0x40  # page association stored in 4 bytes, not 1

SEG_TYPE_MASK = 0x3F  # low six bits hold the segment type

# Retention-flags byte: the top three bits hold the referred-to segment count;
# the count value REF_COUNT_LONG selects the long form, whose count is the low
# 29 bits (REF_COUNT_LONG_MASK) of a 4-byte field.
REF_COUNT_SHORT_MASK = 0xE0
REF_COUNT_LONG_MASK = 0x1FFFFFFF
REF_COUNT_LONG = 7

# Data-length sentinel meaning "length not known in advance".
DATA_LEN_UNKNOWN = 0xFFFFFFFF

# Segment type codes used by this module.
SEG_TYPE_IMMEDIATE_GEN_REGION = 38
SEG_TYPE_END_OF_PAGE = 49
SEG_TYPE_END_OF_FILE = 51

# File-header signature and the header flag for sequential organization.
FILE_HEADER_ID = b"\x97JB2\r\n\x1a\n"
FILE_HEAD_FLAG_SEQUENTIAL = 0x01


def bit_set(bit_pos: int, value: int) -> bool:
    """Return True if bit *bit_pos* (0 = least significant) of *value* is 1."""
    probe = 1 << bit_pos
    return (value & probe) != 0


def check_flag(flag: int, value: int) -> bool:
    """Return True if any bit of *flag* is also set in *value*."""
    overlap = value & flag
    return overlap != 0


def masked_value(mask: int, value: int) -> int:
    """Extract the bit field selected by *mask* from *value*.

    The selected bits are shifted down so the mask's lowest set bit lands at
    bit 0, e.g. ``masked_value(0b11100000, 0b10100000) == 0b101``. Works for a
    mask in any bit position, including bit 31 of a 32-bit word.

    Raises:
        Exception: if *mask* has no set bits.
    """
    if not mask:
        raise Exception("Invalid mask or value")
    # ``mask & -mask`` isolates the lowest set bit; its index is the shift.
    shift = (mask & -mask).bit_length() - 1
    return (value & mask) >> shift


def mask_value(mask: int, value: int) -> int:
    """Place *value* into the bit field selected by *mask* (inverse of masked_value).

    *value* is truncated to the field's width and shifted up so its bit 0
    lands on the mask's lowest set bit, e.g. ``mask_value(0b11100000, 7) ==
    0b11100000``. Works for a mask in any bit position, including bit 31.

    Raises:
        Exception: if *mask* has no set bits.
    """
    if not mask:
        raise Exception("Invalid mask or value")
    # ``mask & -mask`` isolates the lowest set bit; its index is the shift.
    shift = (mask & -mask).bit_length() - 1
    return (value & (mask >> shift)) << shift


def unpack_int(format: str, buffer: bytes) -> int:
    """Unpack a single big-endian unsigned integer from *buffer*.

    Supported formats: ``">B"`` (1 byte), ``">H"`` (2 bytes), and ``">I"`` /
    ``">L"`` (4 bytes each, standard sizes). ``">H"`` is needed for the
    2-byte referred-to segment numbers of JBIG2 segment headers.

    Raises:
        struct.error: if *buffer* does not have exactly the format's size.
    """
    assert format in {">B", ">H", ">I", ">L"}
    [result] = cast(Tuple[int], unpack(format, buffer))
    return result


# Decoded segment-header flags byte (produced by JBIG2StreamReader.parse_flags).
JBIG2SegmentFlags = Dict[str, Union[int, bool]]
# Decoded referred-to segment info (produced by parse_retention_flags).
JBIG2RetentionFlags = Dict[str, Union[int, List[int], List[bool]]]
# One segment: header fields keyed by name, plus e.g. "raw_data" bytes.
_JBIG2SegmentValue = Union[bool, int, bytes, JBIG2SegmentFlags, JBIG2RetentionFlags]
JBIG2Segment = Dict[str, _JBIG2SegmentValue]


class JBIG2StreamReader:
    """Read segments from a JBIG2 byte stream.

    Each segment header is read field by field following ``SEG_STRUCT``; when a
    ``parse_<field name>`` method exists it post-processes the raw value (and
    may consume further bytes from the stream).
    """

    def __init__(self, stream: BinaryIO) -> None:
        # Must support read() and seek(offset, SEEK_CUR) (see is_eof).
        self.stream = stream

    def get_segments(self) -> List[JBIG2Segment]:
        """Parse segments until the stream is exhausted.

        A segment whose fixed-size header fields are truncated is flagged with
        ``_error`` and left out of the result.
        """
        segments: List[JBIG2Segment] = []
        while not self.is_eof():
            segment: JBIG2Segment = {}
            for field_format, name in SEG_STRUCT:
                field_len = calcsize(field_format)
                field = self.stream.read(field_len)
                if len(field) < field_len:
                    segment["_error"] = True
                    break
                value = unpack_int(field_format, field)
                parser = getattr(self, "parse_%s" % name, None)
                if callable(parser):
                    value = parser(segment, value, field)
                segment[name] = value

            if not segment.get("_error"):
                segments.append(segment)
        return segments

    def is_eof(self) -> bool:
        """Return True if no bytes remain; otherwise rewind the probe byte."""
        if self.stream.read(1) == b"":
            return True
        else:
            self.stream.seek(-1, os.SEEK_CUR)
            return False

    def parse_flags(
        self, segment: JBIG2Segment, flags: int, field: bytes
    ) -> JBIG2SegmentFlags:
        """Split the segment-header flags byte into its named parts."""
        return {
            "deferred": check_flag(HEADER_FLAG_DEFERRED, flags),
            "page_assoc_long": check_flag(HEADER_FLAG_PAGE_ASSOC_LONG, flags),
            "type": masked_value(SEG_TYPE_MASK, flags),
        }

    def parse_retention_flags(
        self, segment: JBIG2Segment, flags: int, field: bytes
    ) -> JBIG2RetentionFlags:
        """Decode the referred-to segment count, retention bits and references.

        The top three bits of the first byte hold the reference count. A count
        of ``REF_COUNT_LONG`` selects the long form: the real count is the low
        29 bits of a 4-byte field, followed by ceil((count + 1) / 8) bytes of
        retention bits. The referred-to segment numbers follow in both forms.
        """
        ref_count = masked_value(REF_COUNT_SHORT_MASK, flags)
        retain_segments: List[bool] = []
        ref_segments: List[int] = []

        if ref_count < REF_COUNT_LONG:
            # Short form: the low five bits of this byte are retention flags.
            for bit_pos in range(5):
                retain_segments.append(bit_set(bit_pos, flags))
        else:
            field += self.stream.read(3)
            ref_count = unpack_int(">L", field)
            ref_count = masked_value(REF_COUNT_LONG_MASK, ref_count)
            ret_bytes_count = int(math.ceil((ref_count + 1) / 8))
            for _ in range(ret_bytes_count):
                ret_byte = unpack_int(">B", self.stream.read(1))
                # All eight bits of each retention byte carry a flag.
                for bit_pos in range(8):
                    retain_segments.append(bit_set(bit_pos, ret_byte))

        seg_num = segment["number"]
        assert isinstance(seg_num, int)
        # Referred-to segment numbers are 1, 2 or 4 bytes wide depending on
        # this segment's own number. ">H" is 2 bytes; ">I" would be 4.
        if seg_num <= 256:
            ref_format = ">B"
        elif seg_num <= 65536:
            ref_format = ">H"
        else:
            ref_format = ">L"

        ref_size = calcsize(ref_format)

        for _ in range(ref_count):
            ref_data = self.stream.read(ref_size)
            [ref] = cast(Tuple[int], unpack(ref_format, ref_data))
            ref_segments.append(ref)

        return {
            "ref_count": ref_count,
            "retain_segments": retain_segments,
            "ref_segments": ref_segments,
        }

    def parse_page_assoc(self, segment: JBIG2Segment, page: int, field: bytes) -> int:
        """Return the page association, reading 3 more bytes in the long form."""
        if cast(JBIG2SegmentFlags, segment["flags"])["page_assoc_long"]:
            field += self.stream.read(3)
            page = unpack_int(">L", field)
        return page

    def parse_data_length(
        self, segment: JBIG2Segment, length: int, field: bytes
    ) -> int:
        """Read the segment's data into ``segment["raw_data"]`` and return *length*.

        Zero-length segments get ``b""`` so every parsed segment carries
        ``raw_data`` (JBIG2StreamWriter.encode_data_length reads it).

        Raises:
            NotImplementedError: for an immediate generic region whose length
                is ``DATA_LEN_UNKNOWN``.
        """
        if not length:
            segment["raw_data"] = b""
            return length

        if (
            cast(JBIG2SegmentFlags, segment["flags"])["type"]
            == SEG_TYPE_IMMEDIATE_GEN_REGION
        ) and (length == DATA_LEN_UNKNOWN):
            raise NotImplementedError(
                "Working with unknown segment length " "is not implemented yet"
            )

        segment["raw_data"] = self.stream.read(length)
        return length


class JBIG2StreamWriter:
    """Write JBIG2 segments to a file in JBIG2 format.

    Segments are encoded field by field following ``SEG_STRUCT``; when an
    ``encode_<field name>`` method exists it produces that field's bytes,
    otherwise the value is packed with the field's struct format.
    """

    EMPTY_RETENTION_FLAGS: JBIG2RetentionFlags = {
        "ref_count": 0,
        "ref_segments": cast(List[int], []),
        "retain_segments": cast(List[bool], []),
    }

    def __init__(self, stream: BinaryIO) -> None:
        self.stream = stream

    def write_segments(
        self, segments: Iterable[JBIG2Segment], fix_last_page: bool = True
    ) -> int:
        """Encode and write *segments*; return the number of bytes written.

        With *fix_last_page*, an end-of-page segment numbered one past the last
        segment is appended if the last page association seen was not closed by
        an end-of-page segment.
        """
        data_len = 0
        current_page: Optional[int] = None
        seg_num: Optional[int] = None

        for segment in segments:
            data = self.encode_segment(segment)
            self.stream.write(data)
            data_len += len(data)

            seg_num = cast(Optional[int], segment["number"])

            if fix_last_page:
                seg_page = cast(int, segment.get("page_assoc"))

                if (
                    cast(JBIG2SegmentFlags, segment["flags"])["type"]
                    == SEG_TYPE_END_OF_PAGE
                ):
                    current_page = None
                elif seg_page:
                    current_page = seg_page

        if fix_last_page and current_page and (seg_num is not None):
            segment = self.get_eop_segment(seg_num + 1, current_page)
            data = self.encode_segment(segment)
            self.stream.write(data)
            data_len += len(data)

        return data_len

    def write_file(
        self, segments: Iterable[JBIG2Segment], fix_last_page: bool = True
    ) -> int:
        """Write a complete JBIG2 file (header, segments, end-of-file segment).

        Returns the total number of bytes written.
        """
        # The segments are needed twice (writing, then finding the last segment
        # number), so materialize them to support one-shot iterators.
        segment_list = list(segments)

        header = FILE_HEADER_ID
        header_flags = FILE_HEAD_FLAG_SEQUENTIAL
        header += pack(">B", header_flags)
        # The embedded JBIG2 files in a PDF always
        # only have one page
        number_of_pages = pack(">L", 1)
        header += number_of_pages
        self.stream.write(header)
        data_len = len(header)

        data_len += self.write_segments(segment_list, fix_last_page)

        seg_num = cast(int, segment_list[-1]["number"]) if segment_list else 0

        # Leave room for the end-of-page segment write_segments may have added.
        seg_num_offset = 2 if fix_last_page else 1
        eof_segment = self.get_eof_segment(seg_num + seg_num_offset)
        data = self.encode_segment(eof_segment)

        self.stream.write(data)
        data_len += len(data)

        return data_len

    def encode_segment(self, segment: JBIG2Segment) -> bytes:
        """Serialize one segment (header fields followed by its raw data)."""
        parts: List[bytes] = []
        for field_format, name in SEG_STRUCT:
            value = segment.get(name)
            encoder = getattr(self, "encode_%s" % name, None)
            if callable(encoder):
                parts.append(encoder(value, segment))
            else:
                parts.append(pack(field_format, value))
        return b"".join(parts)

    @staticmethod
    def _is_page_assoc_long(flags: JBIG2SegmentFlags, page: Optional[int]) -> bool:
        """Return True when the page association must be written as 4 bytes.

        That is the case when the flags request it or the page number does not
        fit in one byte.
        """
        return bool(flags.get("page_assoc_long")) or (page or 0) > 255

    def encode_flags(self, value: JBIG2SegmentFlags, segment: JBIG2Segment) -> bytes:
        """Pack the segment-header flags byte."""
        flags = 0
        if value.get("deferred"):
            flags |= HEADER_FLAG_DEFERRED

        page = cast(Optional[int], segment.get("page_assoc"))
        if self._is_page_assoc_long(value, page):
            flags |= HEADER_FLAG_PAGE_ASSOC_LONG

        flags |= mask_value(SEG_TYPE_MASK, value["type"])

        return pack(">B", flags)

    def encode_page_assoc(self, value: int, segment: JBIG2Segment) -> bytes:
        """Pack the page association as 1 byte, or 4 bytes in the long form.

        The width decision matches the HEADER_FLAG_PAGE_ASSOC_LONG bit that
        encode_flags writes for the same segment.
        """
        flags = cast(JBIG2SegmentFlags, segment.get("flags") or {})
        if self._is_page_assoc_long(flags, value):
            return pack(">L", value)
        return pack(">B", value)

    def encode_retention_flags(
        self, value: JBIG2RetentionFlags, segment: JBIG2Segment
    ) -> bytes:
        """Pack the referred-to segment count, retention bits and references."""
        flags: List[int] = []
        flags_format = ">B"
        ref_count = value["ref_count"]
        assert isinstance(ref_count, int)
        retain_segments = cast(List[bool], value.get("retain_segments", []))

        if ref_count <= 4:
            # Short form: count in the top three bits, retention flags for at
            # most five segments in the low five bits (more would clobber the
            # count).
            flags_byte = mask_value(REF_COUNT_SHORT_MASK, ref_count)
            for ref_index, ref_retain in enumerate(retain_segments[:5]):
                if ref_retain:
                    flags_byte |= 1 << ref_index
            flags.append(flags_byte)
        else:
            # Long form: REF_COUNT_LONG in the top three bits and the actual
            # count in the low 29 bits of a dword, then the retention bytes.
            bytes_count = math.ceil((ref_count + 1) / 8)
            flags_format = ">L" + ("B" * bytes_count)
            flags_dword = (
                mask_value(REF_COUNT_SHORT_MASK, REF_COUNT_LONG) << 24
            ) | mask_value(REF_COUNT_LONG_MASK, ref_count)
            flags.append(flags_dword)

            for byte_index in range(bytes_count):
                ret_byte = 0
                ret_part = retain_segments[byte_index * 8 : byte_index * 8 + 8]
                for bit_pos, ret_seg in enumerate(ret_part):
                    if ret_seg:
                        ret_byte |= 1 << bit_pos

                flags.append(ret_byte)

        ref_segments = cast(List[int], value.get("ref_segments", []))

        # Referred-to segment numbers are 1, 2 or 4 bytes wide depending on
        # this segment's own number ("H" is 2 bytes; "I" would be 4).
        seg_num = cast(int, segment["number"])
        if seg_num <= 256:
            ref_format = "B"
        elif seg_num <= 65536:
            ref_format = "H"
        else:
            ref_format = "L"

        flags_format += ref_format * len(ref_segments)
        flags.extend(ref_segments)

        return pack(flags_format, *flags)

    def encode_data_length(self, value: int, segment: JBIG2Segment) -> bytes:
        """Pack the data length followed by the segment's raw data (if any)."""
        data = pack(">L", value)
        data += cast(bytes, segment.get("raw_data", b""))
        return data

    @staticmethod
    def _empty_retention_flags() -> JBIG2RetentionFlags:
        """Return a fresh 'no references' dict so segments never share lists."""
        return {
            "ref_count": 0,
            "ref_segments": cast(List[int], []),
            "retain_segments": cast(List[bool], []),
        }

    def get_eop_segment(self, seg_number: int, page_number: int) -> JBIG2Segment:
        """Build an empty end-of-page segment for *page_number*."""
        return {
            "data_length": 0,
            "flags": {"deferred": False, "type": SEG_TYPE_END_OF_PAGE},
            "number": seg_number,
            "page_assoc": page_number,
            "raw_data": b"",
            "retention_flags": self._empty_retention_flags(),
        }

    def get_eof_segment(self, seg_number: int) -> JBIG2Segment:
        """Build an empty end-of-file segment (page association 0)."""
        return {
            "data_length": 0,
            "flags": {"deferred": False, "type": SEG_TYPE_END_OF_FILE},
            "number": seg_number,
            "page_assoc": 0,
            "raw_data": b"",
            "retention_flags": self._empty_retention_flags(),
        }
