"""v3.6 Runtime Harness — Layer 2: Spawn Detector.

chair_authorization_id=CHAIR-AUTH-TASK-2704-V36-CONTROL-PLANE-P0-MVP-260528

Contract:
- detect_spawn_state(task_id, schedule_id) -> dict
  Returns {"state": <9-state enum>, "signals": {...}, "reason": str}

- 9 states: NOT_REGISTERED, REGISTERED, FIRED, SESSION_SEEN, WORK_STARTED,
            ARTIFACT_SEEN, CALLBACK_REGISTERED, DONE, UNKNOWN

- Cross-validation across ≥2 signals required for most states.
- CRITICAL: When signals are insufficient or conflicting → UNKNOWN.
  Asserting "spawn 0" (that no spawn happened) from a single source is FORBIDDEN.
- No backward state transitions (only forward, or → UNKNOWN).
- Bot sessions are identified by direct system_prompt path comparison — do NOT
  apply a `grep -v cokacdir` filter (that filter is a known bug).
"""
from __future__ import annotations

import glob
import json
import os
import subprocess
import time
from typing import Optional

CHAIR_AUTHORIZATION_ID = "CHAIR-AUTH-TASK-2704-V36-CONTROL-PLANE-P0-MVP-260528"

# ─── State enum ───────────────────────────────────────────────────────────────
# Each state is a plain string constant whose value equals its own name.
NOT_REGISTERED = "NOT_REGISTERED"
REGISTERED = "REGISTERED"
FIRED = "FIRED"
SESSION_SEEN = "SESSION_SEEN"
WORK_STARTED = "WORK_STARTED"
ARTIFACT_SEEN = "ARTIFACT_SEEN"
CALLBACK_REGISTERED = "CALLBACK_REGISTERED"
DONE = "DONE"
UNKNOWN = "UNKNOWN"

# Forward progression, earliest first.  UNKNOWN is deliberately left out: it is
# an escape valve rather than a step in the chain.
_STATE_ORDER = [
    NOT_REGISTERED, REGISTERED, FIRED, SESSION_SEEN,
    WORK_STARTED, ARTIFACT_SEEN, CALLBACK_REGISTERED, DONE,
]

# state -> index in _STATE_ORDER, used for ">=" comparisons between states.
_STATE_ORDINAL = dict(zip(_STATE_ORDER, range(len(_STATE_ORDER))))

# Default filesystem locations.
_WORKSPACE = "/home/jay/workspace"
_EVENTS_DIR = os.path.join(_WORKSPACE, "memory/events")
_TIMERS_FILE = os.path.join(_WORKSPACE, "memory/task-timers.json")
_COKACDIR = "/home/jay/.cokacdir"


# ─── Public API ───────────────────────────────────────────────────────────────

def detect_spawn_state(
    task_id: str,
    schedule_id: Optional[str] = None,
    events_dir: Optional[str] = None,
    timers_file: Optional[str] = None,
    cokacdir: Optional[str] = None,
) -> dict:
    """Resolve the spawn state of *task_id* by cross-validating several signals.

    Args:
        task_id: Task identifier to inspect.
        schedule_id: Optional cron schedule id; when omitted the detector
            looks it up from the task-timers file.
        events_dir: Override for the events directory (a falsy value selects
            the module default ``_EVENTS_DIR``).
        timers_file: Override for the task-timers JSON path (falsy → default).
        cokacdir: Override for the cokacdir root (falsy → default).

    Returns:
        A dict with keys ``"state"`` (one of the 9 state strings),
        ``"signals"`` (raw signal values), ``"reason"`` (human-readable
        justification) and ``"signal_count"`` (number of True signals).

    Any ``Exception`` raised while detecting is converted into an UNKNOWN
    result whose reason carries the exception text; nothing is re-raised.
    """
    try:
        resolved_paths = {
            "events_dir": events_dir or _EVENTS_DIR,
            "timers_file": timers_file or _TIMERS_FILE,
            "cokacdir": cokacdir or _COKACDIR,
        }
        return _detect_spawn_state_impl(
            task_id=task_id,
            schedule_id=schedule_id,
            **resolved_paths,
        )
    except Exception as err:
        # Safe-fail: an internal error must never be reported as a firm state.
        return {
            "state": UNKNOWN,
            "signals": {},
            "reason": f"detect_spawn_state internal error (safe-fail): {err}",
            "signal_count": 0,
        }


