#!/usr/bin/env python3
"""
absorption-health-check.py

Loads absorption-registry.yaml and runs health checks for each item,
then outputs a JSON status report.

Usage:
    python3 absorption-health-check.py [--source SOURCE] [--status STATUS]
                                        [--json] [--summary]
"""

import argparse
import json
import os
import subprocess
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Optional

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Fixed host locations; every other path in this script is derived from these.
HOME_DIR = Path("/home/jay")
WORKSPACE_ROOT = HOME_DIR / "workspace"  # base for relative registry targets
REGISTRY_PATH = WORKSPACE_ROOT.joinpath("config", "absorption-registry.yaml")

# ---------------------------------------------------------------------------
# YAML loading (requires PyYAML; there is no built-in fallback parser)
# ---------------------------------------------------------------------------


def load_yaml(path: Path) -> dict[str, Any]:
    """
    Load the registry YAML file at *path* with PyYAML's ``safe_load``.

    An empty file yields ``{}`` so callers can always use mapping methods.

    Raises:
        SystemExit(1): PyYAML is not installed, the file is missing or
            unreadable, the YAML is malformed, or the top level is not a
            mapping. A one-line "ERROR: ..." message is printed to stderr first.
    """
    try:
        import yaml  # type: ignore
    except ImportError:
        # No fallback parser exists, so a missing PyYAML is fatal.
        print(
            "ERROR: PyYAML not installed. Install it with: pip install pyyaml",
            file=sys.stderr,
        )
        raise SystemExit(1)

    try:
        with open(path, "r", encoding="utf-8") as fh:
            data = yaml.safe_load(fh)
    except FileNotFoundError:
        print(f"ERROR: Registry file not found: {path}", file=sys.stderr)
        raise SystemExit(1)
    except (OSError, UnicodeDecodeError) as exc:
        print(f"ERROR: Cannot read registry file {path}: {exc}", file=sys.stderr)
        raise SystemExit(1)
    except yaml.YAMLError as exc:
        print(f"ERROR: Invalid YAML in {path}: {exc}", file=sys.stderr)
        raise SystemExit(1)

    # safe_load returns None for an empty document.
    if data is None:
        return {}
    if not isinstance(data, dict):
        print(
            f"ERROR: Registry file {path} must contain a mapping at the top level",
            file=sys.stderr,
        )
        raise SystemExit(1)
    return data


# ---------------------------------------------------------------------------
# Path helpers
# ---------------------------------------------------------------------------


def resolve_path(target: str) -> Path:
    """
    Resolve a registry target string to an absolute Path.

    - "~" on its own maps to HOME_DIR, and "~/..." is placed under HOME_DIR
      (HOME_DIR is the configured constant, not the runtime $HOME).
    - Absolute paths (starting with /) are returned unchanged.
    - Anything else is treated as relative to WORKSPACE_ROOT.
    """
    if target == "~":
        # Without this case a bare "~" would become WORKSPACE_ROOT / "~".
        return HOME_DIR
    if target.startswith("~/"):
        return HOME_DIR / target[2:]
    candidate = Path(target)
    if candidate.is_absolute():
        return candidate
    return WORKSPACE_ROOT / candidate


# ---------------------------------------------------------------------------
# Individual health check runners
# ---------------------------------------------------------------------------


def check_file_exists(target: str) -> tuple[str, str]:
    """
    Report whether *target* (after resolve_path) exists on disk.

    Returns ("pass", "exists: <path>") or ("fail", "not found: <path>").
    """
    resolved = resolve_path(target)
    found = resolved.exists()
    return ("pass", f"exists: {resolved}") if found else ("fail", f"not found: {resolved}")


def check_file_recent_activity(target: str, max_age_hours: int) -> tuple[str, str]:
    """
    Verify that *target* was modified no more than *max_age_hours* ago.

    Returns ("pass", detail) when the file's mtime is recent enough and
    ("fail", detail) otherwise. A missing file or a failing stat() is also
    reported as a failure.
    """
    resolved = resolve_path(target)
    if not resolved.exists():
        return "fail", f"not found: {resolved}"
    try:
        modified_at = resolved.stat().st_mtime
    except OSError as err:
        return "fail", f"stat error: {err}"
    elapsed_hours = (datetime.now().timestamp() - modified_at) / 3600
    verdict = "pass" if elapsed_hours <= max_age_hours else "fail"
    return verdict, f"last modified {elapsed_hours:.1f}h ago (limit {max_age_hours}h)"


