#!/usr/bin/env python3
"""utils/circuit_breaker.py 테스트 스위트 (TDD)"""

import os
import subprocess
import sys
import time
from pathlib import Path
from unittest.mock import patch

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from utils.circuit_breaker import (
    AutoFixStrategy,
    CircuitBreaker,
    CircuitState,
    EscalationStrategy,
    FileSnapshot,
    RecoveryAction,
    RollbackManager,
    create_circuit_breaker,
)

# NOTE(review): these appear to mirror the module's real on-disk locations
# (breaker state files and escalation reports) — confirm against
# utils/circuit_breaker.py. Nothing in this file reads them: the tests redirect
# ``utils.circuit_breaker.CB_STATE_DIR`` / ``ESCALATIONS_DIR`` into pytest's
# tmp_path via monkeypatch instead. Candidates for removal once it is confirmed
# no other module imports them from here.
CB_STATE_DIR = Path("/home/jay/workspace/memory/logs/circuit-breaker")
ESCALATIONS_DIR = Path("/home/jay/workspace/memory/escalations")


# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------

def _make_cb(
    tmp_path: Path,
    context: str = "test-ctx",
    strategy=None,
    threshold: int = 3,
    cooldown_seconds: int = 300,
    persistent: bool = True,
    monkeypatch=None,
) -> CircuitBreaker:
    """Build a CircuitBreaker for tests, optionally sandboxed under *tmp_path*.

    Args:
        tmp_path: pytest ``tmp_path``; root for the redirected directories.
        context: breaker context name.
        strategy: recovery strategy; defaults to
            ``AutoFixStrategy(threshold=threshold)``.
        threshold: forwarded to both the default strategy and the breaker.
        cooldown_seconds: forwarded to ``CircuitBreaker``.
        persistent: forwarded to ``CircuitBreaker``.
        monkeypatch: pytest ``monkeypatch``. When given, the module-level
            ``CB_STATE_DIR`` and ``ESCALATIONS_DIR`` of
            ``utils.circuit_breaker`` are redirected into *tmp_path* so the
            test never writes into the real workspace. When ``None``, nothing
            is redirected.

    Returns:
        A freshly constructed ``CircuitBreaker``.
    """
    if strategy is None:
        strategy = AutoFixStrategy(threshold=threshold)
    if monkeypatch is not None:
        fake_log = tmp_path / "logs" / "circuit-breaker"
        fake_log.mkdir(parents=True, exist_ok=True)
        monkeypatch.setattr(
            "utils.circuit_breaker.CB_STATE_DIR", fake_log, raising=False
        )
        # Strategies write escalation files into ESCALATIONS_DIR from
        # on_circuit_open (see TestStrategies). Tests built here routinely
        # trip the breaker, which presumably invokes that hook, so keep those
        # files inside tmp_path as well.
        monkeypatch.setattr(
            "utils.circuit_breaker.ESCALATIONS_DIR",
            tmp_path / "escalations",
            raising=False,
        )
    return CircuitBreaker(
        context=context,
        strategy=strategy,
        threshold=threshold,
        cooldown_seconds=cooldown_seconds,
        persistent=persistent,
    )


# ===========================================================================
# 1. 기본 동작
# ===========================================================================

class TestInitialState:
    """A fresh breaker starts CLOSED and counts recorded errors."""

    def test_initial_state(self, tmp_path, monkeypatch):
        breaker = _make_cb(tmp_path, monkeypatch=monkeypatch)
        assert breaker.state == CircuitState.CLOSED
        assert breaker.error_count == 0

    def test_record_error_increments_count(self, tmp_path, monkeypatch):
        breaker = _make_cb(tmp_path, threshold=5, monkeypatch=monkeypatch)
        # Each recorded error should bump the counter by exactly one.
        for expected_count, message in enumerate(("err1", "err2"), start=1):
            breaker.record_error({"msg": message})
            assert breaker.error_count == expected_count

    def test_record_success_resets_count(self, tmp_path, monkeypatch):
        breaker = _make_cb(tmp_path, threshold=5, monkeypatch=monkeypatch)
        breaker.record_error({"msg": "err"})
        assert breaker.error_count == 1

        breaker.record_success()

        assert breaker.error_count == 0


# ===========================================================================
# 2. 상태 전환
# ===========================================================================