def state_ge(state_a: str, state_b: str) -> bool:
    """Return True if *state_a* is at or beyond *state_b* in the progression.

    UNKNOWN, and any string that is not one of the ordered states (e.g. a
    typo), is treated as incomparable: the result is False whenever either
    argument is such a value.  Previously an unrecognised string ranked as -1,
    so ``state_ge(DONE, "typo")`` and ``state_ge("x", "x")`` were both True.
    """
    if UNKNOWN in (state_a, state_b):
        return False
    rank_a = _STATE_ORDINAL.get(state_a)
    rank_b = _STATE_ORDINAL.get(state_b)
    if rank_a is None or rank_b is None:
        # Unrecognised state: refuse to claim any ordering relationship.
        return False
    return rank_a >= rank_b


# ─── Internal implementation ──────────────────────────────────────────────────

def _detect_spawn_state_impl(
    task_id: str,
    schedule_id: Optional[str],
    events_dir: str,
    timers_file: str,
    cokacdir: str,
) -> dict:
    """Collect raw spawn signals for *task_id* and resolve them to one state.

    Rules are evaluated from the most advanced state (DONE) downward; the
    first rule whose evidence is satisfied wins.  NOT_REGISTERED ("spawn 0")
    is only asserted when every task-related source is empty; any other
    unmatched combination yields UNKNOWN.

    Args:
        task_id: Task identifier used to build marker / worktree patterns.
        schedule_id: Cron schedule id; when falsy it is looked up from the
            task's entry in *timers_file*.
        events_dir: Directory holding ``<task_id>.*`` event / marker files.
        timers_file: Path to the task-timers JSON file.
        cokacdir: Root holding ``schedule_history/<schedule_id>.log`` and
            ``<schedule_id>.output``.

    Returns:
        The dict built by ``_result`` (state, signals, reason, signal_count).
    """
    now = time.time()
    fresh_window_s = 900  # "fresh" = modified less than 900 s before `now`
    task_glob = glob.escape(task_id)  # match task_id literally in patterns
    signals: dict = {}

    # ── Signal 1: task_id in task-timers.json ────────────────────────────────
    timer_status = _get_timer_status(task_id, timers_file)
    signals["timer_status"] = timer_status  # "absent" | entry status | "unknown"
    # "unknown" (unreadable file or entry without a status) still counts as
    # present, which keeps NOT_REGISTERED from being asserted on a read error.
    signals["timer_present"] = timer_status != "absent"
    timer_completed = timer_status in ("completed", "completed_owner_decision_accepted")

    # ── Signal 2: schedule_id (passed in or from timer metadata) ─────────────
    if not schedule_id:
        schedule_id = _get_schedule_id_from_timer(task_id, timers_file)
    signals["schedule_id"] = schedule_id
    signals["schedule_id_present"] = bool(schedule_id)

    # ── Signal 3: dispatch marker in events_dir ───────────────────────────────
    dispatch_markers = glob.glob(os.path.join(events_dir, f"{task_glob}.dispatched*"))
    signals["dispatch_marker_present"] = bool(dispatch_markers)
    signals["dispatch_marker_paths"] = dispatch_markers

    # ── Signal 4: cron schedule_history log ──────────────────────────────────
    schedule_history_present = False
    if schedule_id:
        history_path = os.path.join(cokacdir, "schedule_history", f"{schedule_id}.log")
        schedule_history_present = os.path.isfile(history_path)
    signals["schedule_history_present"] = schedule_history_present

    # ── Signal 5: bot session process (system_prompt path direct compare) ─────
    # ★ NO grep -v cokacdir filter — direct path comparison only.
    # Any claude process whose system_prompt hex differs from the ANU session's
    # hex counts as a bot session.  This match is NOT tied to task_id, so it is
    # never accepted on its own as proof that *this* task spawned.
    # NOTE(review): when anu_session_hex is None, an ANU process that carries a
    # hex is indistinguishable from a bot — confirm ANU_SESSION_HEX is exported.
    anu_session_hex = _get_anu_session_hex()
    bot_session_hex = next(
        (proc_hex for proc_hex in _list_claude_processes()
         if proc_hex and proc_hex != anu_session_hex),
        None,
    )
    bot_session_active = bot_session_hex is not None
    signals["bot_session_active"] = bot_session_active
    signals["bot_session_hex"] = bot_session_hex
    signals["anu_session_hex"] = anu_session_hex

    # ── Signal 6: worktree mtime ──────────────────────────────────────────────
    worktree_mtime_seconds = _get_worktree_mtime_seconds(task_id, now)
    signals["worktree_mtime_seconds"] = worktree_mtime_seconds
    signals["worktree_fresh"] = (
        worktree_mtime_seconds is not None and worktree_mtime_seconds < fresh_window_s
    )

    # ── Signal 7: events/artifact mtime ──────────────────────────────────────
    artifact_mtime_seconds = _get_artifact_mtime_seconds(task_id, events_dir, now)
    signals["artifact_mtime_seconds"] = artifact_mtime_seconds
    signals["artifact_fresh"] = (
        artifact_mtime_seconds is not None and artifact_mtime_seconds < fresh_window_s
    )

    # ── Signal 8: done markers ────────────────────────────────────────────────
    done_markers = glob.glob(os.path.join(events_dir, f"{task_glob}.done*"))
    signals["done_markers"] = done_markers
    signals["done_marker_present"] = bool(done_markers)

    # ── Signal 9: callback markers ────────────────────────────────────────────
    callback_markers = glob.glob(os.path.join(events_dir, f"{task_glob}.callback-*"))
    signals["callback_marker_present"] = bool(callback_markers)

    # ── Signal 10: bot output file ────────────────────────────────────────────
    bot_output_present = False
    if schedule_id:
        bot_output_path = os.path.join(cokacdir, f"{schedule_id}.output")
        bot_output_present = os.path.isfile(bot_output_path)
    signals["bot_output_present"] = bot_output_present

    # ─── State resolution ──────────────────────────────────────────────────────
    # Per spec §1.3: minimum 2-signal OR weighted entries for most states.
    # CRITICAL: insufficient/conflicting signals → UNKNOWN.
    count = _pos_count(signals)  # signals is complete; count it once

    # DONE: .done marker + completed timer (2 independent sources).
    if signals["done_marker_present"] and timer_completed:
        return _result(DONE, signals, "done_marker_present + timer completed", count)

    # CALLBACK_REGISTERED: callback marker + done context (done marker OR
    # completed timer).
    if signals["callback_marker_present"] and (signals["done_marker_present"] or timer_completed):
        return _result(CALLBACK_REGISTERED, signals, "callback_marker_present + done context", count)

    # ARTIFACT_SEEN: a task event file changed within the freshness window
    # (one signal is accepted for this state).
    if signals["artifact_fresh"]:
        return _result(ARTIFACT_SEEN, signals, "recent artifact mtime < 900s", count)

    # WORK_STARTED: dispatch marker + at least one corroborating work signal
    # (fresh worktree, bot session, or bot output file).  artifact_fresh is not
    # listed because it has already returned ARTIFACT_SEEN above.
    if signals["dispatch_marker_present"] and any(
        (signals["worktree_fresh"], bot_session_active, bot_output_present)
    ):
        return _result(WORK_STARTED, signals, "dispatch_marker + corroborating work signal", count)

    # Evidence that this task's schedule actually fired.
    fired_evidence = schedule_history_present or signals["dispatch_marker_present"]

    # SESSION_SEEN: a bot session is running AND this task shows fired
    # evidence.  The session match is not task-specific, so on its own it would
    # be a forbidden single-source assertion.
    if bot_session_active and fired_evidence:
        return _result(
            SESSION_SEEN,
            signals,
            "bot_session_active with distinct system_prompt hex + fired evidence "
            "(schedule_history or dispatch_marker)",
            count,
        )

    # FIRED: schedule_id known + (schedule_history log OR dispatch marker).
    # NOTE: no fire-time comparison is performed here.
    if schedule_id and fired_evidence:
        return _result(FIRED, signals, "schedule_id + schedule_history or dispatch_marker", count)

    # REGISTERED: timer entry present (any status, including "unknown") +
    # schedule_id present.
    if signals["timer_present"] and signals["schedule_id_present"]:
        return _result(REGISTERED, signals, "timer_running + schedule_id_present", count)

    # NOT_REGISTERED ("spawn 0"): only when EVERY task-related source is empty.
    # Any leftover trace — a worktree or task event file of any age, a callback
    # marker — conflicts with "never spawned" and must fall through to UNKNOWN.
    # (schedule_history / bot_output need a schedule_id, already required absent.)
    anything_observed = any((
        signals["timer_present"],
        signals["schedule_id_present"],
        bot_session_active,
        signals["dispatch_marker_present"],
        signals["done_marker_present"],
        signals["callback_marker_present"],
        worktree_mtime_seconds is not None,
        artifact_mtime_seconds is not None,
    ))
    if not anything_observed:
        return _result(
            NOT_REGISTERED,
            signals,
            "all key signals absent (timer, schedule_id, session, dispatch, done, "
            "callback, worktree, artifact)",
            count,
        )

    # UNKNOWN: signals present but insufficient/conflicting for a firm determination
    return _result(
        UNKNOWN,
        signals,
        "signals present but insufficient or conflicting for firm state determination — poll next sample",
        count,
    )


