"""
chunker.py - Text chunking utility using tiktoken for token-aware splitting.

Chunking strategy:
1. Split on paragraph boundaries (\\n\\n)
2. If a paragraph exceeds max_tokens, split on sentence boundaries (. ! ?)
3. If a sentence exceeds max_tokens, force-split by token limit
4. Apply overlap: prepend the last `overlap` tokens of the previous chunk
   to the start of the next chunk.
"""

import re

import tiktoken


def chunk_text(
    text: str,
    max_tokens: int = 500,
    overlap: int = 50,
) -> list[dict[str, int | str]]:
    """Split text into token-aware chunks.

    The text is split into paragraphs on blank lines, paragraphs that are
    too long are broken into smaller segments, and the segments are then
    packed into chunks with a token overlap between consecutive chunks.

    Args:
        text: The input text to be chunked.
        max_tokens: Maximum number of tokens allowed per chunk. Must be
                    at least 1.
        overlap: Number of tokens from the end of the previous chunk
                 to prepend to the start of the next chunk.

    Returns:
        A list of dicts with keys:
            - "content"     (str): The chunk text.
            - "chunk_index" (int): Zero-based index of the chunk.
            - "token_count" (int): Actual token count of the chunk content.
        Empty or whitespace-only input yields an empty list.

    Raises:
        ValueError: If max_tokens is less than 1.
    """
    # No chunk can satisfy a non-positive limit; fail here with a clear
    # message instead of deep inside the splitting helpers.
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be at least 1, got {max_tokens}")

    if not text or not text.strip():
        return []

    encoder: tiktoken.Encoding = tiktoken.get_encoding("cl100k_base")

    # Step 1: split into paragraphs on double newlines, dropping blank ones
    paragraphs: list[str] = [
        p.strip() for p in re.split(r"\n\n+", text) if p.strip()
    ]

    # Step 2: break paragraphs into "segments" that fit within max_tokens
    segments: list[str] = []
    for para in paragraphs:
        segments.extend(_split_paragraph(para, max_tokens, encoder))

    # Step 3: merge segments into chunks, respecting max_tokens
    return _merge_segments(segments, max_tokens, overlap, encoder)


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _token_count(text: str, encoder: tiktoken.Encoding) -> int:
    """Count the tokens that *encoder* produces for *text*."""
    token_ids = encoder.encode(text)
    return len(token_ids)


def _split_paragraph(
    paragraph: str,
    max_tokens: int,
    encoder: tiktoken.Encoding,
) -> list[str]:
    """Return a list of segments from a single paragraph.

    If the paragraph fits within max_tokens it is returned as-is.
    Otherwise it is split after sentence-ending punctuation (. ! ?) that is
    followed by whitespace, and consecutive sentences are greedily re-joined
    with a single space for as long as the joined text stays within
    max_tokens. A sentence that is too long on its own is force-split by
    token count and its pieces are emitted as separate segments.

    Args:
        paragraph: The paragraph text to split.
        max_tokens: Maximum number of tokens allowed per segment.
        encoder: Tokenizer used to measure token counts.

    Returns:
        The segments in their original order.
    """
    if _token_count(paragraph, encoder) <= max_tokens:
        return [paragraph]

    # The lookbehind keeps the punctuation attached to its sentence.
    sentences: list[str] = [
        s.strip() for s in re.split(r"(?<=[.!?])\s+", paragraph) if s.strip()
    ]

    segments: list[str] = []
    current_parts: list[str] = []
    current_tokens: int = 0

    for sentence in sentences:
        sentence_tokens = _token_count(sentence, encoder)

        if sentence_tokens > max_tokens:
            # Flush current accumulation first
            if current_parts:
                segments.append(" ".join(current_parts))
                current_parts = []
                current_tokens = 0
            # Force-split the oversized sentence
            segments.extend(_force_split(sentence, max_tokens, encoder))
            continue

        if current_parts:
            # Cheap pre-check: assume the space separator costs one token.
            if current_tokens + 1 + sentence_tokens <= max_tokens:
                # The estimate is not a guarantee, because joining text can
                # change how it tokenizes. Measure the joined string itself
                # so a group is only extended when it really fits.
                joined_tokens = _token_count(
                    " ".join([*current_parts, sentence]), encoder
                )
                if joined_tokens <= max_tokens:
                    current_parts.append(sentence)
                    current_tokens = joined_tokens
                    continue
            segments.append(" ".join(current_parts))

        # Start a new group with this sentence.
        current_parts = [sentence]
        current_tokens = sentence_tokens

    if current_parts:
        segments.append(" ".join(current_parts))

    return segments or [paragraph]


