#!/usr/bin/env python3
"""
Circuit Breaker 모듈 — 에러 임계치 기반 자동 차단 + 복구 전략.

Usage:
    from utils.circuit_breaker import create_circuit_breaker
    cb = create_circuit_breaker("pyright_check", strategy_type="autofix")
    action = cb.record_error({"message": "...", "source": "pyright", "details": None})
"""

import argparse
import json
import os
import sys
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any

from utils.logger import get_logger

# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)

# Workspace root: read from the WORKSPACE_ROOT environment variable, falling
# back to a hard-coded default path when the variable is not set.
WORKSPACE = Path(os.environ.get("WORKSPACE_ROOT", "/home/jay/workspace"))
# Directory holding one persisted state file per circuit-breaker context.
CB_STATE_DIR = WORKSPACE / "memory" / "logs" / "circuit-breaker"
# Directory where escalation JSON files are written.
ESCALATIONS_DIR = WORKSPACE / "memory" / "escalations"
# Maximum number of history entries a CircuitBreaker keeps.
HISTORY_LIMIT = 50


# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------


class RecoveryAction(Enum):
    """Recovery actions a strategy can return for a recorded error."""

    # RETRY and ESCALATE are the values returned by the strategies and by
    # CircuitBreaker.record_error in this file.
    RETRY = "retry"
    # ROLLBACK and SKIP are declared but not returned by any code in this file.
    ROLLBACK = "rollback"
    ESCALATE = "escalate"
    SKIP = "skip"


class CircuitState(Enum):
    """States of the circuit breaker; the values are what gets persisted as JSON."""

    # Initial state: errors are counted against the threshold.
    CLOSED = "closed"
    # Entered once the error count reaches the threshold.
    OPEN = "open"
    # Entered from OPEN by CircuitBreaker.try_reset after the cooldown elapsed.
    HALF_OPEN = "half_open"


# ---------------------------------------------------------------------------
# RecoveryStrategy ABC
# ---------------------------------------------------------------------------


class RecoveryStrategy(ABC):
    """Abstract interface that decides how to react to recorded errors."""

    @abstractmethod
    def on_error(self, context: str, error_info: dict, attempt: int) -> RecoveryAction:
        """Decide the recovery action when an error occurs.

        error_info has the shape {"message": str, "source": str, "details": any}.
        """

    @abstractmethod
    def on_circuit_open(self, context: str, error_count: int) -> None:
        """Called when the circuit transitions to the OPEN state."""


# ---------------------------------------------------------------------------
# AutoFixStrategy
# ---------------------------------------------------------------------------


class AutoFixStrategy(RecoveryStrategy):
    """fireauto pattern: RETRY while the attempt number is below the threshold.

    Once ``attempt`` reaches ``threshold`` (``attempt >= threshold``) the
    strategy answers ESCALATE instead.
    """

    def __init__(self, threshold: int = 3) -> None:
        """Store the attempt count at which retries stop."""
        self.threshold = threshold

    def on_error(self, context: str, error_info: dict, attempt: int) -> RecoveryAction:
        """Return RETRY for ``attempt < threshold``, ESCALATE otherwise.

        ``context`` and ``error_info`` are accepted for interface
        compatibility and are not inspected.
        """
        still_retrying = attempt < self.threshold
        return RecoveryAction.RETRY if still_retrying else RecoveryAction.ESCALATE

    def on_circuit_open(self, context: str, error_count: int) -> None:
        """Log the OPEN transition and write an escalation file for it."""
        logger.warning(f"[CircuitBreaker] OPEN 전환: context={context}, error_count={error_count}")
        reason = f"에러 {error_count}회 연속 — threshold({self.threshold}) 초과"
        _write_escalation_file(context=context, reason=reason, error_count=error_count)


# ---------------------------------------------------------------------------
# EscalationStrategy
# ---------------------------------------------------------------------------


