#!/usr/bin/env python3
"""
utils/context_compressor.py + context_summarizer.py 테스트 스위트 (TDD — RED → GREEN)

최소 15개 테스트:
- should_compress (임계값 이하/초과)
- 5단계 각각 단독 동작
- 툴 쌍 무결성 (고아 결과 제거, 고아 호출 스텁)
- 경계 정렬 (툴콜 그룹 중간 분할 방지)
- 빈 메시지, 모든 메시지 보호, 반복 압축
"""

from __future__ import annotations

import sys
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from utils.context_compressor import ContextCompressor
from utils.context_summarizer import generate_summary, serialize_for_summary

# ---------------------------------------------------------------------------
# 헬퍼
# ---------------------------------------------------------------------------

_MODEL_CONTEXT = 10_000  # 테스트용 컨텍스트 한계 (chars)


def _msg(role: str, content: str = "", tool_call_id: str | None = None, tool_calls: list | None = None) -> dict:
    """테스트용 메시지 dict 생성 헬퍼."""
    m: dict = {"role": role, "content": content}
    if tool_call_id is not None:
        m["tool_call_id"] = tool_call_id
    if tool_calls is not None:
        m["tool_calls"] = tool_calls
    return m


def _big(n: int = 300) -> str:
    """n글자 긴 문자열."""
    return "x" * n


def _make_compressor(**kwargs) -> ContextCompressor:
    """기본 파라미터 compressor 생성."""
    return ContextCompressor(
        context_limit=_MODEL_CONTEXT,
        threshold_percent=kwargs.get("threshold_percent", 0.50),
        protect_first_n=kwargs.get("protect_first_n", 3),
        tail_token_budget=kwargs.get("tail_token_budget", 2_000),
    )


# ---------------------------------------------------------------------------
# 1. should_compress
# ---------------------------------------------------------------------------


class TestShouldCompress:
    """임계값 기반 압축 필요 여부 판단"""

    def test_below_threshold_returns_false(self):
        """메시지 총 토큰이 임계값 미만이면 False."""
        cc = _make_compressor(threshold_percent=0.50)
        # 100자 * 1메시지 = 25토큰 (chars/4) << 5000토큰 임계값
        msgs = [_msg("user", "x" * 100)]
        assert cc.should_compress(msgs) is False

    def test_above_threshold_returns_true(self):
        """메시지 총 토큰이 임계값 초과이면 True."""
        cc = _make_compressor(threshold_percent=0.50)
        # context_limit=10000 → 임계=5000토큰=20000chars
        msgs = [_msg("user", "x" * 25_000)]
        assert cc.should_compress(msgs) is True

    def test_empty_messages_returns_false(self):
        """빈 메시지 목록이면 False."""
        cc = _make_compressor()
        assert cc.should_compress([]) is False

    def test_exactly_at_threshold_returns_false(self):
        """정확히 임계값이면 False (초과가 아니므로)."""
        cc = _make_compressor(threshold_percent=0.50)
        # 임계 = 5000토큰 = 20000chars — 딱 맞으면 압축 불필요
        msgs = [_msg("user", "x" * 20_000)]
        assert cc.should_compress(msgs) is False


# ---------------------------------------------------------------------------
# 2. Phase 1 — _prune_old_tool_results
# ---------------------------------------------------------------------------


class TestPruneOldToolResults:
    """오래된 툴 결과 플레이스홀더 교체"""

    def test_long_tool_result_replaced_with_placeholder(self):
        """200자 이상 오래된 tool 메시지가 플레이스홀더로 교체된다."""
        cc = _make_compressor()
        msgs = [
            _msg("user", "hello"),
            _msg("tool", _big(250), tool_call_id="call_1"),
            _msg("user", "next"),
        ]
        pruned, count = cc._prune_old_tool_results(msgs, protect_tail_count=0)
        tool_msg = next(m for m in pruned if m.get("role") == "tool")
        assert "[tool result truncated" in tool_msg["content"]
        assert count == 1

    def test_short_tool_result_not_replaced(self):
        """200자 미만 툴 결과는 교체되지 않는다."""
        cc = _make_compressor()
        msgs = [
            _msg("user", "hi"),
            _msg("tool", "short result", tool_call_id="call_2"),
        ]
        pruned, count = cc._prune_old_tool_results(msgs, protect_tail_count=0)
        tool_msg = next(m for m in pruned if m.get("role") == "tool")
        assert tool_msg["content"] == "short result"
        assert count == 0

    def test_tail_protected_tool_result_not_pruned(self):
        """tail 보호 구간 내 툴 결과는 교체되지 않는다."""
        cc = _make_compressor()
        msgs = [
            _msg("user", "q1"),
            _msg("tool", _big(300), tool_call_id="call_3"),
        ]
        # protect_tail_count=2 → 두 메시지 모두 tail에 속함
        pruned, count = cc._prune_old_tool_results(msgs, protect_tail_count=2)
        tool_msg = next(m for m in pruned if m.get("role") == "tool")
        assert "[tool result truncated" not in tool_msg["content"]
        assert count == 0


