#!/usr/bin/env python3
"""failure_callback_dispatcher.py — task-2712 FAILURE_CALLBACK_BEFORE_EXIT_GUARD.

ANU callback 발사 enforcement (SELF_COLLECTOR_FORBIDDEN) + disk handoff /
supervisor crash marker + D1~D5 bypass defense + count-mismatch 감지.

회장 verbatim doctrine (task-2712 §6.3): 모든 cron 발사는 단일 entry point
`_fire_callback_with_enforcement()` wrapper 만 경유한다. collector_role=ANU +
collector_key/owner_key=ANU_KEY + self_key_used=false 가 아니면 fail-closed.

본 모듈은 §6.1 4-step fallback · §6.3.1 5 checkpoint · §6.3.2 D1~D5 · §6.3.3.B
detect_bypass_via_count_mismatch pseudocode 를 1:1 mirror 한다.
"""

from __future__ import annotations

import glob
import hashlib
import json
import os
import sys
import traceback
from datetime import datetime, timezone
from typing import Callable, List, Optional

try:
    from .terminal_state_classifier import ANU_KEY, _verify_bot_spawn  # noqa: F401
    from .failure_envelope_writer import (
        build_envelope,
        write_envelope,
        emit_stderr_line,
    )
except Exception:  # pragma: no cover - 직접 실행 fallback
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from terminal_state_classifier import ANU_KEY, _verify_bot_spawn  # type: ignore # noqa: F401
    from failure_envelope_writer import (  # type: ignore
        build_envelope,
        write_envelope,
        emit_stderr_line,
    )


class CollectorViolation(Exception):
    """Raised when a SELF_COLLECTOR_FORBIDDEN violation is detected."""


# ── §6.3.1 In-flight enforcement (5 checkpoint) ───────────────────────────
def _validate_collector_strict(envelope: dict) -> None:
    """Fail closed unless the envelope identifies ANU as its collector.

    The checks run in this order and stop at the first failure:

    1. ``collector_role`` must equal ``"ANU"``.
    2. ``collector_key`` must equal ``ANU_KEY``.
    3. ``owner_key`` must equal ``ANU_KEY``.
    4. ``self_key_used`` must be the literal ``False``; a missing key,
       ``None`` or any other falsy value is rejected.

    Raises:
        CollectorViolation: with a ``SELF_COLLECTOR_FORBIDDEN: ...`` message
            describing the first failed check.
    """
    role = envelope.get("collector_role")
    if role != "ANU":
        raise CollectorViolation("SELF_COLLECTOR_FORBIDDEN: collector_role != ANU")

    # Both key fields are held to the same expected value.
    for key_field in ("collector_key", "owner_key"):
        if envelope.get(key_field) != ANU_KEY:
            raise CollectorViolation(
                f"SELF_COLLECTOR_FORBIDDEN: {key_field} != {ANU_KEY}"
            )

    # Identity check on purpose: only an explicit False passes.
    if envelope.get("self_key_used") is False:
        return
    raise CollectorViolation(
        "SELF_COLLECTOR_FORBIDDEN: self_key_used must be explicit false"
    )