def check_grep_pattern(target: str, pattern: str) -> tuple[str, str]:
    """
    Search for the extended regex *pattern* in *target* (a file or a directory).

    Non-file targets are searched recursively. The external ``grep`` binary
    does the matching.

    Returns:
        ("pass", detail) naming up to three matching files, or ("fail", detail)
        when the target is missing, nothing matched, grep reported an error,
        or grep timed out.
    """
    path = resolve_path(target)
    if not path.exists():
        return "fail", f"target not found: {path}"

    # "-e" marks the pattern explicitly so one starting with "-" is not taken as
    # an option; "--" ends option parsing before the path argument.
    if path.is_file():
        grep_args = ["grep", "-E", "-l", "-e", pattern, "--", str(path)]
    else:
        grep_args = ["grep", "-r", "--include=*", "-l", "-E", "-e", pattern, "--", str(path)]

    try:
        result = subprocess.run(
            grep_args,
            capture_output=True,
            text=True,
            errors="replace",  # file names need not be valid in the locale encoding
            timeout=15,
        )
    except subprocess.TimeoutExpired:
        return "fail", f"grep timed out on {path}"
    except (OSError, ValueError) as exc:  # e.g. grep missing, NUL byte in pattern
        return "fail", f"grep error: {exc}"

    matches = result.stdout.strip().splitlines()
    # grep exits 2 when any error occurred, even if other files matched (e.g. one
    # unreadable file under -r), so trust the listed matches first.
    if matches:
        return "pass", f"pattern found in: {', '.join(matches[:3])}"
    if result.returncode not in (0, 1):
        stderr_lines = result.stderr.strip().splitlines()
        reason = stderr_lines[0] if stderr_lines else f"exit status {result.returncode}"
        return "fail", f"grep error: {reason}"
    return "fail", f"pattern '{pattern}' not found in {path}"


def _parse_entry_timestamp(entry: Any) -> Optional[datetime]:
    """
    Extract a timezone-aware timestamp from one decoded audit-trail entry.

    Uses the first truthy value among the "timestamp", "ts" and "time" keys.
    Only ISO-8601 strings are understood; a trailing "Z" is accepted, and naive
    values are taken to be UTC. Returns None when the entry is not a JSON
    object, has no usable timestamp, or the value cannot be parsed.
    """
    if not isinstance(entry, dict):
        return None
    raw = entry.get("timestamp") or entry.get("ts") or entry.get("time")
    if not isinstance(raw, str):
        # NOTE(review): numeric epoch values are skipped rather than guessed at
        # (seconds vs. milliseconds is ambiguous); extend if the writer uses them.
        return None
    if raw.endswith("Z"):
        # datetime.fromisoformat() only accepts "Z" from Python 3.11 onward.
        raw = raw[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(raw)
    except ValueError:
        return None
    return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)


def check_audit_trail_recent(target: str, max_age_hours: int) -> tuple[str, str]:
    """
    Check that the JSON-lines file *target* has an entry newer than *max_age_hours*.

    Blank lines, invalid JSON, and entries without a parseable timestamp are
    ignored rather than failing the whole check.

    Returns:
        ("pass", detail) with the count of recent entries, or ("fail", detail)
        when the file is missing, unreadable, or has no recent entries.
    """
    path = resolve_path(target)
    if not path.exists():
        return "fail", f"audit trail not found: {path}"

    cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
    recent_count = 0
    try:
        # errors="replace": one badly encoded line must not abort the whole scan.
        with open(path, "r", encoding="utf-8", errors="replace") as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except ValueError:  # includes json.JSONDecodeError
                    continue
                ts = _parse_entry_timestamp(entry)
                if ts is not None and ts >= cutoff:
                    recent_count += 1
    except OSError as exc:
        return "fail", f"read error: {exc}"

    if recent_count > 0:
        return "pass", f"{recent_count} entries within last {max_age_hours}h"
    return "fail", f"no entries within last {max_age_hours}h in {path}"


