#!/usr/bin/env python3
"""
utils/session_monitor.py — 봇 세션 토큰 사용량 실시간 추적 모니터

봇 세션의 토큰 사용량을 실시간으로 추적하고, 임계값(warning/critical)
도달 시 콜백을 호출하며 현재 상태를 반환합니다.

Usage:
    from utils.session_monitor import SessionMonitor

    monitor = SessionMonitor(context_limit=200_000)
    level = monitor.update({"input_tokens": 50_000, "output_tokens": 30_000})
    status = monitor.get_usage_status()

CLI:
    python3 utils/session_monitor.py --status
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Callable

# 직접 실행 시 workspace root를 sys.path에 추가 (from utils.xxx 임포트 지원)
_SCRIPT_DIR = Path(__file__).resolve().parent.parent
if str(_SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(_SCRIPT_DIR))

from utils.logger import get_logger

logger = get_logger(__name__)

_WORKSPACE_ROOT = os.environ.get("WORKSPACE_ROOT", "/home/jay/workspace")


def _load_session_monitoring_config() -> dict:
    """constants.json에서 session_monitoring 설정을 로드한다.

    로드 실패 시 빈 dict 반환 (호출측에서 fallback 사용).
    """
    config_path = Path(_WORKSPACE_ROOT) / "config" / "constants.json"
    try:
        with open(config_path, encoding="utf-8") as f:
            data = json.load(f)
        return data.get("session_monitoring", {})
    except (FileNotFoundError, json.JSONDecodeError, KeyError):
        return {}


_SESSION_MON_CFG = _load_session_monitoring_config()

_DEFAULT_CONTEXT_LIMIT = _SESSION_MON_CFG.get("context_limit", 200_000)
_DEFAULT_WARNING_PCT = _SESSION_MON_CFG.get("warning_pct", 70) / 100
_DEFAULT_CRITICAL_PCT = _SESSION_MON_CFG.get("critical_pct", 85) / 100

_LEVEL_NORMAL = "normal"
_LEVEL_WARNING = "warning"
_LEVEL_CRITICAL = "critical"

_LEVEL_ORDER = {
    _LEVEL_NORMAL: 0,
    _LEVEL_WARNING: 1,
    _LEVEL_CRITICAL: 2,
}


class SessionMonitor:
    """봇 세션의 토큰 사용량을 실시간으로 추적합니다.

    Args:
        context_limit: 최대 컨텍스트 윈도우 토큰 수 (기본값: 200,000)
        warning_pct: 경고 임계값 비율 (기본값: 0.70 = 70%)
        critical_pct: 위험 임계값 비율 (기본값: 0.85 = 85%)
    """

    def __init__(
        self,
        context_limit: int = _DEFAULT_CONTEXT_LIMIT,
        warning_pct: float = _DEFAULT_WARNING_PCT,
        critical_pct: float = _DEFAULT_CRITICAL_PCT,
    ) -> None:
        self._context_limit = context_limit
        self._warning_pct = warning_pct
        self._critical_pct = critical_pct

        self._total_tokens: int = 0
        self._current_level: str = _LEVEL_NORMAL

        # 레벨별 콜백 목록
        self._callbacks: dict[str, list[Callable[[dict], None]]] = {
            _LEVEL_WARNING: [],
            _LEVEL_CRITICAL: [],
        }

        logger.debug(
            "SessionMonitor 초기화: limit=%d, warning=%.0f%%, critical=%.0f%%",
            context_limit,
            warning_pct * 100,
            critical_pct * 100,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def update(self, usage: dict) -> str:
        """Anthropic API 응답의 usage 필드를 처리하고 누적 카운터를 갱신합니다.

        Args:
            usage: {"input_tokens": N, "output_tokens": M} 형태의 딕셔너리

        Returns:
            현재 레벨 문자열: "normal", "warning", "critical"
        """
        input_tokens = int(usage.get("input_tokens") or 0)
        output_tokens = int(usage.get("output_tokens") or 0)
        self._total_tokens += input_tokens + output_tokens

        new_level = self._compute_level(self._total_tokens)
        self._fire_callbacks_on_transition(new_level)
        self._current_level = new_level

        logger.debug(
            "update: +%d tokens (input=%d, output=%d), total=%d, level=%s",
            input_tokens + output_tokens,
            input_tokens,
            output_tokens,
            self._total_tokens,
            new_level,
        )

        return new_level

    def get_usage_status(self) -> dict:
        """현재 세션 토큰 사용 상태를 반환합니다.

        Returns:
            {
                "total_tokens": N,
                "limit": 200000,
                "usage_pct": 75.0,
                "level": "warning"
            }
        """
        usage_pct = round(self._total_tokens / self._context_limit * 100, 2) if self._context_limit > 0 else 0.0
        return {
            "total_tokens": self._total_tokens,
            "limit": self._context_limit,
            "usage_pct": usage_pct,
            "level": self._current_level,
        }

    def reset(self, new_total: int = 0) -> None:
        """토큰 카운터를 리셋합니다. 압축 후 실제 토큰 수로 설정 가능합니다.

        Args:
            new_total: 리셋 후 설정할 토큰 수 (기본값: 0)
        """
        old_total = self._total_tokens
        self._total_tokens = new_total
        self._current_level = self._compute_level(new_total)

        logger.info(
            "SessionMonitor 리셋: %d → %d tokens, level=%s",
            old_total,
            new_total,
            self._current_level,
        )

    def register_callback(self, level: str, callback: Callable[[dict], None]) -> None:
        """특정 레벨 도달 시 호출될 콜백을 등록합니다.

        Args:
            level: "warning" 또는 "critical"
            callback: 상태 딕셔너리를 인자로 받는 콜백 함수
        """
        normalized = level.lower()
        if normalized not in self._callbacks:
            logger.warning("알 수 없는 콜백 레벨: %s (warning/critical만 지원)", level)
            return
        self._callbacks[normalized].append(callback)
        logger.debug("콜백 등록: level=%s", normalized)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _compute_level(self, total_tokens: int) -> str:
        """토큰 수에 따른 레벨을 계산합니다."""
        if self._context_limit <= 0:
            return _LEVEL_NORMAL
        pct = total_tokens / self._context_limit
        if pct >= self._critical_pct:
            return _LEVEL_CRITICAL
        if pct >= self._warning_pct:
            return _LEVEL_WARNING
        return _LEVEL_NORMAL

    def _fire_callbacks_on_transition(self, new_level: str) -> None:
        """레벨 전환 시에만 해당 레벨의 콜백을 호출합니다."""
        old_order = _LEVEL_ORDER.get(self._current_level, 0)
        new_order = _LEVEL_ORDER.get(new_level, 0)

        # 레벨이 상승한 경우에만 콜백 실행
        if new_order <= old_order:
            return

        status = self.get_usage_status()
        # 새 레벨의 콜백 실행
        if new_level in self._callbacks:
            for cb in self._callbacks[new_level]:
                try:
                    cb(status)
                except Exception as exc:
                    logger.error("콜백 실행 오류 (level=%s): %s", new_level, exc)


# ---------------------------------------------------------------------------
# CLI 헬퍼 함수
# ---------------------------------------------------------------------------


def get_active_sessions_status(
    timers_path: str | None = None,
    ledger_path: str | None = None,
    context_limit: int = _DEFAULT_CONTEXT_LIMIT,
) -> dict:
    """실행 중인 태스크의 세션 토큰 사용 상태를 반환합니다.

    Args:
        timers_path: task-timers.json 파일 경로 (기본: memory/task-timers.json)
        ledger_path: token-ledger.json 파일 경로 (기본: memory/token-ledger.json)
        context_limit: 컨텍스트 윈도우 한도 (기본: 200,000)

    Returns:
        {
            "sessions": [
                {
                    "task_id": "task-100.1",
                    "team_id": "dev6-team",
                    "total_tokens": 140000,
                    "limit": 200000,
                    "usage_pct": 70.0,
                    "level": "warning"
                },
                ...
            ]
        }
    """
    memory_dir = Path(_WORKSPACE_ROOT) / "memory"

    if timers_path is None:
        timers_path = str(memory_dir / "task-timers.json")
    if ledger_path is None:
        ledger_path = str(memory_dir / "token-ledger.json")

    # task-timers.json에서 running 상태 태스크 읽기
    running_tasks: dict[str, dict] = {}
    try:
        with open(timers_path, encoding="utf-8") as f:
            timers_data = json.load(f)
        for task_id, task_info in timers_data.get("tasks", {}).items():
            if task_info.get("status") == "running":
                running_tasks[task_id] = task_info
    except FileNotFoundError:
        logger.warning("task-timers.json 파일을 찾을 수 없음: %s", timers_path)
    except (json.JSONDecodeError, KeyError) as exc:
        logger.error("task-timers.json 파싱 오류: %s", exc)

    # token-ledger.json에서 토큰 정보 읽기
    ledger_tasks: dict[str, dict] = {}
    try:
        with open(ledger_path, encoding="utf-8") as f:
            ledger_data = json.load(f)
        ledger_tasks = ledger_data.get("tasks", {})
    except FileNotFoundError:
        logger.warning("token-ledger.json 파일을 찾을 수 없음: %s", ledger_path)
    except (json.JSONDecodeError, KeyError) as exc:
        logger.error("token-ledger.json 파싱 오류: %s", exc)

    # 세션별 상태 계산
    sessions = []
    for task_id, task_info in running_tasks.items():
        ledger_entry = ledger_tasks.get(task_id, {})
        total_tokens = int(ledger_entry.get("total_tokens", 0))
        team_id = task_info.get("team_id", ledger_entry.get("team_id", ""))

        monitor = SessionMonitor(context_limit=context_limit)
        monitor.reset(new_total=total_tokens)
        status = monitor.get_usage_status()

        sessions.append(
            {
                "task_id": task_id,
                "team_id": team_id,
                "total_tokens": status["total_tokens"],
                "limit": status["limit"],
                "usage_pct": status["usage_pct"],
                "level": status["level"],
            }
        )
        logger.debug(
            "세션 상태: task=%s, tokens=%d, pct=%.1f%%, level=%s",
            task_id,
            total_tokens,
            status["usage_pct"],
            status["level"],
        )

    return {"sessions": sessions}


# ---------------------------------------------------------------------------
# CLI 엔트리포인트
# ---------------------------------------------------------------------------


def _main() -> None:
    parser = argparse.ArgumentParser(
        description="봇 세션 토큰 사용량 모니터",
        prog="session_monitor",
    )
    parser.add_argument(
        "--status",
        action="store_true",
        help="현재 활성 세션 토큰 사용률 출력 (JSON)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=_DEFAULT_CONTEXT_LIMIT,
        help=f"컨텍스트 윈도우 한도 (기본값: {_DEFAULT_CONTEXT_LIMIT:,})",
    )

    args = parser.parse_args()

    if args.status:
        result = get_active_sessions_status(context_limit=args.limit)
        print(json.dumps(result, ensure_ascii=False, indent=2))
    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    _main()
