"""utils/post_merge_smoke_runner.py — task-2512 5 모듈 #4.

회장 명시: 자동 머지 직후 origin/main 기준 smoke를 자동 실행하고,
smoke 실패 시에만 Critical #7 (POST_MERGE_SMOKE_FAILED) escalation packet 생성.

산출물 = 코드 + 회귀 테스트. wiring(task-2514)/reporting(task-2513) 보류.

automation_contracts.py(task-2509+2 freeze) 그대로 import:
  - SmokeResult, CriticalEscalationType, EscalationPacket
  - to_json
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import re
import shlex
import subprocess
import sys
import time
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Optional

# Worktree root: this file lives in utils/, so the root is two levels up.
_HERE = Path(__file__).resolve().parent.parent
# WORKSPACE_ROOT (if set) overrides the default working directory used by
# _default_runner.
WORKSPACE = Path(os.environ.get("WORKSPACE_ROOT", str(_HERE)))

# When executed directly as a CLI, the package root is not on sys.path; add it
# so the absolute `utils.` import resolves.
if str(_HERE) not in sys.path:
    sys.path.insert(0, str(_HERE))

from utils.automation_contracts import (  # noqa: E402  # pyright: ignore[reportMissingImports]
    CriticalEscalationType,
    EscalationPacket,
    SmokeResult,
)

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

DEFAULT_TIMEOUT_SEC = 600  # smoke subprocess timeout, in seconds
DEFAULT_OUTPUT_CAP_BYTES = 65536  # 64 KiB budget for stdout/stderr head+tail capture
HEAD_TAIL_RATIO = 2  # the cap is split evenly: head = tail = cap // 2
TRUNCATE_MARKER_FMT = "\n...[TRUNCATED {n} bytes]...\n"

# Replay fixture registry for the chair's §9 replay (4 fixtures). The entries
# are documentation only: the smoke commands actually run come from the task
# md or from SMOKE_COMMAND_REGISTRY.
REPLAY_FIXTURES: dict[str, dict[str, Any]] = {
    fixture_task: {"merge_commit_hint": sha_hint, "pr": pr_no}
    for fixture_task, sha_hint, pr_no in (
        ("task-2506", "4486ea36", 56),
        ("task-2507", "2cd8178b", 55),
        ("task-2509", "38334b09", 58),
        ("task-2511", "59ec8d37", 62),
    )
}

# Lookup helper over REPLAY_FIXTURES: provides merge_commit_hint for §9 replays.
def get_replay_fixture(task_id: str) -> Optional[dict[str, Any]]:
    """Return the replay fixture registered for *task_id*, or None if there is none."""
    try:
        return REPLAY_FIXTURES[task_id]
    except KeyError:
        return None

# task_id -> smoke_command fallback registry (task §2). A task that is neither
# listed here nor defines smoke_command in its md falls under policy #10.
SMOKE_COMMAND_REGISTRY: dict[str, list[str]] = {
    registry_task: ["pytest", f"tests/regression/{test_module}", "-q"]
    for registry_task, test_module in (
        ("task-2506", "test_critical_gap_2506.py"),
        ("task-2507", "test_git_evidence_2507.py"),
        ("task-2509", "test_merge_queue_executor_2509.py"),
        ("task-2511", "test_auto_gemini_triage_2511.py"),
        ("task-2512", "test_post_merge_smoke_runner_2512.py"),
    )
}

# ---------------------------------------------------------------------------
# Status enum (envelope-only; not in frozen contract)
# ---------------------------------------------------------------------------

class SmokeStatus(str, Enum):
    """Outcome of one post-merge smoke run (envelope only; not in the frozen contract).

    Members subclass ``str``, so ``SmokeStatus.PASS == "PASS"`` holds.
    """

    PASS = "PASS"        # smoke command exited 0
    FAIL = "FAIL"        # non-zero exit, or the runner itself raised
    SKIPPED = "SKIPPED"  # smoke_command undefined + dry_run=True
    TIMEOUT = "TIMEOUT"  # smoke command exceeded its timeout
    BLOCKED = "BLOCKED"  # smoke missing + dry_run=False


# ---------------------------------------------------------------------------
# Envelope dataclass
# ---------------------------------------------------------------------------

@dataclass
class PostMergeSmokeRun:
    """Envelope for one post-merge smoke run.

    Carries the information required by the chair's §6 without touching the
    frozen ``SmokeResult`` / ``EscalationPacket`` contracts.

    Fields:
      merge_commit:       input SHA, kept as given (the stale check compares it
                          with origin/main HEAD)
      task_id:            task_id extracted from the task spec
      status:             SmokeStatus enum
      smoke_result:       frozen-contract ``SmokeResult`` (transport)
      duration_ms:        subprocess run time in milliseconds
      smoke_command:      command actually executed (None = not defined)
      allow_continuation: True on PASS/SKIPPED (False on a stale PASS); signals
                          that merge_queue may proceed to its next step
      escalation:         EscalationPacket on FAIL/TIMEOUT/BLOCKED, else None
      stale:              True when merge_commit != origin/main HEAD
      dry_run:            the dry_run flag passed in, preserved
    """
    merge_commit: str
    task_id: str
    status: SmokeStatus
    smoke_result: SmokeResult
    duration_ms: int
    smoke_command: Optional[list[str]]
    allow_continuation: bool
    escalation: Optional[EscalationPacket]
    stale: bool
    dry_run: bool

    def to_dict(self) -> dict:
        """Return ``asdict(self)`` with the Enum-typed fields replaced by their values."""
        payload = asdict(self)
        current = self.status
        payload["status"] = current.value if isinstance(current, Enum) else current
        packet = self.escalation
        if packet is not None:
            kind = packet.escalation_type
            payload["escalation"]["escalation_type"] = (
                kind.value if isinstance(kind, Enum) else kind
            )
        return payload

    def to_json(self) -> str:
        """Serialise :meth:`to_dict` with ``json.dumps`` (default settings)."""
        return json.dumps(self.to_dict())


# ---------------------------------------------------------------------------
# Subprocess wrapper (테스트 inject 용)
# ---------------------------------------------------------------------------

# Injectable command runner. It is called as runner(args, timeout=...), and
# callers in this module read returncode/stdout/stderr from its result via
# getattr.
RunnerType = Callable[..., Any]


def _default_runner(args: list[str], cwd: Optional[str] = None, timeout: int = DEFAULT_TIMEOUT_SEC):
    """Run *args* with :func:`subprocess.run` (argv list, no shell), capturing text output.

    Args:
        args: command argv.
        cwd: working directory; any falsy value falls back to ``WORKSPACE``.
        timeout: seconds before ``subprocess.TimeoutExpired`` is raised.

    Returns:
        The ``subprocess.CompletedProcess`` with str stdout/stderr.
    """
    workdir = cwd if cwd else str(WORKSPACE)
    return subprocess.run(
        args, cwd=workdir, capture_output=True, text=True, timeout=timeout,
    )


# ---------------------------------------------------------------------------
# Forbidden flag check (회장 명시 — admin/force/rebase/cherry-pick 금지)
# ---------------------------------------------------------------------------

# Flag names that must never appear in a smoke command. Tokens are compared by
# the part before any "=", so "--force-with-lease=<ref>" is caught as well.
FORBIDDEN_GIT_FLAGS = {"--force", "-f", "--force-with-lease", "--no-verify"}


def assert_no_forbidden_git_flags(args: list[str]) -> None:
    """Reject smoke commands carrying force/admin/rebase/cherry-pick arguments.

    Flags are matched by name, so ``--force-with-lease=origin/main`` and
    ``--admin=true`` are rejected just like their bare forms.

    Args:
        args: the smoke command argv.

    Raises:
        RuntimeError: ``"FORBIDDEN_GIT_FLAGS detected: [...]"`` listing the
            offending tokens (any FORBIDDEN_GIT_FLAGS entry or ``--admin``);
            ``"REBASE_FORBIDDEN"`` if a ``rebase`` token is present;
            ``"CHERRY_PICK_FORBIDDEN"`` if a ``cherry-pick`` token is present.
    """
    tokens = list(args)
    blocked = FORBIDDEN_GIT_FLAGS | {"--admin"}
    # Compare only the flag name, not an attached "=value".
    bad = [tok for tok in tokens if str(tok).partition("=")[0] in blocked]
    if bad:
        raise RuntimeError(f"FORBIDDEN_GIT_FLAGS detected: {bad}")
    if "rebase" in tokens:
        raise RuntimeError("REBASE_FORBIDDEN")
    if "cherry-pick" in tokens:
        raise RuntimeError("CHERRY_PICK_FORBIDDEN")


# ---------------------------------------------------------------------------
# stdout/stderr head/tail capture (cap default 64KB)
# ---------------------------------------------------------------------------

def capture_head_tail(text: Optional[str], cap: int = DEFAULT_OUTPUT_CAP_BYTES) -> str:
    """Bound *text* to about *cap* UTF-8 bytes, keeping its head and tail.

    Text whose UTF-8 encoding fits in *cap* bytes is returned unchanged. Longer
    text becomes the first and last ``cap // HEAD_TAIL_RATIO`` bytes joined by
    ``TRUNCATE_MARKER_FMT`` (the marker itself is not counted against *cap*).
    A cut that lands inside a multi-byte character drops that partial
    character instead of emitting U+FFFD, and the marker reports exactly how
    many bytes were omitted.

    Args:
        text: captured stdout/stderr; None yields "".
        cap: byte budget. Values below 2 keep no head/tail, only the marker.

    Returns:
        The original text, or its head + marker + tail form.
    """
    if text is None:
        return ""
    encoded = text.encode("utf-8", errors="replace")
    total = len(encoded)
    if total <= cap:
        return text
    half = max(cap // HEAD_TAIL_RATIO, 0)
    # Slice the tail from an explicit start: encoded[-0:] would be the WHOLE
    # buffer when half == 0 (tiny caps).
    # `encoded` is valid UTF-8, so the only invalid bytes are the partial
    # characters at the two cut points; "ignore" drops exactly those.
    head = encoded[:half].decode("utf-8", errors="ignore")
    tail = encoded[total - half:].decode("utf-8", errors="ignore")
    kept = len(head.encode("utf-8")) + len(tail.encode("utf-8"))
    marker = TRUNCATE_MARKER_FMT.format(n=total - kept)
    return head + marker + tail


# ---------------------------------------------------------------------------
# task spec 파싱 — task_id + smoke_command 추출
# ---------------------------------------------------------------------------

# Fenced ```yaml ... ``` block; group(1) is the body (non-greedy, spans lines).
_YAML_BLOCK_RE = re.compile(r"```yaml\s*\n(.*?)```", re.DOTALL)
# Markdown H1 such as "# task-2512" or "# task-2509+2" at the start of a line.
_TASK_ID_HEADER_RE = re.compile(r"^# (task-\d+(?:\+\d+)?)", re.MULTILINE)


def extract_task_id(task_text: str, task_file: Path) -> str:
    """Return the task id from the first "# task-N" header, else the file stem."""
    header = _TASK_ID_HEADER_RE.search(task_text)
    if header is None:
        return task_file.stem
    return header.group(1)


# "smoke_command: <value>" on a single line. [ \t]* (not \s*) keeps the match
# on the key's own line; \s* would cross the newline after an empty key and
# capture the following line as the command.
_SMOKE_COMMAND_LINE_RE = re.compile(r"^smoke_command:[ \t]*(.+)$", re.MULTILINE)


def extract_smoke_command(task_text: str, task_id: str) -> Optional[list[str]]:
    """Resolve the smoke command for a task from its md text.

    Priority:
      1) ``smoke_command:`` inside a fenced ```yaml block (list or quoted
         scalar); every yaml block is tried in document order
      2) a bare ``smoke_command: <cmd>`` markdown line (backticks stripped)
      3) ``SMOKE_COMMAND_REGISTRY[task_id]`` (returned as a copy)
      4) None (undefined)

    Raises:
        ValueError: from :func:`shlex.split` when the value has unbalanced quotes.
    """
    # (1) yaml blocks
    for fence in _YAML_BLOCK_RE.finditer(task_text):
        cmd = _parse_smoke_command_in_yaml(fence.group(1))
        if cmd:
            return cmd
    # (2) markdown scalar line
    line_m = _SMOKE_COMMAND_LINE_RE.search(task_text)
    if line_m:
        raw = line_m.group(1).strip().strip("`")
        if raw and raw.lower() not in {"none", "null"}:
            return shlex.split(raw)
    # (3) registry (copy, so callers cannot mutate the registry)
    if task_id in SMOKE_COMMAND_REGISTRY:
        return list(SMOKE_COMMAND_REGISTRY[task_id])
    # (4) undefined
    return None