# ---------------------------------------------------------------------------
# 3. Phase 3 — _find_tail_cut_by_tokens
# ---------------------------------------------------------------------------


class TestFindTailCutByTokens:
    """역방향 토큰 누적 tail 경계 탐색"""

    def test_tail_cut_within_budget(self):
        """tail_token_budget 내에서 역방향으로 경계를 찾는다."""
        cc = _make_compressor(tail_token_budget=100)  # 100토큰 = 400chars
        msgs = [
            _msg("user", "head1"),
            _msg("user", "head2"),
            _msg("user", "x" * 100),  # 25토큰
            _msg("user", "x" * 200),  # 50토큰
            _msg("user", "x" * 100),  # 25토큰 → 합계 100토큰
        ]
        tail_start = cc._find_tail_cut_by_tokens(msgs, head_end=1)
        # tail_start는 head_end+1 이상이어야 한다
        assert tail_start > 1
        assert tail_start < len(msgs)

    def test_tail_cut_never_overlaps_head(self):
        """tail 시작점이 항상 head_end 이후여야 한다."""
        cc = _make_compressor(tail_token_budget=10_000)
        msgs = [_msg("user", f"m{i}") for i in range(10)]
        tail_start = cc._find_tail_cut_by_tokens(msgs, head_end=2)
        assert tail_start >= 3


# ---------------------------------------------------------------------------
# 4. Phase 5 — _sanitize_tool_pairs
# ---------------------------------------------------------------------------


class TestSanitizeToolPairs:
    """툴 쌍 무결성: 고아 결과 제거, 고아 호출 스텁 삽입"""

    def test_orphan_tool_result_removed(self):
        """대응하는 tool_call이 없는 tool_result가 제거된다."""
        cc = _make_compressor()
        msgs = [
            _msg("user", "hi"),
            _msg("tool", "orphan result", tool_call_id="missing_call"),
            _msg("assistant", "ok"),
        ]
        sanitized = cc._sanitize_tool_pairs(msgs)
        tool_msgs = [m for m in sanitized if m.get("role") == "tool"]
        assert len(tool_msgs) == 0

    def test_orphan_tool_call_gets_stub(self):
        """대응하는 tool_result가 없는 tool_call에 스텁이 삽입된다."""
        cc = _make_compressor()
        msgs = [
            _msg(
                "assistant",
                tool_calls=[{"id": "call_orphan", "type": "function", "function": {"name": "do_thing"}}],
            ),
            _msg("user", "next question"),
        ]
        sanitized = cc._sanitize_tool_pairs(msgs)
        tool_results = [m for m in sanitized if m.get("role") == "tool"]
        assert len(tool_results) == 1
        assert tool_results[0].get("tool_call_id") == "call_orphan"

    def test_matched_pair_preserved(self):
        """올바른 툴 쌍은 그대로 보존된다."""
        cc = _make_compressor()
        msgs = [
            _msg(
                "assistant",
                tool_calls=[{"id": "call_ok", "type": "function", "function": {"name": "fn"}}],
            ),
            _msg("tool", "result", tool_call_id="call_ok"),
        ]
        sanitized = cc._sanitize_tool_pairs(msgs)
        tool_results = [m for m in sanitized if m.get("role") == "tool" and m.get("content") == "result"]
        assert len(tool_results) == 1

    def test_stub_inserted_before_orphan_result_removed(self):
        """스텁 삽입이 고아 결과 제거보다 먼저 처리되어 새 스텁이 제거되지 않는다."""
        cc = _make_compressor()
        msgs = [
            _msg(
                "assistant",
                tool_calls=[{"id": "call_x", "type": "function", "function": {"name": "f"}}],
            ),
            # call_x에 대한 result 없음 → 스텁 삽입
            _msg("user", "after"),
        ]
        sanitized = cc._sanitize_tool_pairs(msgs)
        # 스텁이 삽입되어야 하고, 제거되지 않아야 한다
        stubs = [m for m in sanitized if m.get("role") == "tool" and m.get("tool_call_id") == "call_x"]
        assert len(stubs) == 1