class EscalationStrategy(RecoveryStrategy):
    """gstack pattern: escalate immediately on every error (CRITICAL GAP)."""

    def on_error(self, context: str, error_info: dict, attempt: int) -> RecoveryAction:
        """Always answer ESCALATE; none of the arguments are inspected."""
        return RecoveryAction.ESCALATE

    def on_circuit_open(self, context: str, error_count: int) -> None:
        """Write an escalation file describing the OPEN transition."""
        reason = f"에러 {error_count}회 연속 — escalation 전략 즉시 발동"
        _write_escalation_file(context=context, reason=reason, error_count=error_count)


# ---------------------------------------------------------------------------
# Escalation helper
# ---------------------------------------------------------------------------


def _write_escalation_file(
    context: str,
    reason: str,
    error_count: int,
    last_errors: list[str] | None = None,
    output_path: Path | None = None,
    **extra: Any,
) -> None:
    try:
        if output_path is not None:
            filepath = output_path
            filepath.parent.mkdir(parents=True, exist_ok=True)
        else:
            ESCALATIONS_DIR.mkdir(parents=True, exist_ok=True)
            ts = datetime.now().strftime("%Y%m%dT%H%M%S")
            filepath = ESCALATIONS_DIR / f"{context}_{ts}_escalation.json"
        payload = {
            "context": context,
            "triggered_at": datetime.now().isoformat(),
            "reason": reason,
            "error_count": error_count,
            "action": "escalation",
            "last_errors": last_errors or [],
        }
        if extra:
            payload.update(extra)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        logger.warning(f"[CircuitBreaker] escalation 파일 생성: {filepath}")
    except OSError as e:
        logger.error(f"[CircuitBreaker] escalation 파일 생성 실패: {e}")


# ---------------------------------------------------------------------------
# FileSnapshot
# ---------------------------------------------------------------------------


class FileSnapshot:
    """In-memory snapshot of UTF-8 text files, with restore support.

    Each captured path maps to the file's text at capture time, or to ``None``
    when the file did not exist. Restoring writes the text back, or removes a
    file that did not exist when it was captured.
    """

    def __init__(self) -> None:
        # filepath -> captured text, or None when the file was absent.
        self._snapshots: dict[str, str | None] = {}

    def capture(self, filepath: str) -> None:
        """Snapshot the current content of ``filepath``.

        A missing file is recorded as ``None``. Files that cannot be read or
        are not valid UTF-8 are logged and left without a snapshot.
        """
        try:
            p = Path(filepath)
            if p.exists():
                # newline="" disables newline translation so that CRLF (or
                # mixed) line endings are captured exactly as stored on disk.
                with open(p, encoding="utf-8", newline="") as f:
                    self._snapshots[filepath] = f.read()
            else:
                self._snapshots[filepath] = None
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is not an OSError; without catching it a
            # single binary file would abort capture_multiple().
            logger.error(f"[FileSnapshot] capture 실패: {filepath}, {e}")

    def capture_multiple(self, filepaths: list[str]) -> None:
        """Snapshot every path in ``filepaths``, in order."""
        for fp in filepaths:
            self.capture(fp)

    def restore(self) -> list[str]:
        """Restore every snapshotted file and return the paths that succeeded."""
        return [fp for fp in list(self._snapshots) if self.restore_file(fp)]

    def restore_file(self, filepath: str) -> bool:
        """Restore one file; return True on success.

        Returns False when ``filepath`` was never captured or when the
        filesystem operation fails (the failure is logged).
        """
        if filepath not in self._snapshots:
            return False
        content = self._snapshots[filepath]
        p = Path(filepath)
        try:
            if content is None:
                # The file did not exist at capture time: remove it again.
                if p.exists():
                    p.unlink()
            else:
                p.parent.mkdir(parents=True, exist_ok=True)
                # newline="" writes the captured text verbatim, matching the
                # untranslated read in capture().
                with open(p, "w", encoding="utf-8", newline="") as f:
                    f.write(content)
            return True
        except OSError as e:
            logger.error(f"[FileSnapshot] restore_file 실패: {filepath}, {e}")
            return False

    @property
    def files(self) -> list[str]:
        """Paths that currently have a snapshot, in capture order."""
        return list(self._snapshots.keys())