def check_process_running(target: str) -> tuple[str, str]:
    """
    Check whether any process has *target* as a substring of its command line.

    Only the command-line column of ``ps`` is searched, so user names and
    resource columns cannot match. Three kinds of process are ignored:
    - this health-check process, whose own arguments (e.g. ``--source <name>``)
      must never count as a match;
    - the ``ps`` child it spawns;
    - processes whose executable is grep/egrep/fgrep.

    Returns:
        ("pass", detail) with the number of matching processes;
        ("fail", detail) when none match or ``ps`` cannot be run;
        ("skip", detail) when *target* is empty (it would match everything).
    """
    if not target:
        return "skip", "no process name specified"

    own_pid = os.getpid()
    try:
        result = subprocess.run(
            ["ps", "-eo", "pid=,ppid=,args="],
            capture_output=True,
            text=True,
            errors="replace",
            timeout=10,
        )
    except (OSError, subprocess.SubprocessError) as exc:
        return "fail", f"ps error: {exc}"

    match_count = 0
    for row in result.stdout.splitlines():
        fields = row.split(None, 2)
        if len(fields) != 3:
            continue
        pid_text, ppid_text, command_line = fields
        try:
            pid, ppid = int(pid_text), int(ppid_text)
        except ValueError:
            continue
        # Skip this script (pid) and the ps process it just started (ppid).
        if own_pid in (pid, ppid):
            continue
        executable = os.path.basename(command_line.split(None, 1)[0])
        if executable in ("grep", "egrep", "fgrep"):
            continue
        if target in command_line:
            match_count += 1

    if match_count:
        return "pass", f"process '{target}' found ({match_count} match(es))"
    if result.returncode != 0:
        reason = result.stderr.strip() or f"exit status {result.returncode}"
        return "fail", f"ps error: {reason}"
    return "fail", f"process '{target}' not running"


# ---------------------------------------------------------------------------
# Dispatch health check
# ---------------------------------------------------------------------------


def run_health_check(hc: Optional[dict[str, Any]]) -> tuple[str, str]:
    """
    Dispatch one health-check configuration to the matching checker.

    Args:
        hc: the item's ``health_check`` mapping (or None). Keys read: ``type``,
            ``target``, ``pattern`` (grep_pattern only) and ``max_age_hours``
            (file_recent_activity / audit_trail_recent, default 24).

    Returns:
        (result, detail) where result is "pass", "fail" or "skip". Missing or
        incomplete configurations are reported as "skip". Any exception raised
        by a checker is reported as "fail" rather than propagated.
    """
    # Items may omit health_check entirely or set it to null in the YAML.
    if not isinstance(hc, dict) or not hc:
        return "skip", "no health check configured"

    check_type: str = hc.get("type", "")
    raw_target = hc.get("target")
    target: str = "" if raw_target is None else str(raw_target)

    known_types = (
        "file_exists",
        "file_recent_activity",
        "grep_pattern",
        "audit_trail_recent",
        "process_running",
    )
    if check_type not in known_types:
        return "skip", f"unknown check type: {check_type}"
    if not target:
        # An empty target resolves to WORKSPACE_ROOT itself or, for
        # process_running, matches every process: a false "pass" either way.
        return "skip", "no target specified"

    try:
        if check_type == "file_exists":
            return check_file_exists(target)

        if check_type == "file_recent_activity":
            return check_file_recent_activity(target, int(hc.get("max_age_hours", 24)))

        if check_type == "grep_pattern":
            pattern: str = hc.get("pattern", "")
            if not pattern:
                return "skip", "no pattern specified"
            return check_grep_pattern(target, pattern)

        if check_type == "audit_trail_recent":
            return check_audit_trail_recent(target, int(hc.get("max_age_hours", 24)))

        # Only "process_running" remains among the known types.
        return check_process_running(target)

    except Exception as exc:  # one broken check must not abort the whole report
        return "fail", f"unexpected error: {exc}"


# ---------------------------------------------------------------------------
# Status update logic
# ---------------------------------------------------------------------------


def compute_effective_status(declared_status: str, hc_result: str) -> str:
    """
    Derive the reported status from the declared status and the check outcome.

    Only one transition exists: an item declared "active" whose health check
    failed is reported as "degraded". Every other combination returns
    *declared_status* unchanged.
    """
    failing_active = (declared_status, hc_result) == ("active", "fail")
    return "degraded" if failing_active else declared_status


# ---------------------------------------------------------------------------
# Main processing
# ---------------------------------------------------------------------------