def _result(state: str, signals: dict, reason: str, signal_count: int) -> dict:
    """Package one detector decision into the public result mapping.

    Keys are produced in the order state, signals, reason, signal_count; the
    *signals* dict is stored as-is (not copied).
    """
    return dict(
        state=state,
        signals=signals,
        reason=reason,
        signal_count=signal_count,
    )


def _pos_count(signals: dict) -> int:
    """Return how many values in *signals* are exactly the boolean ``True``.

    The identity test means truthy non-bool values (non-empty lists, strings,
    non-zero ints) are never counted.
    """
    return sum(1 for value in signals.values() if value is True)


def _get_timer_status(task_id: str, timers_file: str) -> str:
    """Report the ``status`` recorded for *task_id* in the task-timers file.

    The file is read as JSON and the entry looked up under its ``"tasks"`` key.

    Returns:
        "absent" if the file does not exist or has no entry for *task_id*;
        the entry's ``status`` value when it has one;
        "unknown" if the entry has no ``status`` or reading/parsing failed.
    """
    try:
        if os.path.isfile(timers_file):
            with open(timers_file, encoding="utf-8") as handle:
                registry = json.load(handle).get("tasks", {})
            if task_id in registry:
                return registry[task_id].get("status", "unknown")
        return "absent"
    except Exception:
        return "unknown"