class TestStateTransitions:
    """CLOSED -> OPEN -> HALF_OPEN -> CLOSED/OPEN transitions and manual reset."""

    @staticmethod
    def _trip(breaker, times=3):
        """Record *times* identical errors (enough to open a threshold-3 breaker)."""
        for _ in range(times):
            breaker.record_error({"msg": "e"})

    def _tripped_with_short_cooldown(self, tmp_path, monkeypatch):
        """Return an OPEN threshold-3 breaker whose 1-second cooldown has elapsed."""
        breaker = _make_cb(
            tmp_path, threshold=3, cooldown_seconds=1, monkeypatch=monkeypatch
        )
        self._trip(breaker)
        time.sleep(1.1)  # just past the 1-second cooldown
        return breaker

    def test_closed_to_open_at_threshold(self, tmp_path, monkeypatch):
        breaker = _make_cb(tmp_path, threshold=3, monkeypatch=monkeypatch)
        self._trip(breaker)
        assert breaker.state == CircuitState.OPEN

    def test_open_always_escalates(self, tmp_path, monkeypatch):
        breaker = _make_cb(tmp_path, threshold=3, monkeypatch=monkeypatch)
        self._trip(breaker)
        assert breaker.state == CircuitState.OPEN

        follow_up = breaker.record_error({"msg": "extra"})

        assert follow_up == RecoveryAction.ESCALATE

    def test_try_reset_before_cooldown(self, tmp_path, monkeypatch):
        breaker = _make_cb(
            tmp_path, threshold=3, cooldown_seconds=300, monkeypatch=monkeypatch
        )
        self._trip(breaker)

        assert breaker.try_reset() is False
        assert breaker.state == CircuitState.OPEN

    def test_try_reset_after_cooldown(self, tmp_path, monkeypatch):
        breaker = self._tripped_with_short_cooldown(tmp_path, monkeypatch)

        assert breaker.try_reset() is True
        assert breaker.state == CircuitState.HALF_OPEN

    def test_half_open_to_closed_on_success(self, tmp_path, monkeypatch):
        breaker = self._tripped_with_short_cooldown(tmp_path, monkeypatch)
        breaker.try_reset()
        assert breaker.state == CircuitState.HALF_OPEN

        breaker.record_success()

        assert breaker.state == CircuitState.CLOSED
        assert breaker.error_count == 0

    def test_half_open_to_open_on_error(self, tmp_path, monkeypatch):
        breaker = self._tripped_with_short_cooldown(tmp_path, monkeypatch)
        breaker.try_reset()
        assert breaker.state == CircuitState.HALF_OPEN

        breaker.record_error({"msg": "fail again"})

        assert breaker.state == CircuitState.OPEN

    def test_force_reset(self, tmp_path, monkeypatch):
        breaker = _make_cb(tmp_path, threshold=3, monkeypatch=monkeypatch)
        self._trip(breaker)
        assert breaker.state == CircuitState.OPEN

        breaker.force_reset()

        assert breaker.state == CircuitState.CLOSED


# ===========================================================================
# 3. 전략 패턴
# ===========================================================================

class TestStrategies:
    """AutoFixStrategy / EscalationStrategy decisions and escalation-file output."""

    @staticmethod
    def _redirect_escalations(tmp_path, monkeypatch):
        """Point the module's ESCALATIONS_DIR under tmp_path and return that path."""
        escalation_dir = tmp_path / "escalations"
        monkeypatch.setattr(
            "utils.circuit_breaker.ESCALATIONS_DIR", escalation_dir, raising=False
        )
        return escalation_dir

    def test_autofix_retries_below_threshold(self):
        autofix = AutoFixStrategy(threshold=3)
        decision = autofix.on_error("ctx", {"msg": "e"}, attempt=1)
        assert decision == RecoveryAction.RETRY

    def test_autofix_escalates_at_threshold(self):
        autofix = AutoFixStrategy(threshold=3)
        decision = autofix.on_error("ctx", {"msg": "e"}, attempt=3)
        assert decision == RecoveryAction.ESCALATE

    def test_escalation_always_escalates(self):
        escalation = EscalationStrategy()
        for attempt_no in range(5):
            decision = escalation.on_error("ctx", {"msg": "e"}, attempt=attempt_no)
            assert decision == RecoveryAction.ESCALATE

    def test_autofix_on_circuit_open_creates_escalation(self, tmp_path, monkeypatch):
        escalation_dir = self._redirect_escalations(tmp_path, monkeypatch)
        AutoFixStrategy(threshold=3).on_circuit_open("test-ctx", error_count=3)
        written = list(escalation_dir.glob("test-ctx_*_escalation.json"))
        assert len(written) >= 1

    def test_escalation_on_circuit_open_creates_escalation(self, tmp_path, monkeypatch):
        escalation_dir = self._redirect_escalations(tmp_path, monkeypatch)
        EscalationStrategy().on_circuit_open("test-ctx", error_count=5)
        written = list(escalation_dir.glob("test-ctx_*_escalation.json"))
        assert len(written) >= 1


