#!/usr/bin/env python3
"""
utils/session_resilience.py — Session resilience orchestrator (Phase 2).

Wires together Phase 1's SessionMonitor + SessionAutoCompress, monitoring the
token usage of every bot session in the workspace and reacting automatically
when thresholds are crossed.

Usage:
    from utils.session_resilience import SessionResilience

    resilience = SessionResilience()
    result = resilience.check_all_sessions()
"""

from __future__ import annotations

import json
import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path

# When run directly (python utils/session_resilience.py), make the workspace
# root importable so that `utils.*` absolute imports resolve.
_SCRIPT_DIR = Path(__file__).resolve().parent.parent
if str(_SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(_SCRIPT_DIR))

from utils.logger import get_logger

logger = get_logger(__name__)

# Workspace root; overridable via the WORKSPACE_ROOT environment variable.
_WORKSPACE_ROOT = os.environ.get("WORKSPACE_ROOT", "/home/jay/workspace")

# Bot ID ↔ team mapping (loaded dynamically from org_loader when available).
try:
    from utils.org_loader import build_team_bot_map as _build_team_bot_map

    BOT_TEAMS: dict[str, str] = {"anu": "anu-direct", **_build_team_bot_map()}
except ImportError:
    # Static fallback used when org_loader cannot be imported.
    BOT_TEAMS = {
        "anu": "anu-direct",
        "dev1-team": "dev1",
        "dev2-team": "dev2",
        "dev3-team": "dev3",
        "dev4-team": "dev4",
        "dev5-team": "dev5",
        "dev6-team": "dev6",
        "dev7-team": "dev7",
        "dev8-team": "dev8",
    }

# Token-usage severity levels returned by check_session().
_LEVEL_NORMAL = "normal"
_LEVEL_WARNING = "warning"
_LEVEL_CRITICAL = "critical"