# ---------------------------------------------------------------------------
# RollbackManager
# ---------------------------------------------------------------------------


class RollbackManager:
    """Snapshot-based rollback helper for a single task."""

    def __init__(self, task_id: str, scope: str = "last_operation") -> None:
        """Remember the task id and the scope label used in snapshot files."""
        self.task_id = task_id
        self.scope = scope  # "last_operation" | "task_start"

    def create_snapshot(self, filepaths: list[str]) -> FileSnapshot:
        """Capture ``filepaths`` into a new FileSnapshot and persist it to disk."""
        snap = FileSnapshot()
        snap.capture_multiple(filepaths)
        self._persist_snapshot(snap)
        return snap

    def _persist_snapshot(self, snapshot: FileSnapshot) -> None:
        """Dump the snapshot as JSON under the workspace; OSError is logged only."""
        try:
            target_dir = WORKSPACE / "memory" / "logs" / "rollback-snapshots"
            target_dir.mkdir(parents=True, exist_ok=True)
            stamp = datetime.now().strftime("%Y%m%dT%H%M%S")
            target = target_dir / f"{self.task_id}_{self.scope}_{stamp}.json"
            captured_paths = snapshot.files
            record = {
                "task_id": self.task_id,
                "scope": self.scope,
                "created_at": datetime.now().isoformat(),
                "files": captured_paths,
                "contents": {path: snapshot._snapshots[path] for path in captured_paths},
            }
            with open(target, "w", encoding="utf-8") as fh:
                json.dump(record, fh, ensure_ascii=False, indent=2)
        except OSError as e:
            logger.error(f"[RollbackManager] 스냅샷 영속 저장 실패: {e}")

    def rollback(self, snapshot: FileSnapshot) -> dict:
        """Restore every file in ``snapshot``.

        Returns a dict with ``success`` (True when nothing failed) and the
        ``restored`` / ``failed`` path lists.
        """
        restored: list[str] = []
        failed: list[str] = []
        for path in snapshot.files:
            bucket = restored if snapshot.restore_file(path) else failed
            bucket.append(path)
        logger.info(f"[RollbackManager] rollback 완료: restored={len(restored)}, failed={len(failed)}")
        return {"success": not failed, "restored": restored, "failed": failed}

    def rollback_with_escalation(self, snapshot: FileSnapshot) -> dict:
        """Roll back and, if any file failed, write an escalation file."""
        result = self.rollback(snapshot)
        if result["success"]:
            return result
        unrestored = result["failed"]
        _write_escalation_file(
            context=self.task_id,
            reason=f"롤백 실패 — 수동 개입 필요: {unrestored}",
            error_count=len(unrestored),
        )
        return result


# ---------------------------------------------------------------------------
# CircuitBreaker
# ---------------------------------------------------------------------------