# ---------------------------------------------------------------------------
# 5. Phase 2+3+4 — compress 통합
# ---------------------------------------------------------------------------


class TestCompressIntegration:
    """compress() 전체 파이프라인 통합 테스트"""

    def test_compress_reduces_total_length(self):
        """compress() 후 메시지 총 길이가 감소한다."""
        cc = _make_compressor(protect_first_n=1, tail_token_budget=500)
        # 큰 메시지가 가운데에 있는 시나리오
        msgs = (
            [_msg("user", "system setup")]
            + [_msg("user", "q"), _msg("assistant", _big(500))] * 5
            + [_msg("user", "final question")]
        )
        original_len = sum(len(m.get("content") or "") for m in msgs)
        compressed = cc.compress(msgs)
        compressed_len = sum(len(m.get("content") or "") for m in compressed)
        assert compressed_len < original_len

    def test_compress_protects_first_n_messages(self):
        """compress() 후 첫 protect_first_n 메시지가 그대로 유지된다."""
        cc = _make_compressor(protect_first_n=2, tail_token_budget=200)
        msgs = [
            _msg("system", "you are an assistant"),
            _msg("user", "first user message"),
            _msg("assistant", _big(500)),
            _msg("user", _big(500)),
            _msg("assistant", "final"),
        ]
        compressed = cc.compress(msgs)
        # 첫 두 메시지가 원본 그대로 보존
        assert compressed[0]["content"] == "you are an assistant"
        assert compressed[1]["content"] == "first user message"

    def test_compress_empty_messages_returns_empty(self):
        """빈 메시지 목록을 compress해도 빈 목록 반환."""
        cc = _make_compressor()
        assert cc.compress([]) == []

    def test_compress_all_protected_returns_original(self):
        """모든 메시지가 보호 구간에 속하면 원본을 그대로 반환한다."""
        cc = _make_compressor(protect_first_n=10, tail_token_budget=5_000)
        msgs = [_msg("user", f"msg{i}") for i in range(5)]
        compressed = cc.compress(msgs)
        # 압축 대상이 없으므로 실질적으로 동일해야 함
        assert len(compressed) == len(msgs)

    def test_compress_repeated_idempotent_structure(self):
        """반복 압축 후에도 툴 쌍 무결성이 유지된다."""
        cc = _make_compressor(protect_first_n=1, tail_token_budget=200)
        msgs = [
            _msg("user", "start"),
            _msg(
                "assistant",
                tool_calls=[{"id": "c1", "type": "function", "function": {"name": "fn"}}],
            ),
            _msg("tool", "res", tool_call_id="c1"),
            _msg("user", "followup"),
            _msg("assistant", "done"),
        ]
        result1 = cc.compress(msgs)
        result2 = cc.compress(result1)
        # 두 번 압축 후에도 고아 tool 메시지가 없어야 한다
        tool_call_ids = set()
        for m in result2:
            if m.get("tool_calls"):
                for tc in m["tool_calls"]:
                    tool_call_ids.add(tc["id"])
        for m in result2:
            if m.get("role") == "tool":
                assert m.get("tool_call_id") in tool_call_ids


# ---------------------------------------------------------------------------
# 6. _align_boundary
# ---------------------------------------------------------------------------