def _strip_outer_quotes(raw: str) -> str:
    """Drop one surrounding pair of matching ``'`` or ``"`` quotes, if present."""
    quote = raw[:1]
    if quote in ("'", '"') and raw.endswith(quote):
        return raw[1:-1]
    return raw


def _split_scalar_command(raw: str) -> Optional[list[str]]:
    """Turn a YAML scalar or flow-sequence value into an argv list.

    * ``[a, "b", 'c']`` flow sequences yield their items (JSON parsing first,
      then a plain comma split with per-item quote stripping).
    * A value that is a single quoted string (``"pytest -k 'a or b'"``) is
      unwrapped and shell-split, so inner quotes are honoured.
    * Anything else is shell-split as-is.

    Returns None for empty values and YAML nulls (``none``/``null``/``~``).

    Raises:
        ValueError: from :func:`shlex.split` on unbalanced quotes.
    """
    value = raw.strip()
    if value.startswith("[") and value.endswith("]"):
        try:
            parsed = json.loads(value)
        except ValueError:
            parsed = [
                _strip_outer_quotes(part.strip())
                for part in value[1:-1].split(",")
                if part.strip()
            ]
        items = [str(item) for item in parsed]
        return items or None
    tokens = shlex.split(value)
    # A fully quoted scalar comes back as ONE token holding the real command.
    if len(tokens) == 1 and value[:1] in ("'", '"'):
        tokens = shlex.split(tokens[0])
    if not tokens or (len(tokens) == 1 and tokens[0].lower() in {"none", "null", "~"}):
        return None
    return tokens