def _now_iso() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (second precision)."""
    utc_now = datetime.now(tz=timezone.utc)
    return f"{utc_now:%Y-%m-%dT%H:%M:%SZ}"


def _write_audit(envelope: dict, events_dir: str, caller_stack: str) -> str:
    """Persist one cron-fire audit record (§6.3.2 D4) and return its path.

    The record is written to ``cron-fire-audit-<ts>.json`` inside
    ``events_dir`` (created if missing) and contains ``ts``, ``task_id``,
    ``terminal_state``, ``collector_role``, the SHA-256 of the envelope's
    key-sorted JSON serialization, and ``caller_stack``.

    The file is created exclusively. If a record with the same microsecond
    timestamp already exists, a numeric suffix is appended rather than
    overwriting it, so every fire leaves its own file. The suffixed name still
    matches the ``cron-fire-audit-*.json`` pattern that
    ``detect_bypass_via_count_mismatch`` counts.

    Args:
        envelope: Envelope being fired; must be JSON-serializable.
        events_dir: Directory that receives the audit file.
        caller_stack: Pre-formatted stack text stored verbatim in the record.

    Returns:
        Path of the audit file that was created.

    Raises:
        TypeError: if ``envelope`` cannot be serialized by ``json.dumps``.
        OSError: if the directory or file cannot be created.
    """
    os.makedirs(events_dir, exist_ok=True)

    # Read the clock once so the filename stamp and the "ts" field agree.
    now = datetime.now(timezone.utc)
    stamp = now.strftime("%Y%m%dT%H%M%S%f")

    payload = json.dumps(envelope, ensure_ascii=False, sort_keys=True)
    record = {
        "ts": now.strftime("%Y-%m-%dT%H:%M:%SZ"),
        "task_id": envelope.get("task_id"),
        "terminal_state": envelope.get("terminal_state"),
        "collector_role": envelope.get("collector_role"),
        "sha256": hashlib.sha256(payload.encode("utf-8")).hexdigest(),
        "caller_stack": caller_stack,
    }

    attempt = 0
    while True:
        suffix = f"-{attempt}" if attempt else ""
        path = os.path.join(events_dir, f"cron-fire-audit-{stamp}{suffix}.json")
        try:
            # "x" refuses to clobber an existing record; a lost audit file
            # would later be reported as a missing audit entry.
            with open(path, "x", encoding="utf-8") as f:
                json.dump(record, f, ensure_ascii=False, indent=2)
        except FileExistsError:
            attempt += 1
            continue
        return path


def _fire_callback_with_enforcement(
    envelope: dict, cron_fn: Callable[[dict], object], events_dir: Optional[str] = None
) -> object:
    """Validate the collector, record an audit entry, then fire ``cron_fn``.

    Intended as the single entry point for firing a cron callback (§6.3.1).
    The steps run strictly in this order:

    1. ``_validate_collector_strict(envelope)`` — a violation raises before
       anything is written or fired.
    2. If ``events_dir`` is truthy, an audit record containing the current
       call stack (at most 6 frames) is written there.
    3. ``cron_fn(envelope)`` is called and its result returned unchanged.

    Raises:
        CollectorViolation: if the envelope fails collector validation.
    """
    _validate_collector_strict(envelope)

    # Captured in this frame so the recorded stack ends at the wrapper itself.
    stack_text = "".join(traceback.format_stack(limit=6))

    audit_enabled = bool(events_dir)
    if audit_enabled:
        _write_audit(envelope, events_dir, stack_text)

    result = cron_fn(envelope)
    return result


# ── §6.1 4-step fallback: disk handoff / supervisor crash marker ──────────
def write_handoff_marker(
    task_id: str,
    terminal_state: str,
    *,
    failure_kind: str = "",
    events_dir: str = "memory/events",
    phase: str = "",
    exit_code: int = 1,
    artifact_paths: Optional[list] = None,
    residual_pid: Optional[int] = None,
    critical7_match: bool = False,
) -> dict:
    """Build an envelope and persist it as a ``failure_handoff`` marker.

    Every keyword argument except ``events_dir`` is forwarded to
    ``build_envelope`` together with
    ``registration_mode="failure_callback_before_exit_guard"``. The resulting
    envelope is handed to ``write_envelope`` with
    ``marker_type="failure_handoff"``.

    Within this module it is used by ``fallback_chain`` when the cron fire
    fails, and by the ``handoff`` CLI command.

    NOTE(review): the exactly-one-terminal-marker rule is presumably enforced
    inside ``write_envelope`` — confirm against that module.

    Returns:
        Whatever ``write_envelope`` returns for the marker.
    """
    envelope_fields = {
        "phase": phase,
        "exit_code": exit_code,
        "failure_kind": failure_kind,
        "artifact_paths": artifact_paths,
        "residual_pid": residual_pid,
        "critical7_match": critical7_match,
        "registration_mode": "failure_callback_before_exit_guard",
    }
    marker_envelope = build_envelope(task_id, terminal_state, **envelope_fields)
    return write_envelope(marker_envelope, events_dir, marker_type="failure_handoff")


def write_supervisor_crash_marker(
    task_id: str,
    *,
    exit_code: int = -9,
    failure_kind: str = "sigkill_or_oom",
    events_dir: str = "memory/events",
    phase: str = "",
    signal_source: str = "",
) -> dict:
    """Persist a ``supervisor_crash`` marker with state ``CRASH_NO_EXIT_CODE``.

    An empty ``failure_kind`` falls back to ``"sigkill_or_oom"``.
    ``signal_source`` is stored as the envelope's ``summary``. The envelope is
    built with ``registration_mode="failure_callback_before_exit_guard"`` and
    written via ``write_envelope`` with ``marker_type="supervisor_crash"``.

    Returns:
        Whatever ``write_envelope`` returns for the marker.
    """
    effective_kind = failure_kind if failure_kind else "sigkill_or_oom"
    envelope_fields = {
        "phase": phase,
        "exit_code": exit_code,
        "failure_kind": effective_kind,
        "registration_mode": "failure_callback_before_exit_guard",
        "summary": signal_source,
    }
    crash_envelope = build_envelope(task_id, "CRASH_NO_EXIT_CODE", **envelope_fields)
    return write_envelope(crash_envelope, events_dir, marker_type="supervisor_crash")


def fallback_chain(
    envelope: dict,
    cron_fn: Callable[[dict], object],
    events_dir: str,
) -> dict:
    """Run the §6.1 fallback chain for one failure envelope.

    Steps:

    1. Write the envelope to disk as a ``failure_envelope`` marker.
    2. Fire the cron callback through ``_fire_callback_with_enforcement``.
    3. If the fire raises anything other than ``CollectorViolation``, write a
       ``cron_fallback`` handoff marker; when that write reports a status
       starting with ``"FALLBACK"``, also emit a stderr line.
    4. Return; exiting is left to the caller.

    Returns:
        A dict with ``envelope_write`` (result of step 1), ``cron_status``
        (``"fired"`` or ``"cron_fail"``) and ``handoff`` (result of step 3, or
        ``None`` when the cron fire succeeded).

    Raises:
        KeyError: if ``envelope`` has no ``task_id``.
        CollectorViolation: propagated untouched so the violation stays
            fail-closed.
    """
    task_id = envelope["task_id"]

    # Step 1: the terminal marker goes to disk before any cron attempt.
    outcome = {
        "envelope_write": write_envelope(
            envelope, events_dir, marker_type="failure_envelope"
        ),
        "cron_status": "not_attempted",
        "handoff": None,
    }

    # Step 2: fire through the enforcement wrapper.
    try:
        _fire_callback_with_enforcement(envelope, cron_fn, events_dir=events_dir)
    except CollectorViolation:
        # Never swallowed: a collector violation must surface to the caller.
        raise
    except Exception:
        outcome["cron_status"] = "cron_fail"

        # Step 3: envelope fields are read only now, after the failed fire.
        phase = envelope.get("phase", "")
        exit_code = envelope.get("exit_code", 1)
        handoff = write_handoff_marker(
            task_id,
            envelope.get("terminal_state", "INFRA_DEFECT"),
            failure_kind="cron_fallback",
            events_dir=events_dir,
            phase=phase,
            exit_code=exit_code,
        )
        outcome["handoff"] = handoff

        if handoff.get("status", "").startswith("FALLBACK"):
            emit_stderr_line(
                task_id,
                envelope.get("terminal_state", ""),
                exit_code,
                "cron_fallback",
                phase,
            )
    else:
        outcome["cron_status"] = "fired"

    return outcome


# ── §6.3.3.B detect_bypass_via_count_mismatch (1:1 mirror) ─────────────────
def _glob(events_dir: str, pattern: str):
    """Return the ``glob.glob`` matches for ``pattern`` joined onto ``events_dir``."""
    search_expr = os.path.join(events_dir, pattern)
    return glob.glob(search_expr)


def _mtime_in_window(path: str, window_start: float, window_end: float) -> bool:
    """Return True when ``path`` can be stat'ed and its mtime is in the closed window."""
    try:
        mtime = os.path.getmtime(path)
    except OSError:
        # The file vanished (or became inaccessible) between glob and stat;
        # a marker that no longer exists cannot be counted.
        return False
    return window_start <= mtime <= window_end