class CircuitBreaker:
    """Error-threshold circuit breaker with optional on-disk state.

    State machine implemented by this class:
      * CLOSED: errors are counted; reaching ``threshold`` opens the circuit.
      * OPEN: every recorded error is answered with ESCALATE.
      * HALF_OPEN: entered through ``try_reset`` once the cooldown elapsed;
        a success closes the circuit, an error re-opens it.

    When ``persistent`` is true the state is loaded from, and saved to,
    ``CB_STATE_DIR/<context>.json``.
    """

    def __init__(
        self,
        context: str,
        strategy: RecoveryStrategy,
        threshold: int = 3,
        cooldown_seconds: int = 300,
        persistent: bool = True,
    ) -> None:
        """Create a breaker, loading persisted state when ``persistent`` is true.

        Args:
            context: Key identifying this breaker; also the state file stem.
            strategy: Decides the recovery action while the circuit is CLOSED.
            threshold: Error count at which the circuit opens.
            cooldown_seconds: Seconds after the last recorded error before
                ``try_reset`` may move OPEN to HALF_OPEN.
            persistent: Whether state is loaded from and saved to disk.
        """
        self.context = context
        self.strategy = strategy
        self.threshold = threshold
        self.cooldown_seconds = cooldown_seconds
        self.persistent = persistent

        self._state: CircuitState = CircuitState.CLOSED
        self._error_count: int = 0
        self._last_error_ts: str | None = None
        self._last_success_ts: str | None = None
        self._history: list[dict] = []
        # Recent error messages; kept in memory only (not part of saved state).
        self._last_errors: list[str] = []

        if self.persistent:
            self._load_state()

    @property
    def state(self) -> CircuitState:
        """Current circuit state."""
        return self._state

    @property
    def error_count(self) -> int:
        """Number of errors counted while CLOSED since the last reset."""
        return self._error_count

    def record_error(self, error_info: dict) -> RecoveryAction:
        """Record an error and return the recovery action to take.

        While OPEN or HALF_OPEN the answer is always ESCALATE (HALF_OPEN also
        falls back to OPEN). While CLOSED the strategy decides, and the
        circuit opens once the error count reaches ``threshold``.
        """
        # Coerce the message to text so history entries stay JSON-serializable
        # even when a caller passes an exception object or None.
        raw_msg = error_info.get("message", "")
        msg = "" if raw_msg is None else str(raw_msg)
        self._last_errors.append(msg)
        # Keep at least one entry: with threshold <= 0 a plain
        # ``[-threshold:]`` slice would never trim the list.
        keep = max(self.threshold, 1)
        if len(self._last_errors) > keep:
            self._last_errors = self._last_errors[-keep:]

        now = datetime.now().isoformat()
        # NOTE(review): this timestamp is refreshed for blocked errors too, so
        # the cooldown in try_reset() restarts on every error recorded while
        # OPEN — confirm that this is the intended cooldown semantics.
        self._last_error_ts = now

        if self._state == CircuitState.OPEN:
            self._append_history({"ts": now, "action": "blocked", "message": msg})
            if self.persistent:
                self._save_state()
            return RecoveryAction.ESCALATE

        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.OPEN
            self._append_history({"ts": now, "action": "half_open_fail", "message": msg})
            logger.warning(f"[CircuitBreaker] HALF_OPEN → OPEN: context={self.context}")
            if self.persistent:
                self._save_state()
            return RecoveryAction.ESCALATE

        # CLOSED
        self._error_count += 1
        action = self.strategy.on_error(self.context, error_info, self._error_count)
        self._append_history({"ts": now, "action": "error", "message": msg})

        if self._error_count >= self.threshold:
            self._state = CircuitState.OPEN
            self._append_history({"ts": now, "action": "circuit_open", "error_count": self._error_count})
            logger.warning(
                f"[CircuitBreaker] CLOSED → OPEN: context={self.context}, error_count={self._error_count}"
            )
            try:
                self.strategy.on_circuit_open(self.context, self._error_count)
            except Exception as e:
                # A failing notification hook must not prevent the state
                # change from being saved and the action from being returned.
                logger.error(f"[CircuitBreaker] on_circuit_open 오류: {e}")

        if self.persistent:
            self._save_state()
        return action

    def record_success(self) -> None:
        """Record a success.

        Resets the error count when CLOSED, closes the circuit when
        HALF_OPEN, and leaves an OPEN circuit open.
        """
        now = datetime.now().isoformat()
        self._last_success_ts = now
        self._append_history({"ts": now, "action": "success"})

        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.CLOSED
            self._error_count = 0
            self._last_errors = []
            logger.info(f"[CircuitBreaker] HALF_OPEN → CLOSED: context={self.context}")
        elif self._state == CircuitState.CLOSED:
            self._error_count = 0
            self._last_errors = []

        if self.persistent:
            self._save_state()

    def try_reset(self) -> bool:
        """Move OPEN to HALF_OPEN once the cooldown has elapsed.

        Returns True only when the transition happened.
        """
        if self._state != CircuitState.OPEN:
            return False
        if self._last_error_ts is None:
            return False
        try:
            last_ts = datetime.fromisoformat(self._last_error_ts)
            elapsed = (datetime.now() - last_ts).total_seconds()
        except (ValueError, TypeError) as e:
            # TypeError covers a non-string or timezone-aware timestamp read
            # from a state file this code did not write.
            logger.error(f"[CircuitBreaker] try_reset 타임스탬프 파싱 오류: {e}")
            return False
        if elapsed < self.cooldown_seconds:
            return False

        self._state = CircuitState.HALF_OPEN
        now = datetime.now().isoformat()
        self._append_history({"ts": now, "action": "half_open_try"})
        logger.info(f"[CircuitBreaker] OPEN → HALF_OPEN: context={self.context}, elapsed={elapsed:.0f}s")
        if self.persistent:
            self._save_state()
        return True

    def force_reset(self) -> None:
        """Force the circuit back to CLOSED and clear the error counters."""
        self._state = CircuitState.CLOSED
        self._error_count = 0
        self._last_errors = []
        now = datetime.now().isoformat()
        self._append_history({"ts": now, "action": "force_reset"})
        logger.info(f"[CircuitBreaker] 강제 리셋: context={self.context}")
        if self.persistent:
            self._save_state()

    def _append_history(self, entry: dict) -> None:
        """Append ``entry`` and keep only the newest HISTORY_LIMIT entries."""
        self._history.append(entry)
        if len(self._history) > HISTORY_LIMIT:
            self._history = self._history[-HISTORY_LIMIT:]

    def _state_filepath(self) -> Path:
        """Path of this context's persisted state file."""
        return CB_STATE_DIR / f"{self.context}.json"

    def _load_state(self) -> None:
        """Load persisted state; on any problem log it and keep the defaults."""
        fp = self._state_filepath()
        if not fp.exists():
            return
        try:
            with open(fp, encoding="utf-8") as f:
                data = json.load(f)
            # Validate into locals first so a malformed file cannot leave the
            # breaker half-initialised.
            if not isinstance(data, dict):
                raise ValueError(f"expected a JSON object, got {type(data).__name__}")
            state = CircuitState(data.get("state", "closed"))
            error_count = int(data.get("error_count", 0))
            history = data.get("history", [])
            if not isinstance(history, list):
                raise ValueError("'history' must be a list")
        except (OSError, json.JSONDecodeError, ValueError, TypeError) as e:
            logger.error(f"[CircuitBreaker] 상태 로드 실패: {fp}, {e}")
            return

        self._state = state
        self._error_count = error_count
        self._last_error_ts = data.get("last_error_ts")
        self._last_success_ts = data.get("last_success_ts")
        self._history = history

    def _save_state(self) -> None:
        """Persist the current state; failures are logged, never raised."""
        data = {
            "context": self.context,
            "state": self._state.value,
            "error_count": self._error_count,
            "threshold": self.threshold,
            "last_error_ts": self._last_error_ts,
            "last_success_ts": self._last_success_ts,
            "history": self._history,
        }
        try:
            # Serialize before touching the disk so an unserializable value
            # cannot truncate the existing state file.
            serialized = json.dumps(data, ensure_ascii=False, indent=2)
            CB_STATE_DIR.mkdir(parents=True, exist_ok=True)
            fp = self._state_filepath()
            # Write to a per-process temp file and rename it over the target,
            # so readers never observe a partially written file.
            tmp = fp.with_name(f"{fp.name}.{os.getpid()}.tmp")
            with open(tmp, "w", encoding="utf-8") as f:
                f.write(serialized)
            tmp.replace(fp)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"[CircuitBreaker] 상태 저장 실패: {e}")


# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------


def create_circuit_breaker(
    context: str,
    strategy_type: str = "autofix",
    threshold: int = 3,
    **kwargs: Any,
) -> CircuitBreaker:
    """Convenience factory for a CircuitBreaker with a named strategy.

    Args:
        context: Key identifying the breaker.
        strategy_type: ``"escalation"`` selects EscalationStrategy; any other
            value selects AutoFixStrategy. Values other than ``"autofix"`` and
            ``"escalation"`` are reported with a warning.
        threshold: Passed both to AutoFixStrategy and to the breaker.
        **kwargs: Forwarded to the CircuitBreaker constructor.
    """
    if strategy_type == "escalation":
        strategy: RecoveryStrategy = EscalationStrategy()
    else:
        if strategy_type != "autofix":
            # Keep the historical fallback, but make a misspelled strategy
            # (e.g. "escalate") visible instead of silently retrying.
            logger.warning(
                f"[CircuitBreaker] unknown strategy_type={strategy_type!r}; falling back to 'autofix'"
            )
        strategy = AutoFixStrategy(threshold=threshold)

    return CircuitBreaker(
        context=context,
        strategy=strategy,
        threshold=threshold,
        **kwargs,
    )


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def _cli_record_error(args: argparse.Namespace) -> None:
    """CLI handler: record one error and print the resulting state as JSON.

    Reads ``context``, ``strategy`` and ``message`` from ``args``; the context
    doubles as the error's ``source``.
    """
    breaker = create_circuit_breaker(args.context, strategy_type=args.strategy)
    decided = breaker.record_error(
        {"message": args.message, "source": args.context, "details": None}
    )
    summary = dict(
        context=args.context,
        state=breaker.state.value,
        error_count=breaker.error_count,
        action=decided.value,
    )
    print(json.dumps(summary, ensure_ascii=False))