def _parse_smoke_command_in_yaml(yaml_block: str) -> Optional[list[str]]:
    """Extract a top-level ``smoke_command`` from the body of a yaml block.

    Supported forms:
      * block sequence: ``smoke_command:`` followed by ``- item`` lines
      * flow sequence:  ``smoke_command: [pytest, tests/x.py, -q]``
      * scalar:         ``smoke_command: "pytest tests/x.py -q"``

    Returns None when the key is absent, empty, or a YAML null.
    """
    # Block sequence. Only whitespace may separate the key from its items, and
    # "[ \t]*\S" after each dash keeps an item on its own line.
    list_m = re.search(
        r"^smoke_command:\s*\n((?:\s*-[ \t]*\S.*\n?)+)",
        yaml_block,
        re.MULTILINE,
    )
    if list_m:
        items: list[str] = []
        for line in list_m.group(1).splitlines():
            entry = line.strip()
            if not entry.startswith("-"):
                continue
            items.append(_strip_outer_quotes(entry[1:].strip()))
        return items or None
    # Scalar / flow sequence. [ \t]* (not \s*) stops an empty key from
    # swallowing the next line (e.g. "other_key: value") as its value.
    scalar_m = re.search(r"^smoke_command:[ \t]*(.+)$", yaml_block, re.MULTILINE)
    if scalar_m:
        return _split_scalar_command(scalar_m.group(1))
    return None