# ===========================================================================
# 4. FileSnapshot
# ===========================================================================

class TestFileSnapshot:
    """FileSnapshot capture/restore for existing and not-yet-existing files."""

    def test_snapshot_capture_existing_file(self, tmp_path):
        path = tmp_path / "data.txt"
        path.write_text("hello")
        snapshot = FileSnapshot()

        snapshot.capture(str(path))

        assert str(path) in snapshot.files

    def test_snapshot_capture_nonexistent_file(self, tmp_path):
        absent = str(tmp_path / "nonexistent.txt")
        snapshot = FileSnapshot()

        snapshot.capture(absent)

        assert absent in snapshot.files
        # A missing file is recorded with a None placeholder.
        assert snapshot._snapshots[absent] is None

    def test_snapshot_restore_file(self, tmp_path):
        path = tmp_path / "data.txt"
        path.write_text("original")
        snapshot = FileSnapshot()
        snapshot.capture(str(path))
        path.write_text("modified")

        restored_ok = snapshot.restore_file(str(path))

        assert restored_ok is True
        assert path.read_text() == "original"

    def test_snapshot_restore_deletes_new_file(self, tmp_path):
        path = tmp_path / "new_file.txt"
        key = str(path)
        snapshot = FileSnapshot()
        snapshot.capture(key)
        assert snapshot._snapshots[key] is None

        path.write_text("new content")
        snapshot.restore_file(key)

        assert not path.exists()

    def test_snapshot_files_property(self, tmp_path):
        first, second = tmp_path / "a.txt", tmp_path / "b.txt"
        first.write_text("a")
        second.write_text("b")
        snapshot = FileSnapshot()

        snapshot.capture_multiple([str(first), str(second)])
        tracked = snapshot.files

        assert str(first) in tracked
        assert str(second) in tracked
        assert len(tracked) == 2

    def test_snapshot_restore_returns_restored_list(self, tmp_path):
        path = tmp_path / "c.txt"
        path.write_text("original")
        snapshot = FileSnapshot()
        snapshot.capture(str(path))
        path.write_text("changed")

        assert str(path) in snapshot.restore()


# ===========================================================================
# 5. RollbackManager
# ===========================================================================

class TestRollbackManager:
    """RollbackManager restores snapshots and escalates when a restore fails."""

    def test_rollback_success(self, tmp_path):
        target = tmp_path / "target.txt"
        target.write_text("v1")
        manager = RollbackManager(task_id="task-1", scope="last_operation")
        snapshot = manager.create_snapshot([str(target)])
        target.write_text("v2")

        outcome = manager.rollback(snapshot)

        assert outcome["success"] is True
        assert str(target) in outcome["restored"]
        assert outcome["failed"] == []
        assert target.read_text() == "v1"

    def test_rollback_with_escalation_on_failure(self, tmp_path, monkeypatch):
        escalation_dir = tmp_path / "escalations"
        monkeypatch.setattr(
            "utils.circuit_breaker.ESCALATIONS_DIR", escalation_dir, raising=False
        )
        locked = tmp_path / "locked.txt"
        locked.write_text("original")
        manager = RollbackManager(task_id="task-fail", scope="last_operation")
        snapshot = manager.create_snapshot([str(locked)])

        # Make the snapshot's restore_file report failure for every call.
        with patch.object(snapshot, "restore_file", return_value=False):
            outcome = manager.rollback_with_escalation(snapshot)

        reports = list(escalation_dir.glob("*_escalation.json"))
        assert len(reports) >= 1
        assert outcome["success"] is False
        assert result["success"] is False


# ===========================================================================
# 6. 영속성
# ===========================================================================