def detect_bypass_via_count_mismatch(
    window_start: float, window_end: float, events_dir: str = "memory/events"
) -> List[dict]:
    """Compare terminal markers against audit records inside a time window.

    Mirrors the §6.3.3.B check. Files in ``events_dir`` whose mtime lies in
    ``[window_start, window_end]`` are grouped into four marker classes:

    * ``*.failure-envelope.json`` — counted only when the JSON object's
      ``registration_mode`` is ``"normal_callback"`` or
      ``"failure_callback_before_exit_guard"``; unreadable files and payloads
      that are not JSON objects are not counted.
    * ``*.failure-handoff-marker.json``
    * ``*.supervisor-crash-marker.json``
    * ``*.done`` (the success marker)

    and compared with the number of ``cron-fire-audit-*.json`` files in the
    same window. Files that disappear while being scanned are skipped.

    Args:
        window_start: Inclusive lower bound, seconds since the epoch.
        window_end: Inclusive upper bound, seconds since the epoch.
        events_dir: Directory holding the markers and audit records.

    Returns:
        A list of violation dicts, in this order when present:

        * ``EXACTLY_ONE_TERMINAL_VIOLATION_FAILURE_AND_DONE_CONCURRENT`` —
          at least one success and one failure marker.
        * ``EXACTLY_ONE_TERMINAL_VIOLATION_MULTI_CLASS_CONCURRENT`` —
          two or more marker classes have entries.
        * ``EXACTLY_ONE_TERMINAL_VIOLATION_FAILURE_STATE_WITH_DONE`` —
          a counted failure envelope carries a non-``SUCCESS`` terminal
          state while a ``.done`` marker exists.
        * ``AUDIT_LOG_MISSING_FOR_CRON_FIRE`` — more markers than audit
          records.
    """
    counted_modes = ("normal_callback", "failure_callback_before_exit_guard")
    class_counts = {
        "failure_envelope": 0,
        "failure_handoff_marker": 0,
        "supervisor_crash_marker": 0,
        "done": 0,
    }
    # terminal_state of every failure envelope that was counted.
    envelope_states = []

    for path in _glob(events_dir, "*.failure-envelope.json"):
        if not _mtime_in_window(path, window_start, window_end):
            continue
        try:
            with open(path, encoding="utf-8") as f:
                envelope = json.load(f)
        except Exception:
            # Best effort: an unreadable or corrupt envelope is not counted.
            envelope = {}
        if not isinstance(envelope, dict):
            # Valid JSON that is not an object (null, list, ...) has no
            # registration_mode; treat it like an unreadable file.
            envelope = {}
        if envelope.get("registration_mode") in counted_modes:
            class_counts["failure_envelope"] += 1
            envelope_states.append(envelope.get("terminal_state"))

    # The remaining classes are counted by presence alone.
    presence_classes = (
        ("*.failure-handoff-marker.json", "failure_handoff_marker"),
        ("*.supervisor-crash-marker.json", "supervisor_crash_marker"),
        ("*.done", "done"),
    )
    for pattern, marker_class in presence_classes:
        class_counts[marker_class] += sum(
            1
            for path in _glob(events_dir, pattern)
            if _mtime_in_window(path, window_start, window_end)
        )

    audit_count = sum(
        1
        for path in _glob(events_dir, "cron-fire-audit-*.json")
        if _mtime_in_window(path, window_start, window_end)
    )

    success_marker_count = class_counts["done"]
    failure_marker_count = sum(
        count for marker_class, count in class_counts.items() if marker_class != "done"
    )

    violations: List[dict] = []
    if success_marker_count >= 1 and failure_marker_count >= 1:
        violations.append(
            {
                "violation": "EXACTLY_ONE_TERMINAL_VIOLATION_FAILURE_AND_DONE_CONCURRENT",
                "success_marker_count": success_marker_count,
                "failure_marker_count": failure_marker_count,
            }
        )

    classes_with_fires = [c for c, n in class_counts.items() if n >= 1]
    if len(classes_with_fires) >= 2:
        violations.append(
            {
                "violation": "EXACTLY_ONE_TERMINAL_VIOLATION_MULTI_CLASS_CONCURRENT",
                "classes": classes_with_fires,
            }
        )

    failure_envelope_non_success = any(
        state and state != "SUCCESS" for state in envelope_states
    )
    if failure_envelope_non_success and class_counts["done"] >= 1:
        violations.append(
            {"violation": "EXACTLY_ONE_TERMINAL_VIOLATION_FAILURE_STATE_WITH_DONE"}
        )

    total_fires = success_marker_count + failure_marker_count
    if total_fires > audit_count:
        violations.append(
            {
                "violation": "AUDIT_LOG_MISSING_FOR_CRON_FIRE",
                "total_fires": total_fires,
                "audit_entries": audit_count,
            }
        )
    return violations