# ---------------------------------------------------------------------------
# main HEAD ↔ merge_commit stale 판정
# ---------------------------------------------------------------------------

def check_main_head_stale(merge_commit: str, runner: RunnerType) -> bool:
    """Return True when origin/main HEAD (after ``git fetch``) is not *merge_commit*.

    ``merge_commit`` may be a full or abbreviated SHA. After trimming
    whitespace and lower-casing, either value being a prefix of the other
    counts as a match.

    The result is conservatively True (stale) when:
      * ``merge_commit`` is blank or shorter than 4 characters. Such a value
        prefix-matches almost any HEAD (``head.startswith("")`` is always True).
      * ``git fetch origin main`` returns a non-zero integer exit code, leaving
        the local origin/main ref out of date.
      * ``git rev-parse origin/main`` fails or prints nothing.
      * The runner raises (logged as a warning).
    """
    expected = (merge_commit or "").strip().lower()
    # 4 is git's minimum abbreviated-SHA length.
    if len(expected) < 4:
        logger.warning("check_main_head_stale: merge_commit %r too short to verify", merge_commit)
        return True
    try:
        fetched = runner(["git", "fetch", "origin", "main"], timeout=60)
        fetch_rc = getattr(fetched, "returncode", 0)
        # Only an explicit non-zero int counts as a failed fetch, so runners
        # (or test doubles) that return no result object keep working.
        if isinstance(fetch_rc, int) and fetch_rc != 0:
            logger.warning("check_main_head_stale: git fetch failed (rc=%s)", fetch_rc)
            return True
        rp = runner(["git", "rev-parse", "origin/main"], timeout=30)
        if getattr(rp, "returncode", 1) != 0:
            return True
        head = (getattr(rp, "stdout", "") or "").strip().lower()
        if not head:
            return True
        return not (head.startswith(expected) or expected.startswith(head))
    except Exception as exc:  # noqa: BLE001
        logger.warning("check_main_head_stale failed: %s", exc)
        return True


