"""engine_v2/engine_orchestrator.py — 멀티엔진 오케스트레이터."""

from __future__ import annotations

import asyncio
import logging
import time
from typing import Literal

from engine_v2 import cost_tracker
from engine_v2.circuit_breaker import CircuitBreaker
from engine_v2.cli_runner import CLIRunner
from engine_v2.content_sanitizer import check_error_gate, check_gate, sanitize
from engine_v2.engine_result import EngineResult, EngineRole

logger = logging.getLogger(name=__name__)

# Upper bound on engine calls in flight at once in PARALLEL / BROADCAST modes.
_SEMAPHORE_LIMIT: int = 3
# Execution modes accepted by EngineOrchestrator.run().
Mode = Literal["SEQUENTIAL", "PARALLEL", "BROADCAST"]


class EngineOrchestrator:
    """Multi-engine orchestrator.

    Sends prompts to the ``claude``, ``gemini`` and ``codex`` engines through
    ``CLIRunner``, sanitizes the output, and records every call on a
    per-engine ``CircuitBreaker`` and in ``cost_tracker``.

    Modes:
        SEQUENTIAL — pairs run one after another (no semaphore needed) and
            stop early when the L3 or L4 gate triggers.
        PARALLEL — ``asyncio.gather`` over prompt/engine pairs, throttled by a
            shared ``Semaphore(_SEMAPHORE_LIMIT)``.
        BROADCAST — the same prompt (``prompts[0]``) is sent to every engine
            concurrently, under the same semaphore.
    """

    def __init__(self) -> None:
        # Caps how many engine calls PARALLEL/BROADCAST run at the same time.
        self._semaphore = asyncio.Semaphore(_SEMAPHORE_LIMIT)
        # One breaker per known engine; any other engine name gets no breaker.
        self._breakers: dict[str, CircuitBreaker] = {
            "claude": CircuitBreaker(),
            "gemini": CircuitBreaker(),
            "codex": CircuitBreaker(),
        }

    async def run(
        self,
        mode: Mode,
        prompts: list[str],
        engines: list[EngineRole],
        task_id: str,
        step: int,
        timeout: int = 600,
    ) -> list[EngineResult]:
        """Run the engines in the requested mode.

        Args:
            mode: ``"SEQUENTIAL"``, ``"PARALLEL"`` or ``"BROADCAST"``.
            prompts: Prompts, paired index-wise with ``engines``; BROADCAST
                only uses ``prompts[0]``.
            engines: Engine names to call.
            task_id: Identifier copied into every ``EngineResult``.
            step: Base step number. SEQUENTIAL/PARALLEL results get
                ``step + i``; BROADCAST results all get ``step``.
            timeout: Per-call timeout forwarded to ``CLIRunner`` (presumably
                seconds — confirm against CLIRunner).

        Returns:
            One ``EngineResult`` per executed call (SEQUENTIAL may stop early).

        Raises:
            ValueError: If ``mode`` is not one of the known modes.
        """
        if mode == "SEQUENTIAL":
            return await self._run_sequential(prompts, engines, task_id, step, timeout)
        elif mode == "PARALLEL":
            return await self._run_parallel(prompts, engines, task_id, step, timeout)
        elif mode == "BROADCAST":
            return await self._run_broadcast(prompts, engines, task_id, step, timeout)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    @staticmethod
    def _warn_on_length_mismatch(prompts: list[str], engines: list[EngineRole]) -> None:
        """Warn when prompts and engines cannot be paired one-to-one.

        ``zip`` silently drops the unpaired tail, so without this warning a
        caller passing mismatched lists just gets fewer results.
        """
        if len(prompts) != len(engines):
            logger.warning(
                "prompts/engines length mismatch: %d prompts vs %d engines; unpaired items are skipped",
                len(prompts),
                len(engines),
            )

    async def _run_sequential(
        self,
        prompts: list[str],
        engines: list[EngineRole],
        task_id: str,
        step: int,
        timeout: int,
    ) -> list[EngineResult]:
        """Run each (prompt, engine) pair in order, stopping at the first gate hit.

        Result ``i`` uses step ``step + i``. When the L3 gate (``check_gate``
        on ``flagged_count``) or the L4 gate (``check_error_gate``: error
        without fallback) triggers, the offending result is still returned
        but no further engines are called.
        """
        self._warn_on_length_mismatch(prompts, engines)
        results: list[EngineResult] = []
        for i, (prompt, engine) in enumerate(zip(prompts, engines)):
            result = await self._call_engine(engine, prompt, task_id, step + i, timeout)
            if check_gate(result.flagged_count):
                logger.warning("L3 gate triggered: flagged_count=%d", result.flagged_count)
                # Marked as an error only here, i.e. after _call_engine already
                # recorded the breaker outcome and logged usage for this call.
                result.error = True
                results.append(result)
                break
            if check_error_gate(result.error, result.fallback_used):
                logger.warning("L4 gate triggered: error without fallback")
                results.append(result)
                break
            results.append(result)
        return results

    async def _run_parallel(
        self,
        prompts: list[str],
        engines: list[EngineRole],
        task_id: str,
        step: int,
        timeout: int,
    ) -> list[EngineResult]:
        """Run all (prompt, engine) pairs concurrently, limited by the semaphore.

        Result ``i`` uses step ``step + i``. Exceptions escaping a call are
        converted into error results by ``_unwrap`` instead of propagating.
        """
        self._warn_on_length_mismatch(prompts, engines)

        async def _task(engine: EngineRole, prompt: str, s: int) -> EngineResult:
            async with self._semaphore:
                return await self._call_engine(engine, prompt, task_id, s, timeout)

        coros = [_task(eng, pmt, step + i) for i, (pmt, eng) in enumerate(zip(prompts, engines))]
        raw = await asyncio.gather(*coros, return_exceptions=True)
        return self._unwrap(raw, engines, task_id, step)

    async def _run_broadcast(
        self,
        prompts: list[str],
        engines: list[EngineRole],
        task_id: str,
        step: int,
        timeout: int,
    ) -> list[EngineResult]:
        """Send ``prompts[0]`` to every engine concurrently.

        All results share the same ``step``. Returns ``[]`` when ``prompts``
        is empty; any prompts after the first are ignored.
        """
        if not prompts:
            return []
        prompt = prompts[0]

        async def _task(engine: EngineRole) -> EngineResult:
            async with self._semaphore:
                return await self._call_engine(engine, prompt, task_id, step, timeout)

        raw = await asyncio.gather(*[_task(e) for e in engines], return_exceptions=True)
        # Every broadcast call ran at `step`, so failed entries must not get step + i.
        return self._unwrap(raw, engines, task_id, step, per_index_step=False)

    def _unwrap(
        self,
        raw: list[EngineResult | BaseException],
        engines: list[EngineRole],
        task_id: str,
        step: int,
        *,
        per_index_step: bool = True,
    ) -> list[EngineResult]:
        """Convert exceptions in ``asyncio.gather`` output into error results.

        Args:
            raw: Output of ``gather(..., return_exceptions=True)``; entry ``i``
                belongs to ``engines[i]``.
            engines: Engine names, index-aligned with ``raw``.
            task_id: Task id for synthesized error results.
            step: Base step number for synthesized error results.
            per_index_step: If True (PARALLEL), the error result for entry
                ``i`` gets ``step + i``; if False (BROADCAST), it gets ``step``,
                matching the step the call actually ran with.

        Returns:
            A list the same length as ``raw``; non-exception entries are
            passed through unchanged.
        """
        out: list[EngineResult] = []
        for i, r in enumerate(raw):
            if not isinstance(r, BaseException):
                out.append(r)
                continue
            # Surface the traceback instead of silently replacing it.
            logger.error("Unhandled exception from engine=%s", engines[i], exc_info=r)
            result_step = step + i if per_index_step else step
            out.append(
                EngineResult(engine=engines[i], content="", clean="", task_id=task_id, step=result_step, error=True)
            )
        return out

    async def _call_engine(
        self,
        engine: EngineRole,
        prompt: str,
        task_id: str,
        step: int,
        timeout: int,
    ) -> EngineResult:
        """Call a single engine, then sanitize, update the breaker and log cost.

        Returns an error result (empty content) when the breaker rejects the
        request, when the CLI call raises, or for an unknown engine name.
        A non-zero return code or a timeout yields an error result whose
        content is the CLI's stderr.
        """
        breaker = self._breakers.get(engine)

        if breaker is not None and not breaker.allow_request():
            logger.warning("CircuitBreaker OPEN: engine=%s blocked", engine)
            return EngineResult(engine=engine, content="", clean="", task_id=task_id, step=step, error=True)

        start = time.monotonic()
        try:
            if engine == "claude":
                cli_result = await CLIRunner.run_claude(prompt, timeout=timeout)
            elif engine == "gemini":
                cli_result = await CLIRunner.run_gemini(prompt, timeout=timeout)
            elif engine == "codex":
                cli_result = await CLIRunner.run_codex(prompt, timeout=timeout)
            else:
                raise ValueError(f"Unknown engine: {engine}")

            is_error = cli_result.returncode != 0 or cli_result.timed_out
            # On failure, stderr is kept as content so the error text is not lost.
            raw = cli_result.stdout if not is_error else cli_result.stderr
            clean, flagged = sanitize(raw)
            result = EngineResult(
                engine=engine,
                content=raw,
                clean=clean,
                task_id=task_id,
                step=step,
                token_est=len(raw) // 4,  # rough heuristic: ~4 characters per token
                error=is_error,
                fallback_used=cli_result.fallback_used,
                flagged_count=flagged,
            )
        except Exception:
            logger.exception("Engine call failed: engine=%s", engine)
            result = EngineResult(engine=engine, content="", clean="", task_id=task_id, step=step, error=True)

        duration = time.monotonic() - start

        if breaker is not None:
            if result.error:
                breaker.record_failure()
            else:
                breaker.record_success()

        try:
            cost_tracker.log_usage(result, prompt_chars=len(prompt), duration_sec=duration)
        except Exception:
            # Accounting is best-effort: a tracker failure must not discard a
            # result whose outcome was already recorded on the breaker.
            logger.exception("cost_tracker.log_usage failed: engine=%s", engine)
        return result