def _cli_check(args: argparse.Namespace) -> None:
    """CLI handler: print the breaker's current state and error count as JSON.

    The breaker is only constructed (which loads persisted state); no
    transition is attempted here.
    """
    breaker = create_circuit_breaker(args.context)
    report = dict(
        context=args.context,
        state=breaker.state.value,
        error_count=breaker.error_count,
    )
    print(json.dumps(report, ensure_ascii=False))


def _cli_record_success(args: argparse.Namespace) -> None:
    """CLI handler: record a success and print the resulting state as JSON."""
    breaker = create_circuit_breaker(args.context)
    breaker.record_success()
    state_value = breaker.state.value
    count = breaker.error_count
    print(
        json.dumps(
            {"context": args.context, "state": state_value, "error_count": count},
            ensure_ascii=False,
        )
    )


def _cli_reset(args: argparse.Namespace) -> None:
    """CLI handler: force the breaker to CLOSED and print the state as JSON."""
    breaker = create_circuit_breaker(args.context)
    breaker.force_reset()
    fields = ("context", "state", "error_count")
    values = (args.context, breaker.state.value, breaker.error_count)
    print(json.dumps(dict(zip(fields, values)), ensure_ascii=False))


def main() -> None:
    """Parse the command line and dispatch to the matching CLI handler.

    Sub-commands: ``record-error``, ``check``, ``record-success``, ``reset``.
    Any exception raised by a handler is logged and turned into exit code 1.
    """
    parser = argparse.ArgumentParser(prog="circuit_breaker", description="Circuit Breaker CLI")
    subparsers = parser.add_subparsers(dest="command", required=True)

    def _add_command(name: str, help_text: str, handler: Any) -> argparse.ArgumentParser:
        # Every sub-command takes a required --context and binds its handler.
        sub = subparsers.add_parser(name, help=help_text)
        sub.add_argument("--context", required=True, help="컨텍스트 키")
        sub.set_defaults(func=handler)
        return sub

    # record-error additionally needs the message and the strategy choice.
    err_parser = _add_command("record-error", "에러 기록", _cli_record_error)
    err_parser.add_argument("--message", required=True, help="에러 메시지")
    err_parser.add_argument(
        "--strategy",
        default="autofix",
        choices=["autofix", "escalation"],
        help="복구 전략",
    )

    # The remaining commands only need --context.
    simple_commands = (
        ("check", "상태 확인", _cli_check),
        ("record-success", "성공 기록", _cli_record_success),
        ("reset", "강제 리셋", _cli_reset),
    )
    for name, help_text, handler in simple_commands:
        _add_command(name, help_text, handler)

    args = parser.parse_args()
    try:
        args.func(args)
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        logger.error(f"[CircuitBreaker CLI] 오류: {e}")
        sys.exit(1)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