class TestAlignBoundary:
    """툴콜 그룹 중간 분할 방지 경계 정렬"""

    def test_align_forward_skips_tool_group(self):
        """forward 방향 정렬이 툴콜/결과 그룹 경계를 넘어간다."""
        cc = _make_compressor()
        msgs = [
            _msg("user", "q"),
            _msg("assistant", tool_calls=[{"id": "c1", "type": "function", "function": {"name": "f"}}]),
            _msg("tool", "res", tool_call_id="c1"),
            _msg("user", "next"),
        ]
        # 경계가 index=1(tool_call assistant)에 걸리면 forward로 3까지 이동
        aligned = cc._align_boundary(msgs, idx=1, direction="forward")
        assert aligned >= 3

    def test_align_backward_skips_tool_group(self):
        """backward 방향 정렬이 툴콜/결과 그룹 경계를 뒤로 보낸다."""
        cc = _make_compressor()
        msgs = [
            _msg("user", "q"),
            _msg("assistant", tool_calls=[{"id": "c2", "type": "function", "function": {"name": "g"}}]),
            _msg("tool", "res2", tool_call_id="c2"),
            _msg("user", "next"),
        ]
        # 경계가 index=2(tool_result)에 걸리면 backward로 1 이하로 이동
        aligned = cc._align_boundary(msgs, idx=2, direction="backward")
        assert aligned <= 1


# ---------------------------------------------------------------------------
# 7. context_summarizer
# ---------------------------------------------------------------------------


class TestContextSummarizer:
    """규칙 기반 요약 생성"""

    def test_generate_summary_returns_string(self):
        """generate_summary()가 문자열 또는 None을 반환한다."""
        turns = [
            _msg("user", "what is the capital of France?"),
            _msg("assistant", "Paris"),
        ]
        result = generate_summary(turns)
        assert result is None or isinstance(result, str)

    def test_generate_summary_empty_turns_returns_none(self):
        """빈 turns로 generate_summary() 호출 시 None 반환."""
        result = generate_summary([])
        assert result is None

    def test_generate_summary_includes_role_content(self):
        """요약이 turns의 핵심 내용을 포함한다."""
        turns = [
            _msg("user", "deploy the service"),
            _msg("assistant", "deploying now"),
        ]
        result = generate_summary(turns)
        if result is not None:
            # 요약에 뭔가 의미있는 내용이 포함되어야 함
            assert len(result) > 0

    def test_serialize_for_summary_basic(self):
        """serialize_for_summary()가 turns를 텍스트로 변환한다."""
        turns = [
            _msg("user", "hello"),
            _msg("assistant", "world"),
        ]
        text = serialize_for_summary(turns)
        assert isinstance(text, str)
        assert "hello" in text
        assert "world" in text

    def test_serialize_includes_tool_call_info(self):
        """serialize_for_summary()가 tool_calls 정보를 포함한다."""
        turns = [
            _msg(
                "assistant",
                tool_calls=[{"id": "c1", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}],
            ),
            _msg("tool", "file contents here", tool_call_id="c1"),
        ]
        text = serialize_for_summary(turns)
        assert "read_file" in text

    def test_generate_summary_with_prev_summary(self):
        """prev_summary가 주어지면 이를 반영한 요약이 생성된다."""
        turns = [_msg("user", "continue"), _msg("assistant", "ok")]
        prev = "Previously: discussed deployment strategy."
        result = generate_summary(turns, prev_summary=prev)
        # None이거나 문자열이어야 한다
        assert result is None or isinstance(result, str)


# ---------------------------------------------------------------------------
# 8. _estimate_tokens 한국어 보정
# ---------------------------------------------------------------------------