# ---------------------------------------------------------------------------
# Critical #7 packet builder
# ---------------------------------------------------------------------------

def build_smoke_failed_packet(
    *,
    task_id: str,
    pr_number: int,
    merge_commit: str,
    smoke_result: SmokeResult,
    status: SmokeStatus,
    smoke_command: Optional[list[str]],
    duration_ms: int,
    extra_evidence: Optional[dict] = None,
    timeout_sec: int = DEFAULT_TIMEOUT_SEC,
) -> EscalationPacket:
    """Build the POST_MERGE_SMOKE_FAILED (Critical #7) escalation packet.

    Per the chair, nothing outside the 7 Critical types may be reported, so the
    type is always exactly ``CriticalEscalationType.POST_MERGE_SMOKE_FAILED``.

    Args:
        status: FAIL / TIMEOUT / BLOCKED. It selects the reason text and the
            recommended option; other values fall back to "smoke failure".
        timeout_sec: quoted in the TIMEOUT reason.
        extra_evidence: merged into ``evidence`` last, so its keys win.

    Returns:
        EscalationPacket whose ``evidence`` holds the smoke details and a UTC
        ISO-8601 timestamp.
    """
    reason_map = {
        SmokeStatus.FAIL: "smoke command exited non-zero on origin/main",
        SmokeStatus.TIMEOUT: f"smoke command exceeded timeout ({timeout_sec}s)",
        SmokeStatus.BLOCKED: "smoke_command undefined and dry_run=False (auto merge requires smoke definition)",
    }
    safe_options = [
        "Re-run smoke locally on freshly fetched origin/main",
        "Disable auto-merge for this task and route to chair",
        "Define smoke_command in task spec and retry",
    ]
    if status == SmokeStatus.BLOCKED:
        # No smoke_command exists, so "re-run smoke" is not an actionable
        # option; point at the fix that unblocks the queue instead.
        recommended = "Define smoke_command in task spec and retry"
        why = (
            "No smoke_command is defined for this task, so post-merge smoke against "
            "origin/main could not run. Continuing auto-merge of subsequent PRs "
            "without human review violates 자동 머지 10조건 #10."
        )
    else:
        recommended = "Re-run smoke locally on freshly fetched origin/main"
        why = (
            "Post-merge smoke against origin/main failed. Continuing auto-merge "
            "of subsequent PRs without human review violates 자동 머지 10조건 #10."
        )
    evidence: dict[str, Any] = {
        "merge_commit": merge_commit,
        "task_id": task_id,
        "smoke_command": smoke_command,
        "status": status.value,
        "exit_code": smoke_result.exit_code,
        "stdout_tail": smoke_result.stdout_tail,
        "stderr_tail": smoke_result.stderr_tail,
        "failure_reason": smoke_result.failure_reason,
        "duration_ms": duration_ms,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    if extra_evidence:
        evidence.update(extra_evidence)
    return EscalationPacket(
        task_id=task_id,
        pr_number=pr_number,
        escalation_type=CriticalEscalationType.POST_MERGE_SMOKE_FAILED,
        reason=reason_map.get(status, "smoke failure"),
        why_auto_cannot_continue=why,
        safe_options=safe_options,
        recommended_option=recommended,
        evidence=evidence,
    )


# ---------------------------------------------------------------------------
# Core API: run_post_merge_smoke
# ---------------------------------------------------------------------------

def _elapsed_ms(started: float) -> int:
    """Return whole milliseconds elapsed since *started* (a ``time.monotonic()`` value)."""
    return int((time.monotonic() - started) * 1000)


def _as_text(data: Any) -> Any:
    """Normalise captured subprocess output before :func:`capture_head_tail`.

    ``subprocess.TimeoutExpired.stdout``/``.stderr`` hold bytes even when the
    child was started with ``text=True``, and injected runners may not decode
    at all. ``capture_head_tail`` calls ``str.encode``, so bytes are decoded
    as UTF-8 (invalid sequences replaced). Any other value is returned as
    ``data or ""``.
    """
    if isinstance(data, (bytes, bytearray)):
        return bytes(data).decode("utf-8", errors="replace")
    return data or ""


def _escalated_run(
    *,
    task_id: str,
    pr_number: int,
    merge_commit: str,
    status: SmokeStatus,
    smoke_result: SmokeResult,
    smoke_command: Optional[list[str]],
    duration_ms: int,
    stale: bool,
    dry_run: bool,
    timeout_sec: int,
) -> PostMergeSmokeRun:
    """Wrap a FAIL/TIMEOUT/BLOCKED outcome in an envelope with a Critical #7 packet.

    ``allow_continuation`` is always False, and ``stale`` is copied into the
    packet evidence.
    """
    packet = build_smoke_failed_packet(
        task_id=task_id,
        pr_number=pr_number,
        merge_commit=merge_commit,
        smoke_result=smoke_result,
        status=status,
        smoke_command=smoke_command,
        duration_ms=duration_ms,
        extra_evidence={"stale": stale},
        timeout_sec=timeout_sec,
    )
    return PostMergeSmokeRun(
        merge_commit=merge_commit,
        task_id=task_id,
        status=status,
        smoke_result=smoke_result,
        duration_ms=duration_ms,
        smoke_command=smoke_command,
        allow_continuation=False,
        escalation=packet,
        stale=stale,
        dry_run=dry_run,
    )


def run_post_merge_smoke(
    *,
    task_file: Path,
    merge_commit: str,
    dry_run: bool = True,
    runner: Optional[RunnerType] = None,
    pr_number: int = 0,
    timeout_sec: int = DEFAULT_TIMEOUT_SEC,
    output_cap_bytes: int = DEFAULT_OUTPUT_CAP_BYTES,
    skip_stale_check: bool = False,
    expected_task_id: Optional[str] = None,
) -> PostMergeSmokeRun:
    """Run the task spec's smoke command against main after an auto-merge.

    Args:
        task_file: absolute path to the task md.
        merge_commit: SHA of the auto-merged commit.
        dry_run: only matters when smoke_command is undefined. Undefined with
            dry_run=True gives SKIPPED; undefined with dry_run=False gives
            BLOCKED. A defined smoke_command is always executed, regardless
            of dry_run (chair §10 table).
        runner: subprocess wrapper (test injection).
        pr_number: PR number (0 if none).
        timeout_sec: subprocess timeout (default 600), quoted in the TIMEOUT
            escalation reason.
        output_cap_bytes: stdout/stderr head/tail cap (default 64KB).
        skip_stale_check: if True, skip the origin/main HEAD comparison
            (test isolation).
        expected_task_id: task_id injected directly by the caller (testing).

    Returns:
        PostMergeSmokeRun envelope. If stale=True, even a PASS is downgraded
        to allow_continuation=False, which blocks the auto-merge queue.

    Raises:
        OSError: if task_file cannot be read.
        RuntimeError: if the smoke command carries a forbidden git flag.
    """
    runner = runner or _default_runner
    task_text = task_file.read_text(encoding="utf-8")
    task_id = expected_task_id or extract_task_id(task_text, task_file)
    smoke_command = extract_smoke_command(task_text, task_id)

    # Forbidden-flag check runs before anything is executed.
    if smoke_command:
        assert_no_forbidden_git_flags(smoke_command)

    stale = False if skip_stale_check else check_main_head_stale(merge_commit, runner)

    def escalate(status: SmokeStatus, smoke_result: SmokeResult, duration_ms: int) -> PostMergeSmokeRun:
        # Every FAIL/TIMEOUT/BLOCKED outcome shares one envelope shape.
        return _escalated_run(
            task_id=task_id,
            pr_number=pr_number,
            merge_commit=merge_commit,
            status=status,
            smoke_result=smoke_result,
            smoke_command=smoke_command,
            duration_ms=duration_ms,
            stale=stale,
            dry_run=dry_run,
            timeout_sec=timeout_sec,
        )

    # Policy #10: smoke_command undefined.
    if smoke_command is None:
        if dry_run:
            return _build_skipped_run(
                merge_commit=merge_commit,
                task_id=task_id,
                stale=stale,
                dry_run=dry_run,
            )
        # Not a dry run: auto-merge requires a definition, so BLOCKED + Critical #7.
        undefined = SmokeResult(
            command="",
            passed=False,
            exit_code=-1,
            stdout_tail="",
            stderr_tail="smoke_command undefined; auto-merge requires definition",
            failure_reason="SMOKE_COMMAND_UNDEFINED",
        )
        return escalate(SmokeStatus.BLOCKED, undefined, 0)

    # Defined command: always executed, regardless of dry_run (chair §10 table).
    cmd_str = " ".join(smoke_command)
    # Monotonic clock: datetime.now() can jump (NTP/DST) and yield bogus durations.
    started = time.monotonic()
    try:
        result = runner(smoke_command, timeout=timeout_sec)
    except subprocess.TimeoutExpired as exc:
        duration_ms = _elapsed_ms(started)
        # TimeoutExpired output is bytes even with text=True; decode first.
        timed_out = SmokeResult(
            command=cmd_str,
            passed=False,
            exit_code=-1,
            stdout_tail=capture_head_tail(_as_text(getattr(exc, "stdout", None)), output_cap_bytes),
            stderr_tail=capture_head_tail(_as_text(getattr(exc, "stderr", None)), output_cap_bytes),
            failure_reason="TIMEOUT",
        )
        return escalate(SmokeStatus.TIMEOUT, timed_out, duration_ms)
    except Exception as exc:  # noqa: BLE001 - any runner error is reported as FAIL
        duration_ms = _elapsed_ms(started)
        crashed = SmokeResult(
            command=cmd_str,
            passed=False,
            exit_code=-1,
            stdout_tail="",
            stderr_tail=capture_head_tail(str(exc), output_cap_bytes),
            failure_reason=f"RUNNER_ERROR: {type(exc).__name__}",
        )
        return escalate(SmokeStatus.FAIL, crashed, duration_ms)

    duration_ms = _elapsed_ms(started)
    rc = getattr(result, "returncode", -1)
    stdout_tail = capture_head_tail(_as_text(getattr(result, "stdout", "")), output_cap_bytes)
    stderr_tail = capture_head_tail(_as_text(getattr(result, "stderr", "")), output_cap_bytes)
    if rc == 0:
        passed = SmokeResult(
            command=cmd_str,
            passed=True,
            exit_code=0,
            stdout_tail=stdout_tail,
            stderr_tail=stderr_tail,
            failure_reason=None,
        )
        # Stale gate: even on PASS, if origin/main HEAD != merge_commit the
        # smoke may have run on a different SHA, so the queue must not advance.
        return PostMergeSmokeRun(
            merge_commit=merge_commit,
            task_id=task_id,
            status=SmokeStatus.PASS,
            smoke_result=passed,
            duration_ms=duration_ms,
            smoke_command=smoke_command,
            allow_continuation=not stale,
            escalation=None,
            stale=stale,
            dry_run=dry_run,
        )
    failed = SmokeResult(
        command=cmd_str,
        passed=False,
        exit_code=rc,
        stdout_tail=stdout_tail,
        stderr_tail=stderr_tail,
        failure_reason=f"EXIT_{rc}",
    )
    return escalate(SmokeStatus.FAIL, failed, duration_ms)


def _build_skipped_run(
    *, merge_commit: str, task_id: str, stale: bool, dry_run: bool,
    smoke_command: Optional[list[str]] = None,
) -> PostMergeSmokeRun:
    """Envelope for a smoke that was deliberately not executed (status SKIPPED).

    The result reads as passed with exit code 0, empty output, and zero
    duration. There is no escalation, and allow_continuation is True.
    ``failure_reason`` is "SKIPPED" only when no (non-empty) command was given.
    """
    has_command = bool(smoke_command)
    skipped = SmokeResult(
        command=" ".join(smoke_command) if has_command else "",
        passed=True,
        exit_code=0,
        stdout_tail="",
        stderr_tail="",
        failure_reason=None if has_command else "SKIPPED",
    )
    return PostMergeSmokeRun(
        merge_commit=merge_commit,
        task_id=task_id,
        status=SmokeStatus.SKIPPED,
        smoke_result=skipped,
        duration_ms=0,
        smoke_command=smoke_command,
        allow_continuation=True,
        escalation=None,
        stale=stale,
        dry_run=dry_run,
    )


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments (``--dry-run`` and ``--apply`` are mutually exclusive).

    NOTE: ``--no-audit`` is accepted, but main() does not currently read it.
    """
    parser = argparse.ArgumentParser(prog="post_merge_smoke_runner")
    add = parser.add_argument
    add("--task-file", required=True, help="task md 절대 경로")
    add("--merge-commit", required=True, help="자동 머지된 commit SHA")

    exclusive = parser.add_mutually_exclusive_group()
    exclusive.add_argument(
        "--dry-run",
        action="store_true",
        help="smoke_command 미정의 시 SKIPPED 처리 (정의된 경우는 항상 실행).",
    )
    exclusive.add_argument(
        "--apply",
        action="store_true",
        help="smoke_command 미정의 시 BLOCKED 처리 (정의된 경우 항상 실행).",
    )

    add("--no-audit", action="store_true", help="audit 로그 기록 skip (테스트용)")
    add("--pr-number", type=int, default=0, help="PR 번호 (없으면 0)")
    add("--timeout", type=int, default=DEFAULT_TIMEOUT_SEC)
    add("--skip-stale-check", action="store_true", help="origin/main HEAD 비교 skip")
    return parser.parse_args(argv)


def main(argv: Optional[list[str]] = None) -> int:
    """CLI entry point: run the smoke, print the envelope as JSON, and map the status to an exit code.

    Exit codes: 0 PASS/SKIPPED, 1 FAIL, 2 TIMEOUT, 3 anything else (BLOCKED).
    """
    args = _parse_args(argv)
    # Dry-run unless --apply was given (--dry-run merely makes it explicit).
    dry_run = args.dry_run or not args.apply
    run = run_post_merge_smoke(
        task_file=Path(args.task_file),
        merge_commit=args.merge_commit,
        dry_run=dry_run,
        pr_number=args.pr_number,
        timeout_sec=args.timeout,
        skip_stale_check=args.skip_stale_check,
    )
    print(run.to_json())
    exit_codes = {
        SmokeStatus.PASS: 0,
        SmokeStatus.SKIPPED: 0,
        SmokeStatus.FAIL: 1,
        SmokeStatus.TIMEOUT: 2,
    }
    return exit_codes.get(run.status, 3)  # BLOCKED


if __name__ == "__main__":  # pragma: no cover
    # Equivalent to sys.exit(main()): the exit code comes from main().
    raise SystemExit(main())