def _force_split(
    text: str,
    max_tokens: int,
    encoder: tiktoken.Encoding,
) -> list[str]:
    """Force-split *text* purely by token count.

    The text is encoded once and cut into consecutive windows of at most
    *max_tokens* tokens. A single character can be spread over several
    tokens, so a cut at an arbitrary token index may land inside a
    character; decoding such a slice leniently would put U+FFFD replacement
    characters on both sides of the cut. To avoid that, each window is
    shortened by up to three tokens until its bytes form valid UTF-8.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens per piece. Must be at least 1.
        encoder: Tokenizer providing ``encode``, ``decode`` and
            ``decode_bytes``.

    Returns:
        The decoded pieces in order; an empty list when *text* encodes to
        no tokens.

    Raises:
        ValueError: If max_tokens is less than 1.
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be at least 1, got {max_tokens}")

    tokens: list[int] = encoder.encode(text)
    total = len(tokens)
    pieces: list[str] = []
    start = 0

    while start < total:
        end = min(start + max_tokens, total)
        piece: str | None = None

        # A UTF-8 character is at most four bytes and every token covers at
        # least one byte, so backing off three tokens is normally enough to
        # reach a character boundary. Never back off to an empty window.
        for candidate_end in range(end, max(start, end - 4), -1):
            try:
                piece = encoder.decode_bytes(tokens[start:candidate_end]).decode(
                    "utf-8"
                )
            except UnicodeDecodeError:
                continue
            end = candidate_end
            break

        if piece is None:
            # No clean boundary within reach (e.g. max_tokens is smaller
            # than one character's token span): fall back to the lenient
            # decode of the full window so the loop always makes progress.
            piece = encoder.decode(tokens[start:end])

        pieces.append(piece)
        start = end

    return pieces


def _overlap_prefix(
    prev_tokens: list[int],
    size: int,
    encoder: tiktoken.Encoding,
) -> str:
    """Decode the last *size* tokens of *prev_tokens* as overlap text.

    The tail is cut at an arbitrary token index, which may fall inside a
    multi-byte character. Up to three leading tokens are dropped until the
    remaining bytes are valid UTF-8, so the overlap does not start with a
    U+FFFD replacement character. If no clean start is found, the tail is
    decoded leniently.
    """
    tail = prev_tokens[-size:]
    for skip in range(min(4, len(tail))):
        try:
            return encoder.decode_bytes(tail[skip:]).decode("utf-8")
        except UnicodeDecodeError:
            continue
    return encoder.decode(tail)


def _merge_segments(
    segments: list[str],
    max_tokens: int,
    overlap: int,
    encoder: tiktoken.Encoding,
) -> list[dict[str, int | str]]:
    """Merge small segments together and apply overlap between chunks.

    Segments are packed greedily, joined by a blank line, for as long as
    the joined text stays within *max_tokens*. Every chunk after the first
    is then prefixed with text decoded from the tail of the previous packed
    chunk. The overlap is best effort: it is limited to the room the chunk
    has left under *max_tokens*, and it is shortened further if the
    combined text would still exceed the limit once re-tokenized. The
    overlap text is concatenated directly, without a separator.

    Args:
        segments: Text segments in document order.
        max_tokens: Maximum number of tokens allowed per chunk.
        overlap: Desired number of overlap tokens; zero or less disables
            overlap.
        encoder: Tokenizer providing ``encode``, ``decode`` and
            ``decode_bytes``.

    Returns:
        A list of dicts with keys "content", "chunk_index" and
        "token_count", where "token_count" is measured on the final
        content including any overlap.
    """
    if not segments:
        return []

    # Build raw chunk texts by greedily packing segments
    raw_chunks: list[str] = []
    current_parts: list[str] = []
    current_tokens: int = 0

    for segment in segments:
        seg_tokens = _token_count(segment, encoder)

        if current_parts:
            # Cheap pre-check: assume the separator costs one token.
            if current_tokens + 1 + seg_tokens <= max_tokens:
                # Joining text can change how it tokenizes, so confirm
                # against the actual joined string before extending.
                joined_tokens = _token_count(
                    "\n\n".join([*current_parts, segment]), encoder
                )
                if joined_tokens <= max_tokens:
                    current_parts.append(segment)
                    current_tokens = joined_tokens
                    continue
            raw_chunks.append("\n\n".join(current_parts))

        # Start a new chunk with this segment.
        current_parts = [segment]
        current_tokens = seg_tokens

    raw_chunks.append("\n\n".join(current_parts))

    # Apply overlap
    result: list[dict[str, int | str]] = []
    prev_tokens: list[int] = []

    for idx, chunk in enumerate(raw_chunks):
        curr_tokens: list[int] = encoder.encode(chunk)
        content = chunk
        token_count = len(curr_tokens)

        if idx > 0 and overlap > 0:
            # Never take more than requested, more than the previous chunk
            # holds, or more than the room left under max_tokens.
            size = min(overlap, len(prev_tokens), max_tokens - len(curr_tokens))
            while size > 0:
                candidate = _overlap_prefix(prev_tokens, size, encoder) + chunk
                candidate_tokens = _token_count(candidate, encoder)
                if candidate_tokens <= max_tokens:
                    content = candidate
                    token_count = candidate_tokens
                    break
                # Re-tokenizing across the seam produced extra tokens;
                # retry with a shorter overlap.
                size -= 1

        result.append(
            {
                "content": content,
                "chunk_index": idx,
                "token_count": token_count,
            }
        )
        # Overlap is always taken from the packed chunk, not from content
        # that already carries its own overlap prefix.
        prev_tokens = curr_tokens

    return result