class TestPersistence:
    """Breaker state survives re-instantiation only when ``persistent=True``."""

    @staticmethod
    def _sandbox_state_dir(tmp_path, monkeypatch):
        """Point the module's CB_STATE_DIR at a fresh directory under tmp_path."""
        state_dir = tmp_path / "logs" / "circuit-breaker"
        state_dir.mkdir(parents=True, exist_ok=True)
        monkeypatch.setattr(
            "utils.circuit_breaker.CB_STATE_DIR", state_dir, raising=False
        )
        return state_dir

    def test_persistent_state_save_load(self, tmp_path, monkeypatch):
        self._sandbox_state_dir(tmp_path, monkeypatch)
        writer = CircuitBreaker(
            context="persist-ctx",
            strategy=AutoFixStrategy(threshold=5),
            threshold=5,
            persistent=True,
        )
        for message in ("e1", "e2"):
            writer.record_error({"msg": message})

        # A second breaker for the same context should load the saved state.
        reader = CircuitBreaker(
            context="persist-ctx",
            strategy=AutoFixStrategy(threshold=5),
            threshold=5,
            persistent=True,
        )

        assert reader.error_count == 2
        assert reader.state == CircuitState.CLOSED

    def test_non_persistent_no_file(self, tmp_path, monkeypatch):
        state_dir = self._sandbox_state_dir(tmp_path, monkeypatch)
        breaker = CircuitBreaker(
            context="no-persist-ctx",
            strategy=AutoFixStrategy(threshold=3),
            threshold=3,
            persistent=False,
        )

        breaker.record_error({"msg": "e"})

        assert not (state_dir / "no-persist-ctx.json").exists()


# ===========================================================================
# 7. 팩토리
# ===========================================================================

class TestFactory:
    """create_circuit_breaker wires up the requested strategy type."""

    @staticmethod
    def _sandbox_state_dir(tmp_path, monkeypatch):
        """Point the module's CB_STATE_DIR at a fresh directory under tmp_path."""
        state_dir = tmp_path / "logs" / "circuit-breaker"
        state_dir.mkdir(parents=True, exist_ok=True)
        monkeypatch.setattr(
            "utils.circuit_breaker.CB_STATE_DIR", state_dir, raising=False
        )

    def test_create_circuit_breaker_autofix(self, tmp_path, monkeypatch):
        self._sandbox_state_dir(tmp_path, monkeypatch)

        breaker = create_circuit_breaker(
            "factory-ctx", strategy_type="autofix", threshold=5
        )

        assert isinstance(breaker, CircuitBreaker)
        assert isinstance(breaker.strategy, AutoFixStrategy)
        assert breaker.state == CircuitState.CLOSED

    def test_create_circuit_breaker_escalation(self, tmp_path, monkeypatch):
        self._sandbox_state_dir(tmp_path, monkeypatch)

        breaker = create_circuit_breaker(
            "factory-ctx2", strategy_type="escalation", threshold=3
        )

        assert isinstance(breaker, CircuitBreaker)
        assert isinstance(breaker.strategy, EscalationStrategy)


# ===========================================================================
# 8. CLI
# ===========================================================================

class TestCLI:
    """Smoke tests that run utils/circuit_breaker.py as a script in a subprocess.

    NOTE(review): unlike the in-process tests, these runs are not sandboxed.
    The child process uses whatever state/escalation directories the module is
    configured with, so it may write into the real workspace. TODO: isolate
    once the module exposes a way to override its directories.
    """

    # Upper bound for one CLI invocation: a hung script should fail the test
    # rather than stall the whole suite.
    _TIMEOUT_SECONDS = 30

    @classmethod
    def _run_cli(cls, *args: str) -> subprocess.CompletedProcess:
        """Run the circuit_breaker script with *args*; capture text stdout/stderr.

        Raises:
            subprocess.TimeoutExpired: if the script runs longer than
                ``_TIMEOUT_SECONDS``.
        """
        workspace_root = str(Path(__file__).parent.parent.parent)
        # Make the ``utils`` package importable inside the child process.
        env = {**os.environ, "PYTHONPATH": workspace_root}
        module_path = Path(__file__).parent.parent / "circuit_breaker.py"
        return subprocess.run(
            [sys.executable, str(module_path), *args],
            capture_output=True,
            text=True,
            env=env,
            timeout=cls._TIMEOUT_SECONDS,
        )

    def test_cli_record_error(self):
        result = self._run_cli(
            "record-error",
            "--context", "cli-ctx",
            "--message", "cli error",
        )
        assert result.returncode == 0, result.stderr

    def test_cli_check_state(self):
        result = self._run_cli("check", "--context", "cli-check-ctx")
        assert result.returncode == 0, result.stderr
        output = (result.stdout + result.stderr).lower()
        assert "closed" in output or "state" in output
