#!/usr/bin/env python3
"""
utils/context_compressor.py — 5단계 컨텍스트 압축 엔진

Phase1: 오래된 툴 결과 플레이스홀더 교체 | Phase2: head 보호
Phase3: tail 토큰 예산 보호 | Phase4: 중간 구역 규칙 기반 요약
Phase5: 툴 쌍 무결성 복원 (고아 result 제거, 고아 call 스텁 삽입)
토큰 추정: 한국어 비율 > 30%이면 chars / 2.5, 아니면 chars / 4
"""

from __future__ import annotations

import json
from typing import Any

from utils.context_summarizer import generate_summary
from utils.logger import get_logger

logger = get_logger(__name__)

_TOOL_RESULT_MIN_LEN = 200
_CHARS_PER_TOKEN = 4
_CHARS_PER_TOKEN_KO = 2.5
_KOREAN_THRESHOLD = 0.30


def _detect_korean_ratio(text: str) -> float:
    """텍스트 내 한글 음절 비율을 반환한다 (0.0 ~ 1.0)."""
    if not text:
        return 0.0
    korean_count = sum(1 for ch in text if '\uAC00' <= ch <= '\uD7AF')
    return korean_count / len(text)


def _estimate_tokens(text: str) -> int:
    """텍스트의 토큰 수를 추정한다. 한국어 비율에 따라 보정."""
    if not text:
        return 1
    ratio = _detect_korean_ratio(text)
    if ratio > _KOREAN_THRESHOLD:
        return max(1, int(len(text) / _CHARS_PER_TOKEN_KO))
    return max(1, len(text) // _CHARS_PER_TOKEN)


def _msg_tokens(msg: dict[str, Any]) -> int:
    content = msg.get("content") or ""
    tcs = msg.get("tool_calls")
    extra = ""
    if tcs:
        try:
            extra = json.dumps(tcs, ensure_ascii=False)
        except (TypeError, ValueError):
            extra = str(tcs)
    return _estimate_tokens(content + extra)


class ContextCompressor:
    """5단계 컨텍스트 압축기."""

    def __init__(
        self,
        context_limit: int = 128_000,
        threshold_percent: float = 0.50,
        protect_first_n: int = 3,
        tail_token_budget: int = 20_000,
    ) -> None:
        self.context_limit = context_limit
        self.threshold_percent = threshold_percent
        self.protect_first_n = protect_first_n
        self.tail_token_budget = tail_token_budget
        self._previous_summary: str | None = None

    def should_compress(self, messages: list[dict]) -> bool:
        """총 토큰이 threshold를 초과하면 True."""
        if not messages:
            return False
        total = sum(_msg_tokens(m) for m in messages)
        return total > int(self.context_limit * self.threshold_percent)

    def compress(self, messages: list[dict]) -> list[dict]:
        """5단계 압축을 수행하고 압축된 메시지 목록을 반환한다."""
        if not messages:
            return []

        head_end = min(self.protect_first_n, len(messages)) - 1

        # Phase 1: 오래된 툴 결과 교체
        tail_est = self._find_tail_cut_by_tokens(messages, head_end)
        msgs, pruned = self._prune_old_tool_results(messages, len(messages) - tail_est)
        if pruned:
            logger.debug("Phase1: pruned %d tool results", pruned)

        # Phase 2+3: 경계 정렬
        head_end = self._align_boundary(msgs, min(self.protect_first_n, len(msgs)) - 1, "forward")
        tail_start = self._align_boundary(msgs, self._find_tail_cut_by_tokens(msgs, head_end), "backward")

        mid_start, mid_end = head_end + 1, tail_start
        if mid_start >= mid_end:
            return self._sanitize_tool_pairs(msgs)

        head, middle, tail = msgs[:mid_start], msgs[mid_start:mid_end], msgs[mid_end:]

        # Phase 4: 중간 구역 요약
        summary_text = generate_summary(middle, prev_summary=self._previous_summary)
        if summary_text:
            self._previous_summary = summary_text
            compressed = head + [{"role": "user", "content": summary_text}] + tail
        else:
            compressed = head + tail

        # Phase 5: 툴 쌍 무결성
        result = self._sanitize_tool_pairs(compressed)
        logger.debug("compress: %d → %d messages (mid=%d)", len(messages), len(result), len(middle))
        return result

    def _prune_old_tool_results(self, messages: list[dict], protect_tail_count: int) -> tuple[list[dict], int]:
        """200자 이상 오래된 툴 결과를 플레이스홀더로 교체. (count, pruned) 반환."""
        result: list[dict] = []
        pruned = 0
        tail_boundary = len(messages) - protect_tail_count
        for i, msg in enumerate(messages):
            if (
                msg.get("role") == "tool"
                and i < tail_boundary
                and len(msg.get("content") or "") >= _TOOL_RESULT_MIN_LEN
            ):
                orig_len = len(msg.get("content") or "")
                replacement = dict(msg)
                replacement["content"] = f"[tool result truncated — original {orig_len} chars]"
                result.append(replacement)
                pruned += 1
            else:
                result.append(msg)
        return result, pruned

    def _find_tail_cut_by_tokens(self, messages: list[dict], head_end: int) -> int:
        """역방향 토큰 누적으로 tail 시작 인덱스를 반환한다 (head_end+1 이상 보장)."""
        accumulated = 0
        tail_start = len(messages)
        for i in range(len(messages) - 1, head_end, -1):
            accumulated += _msg_tokens(messages[i])
            if accumulated <= self.tail_token_budget:
                tail_start = i
            else:
                break
        return max(tail_start, head_end + 1)

    def _sanitize_tool_pairs(self, messages: list[dict]) -> list[dict]:
        """툴 쌍 무결성 복원: 고아 result 제거 → 고아 call에 스텁 삽입."""
        # 존재하는 tool_call id 수집
        all_call_ids: set[str] = {
            tc["id"]
            for msg in messages
            for tc in (msg.get("tool_calls") or [])
            if isinstance(tc, dict) and tc.get("id")
        }
        # 고아 tool_result 제거
        cleaned: list[dict] = [
            msg for msg in messages if not (msg.get("role") == "tool" and msg.get("tool_call_id") not in all_call_ids)
        ]
        # result_ids 재수집
        result_ids: set[str] = {
            msg["tool_call_id"] for msg in cleaned if msg.get("role") == "tool" and msg.get("tool_call_id")
        }
        # 고아 tool_call에 스텁 삽입
        final: list[dict] = []
        for msg in cleaned:
            final.append(msg)
            for tc in msg.get("tool_calls") or []:
                if isinstance(tc, dict):
                    cid = tc.get("id", "")
                    if cid and cid not in result_ids:
                        final.append(
                            {
                                "role": "tool",
                                "tool_call_id": cid,
                                "content": "[tool result not available — context compressed]",
                            }
                        )
                        result_ids.add(cid)
        return final

    def _align_boundary(self, messages: list[dict], idx: int, direction: str) -> int:
        """툴콜 그룹 중간에 경계가 걸리지 않도록 idx를 조정한다.

        direction="forward": 그룹 전체를 head에 포함, idx를 그룹 밖으로 전진
        direction="backward": 그룹 전체를 tail에서 제외, idx를 그룹 밖으로 후퇴
        """
        if not messages:
            return idx
        n = len(messages)
        idx = max(0, min(idx, n - 1))
        if direction == "forward":
            while idx < n - 1:
                cur, nxt = messages[idx], messages[idx + 1]
                if cur.get("tool_calls") or cur.get("role") == "tool" or nxt.get("role") == "tool":
                    idx += 1
                else:
                    break
        else:
            while idx > 0:
                cur, prev = messages[idx], messages[idx - 1]
                if cur.get("role") == "tool" or prev.get("tool_calls"):
                    idx -= 1
                else:
                    break
        return idx