def process_registry(
    registry: dict[str, Any],
    source_filter: Optional[str] = None,
    status_filter: Optional[str] = None,
) -> dict[str, Any]:
    """
    Run every item's health check and build the report structure.

    Args:
        registry: parsed registry mapping. Reads the top-level "sources" key
            (source name -> {"items": [...]}) and the "duplicates" key (list of
            mappings). Missing or null sections are treated as empty.
        source_filter: if given, only the source with exactly this name is processed.
        status_filter: if given, only items whose *effective* status (after the
            health check) equals this value are included.

    Returns:
        dict with "timestamp" (UTC ISO-8601), "summary" (status -> count plus
        "total"), "by_source" (source -> the same counts), "duplicates" and
        "items". Count dicts always contain every STATUS_KEYS bucket. Any other
        status found in the registry gets its own bucket, so the buckets add up
        to "total".
    """
    STATUS_KEYS = (
        "active", "implemented", "implementing", "recommended", "deferred",
        "degraded", "duplicate", "absorbed", "archived", "skipped",
    )

    # "or" guards against keys that are present but null in the YAML.
    sources_data: dict[str, Any] = registry.get("sources") or {}
    duplicates_raw: list[dict[str, Any]] = registry.get("duplicates") or []

    all_items: list[dict[str, Any]] = []
    by_source: dict[str, dict[str, int]] = {}

    for source_name, source_info in sources_data.items():
        if source_filter and source_name != source_filter:
            continue

        items: list[dict[str, Any]] = (source_info or {}).get("items") or []
        source_counts: dict[str, int] = dict.fromkeys(STATUS_KEYS, 0)
        source_counts["total"] = 0

        for item in items:
            declared_status: str = item.get("status") or "recommended"
            hc_config: dict[str, Any] = item.get("health_check") or {}

            hc_result, hc_detail = run_health_check(hc_config)
            effective_status = compute_effective_status(declared_status, hc_result)

            # The filter applies to the effective status, so it must run after the check.
            if status_filter and effective_status != status_filter:
                continue

            source_counts["total"] += 1
            # A status literally named "total" must not corrupt the total count.
            if effective_status != "total":
                source_counts[effective_status] = source_counts.get(effective_status, 0) + 1

            item_out: dict[str, Any] = {
                "id": item.get("id", ""),
                "name": item.get("name", ""),
                "source": source_name,
                "priority": item.get("priority", ""),
                "declared_status": declared_status,
                "status": effective_status,
                "health_check_result": hc_result,
                "health_check_detail": hc_detail,
            }
            if item.get("notes"):
                item_out["notes"] = item["notes"]
            if item.get("implemented_in"):
                item_out["implemented_in"] = item["implemented_in"]
            all_items.append(item_out)

        # With a status filter, hide sources that contributed nothing.
        if source_counts["total"] > 0 or not status_filter:
            by_source[source_name] = source_counts

    summary: dict[str, int] = dict.fromkeys(STATUS_KEYS, 0)
    summary["total"] = len(all_items)
    for item in all_items:
        st = item["status"]
        if st != "total":
            summary[st] = summary.get(st, 0) + 1

    duplicates_out: list[dict[str, Any]] = []
    for d in duplicates_raw:
        entry: dict[str, Any] = {"items": d.get("items", []), "description": d.get("description", "")}
        if d.get("resolution"):
            entry["resolution"] = d["resolution"]
        duplicates_out.append(entry)

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "summary": summary,
        "by_source": by_source,
        "duplicates": duplicates_out,
        "items": all_items,
    }


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this script.

    The --status choices list every status bucket in the report, including
    "implementing" and "skipped". --json defaults to True, so passing it does
    not change anything.
    """
    parser = argparse.ArgumentParser(description="Run health checks on absorption-registry.yaml items.")
    parser.add_argument(
        "--source",
        metavar="SOURCE",
        help="Filter results to a specific source (e.g. fireauto, gstack)",
    )
    parser.add_argument(
        "--status",
        metavar="STATUS",
        choices=[
            "active",
            "implemented",
            "implementing",
            "recommended",
            "deferred",
            "degraded",
            "duplicate",
            "absorbed",
            "archived",
            "skipped",
        ],
        help="Filter results to a specific status",
    )
    parser.add_argument(
        "--json",
        dest="output_json",
        action="store_true",
        default=True,
        help="Output as JSON (default)",
    )
    parser.add_argument(
        "--summary",
        action="store_true",
        default=False,
        help="Output summary only (no per-item details)",
    )
    return parser


def main() -> None:
    """
    CLI entry point: parse arguments, load the registry, print the JSON report.

    With --summary the per-item "items" list is dropped; every other top-level
    key is kept in its original order. The output is always JSON (the
    --json flag is never consulted).
    """
    cli_args = build_parser().parse_args()
    report = process_registry(
        load_yaml(REGISTRY_PATH),
        source_filter=cli_args.source,
        status_filter=cli_args.status,
    )
    if cli_args.summary:
        kept_keys = ("timestamp", "summary", "by_source", "duplicates")
        report = {key: report[key] for key in kept_keys}
    print(json.dumps(report, ensure_ascii=False, indent=2))


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
