"""
tests/test_session_monitor.py

SessionMonitor 단위 테스트 (TDD - RED 단계)

테스트 항목:
- 초기 상태: total_tokens=0, level="normal"
- update()로 토큰 누적 → 70% 도달 시 level="warning"
- update()로 토큰 누적 → 85% 도달 시 level="critical"
- 임계값 커스텀 설정 테스트
- reset() 후 상태 초기화 확인
- 콜백 등록 + 호출 확인
- CLI --status 테스트 (mock으로)
"""

from __future__ import annotations

import json
import os
import sys
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, mock_open, patch

import pytest

_WORKSPACE = Path(os.environ.get("WORKSPACE_ROOT", "/home/jay/workspace"))
if str(_WORKSPACE) not in sys.path:
    sys.path.insert(0, str(_WORKSPACE))

from utils.session_monitor import SessionMonitor  # noqa: E402

# ---------------------------------------------------------------------------
# 1. 초기 상태 검증
# ---------------------------------------------------------------------------


class TestInitialState:
    """Check what a freshly constructed SessionMonitor reports."""

    def test_default_context_limit(self):
        """With no arguments the reported limit is expected to be 200,000."""
        snapshot = SessionMonitor().get_usage_status()
        assert snapshot["limit"] == 200_000

    def test_initial_total_tokens_zero(self):
        """No tokens are counted before any update() call."""
        snapshot = SessionMonitor().get_usage_status()
        assert snapshot["total_tokens"] == 0

    def test_initial_level_normal(self):
        """The level starts out as "normal"."""
        snapshot = SessionMonitor().get_usage_status()
        assert snapshot["level"] == "normal"

    def test_initial_usage_pct_zero(self):
        """The usage percentage starts at 0.0."""
        snapshot = SessionMonitor().get_usage_status()
        assert snapshot["usage_pct"] == 0.0

    def test_custom_context_limit(self):
        """A context_limit passed to the constructor is echoed in the status."""
        snapshot = SessionMonitor(context_limit=100_000).get_usage_status()
        assert snapshot["limit"] == 100_000

    def test_custom_warning_pct(self):
        """With warning_pct=0.60, 60% usage is expected to report "warning"."""
        tracker = SessionMonitor(context_limit=100_000, warning_pct=0.60)
        # 60_000 of 100_000 tokens sits exactly on the 60% threshold.
        tracker.update({"input_tokens": 60_000, "output_tokens": 0})
        assert tracker.get_usage_status()["level"] == "warning"

    def test_custom_critical_pct(self):
        """With critical_pct=0.90, 90% usage is expected to report "critical"."""
        tracker = SessionMonitor(context_limit=100_000, critical_pct=0.90)
        # 90_000 of 100_000 tokens sits exactly on the 90% threshold.
        tracker.update({"input_tokens": 90_000, "output_tokens": 0})
        assert tracker.get_usage_status()["level"] == "critical"


# ---------------------------------------------------------------------------
# 2. update() 토큰 누적 및 레벨 반환
# ---------------------------------------------------------------------------