def _get_schedule_id_from_timer(task_id: str, timers_file: str) -> Optional[str]:
    """Return the ``schedule_id`` stored for *task_id* in the task-timers file.

    Returns None when the file does not exist, the task has no entry, the
    entry has no ``schedule_id`` key, or the file cannot be read / parsed.
    """
    try:
        if not os.path.isfile(timers_file):
            return None
        with open(timers_file, encoding="utf-8") as handle:
            registry = json.load(handle).get("tasks", {})
        return registry[task_id].get("schedule_id") if task_id in registry else None
    except Exception:
        return None


def _get_anu_session_hex() -> Optional[str]:
    """Return the system_prompt hex of the ANU (orchestrator) session, if known.

    Lookup order:
      1. A non-empty ``ANU_SESSION_HEX`` environment variable, returned as-is.
      2. The first 16-char lowercase hex following ``system_prompt_`` in
         ``/proc/self/cmdline`` — the command line of the process running this
         code.  NOTE(review): that is only the ANU process when this module
         runs inside it; confirm for other callers.

    Returns None when neither source yields a value or /proc is unreadable.
    """
    import re

    from_env = os.getenv("ANU_SESSION_HEX")
    if from_env:
        return from_env

    try:
        with open("/proc/self/cmdline", "rb") as handle:
            raw_args = handle.read()
        # cmdline is NUL-separated; join the arguments with spaces.
        joined = raw_args.decode("utf-8", errors="replace").replace("\x00", " ")
        match = re.search(r"system_prompt_([0-9a-f]{16})", joined)
    except Exception:
        return None
    return match.group(1) if match else None