class SessionResilience:
    """Orchestrator for whole-workspace bot-session resilience.

    Connects Phase 1's SessionMonitor + SessionAutoCompress: checks the token
    usage of every running bot session and reacts automatically — recording a
    WARNING event at the first threshold and, at the CRITICAL threshold,
    additionally saving a session summary and triggering a resume into a
    fresh session.

    Args:
        workspace_root: Workspace root path (defaults to _WORKSPACE_ROOT).
        context_limit: Context-window token limit (default: 200,000).
        warning_pct: WARNING threshold as a fraction of the limit (default: 0.70).
        critical_pct: CRITICAL threshold as a fraction of the limit (default: 0.85).
    """

    def __init__(
        self,
        workspace_root: str | None = None,
        context_limit: int = 200_000,
        warning_pct: float = 0.70,
        critical_pct: float = 0.85,
    ) -> None:
        root = Path(workspace_root) if workspace_root is not None else Path(_WORKSPACE_ROOT)

        self.workspace_root: Path = root
        self.context_limit: int = context_limit
        self.warning_pct: float = warning_pct
        self.critical_pct: float = critical_pct

        # Well-known state files and directories under the workspace.
        self.timers_path: Path = root / "memory" / "task-timers.json"
        self.ledger_path: Path = root / "memory" / "token-ledger.json"
        self.events_dir: Path = root / "memory" / "events"
        self.sessions_dir: Path = root / "memory" / "sessions"

        self.bot_teams: dict[str, str] = BOT_TEAMS

        logger.debug(
            "SessionResilience 초기화: root=%s, limit=%d, warning=%.0f%%, critical=%.0f%%",
            root,
            context_limit,
            warning_pct * 100,
            critical_pct * 100,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def check_all_sessions(self) -> dict:
        """Check token usage of every running session and react automatically.

        Returns:
            {
                "checked": N,
                "warnings": [{"task_id": ..., "usage_pct": ..., "event_path": ...}],
                "criticals": [{"task_id": ..., "usage_pct": ..., "event_path": ...,
                               "summary_path": ..., "resume_triggered": bool}],
                "normals": N,
                "timestamp": "..."
            }
        """
        running_tasks = self._load_running_tasks()
        ledger_tasks = self._load_ledger_tasks()

        warnings: list[dict] = []
        criticals: list[dict] = []
        normals = 0

        for task_id, task_info in running_tasks.items():
            ledger_info = ledger_tasks.get(task_id, {})
            status = self.check_session(task_id, task_info, ledger_info)
            level = status["level"]
            team_id = status["team_id"]

            if level == _LEVEL_WARNING:
                w_result = self.handle_warning(task_id, team_id, status)
                warnings.append(
                    {
                        "task_id": task_id,
                        "usage_pct": status["usage_pct"],
                        "event_path": w_result["event_path"],
                    }
                )
            elif level == _LEVEL_CRITICAL:
                if self._is_already_handled(task_id):
                    # Don't re-summarize / re-resume a session that was
                    # already escalated on a previous sweep.
                    logger.info("세션 %s 이미 CRITICAL 처리됨 — 중복 트리거 방지", task_id)
                    criticals.append(
                        {
                            "task_id": task_id,
                            "usage_pct": status["usage_pct"],
                            "event_path": "",
                            "summary_path": "",
                            "resume_triggered": False,
                        }
                    )
                else:
                    original_desc = self._get_original_task_desc(task_id)
                    c_result = self.handle_critical(task_id, team_id, status, original_desc)
                    criticals.append(
                        {
                            "task_id": task_id,
                            "usage_pct": status["usage_pct"],
                            "event_path": c_result["event_path"],
                            "summary_path": c_result["summary_path"],
                            "resume_triggered": c_result["resume_triggered"],
                        }
                    )
            else:
                normals += 1

        checked = len(running_tasks)
        result = {
            "checked": checked,
            "warnings": warnings,
            "criticals": criticals,
            "normals": normals,
            "timestamp": datetime.now().isoformat(),
        }

        logger.info(
            "check_all_sessions 완료: checked=%d, warnings=%d, criticals=%d, normals=%d",
            checked,
            len(warnings),
            len(criticals),
            normals,
        )
        return result

    def check_session(self, task_id: str, task_info: dict, ledger_info: dict) -> dict:
        """Compute the token-usage status of a single session.

        Returns:
            {
                "task_id": str,
                "team_id": str,
                "total_tokens": int,
                "usage_pct": float,
                "level": "normal"|"warning"|"critical"
            }
        """
        raw_tokens = ledger_info.get("total_tokens", 0)
        try:
            total_tokens = int(raw_tokens)
        except (TypeError, ValueError):
            # A malformed ledger entry must not abort the whole sweep;
            # treat the session as having zero recorded usage.
            logger.warning(
                "total_tokens 값이 올바르지 않음: task=%s, value=%r — 0으로 처리",
                task_id,
                raw_tokens,
            )
            total_tokens = 0
        team_id = str(task_info.get("team_id", ledger_info.get("team_id", "")))

        if self.context_limit > 0:
            pct = total_tokens / self.context_limit
            usage_pct = round(pct * 100, 2)
        else:
            # Degenerate configuration (limit <= 0): report zero usage.
            pct = 0.0
            usage_pct = 0.0

        if pct >= self.critical_pct:
            level = _LEVEL_CRITICAL
        elif pct >= self.warning_pct:
            level = _LEVEL_WARNING
        else:
            level = _LEVEL_NORMAL

        return {
            "task_id": task_id,
            "team_id": team_id,
            "total_tokens": total_tokens,
            "usage_pct": usage_pct,
            "level": level,
        }

    def handle_warning(self, task_id: str, team_id: str, session_status: dict) -> dict:
        """WARNING threshold reached: write an event file and log it.

        Event file: memory/events/session-warning-{task_id}-{timestamp}.json

        Returns:
            {"event_path": str, "level": "warning"}
        """
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        event_path = self._write_event(
            _LEVEL_WARNING, task_id, team_id, session_status, timestamp
        )

        logger.warning(
            "WARNING 이벤트 기록: task=%s, team=%s, usage_pct=%.1f%%, event=%s",
            task_id,
            team_id,
            session_status.get("usage_pct", 0.0),
            event_path,
        )

        return {"event_path": str(event_path), "level": _LEVEL_WARNING}

    def handle_critical(
        self,
        task_id: str,
        team_id: str,
        session_status: dict,
        original_task_desc: str = "",
    ) -> dict:
        """CRITICAL threshold reached: event + summary save + resume trigger.

        Event file: memory/events/session-critical-{task_id}-{timestamp}.json
        Session summary: memory/sessions/summary-{task_id}-{timestamp}.md

        Returns:
            {"event_path": str, "summary_path": str, "resume_triggered": bool, "level": "critical"}
        """
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")

        event_path = self._write_event(
            _LEVEL_CRITICAL, task_id, team_id, session_status, timestamp
        )

        # Persist a summary the replacement session can be bootstrapped from.
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        summary_path = self.sessions_dir / f"summary-{task_id}-{timestamp}.md"

        summary_content = self._build_summary_markdown(
            task_id=task_id,
            team_id=team_id,
            timestamp=timestamp,
            session_status=session_status,
            original_task_desc=original_task_desc,
        )
        summary_path.write_text(summary_content, encoding="utf-8")

        logger.warning(
            "CRITICAL 이벤트 기록: task=%s, team=%s, usage_pct=%.1f%%, event=%s, summary=%s",
            task_id,
            team_id,
            session_status.get("usage_pct", 0.0),
            event_path,
            summary_path,
        )

        # Start a fresh session seeded with the summary.
        resume_triggered = self._trigger_resume(task_id, team_id, str(summary_path))

        return {
            "event_path": str(event_path),
            "summary_path": str(summary_path),
            "resume_triggered": resume_triggered,
            "level": _LEVEL_CRITICAL,
        }

    def _write_event(
        self,
        level: str,
        task_id: str,
        team_id: str,
        session_status: dict,
        timestamp: str,
    ) -> Path:
        """Write a session-{level} event JSON file and return its path.

        Shared by handle_warning / handle_critical so the event schema and
        filename format stay in one place.
        """
        self.events_dir.mkdir(parents=True, exist_ok=True)
        event_path = self.events_dir / f"session-{level}-{task_id}-{timestamp}.json"

        event_data = {
            "level": level,
            "task_id": task_id,
            "team_id": team_id,
            "total_tokens": session_status.get("total_tokens", 0),
            "usage_pct": session_status.get("usage_pct", 0.0),
            "timestamp": timestamp,
        }

        with open(event_path, "w", encoding="utf-8") as f:
            json.dump(event_data, f, ensure_ascii=False, indent=2)
        return event_path

    def _is_already_handled(self, task_id: str) -> bool:
        """Return True if this session already has a CRITICAL event (dedup).

        Judged by the existence of a session-critical-{task_id}-*.json file
        in the events/ directory. The trailing "-" in the prefix is required:
        without it, task "abc" would falsely match events of task "abc2".
        """
        if not self.events_dir.exists():
            return False

        prefix = f"session-critical-{task_id}-"
        return any(p.name.startswith(prefix) for p in self.events_dir.iterdir())

    def _get_original_task_desc(self, task_id: str) -> str:
        """Read the original task description from memory/tasks/{task_id}.md.

        Returns an empty string when the file is missing or unreadable.
        """
        task_file = self.workspace_root / "memory" / "tasks" / f"{task_id}.md"
        try:
            return task_file.read_text(encoding="utf-8")
        except FileNotFoundError:
            logger.debug("태스크 설명 파일 없음: %s", task_file)
            return ""
        except OSError as exc:
            logger.warning("태스크 설명 파일 읽기 오류: %s — %s", task_file, exc)
            return ""

    def _trigger_resume(self, task_id: str, team_id: str, summary_path: str) -> bool:
        """Start a new session via `dispatch.py --resume-from` (subprocess).

        Returns:
            True on returncode 0; False on failure, timeout, or exec error.
        """
        dispatch_path = self.workspace_root / "dispatch.py"
        cmd = [
            sys.executable,
            str(dispatch_path),
            "--team",
            team_id,
            "--resume-from",
            summary_path,
        ]

        logger.info(
            "resume 트리거: task=%s, team=%s, summary=%s",
            task_id,
            team_id,
            summary_path,
        )

        try:
            proc = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=30,
            )
        except subprocess.TimeoutExpired:
            logger.error("resume 타임아웃: task=%s", task_id)
            return False
        except OSError as exc:
            logger.error("resume 실행 오류: task=%s — %s", task_id, exc)
            return False

        if proc.returncode == 0:
            logger.info("resume 성공: task=%s", task_id)
            return True
        logger.error(
            "resume 실패: task=%s, returncode=%d, stderr=%s",
            task_id,
            proc.returncode,
            proc.stderr,
        )
        return False

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _load_running_tasks(self) -> dict[str, dict]:
        """Load tasks with status == "running" from task-timers.json."""
        running: dict[str, dict] = {}
        try:
            with open(self.timers_path, encoding="utf-8") as f:
                data = json.load(f)
            for task_id, task_info in data.get("tasks", {}).items():
                if task_info.get("status") == "running":
                    running[task_id] = task_info
        except FileNotFoundError:
            logger.warning("task-timers.json 없음: %s", self.timers_path)
        except (json.JSONDecodeError, KeyError) as exc:
            logger.error("task-timers.json 파싱 오류: %s", exc)
        return running

    def _load_ledger_tasks(self) -> dict[str, dict]:
        """Load per-task token info from token-ledger.json."""
        ledger: dict[str, dict] = {}
        try:
            with open(self.ledger_path, encoding="utf-8") as f:
                data = json.load(f)
            ledger = data.get("tasks", {})
        except FileNotFoundError:
            logger.warning("token-ledger.json 없음: %s", self.ledger_path)
        except (json.JSONDecodeError, KeyError) as exc:
            logger.error("token-ledger.json 파싱 오류: %s", exc)
        return ledger

    def _build_summary_markdown(
        self,
        task_id: str,
        team_id: str,
        timestamp: str,
        session_status: dict,
        original_task_desc: str,
    ) -> str:
        """Build the session-summary markdown string (always newline-terminated)."""
        usage_pct = session_status.get("usage_pct", 0.0)
        total_tokens = session_status.get("total_tokens", 0)

        lines: list[str] = [
            f"# 세션 요약: {task_id}",
            "",
            "## 기본 정보",
            f"- 작업 ID: {task_id}",
            f"- 팀: {team_id}",
            f"- 생성 시각: {timestamp}",
            f"- 토큰 사용량: {total_tokens:,} / {self.context_limit:,} ({usage_pct:.1f}%)",
            "- 레벨: CRITICAL",
            "",
            "## 원래 작업 설명",
            original_task_desc if original_task_desc else "(없음)",
            "",
            "## 자동 요약",
            f"세션 {task_id}이 CRITICAL 임계값({self.critical_pct * 100:.0f}%)에 도달하여 자동으로 요약되었습니다.",
            f"토큰 사용량: {usage_pct:.1f}%",
            "",
            "## 재시작 안내",
            "이 요약 파일을 기반으로 새 세션이 자동 시작됩니다.",
            f"팀: {team_id}",
        ]

        return "\n".join(lines) + "\n"