class TestEstimateTokensKorean:
    """한국어 비율에 따른 토큰 추정 보정 테스트"""

    def test_english_text_uses_default_ratio(self):
        """영어 텍스트는 chars/4 비율로 추정한다."""
        from utils.context_compressor import _estimate_tokens
        text = "a" * 400
        assert _estimate_tokens(text) == 100  # 400 / 4

    def test_korean_text_uses_korean_ratio(self):
        """한국어 비율 > 30%이면 chars/2.5 비율로 추정한다."""
        from utils.context_compressor import _estimate_tokens
        text = "안녕하세요 반갑습니다 테스트입니다"  # 한국어 비율 높음
        tokens = _estimate_tokens(text)
        # chars/2.5 기준으로 추정되어야 함
        expected = int(len(text) / 2.5)
        assert tokens == max(1, expected)

    def test_mixed_text_below_threshold_uses_default(self):
        """한국어 비율 < 30%이면 기본 chars/4 비율로 추정한다."""
        from utils.context_compressor import _estimate_tokens
        # 영어 70% + 한국어 30% 미만
        text = "a" * 80 + "가나다" * 3  # 80 + 9 = 89 chars, 한국어 9/89 ≈ 10%
        tokens = _estimate_tokens(text)
        assert tokens == max(1, len(text) // 4)

    def test_empty_text_returns_one(self):
        """빈 텍스트는 1을 반환한다."""
        from utils.context_compressor import _estimate_tokens
        assert _estimate_tokens("") == 1

    def test_detect_korean_ratio_pure_korean(self):
        """순수 한국어 텍스트의 비율이 1.0에 가깝다."""
        from utils.context_compressor import _detect_korean_ratio
        text = "가나다라마바사아"
        ratio = _detect_korean_ratio(text)
        assert ratio == 1.0

    def test_detect_korean_ratio_empty(self):
        """빈 텍스트는 0.0을 반환한다."""
        from utils.context_compressor import _detect_korean_ratio
        assert _detect_korean_ratio("") == 0.0

    def test_detect_korean_ratio_mixed(self):
        """혼합 텍스트에서 한글 비율이 올바르게 계산된다."""
        from utils.context_compressor import _detect_korean_ratio
        text = "hello가나다"  # 5 + 3 = 8 chars, 한글 3개
        ratio = _detect_korean_ratio(text)
        assert abs(ratio - 3/8) < 0.01


# ---------------------------------------------------------------------------
# 9. context_summarizer use_llm 옵션
# ---------------------------------------------------------------------------


class TestGenerateSummaryUseLlm:
    """use_llm 파라미터 동작 테스트 (mock 사용)"""

    def test_use_llm_false_returns_rule_based(self):
        """use_llm=False는 기존 규칙 기반 요약을 반환한다."""
        turns = [
            _msg("user", "deploy the service"),
            _msg("assistant", "deploying now"),
        ]
        result = generate_summary(turns, use_llm=False)
        # 규칙 기반이므로 [압축 요약] 헤더 포함
        assert result is not None
        assert "[압축 요약]" in result

    def test_use_llm_true_without_api_key_falls_back(self):
        """ANTHROPIC_API_KEY가 없으면 규칙 기반으로 fallback한다."""
        from unittest.mock import patch
        turns = [
            _msg("user", "deploy the service"),
            _msg("assistant", "deploying now"),
        ]
        with patch.dict("os.environ", {}, clear=False):
            # ANTHROPIC_API_KEY 삭제 보장
            import os
            env_backup = os.environ.pop("ANTHROPIC_API_KEY", None)
            try:
                result = generate_summary(turns, use_llm=True)
                # API 키 없으면 규칙 기반 fallback
                assert result is not None
                assert "[압축 요약]" in result
            finally:
                if env_backup is not None:
                    os.environ["ANTHROPIC_API_KEY"] = env_backup

    def test_use_llm_true_with_mock_api(self):
        """LLM 호출이 성공하면 LLM 요약을 반환한다."""
        import sys
        from unittest.mock import MagicMock, patch

        turns = [
            _msg("user", "deploy the service"),
            _msg("assistant", "deploying now"),
        ]

        mock_response = MagicMock()
        mock_response.content = [MagicMock(text="Goal: Deploy service\nProgress: In progress")]

        mock_client = MagicMock()
        mock_client.messages.create.return_value = mock_response

        mock_anthropic = MagicMock()
        mock_anthropic.Anthropic.return_value = mock_client

        with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"}):
            with patch.dict(sys.modules, {"anthropic": mock_anthropic}):
                result = generate_summary(turns, use_llm=True)
                assert result is not None
                assert "[LLM 요약]" in result
                assert "Goal: Deploy service" in result

    def test_use_llm_true_api_error_falls_back(self):
        """LLM 호출 중 예외 발생 시 규칙 기반으로 fallback한다."""
        import sys
        from unittest.mock import MagicMock, patch

        turns = [
            _msg("user", "deploy the service"),
            _msg("assistant", "deploying now"),
        ]

        mock_anthropic = MagicMock()
        mock_anthropic.Anthropic.side_effect = Exception("API error")

        with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"}):
            with patch.dict(sys.modules, {"anthropic": mock_anthropic}):
                result = generate_summary(turns, use_llm=True)
                # API 에러 시 규칙 기반 fallback
                assert result is not None
                assert "[압축 요약]" in result