def _list_claude_processes(proc_root: str = "/proc") -> list[Optional[str]]:
    """Return the system_prompt hex of every running claude process.

    Reads ``<proc_root>/<pid>/cmdline`` directly — NO ``grep -v cokacdir``
    filter (spec §1.1).  A process qualifies when its command line contains
    the substring "claude".  Its entry is the first 16-char lowercase hex
    following ``system_prompt_``, or None when no such token is present.

    Args:
        proc_root: procfs mount point; defaults to "/proc" and is overridable
            for tests or non-standard mounts.

    Returns:
        Hex strings / None, in directory-listing order.  Empty when
        *proc_root* cannot be listed.
    """
    import re

    hex_pattern = re.compile(r"system_prompt_([0-9a-f]{16})")  # hoisted out of the loop
    try:
        pid_entries = os.listdir(proc_root)
    except OSError:
        return []

    hexes: list[Optional[str]] = []
    for pid_entry in pid_entries:
        if not pid_entry.isdigit():
            continue
        try:
            with open(os.path.join(proc_root, pid_entry, "cmdline"), "rb") as fh:
                raw = fh.read()
        except OSError:
            # The process exited (ENOENT/ESRCH), is not accessible, or the entry
            # is not a regular file.  Skip only this pid: a per-process failure
            # must not abort the scan and silently truncate the result.
            continue
        cmdline = raw.decode("utf-8", errors="replace").replace("\x00", " ")
        if "claude" not in cmdline:
            continue
        match = hex_pattern.search(cmdline)
        hexes.append(match.group(1) if match else None)
    return hexes


def _get_worktree_mtime_seconds(
    task_id: str,
    now: float,
    worktrees_base: Optional[str] = None,
) -> Optional[int]:
    """Return whole seconds since the newest ``<task_id>-*`` worktree changed.

    Args:
        task_id: Task id, matched literally (glob metacharacters are escaped).
        now: Reference timestamp to measure the age from.
        worktrees_base: Directory holding the worktrees; defaults to
            ``<_WORKSPACE>/.worktrees`` when None.

    Returns:
        ``int(now - newest_mtime)`` over the matching paths, or None when the
        base directory is missing, nothing matches, or an error occurs
        (best-effort signal).

    NOTE(review): this uses each worktree path's own mtime.  For a directory
    that changes only when entries directly inside it are added, removed or
    renamed, not when nested files are edited.  Confirm that is the intended
    freshness signal.
    """
    try:
        base = (
            worktrees_base
            if worktrees_base is not None
            else os.path.join(_WORKSPACE, ".worktrees")
        )
        if not os.path.isdir(base):
            return None
        matches = glob.glob(os.path.join(base, f"{glob.escape(task_id)}-*"))
        latest = max(
            (os.path.getmtime(p) for p in matches if os.path.exists(p)),
            default=None,
        )
        return None if latest is None else int(now - latest)
    except Exception:
        return None


def _get_artifact_mtime_seconds(task_id: str, events_dir: str, now: float) -> Optional[int]:
    """Return whole seconds since the newest task artifact in *events_dir* changed.

    Considers every ``<task_id>.*`` file except dispatch markers
    (``<task_id>.dispatched*``).  A dispatch marker records the dispatch
    itself, not work output.  Counting it let the marker alone look like
    fresh artifact activity, so a just-dispatched task skipped WORK_STARTED.

    Args:
        task_id: Task id, matched literally (glob metacharacters are escaped).
        events_dir: Directory holding the task's event files.
        now: Reference timestamp to measure the age from.

    Returns:
        ``int(now - newest_mtime)``, or None when no qualifying file exists
        or an error occurs (best-effort signal).
    """
    try:
        dispatch_prefix = f"{task_id}.dispatched"
        candidates = [
            path
            for path in glob.glob(os.path.join(events_dir, f"{glob.escape(task_id)}.*"))
            if not os.path.basename(path).startswith(dispatch_prefix)
        ]
        latest = max(
            (os.path.getmtime(p) for p in candidates if os.path.exists(p)),
            default=None,
        )
        return None if latest is None else int(now - latest)
    except Exception:
        return None