class TestUpdate:
    """Check that update() accumulates tokens and returns the resulting level."""

    def test_update_returns_normal_below_warning(self):
        """69% usage (just under the 70% warning threshold) returns "normal"."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        # 138_000 / 200_000 = 69% -> normal
        level = monitor.update({"input_tokens": 100_000, "output_tokens": 38_000})
        assert level == "normal"

    def test_update_returns_warning_at_70pct(self):
        """Exactly 70% usage returns "warning"."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        # 140_000 / 200_000 = 70% -> warning
        level = monitor.update({"input_tokens": 100_000, "output_tokens": 40_000})
        assert level == "warning"

    def test_update_returns_warning_between_70_and_85(self):
        """Usage between the two thresholds returns "warning"."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        # 150_000 / 200_000 = 75% -> warning
        level = monitor.update({"input_tokens": 100_000, "output_tokens": 50_000})
        assert level == "warning"

    def test_update_returns_critical_at_85pct(self):
        """Exactly 85% usage returns "critical"."""
        # Thresholds are pinned (like the sibling boundary tests) so the 85%
        # boundary being checked does not depend on the constructor defaults.
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        # 170_000 / 200_000 = 85% -> critical
        level = monitor.update({"input_tokens": 100_000, "output_tokens": 70_000})
        assert level == "critical"

    def test_update_accumulates_tokens(self):
        """Successive update() calls add to the running total."""
        monitor = SessionMonitor(context_limit=200_000)
        monitor.update({"input_tokens": 50_000, "output_tokens": 10_000})
        monitor.update({"input_tokens": 30_000, "output_tokens": 5_000})
        status = monitor.get_usage_status()
        assert status["total_tokens"] == 95_000

    def test_update_uses_input_and_output_tokens(self):
        """Both input_tokens and output_tokens count towards the total."""
        monitor = SessionMonitor(context_limit=200_000)
        monitor.update({"input_tokens": 30_000, "output_tokens": 20_000})
        status = monitor.get_usage_status()
        assert status["total_tokens"] == 50_000

    def test_update_with_only_input_tokens(self):
        """A missing output_tokens key is treated as zero."""
        monitor = SessionMonitor(context_limit=200_000)
        monitor.update({"input_tokens": 50_000})
        status = monitor.get_usage_status()
        assert status["total_tokens"] == 50_000

    def test_update_with_only_output_tokens(self):
        """A missing input_tokens key is treated as zero."""
        monitor = SessionMonitor(context_limit=200_000)
        monitor.update({"output_tokens": 30_000})
        status = monitor.get_usage_status()
        assert status["total_tokens"] == 30_000

    def test_update_empty_dict_stays_normal(self):
        """An empty usage dict leaves a fresh monitor at "normal"."""
        monitor = SessionMonitor(context_limit=200_000)
        level = monitor.update({})
        assert level == "normal"

    def test_usage_pct_calculation(self):
        """usage_pct is reported on a 0-100 scale (150_000 / 200_000 -> 75.0)."""
        monitor = SessionMonitor(context_limit=200_000)
        monitor.update({"input_tokens": 100_000, "output_tokens": 50_000})
        status = monitor.get_usage_status()
        assert status["usage_pct"] == pytest.approx(75.0)


# ---------------------------------------------------------------------------
# 3. get_usage_status() 반환 형식 검증
# ---------------------------------------------------------------------------


class TestGetUsageStatus:
    """Check the shape and contents of the dict returned by get_usage_status()."""

    def test_returns_required_keys(self):
        """The status exposes total_tokens, limit, usage_pct and level."""
        snapshot = SessionMonitor().get_usage_status()
        for required_key in ("total_tokens", "limit", "usage_pct", "level"):
            assert required_key in snapshot

    def test_level_warning_in_status(self):
        """After reaching 70% usage the status level reads "warning"."""
        tracker = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        tracker.update({"input_tokens": 140_000, "output_tokens": 0})
        assert tracker.get_usage_status()["level"] == "warning"

    def test_level_critical_in_status(self):
        """After 170_000 of 200_000 tokens the status level reads "critical"."""
        tracker = SessionMonitor(context_limit=200_000)
        tracker.update({"input_tokens": 170_000, "output_tokens": 0})
        assert tracker.get_usage_status()["level"] == "critical"

    def test_usage_pct_is_float(self):
        """usage_pct is a float."""
        tracker = SessionMonitor(context_limit=200_000)
        tracker.update({"input_tokens": 50_000, "output_tokens": 0})
        usage_pct = tracker.get_usage_status()["usage_pct"]
        assert isinstance(usage_pct, float)


# ---------------------------------------------------------------------------
# 4. reset() 후 상태 초기화 확인
# ---------------------------------------------------------------------------


class TestReset:
    """Check that reset() restores or replaces the token counter."""

    def test_reset_clears_tokens(self):
        """reset() with no argument brings the total back to zero."""
        tracker = SessionMonitor(context_limit=200_000)
        tracker.update({"input_tokens": 100_000, "output_tokens": 50_000})
        tracker.reset()
        assert tracker.get_usage_status()["total_tokens"] == 0

    def test_reset_level_returns_normal(self):
        """reset() with no argument brings the level back to "normal"."""
        tracker = SessionMonitor(context_limit=200_000)
        tracker.update({"input_tokens": 170_000, "output_tokens": 0})
        tracker.reset()
        assert tracker.get_usage_status()["level"] == "normal"

    def test_reset_with_new_total(self):
        """reset(new_total=...) replaces the running total with that value."""
        tracker = SessionMonitor(context_limit=200_000)
        tracker.update({"input_tokens": 170_000, "output_tokens": 0})
        tracker.reset(new_total=50_000)
        assert tracker.get_usage_status()["total_tokens"] == 50_000

    def test_reset_with_new_total_updates_level(self):
        """The level is re-evaluated against the new total after reset()."""
        tracker = SessionMonitor(context_limit=200_000)
        tracker.update({"input_tokens": 170_000, "output_tokens": 0})
        # 50_000 / 200_000 = 25% -> normal
        tracker.reset(new_total=50_000)
        assert tracker.get_usage_status()["level"] == "normal"

    def test_reset_with_new_total_warning(self):
        """reset() to a total inside the warning band reports "warning"."""
        tracker = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        # 150_000 / 200_000 = 75% -> warning
        tracker.reset(new_total=150_000)
        assert tracker.get_usage_status()["level"] == "warning"

    def test_reset_usage_pct_zero(self):
        """reset() with no argument brings usage_pct back to 0.0."""
        tracker = SessionMonitor(context_limit=200_000)
        tracker.update({"input_tokens": 100_000, "output_tokens": 50_000})
        tracker.reset()
        assert tracker.get_usage_status()["usage_pct"] == 0.0


# ---------------------------------------------------------------------------
# 5. 콜백 등록 및 호출 확인
# ---------------------------------------------------------------------------


class TestCallbacks:
    """Check callback registration and automatic invocation on level changes.

    Every test that reasons about a specific percentage pins
    ``warning_pct=0.70`` and ``critical_pct=0.85`` explicitly, so the
    expectations do not depend on the constructor defaults.
    """

    def test_register_warning_callback_called(self):
        """Reaching the warning threshold invokes a "warning" callback once."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        callback = MagicMock()
        monitor.register_callback("warning", callback)
        # 140_000 / 200_000 = 70% -> warning callback fires
        monitor.update({"input_tokens": 140_000, "output_tokens": 0})
        callback.assert_called_once()

    def test_register_critical_callback_called(self):
        """Reaching the critical threshold invokes a "critical" callback once."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        callback = MagicMock()
        monitor.register_callback("critical", callback)
        # 170_000 / 200_000 = 85% -> critical callback fires
        monitor.update({"input_tokens": 170_000, "output_tokens": 0})
        callback.assert_called_once()

    def test_warning_callback_not_called_below_threshold(self):
        """Staying below the warning threshold invokes no "warning" callback."""
        # Pinned thresholds guarantee 69% really is "normal"; with lower
        # defaults this test could pass for an unrelated reason.
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        callback = MagicMock()
        monitor.register_callback("warning", callback)
        # 138_000 / 200_000 = 69% -> normal, callback must not fire
        monitor.update({"input_tokens": 138_000, "output_tokens": 0})
        callback.assert_not_called()

    def test_critical_callback_not_called_at_warning(self):
        """At warning level only the "warning" callback fires."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        warning_cb = MagicMock()
        critical_cb = MagicMock()
        monitor.register_callback("warning", warning_cb)
        monitor.register_callback("critical", critical_cb)
        # 150_000 / 200_000 = 75% -> warning only
        monitor.update({"input_tokens": 150_000, "output_tokens": 0})
        warning_cb.assert_called_once()
        critical_cb.assert_not_called()

    def test_callback_receives_status_dict(self):
        """The callback is passed one argument containing the status keys."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        received: list[dict] = []
        # The bound method is enough; no wrapper lambda is needed.
        monitor.register_callback("warning", received.append)
        monitor.update({"input_tokens": 140_000, "output_tokens": 0})
        assert len(received) == 1
        assert "total_tokens" in received[0]
        assert "level" in received[0]

    def test_multiple_callbacks_same_level(self):
        """Every callback registered for a level is invoked."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        cb1 = MagicMock()
        cb2 = MagicMock()
        monitor.register_callback("warning", cb1)
        monitor.register_callback("warning", cb2)
        monitor.update({"input_tokens": 140_000, "output_tokens": 0})
        cb1.assert_called_once()
        cb2.assert_called_once()

    def test_callback_called_only_on_level_transition(self):
        """The callback fires on entering "warning", not again while staying there."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        callback = MagicMock()
        monitor.register_callback("warning", callback)
        # First update: warning reached -> callback fires.
        monitor.update({"input_tokens": 140_000, "output_tokens": 0})
        # Second update: still warning (141_000 tokens) -> no second call.
        monitor.update({"input_tokens": 1_000, "output_tokens": 0})
        callback.assert_called_once()

    def test_critical_callback_on_transition_from_warning(self):
        """Moving from warning to critical fires only the "critical" callback."""
        monitor = SessionMonitor(context_limit=200_000, warning_pct=0.70, critical_pct=0.85)
        warning_cb = MagicMock()
        critical_cb = MagicMock()
        monitor.register_callback("warning", warning_cb)
        monitor.register_callback("critical", critical_cb)
        # Reach warning (70%).
        monitor.update({"input_tokens": 140_000, "output_tokens": 0})
        warning_cb.assert_called_once()
        critical_cb.assert_not_called()
        # Reach critical (140_000 + 30_000 = 170_000 -> 85%).
        monitor.update({"input_tokens": 30_000, "output_tokens": 0})
        warning_cb.assert_called_once()  # no additional warning call
        critical_cb.assert_called_once()


# ---------------------------------------------------------------------------
# 6. CLI --status 테스트 (mock으로)
# ---------------------------------------------------------------------------


class TestCLIStatus:
    """Check the data get_active_sessions_status returns for the CLI --status mode."""

    @staticmethod
    def _query_status(workdir, timers, ledger):
        """Write both JSON fixtures into *workdir* and return the reported status."""
        timers_file = workdir / "task-timers.json"
        ledger_file = workdir / "token-ledger.json"
        timers_file.write_text(json.dumps(timers))
        ledger_file.write_text(json.dumps(ledger))

        # Imported at call time rather than at module import.
        from utils.session_monitor import get_active_sessions_status

        return get_active_sessions_status(
            timers_path=str(timers_file),
            ledger_path=str(ledger_file),
        )

    def test_cli_status_output_format(self, tmp_path):
        """A running task with ledger data yields one fully populated session entry."""
        timers = {
            "tasks": {
                "task-100.1": {
                    "task_id": "task-100.1",
                    "team_id": "dev6-team",
                    "status": "running",
                }
            }
        }
        ledger = {
            "tasks": {
                "task-100.1": {
                    "session_id": "abc123",
                    "team_id": "dev6-team",
                    "input_tokens": 80_000,
                    "output_tokens": 60_000,
                    "total_tokens": 140_000,
                    "message_count": 10,
                }
            }
        }

        report = self._query_status(tmp_path, timers, ledger)

        assert "sessions" in report
        assert len(report["sessions"]) == 1
        entry = report["sessions"][0]
        assert entry["task_id"] == "task-100.1"
        assert entry["team_id"] == "dev6-team"
        assert entry["total_tokens"] == 140_000
        assert entry["limit"] == 200_000
        assert "usage_pct" in entry
        assert "level" in entry

    def test_cli_status_warning_level(self, tmp_path):
        """55% usage is expected to report "warning".

        The original note states this assumes default thresholds of 50%/65%.
        """
        timers = {
            "tasks": {
                "task-200.1": {
                    "task_id": "task-200.1",
                    "team_id": "dev1-team",
                    "status": "running",
                }
            }
        }
        ledger = {
            "tasks": {
                "task-200.1": {
                    "total_tokens": 110_000,
                }
            }
        }

        report = self._query_status(tmp_path, timers, ledger)

        entry = report["sessions"][0]
        assert entry["level"] == "warning"
        assert entry["usage_pct"] == pytest.approx(55.0)

    def test_cli_status_no_running_tasks(self, tmp_path):
        """With no task in "running" status the sessions list is empty."""
        timers = {
            "tasks": {
                "task-300.1": {
                    "task_id": "task-300.1",
                    "team_id": "dev1-team",
                    "status": "completed",
                }
            }
        }
        ledger = {"tasks": {}}

        report = self._query_status(tmp_path, timers, ledger)

        assert report["sessions"] == []

    def test_cli_status_task_not_in_ledger(self, tmp_path):
        """A running task absent from the ledger is reported with zero tokens."""
        timers = {
            "tasks": {
                "task-400.1": {
                    "task_id": "task-400.1",
                    "team_id": "dev2-team",
                    "status": "running",
                }
            }
        }
        ledger = {"tasks": {}}

        report = self._query_status(tmp_path, timers, ledger)

        assert len(report["sessions"]) == 1
        only_entry = report["sessions"][0]
        assert only_entry["total_tokens"] == 0
        assert only_entry["level"] == "normal"
        assert result["sessions"][0]["level"] == "normal"