if __name__ == "__main__":  # pragma: no cover - CLI smoke
    # Smoke-test CLI: run one sub-command and print its result as JSON.
    import argparse
    import time

    default_events_dir = os.environ.get(
        "FAILURE_CALLBACK_2712_EVENTS_DIR", "/home/jay/workspace/memory/events"
    )

    parser = argparse.ArgumentParser(
        description="task-2712 failure callback dispatcher"
    )
    parser.add_argument("command", choices=["handoff", "crash", "verify-bypass"])
    parser.add_argument("--task-id", default="task-2712")
    parser.add_argument("--terminal-state", default="INFRA_DEFECT")
    parser.add_argument("--failure-kind", default="")
    parser.add_argument("--exit-code", type=int, default=1)
    parser.add_argument("--phase", default="")
    parser.add_argument("--events-dir", default=default_events_dir)
    opts = parser.parse_args()

    if opts.command == "handoff":
        result = write_handoff_marker(
            opts.task_id,
            opts.terminal_state,
            failure_kind=opts.failure_kind,
            events_dir=opts.events_dir,
            phase=opts.phase,
            exit_code=opts.exit_code,
        )
    elif opts.command == "crash":
        result = write_supervisor_crash_marker(
            opts.task_id,
            exit_code=opts.exit_code,
            failure_kind=opts.failure_kind or "sigkill_or_oom",
            events_dir=opts.events_dir,
            phase=opts.phase,
        )
    else:
        # verify-bypass: scan from the epoch up to one second past "now".
        result = detect_bypass_via_count_mismatch(
            0, time.time() + 1, opts.events_dir
        )

    print(json.dumps(result, ensure_ascii=False))
