"""Axis 3 restricted canary - tool call classifier (canonical, PYTHONPATH-agnostic).

chair_authorization_id = CHAIR-AUTH-AXIS-3-CANARY-20260524-JJONGS-RESTRICTED-001
Pure function classifier. No I/O, no side effects, no globals mutated.
"""

from __future__ import annotations

import os
import re
import shlex
from typing import Iterable, Mapping, Optional, Tuple

# Import the policy map: decision constants, credential patterns and the
# forbidden-path tables. The import is absolute (`utils....`), so it resolves
# only when the directory that contains `utils/` is importable; per the
# original note, the hook is expected to have resolved that package root.
try:
    from utils.runtime_guard_policy_map import (
        CREDENTIAL_PATTERNS,
        DECISION_AUDIT_ONLY,
        DECISION_BLOCK,
        DECISION_WARN,
        FORBIDDEN_PATH_GLOB_SUFFIXES,
        FORBIDDEN_PATH_PREFIXES,
        initial_decision_for_category,
    )
except Exception:  # pragma: no cover - fallback when utils not on sys.path
    # Mirror constants so unit tests can still call classifier directly.
    # NOTE(review): in this fallback the credential and forbidden-path tables
    # are empty, so only the destructive-command rules can ever fire. Because
    # the handler catches `Exception` (not just ImportError), an error raised
    # inside the policy module itself degrades to this mode silently as well -
    # confirm that this is acceptable for the production hook.
    DECISION_AUDIT_ONLY = "AUDIT_ONLY"
    DECISION_WARN = "WARN"
    DECISION_BLOCK = "BLOCK"
    CREDENTIAL_PATTERNS = ()
    FORBIDDEN_PATH_GLOB_SUFFIXES = ()
    FORBIDDEN_PATH_PREFIXES = ()

    def initial_decision_for_category(category: str) -> str:
        """Fallback category -> decision mapping.

        "destructive" maps to BLOCK, "forbidden_path" and
        "credential_pattern" map to WARN, anything else to AUDIT_ONLY.
        Presumably mirrors the real policy map - TODO confirm the two stay
        in sync.
        """
        if category == "destructive":
            return DECISION_BLOCK
        if category in ("forbidden_path", "credential_pattern"):
            return DECISION_WARN
        return DECISION_AUDIT_ONLY


# Compiled once at import time. Empty when the fallback above is active, in
# which case the credential scan never matches.
_CREDENTIAL_REGEXES = tuple(re.compile(p) for p in CREDENTIAL_PATTERNS)

# Bash command surface tokens that indicate destructive intent.
#
# Every pattern below is a heuristic applied with `.search()` to the raw
# command string; none of them parses shell syntax.

# `rm` with both a recursive and a force option, aimed at `/` or `/*`.
# The options may be clustered (`-rf`, `-fr`, `-Rfv`), given separately
# (`-r -f`), spelled out (`--recursive --force`) and mixed with other options
# such as `--no-preserve-root` or `--`. Paths below the root (`/tmp/x`) do
# not match.
_RM_RF_ROOT_RE = re.compile(
    r"(?:^|\s|;|&&|\|\|)rm\s+"
    # Some option token requests recursion (`-r` and `-R` are synonyms) ...
    r"(?=(?:-[\w=-]+\s+)*(?:-[a-zA-Z]*[rR]|--recursive\b))"
    # ... and some option token requests force.
    r"(?=(?:-[\w=-]+\s+)*(?:-[a-zA-Z]*f|--force\b))"
    # Consume the whole run of option tokens, then require the root itself.
    r"(?:-[\w=-]+\s+)+"
    r"/\*?(?:\s|$|;|&&|\|\|)"
)
# Force push (`--force` / `--force-with-lease`) that names main, with the flag
# either before or after the ref. The flag-after-ref branch is confined to a
# single command (it never crosses `;`, `|`, `&` or a newline).
# NOTE: `main\b` also accepts names such as `main-x`; the rule errs on the
# side of blocking.
_GIT_PUSH_FORCE_MAIN_RE = re.compile(
    # flag before the ref: `git push --force origin main`
    r"git\s+push\s+(?:.*\s)?--force(?:[^-]|-with-lease)?[^\n;|&]*?\s(?:origin\s+main|origin/main|main)\b"
    # flag after the ref: `git push origin main --force`
    r"|git\s+push\s+(?:[^\n;|&]*\s)?(?:origin\s+main|origin/main|main)\b[^\n;|&]*?\s--force\b"
)
# Same as above for the short `-f` flag.
_GIT_PUSH_FORCE_F_MAIN_RE = re.compile(
    # flag before the ref: `git push -f origin main`
    r"git\s+push\s+(?:.*\s)?-f\s+(?:.*\s)?(?:origin\s+main|origin/main|main)\b"
    # flag after the ref: `git push origin main -f`
    r"|git\s+push\s+(?:[^\n;|&]*\s)?(?:origin\s+main|origin/main|main)\b[^\n;|&]*?\s-f\b"
)
# Forced deletion of the local main branch.
_GIT_BRANCH_DELETE_MAIN_RE = re.compile(r"git\s+branch\s+-D\s+(?:.*\s)?main\b")
# Any `cokacdir` invocation that carries `--cron-remove`.
_COKACDIR_CRON_REMOVE_RE = re.compile(r"cokacdir(?:\s+\S+)*\s+--cron-remove\b")
# `git reset --hard [<ref>]`. Group 1 is the ref argument when one follows on
# the same line; it stops at shell metacharacters so that `HEAD;`, `HEAD)` or
# a following `&&` are never mistaken for (part of) a ref. Group 1 is None
# for a bare `git reset --hard`.
_GIT_RESET_HARD_RE = re.compile(
    r"git\s+reset\s+--hard(?![\w-])(?:[ \t]+([^\s;&|()<>`]+))?"
)
# Refs that are always accepted as a reset target.
_SAFE_RESET_REFS = frozenset({"HEAD", "@{u}", "ORIG_HEAD", "FETCH_HEAD"})
# HEAD followed by ancestry operators: HEAD~, HEAD~3, HEAD^, HEAD^^, HEAD~2^ ...
_HEAD_RELATIVE_RE = re.compile(r"HEAD(?:[~^]\d*)+")


def _is_destructive_git_reset(command: str) -> bool:
    """Return True when *command* hard-resets to an explicit foreign ref.

    Every `git reset --hard` occurrence in *command* is inspected. An
    occurrence is considered safe when it has no ref argument (git then
    resets to HEAD), when the ref is one of HEAD, @{u}, ORIG_HEAD or
    FETCH_HEAD, or when it is HEAD followed by `~` / `^` ancestry operators.
    Any other argument (branch, sha, tag, or an option token) makes the
    command destructive.

    Surrounding quotes and trailing shell punctuation are not treated as part
    of the ref, so `git reset --hard HEAD; git status` and
    `bash -c "git reset --hard HEAD"` are not flagged.
    """
    for match in _GIT_RESET_HARD_RE.finditer(command):
        # Drop quotes left over from e.g. `bash -c "git reset --hard HEAD"`.
        ref = (match.group(1) or "").strip("\"'")
        if not ref:
            # Bare `git reset --hard`: the target is HEAD.
            continue
        if ref in _SAFE_RESET_REFS or _HEAD_RELATIVE_RE.fullmatch(ref):
            continue
        # any other explicit ref (branch, sha, tag) -> destructive
        return True
    return False


def _scan_destructive(command: str) -> Optional[Tuple[str, str]]:
    """Return ``(rule_id, reason)`` for the first destructive rule that fires.

    Rules are evaluated lazily in a fixed order: rm on root, force push to
    main, hard reset to another ref, deletion of main, cokacdir cron removal.
    Returns None when no rule matches *command*.
    """
    rules = (
        (
            lambda: _RM_RF_ROOT_RE.search(command),
            "destructive.rm_rf_root",
            "rm -rf / detected",
        ),
        (
            lambda: _GIT_PUSH_FORCE_MAIN_RE.search(command)
            or _GIT_PUSH_FORCE_F_MAIN_RE.search(command),
            "destructive.git_push_force_main",
            "git push --force origin/main detected",
        ),
        (
            lambda: _is_destructive_git_reset(command),
            "destructive.git_reset_hard_other_branch",
            "git reset --hard <not_current_branch> detected",
        ),
        (
            lambda: _GIT_BRANCH_DELETE_MAIN_RE.search(command),
            "destructive.git_branch_delete_main",
            "git branch -D main detected",
        ),
        (
            lambda: _COKACDIR_CRON_REMOVE_RE.search(command),
            "destructive.cokacdir_cron_remove",
            "cokacdir --cron-remove detected",
        ),
    )
    # Each predicate runs only if every earlier rule stayed silent.
    for fires, rule_id, reason in rules:
        if fires():
            return (rule_id, reason)
    return None


def _scan_credentials(text: str) -> Optional[Tuple[str, str]]:
    """Return a credential hit for *text*, or None.

    The compiled credential regexes are tried in order; the first one that
    matches anywhere in *text* is reported, with its pattern source in the
    reason. Empty text never matches.
    """
    if not text:
        return None
    matched = next((rx for rx in _CREDENTIAL_REGEXES if rx.search(text)), None)
    if matched is None:
        return None
    return (
        "credential_pattern.match",
        f"credential pattern matched: {matched.pattern}",
    )


def _norm_path(p: str) -> str:
    """Return *p* as an absolute, normalized path with a leading ``~`` expanded.

    Home-relative spellings such as ``~/.ssh/id_rsa`` have to be expanded
    before they are made absolute; otherwise ``abspath`` treats ``~`` as a
    literal directory under the current working directory and the result can
    never equal an absolute path. ``..`` and ``.`` segments are collapsed by
    ``abspath``.

    Best effort: if normalization raises for any reason, *p* is returned
    unchanged rather than failing the caller.
    """
    try:
        return os.path.abspath(os.path.expanduser(p))
    except Exception:
        return p


def _scan_forbidden_path(path: str) -> Optional[Tuple[str, str]]:
    """Return a forbidden-path hit for *path*, or None.

    The path is normalized first. Prefix rules are checked before suffix
    rules: a prefix fires when the normalized path equals it or lies beneath
    it; a suffix fires, case-insensitively on the path side, when the path
    contains ``/<suffix>`` or ends with ``.<suffix>``.
    """
    if not path:
        return None

    resolved = _norm_path(path)
    for root in FORBIDDEN_PATH_PREFIXES:
        inside_root = resolved.startswith(root + "/")
        if inside_root or resolved == root:
            return ("forbidden_path.prefix", f"forbidden path prefix: {root}")

    folded = resolved.lower()
    for name in FORBIDDEN_PATH_GLOB_SUFFIXES:
        # The containment test already covers "ends with /<name>", so only
        # the ".<name>" ending needs a separate check.
        if ("/" + name) in folded or folded.endswith("." + name):
            return ("forbidden_path.suffix", f"forbidden path suffix: {name}")
    return None


# Shell syntax that can be glued to the front of a path inside one token:
# a redirection operator (`>`, `2>>`, `&>`, `<`, `>|`) or a `name=` assignment
# (`--output=`, `of=`, `VAR=`).
_PATH_LEAD_IN_RE = re.compile(r"(?:[&\d]*[<>]+[|&]?|[^=/~\s]+=)")


def _extract_tool_paths(tool_name: str, tool_input: Mapping[str, object]) -> Iterable[str]:
    """Yield the path-like strings found in a tool call.

    For every tool, non-empty string values of the ``file_path``, ``path``
    and ``notebook_path`` fields are yielded as-is. For the ``Bash`` tool the
    ``command`` string is additionally tokenized, and every token that is an
    absolute (``/...``) or home-relative (``~/...``) path is yielded. A path
    attached to a redirection or an assignment (``>/etc/x``, ``2>>/var/log``,
    ``--out=/x``, ``of=/dev/sda``) is yielded without the lead-in.

    Relative paths inside a Bash command are not extracted.
    """
    for key in ("file_path", "path", "notebook_path"):
        value = tool_input.get(key)
        if isinstance(value, str) and value:
            yield value

    if tool_name != "Bash":
        return
    cmd = tool_input.get("command", "")
    if not isinstance(cmd, str) or not cmd:
        return

    try:
        tokens = shlex.split(cmd, posix=True)
    except ValueError:
        # Unbalanced quoting: fall back to whitespace splitting and drop the
        # stray quote characters so a quoted path is still recognized.
        tokens = [raw.strip("\"'") for raw in cmd.split()]

    for token in tokens:
        if token.startswith(("/", "~/")):
            yield token
            continue
        lead_in = _PATH_LEAD_IN_RE.match(token)
        if lead_in is None:
            continue
        remainder = token[lead_in.end():]
        if remainder.startswith(("/", "~/")):
            yield remainder


def classify(
    tool_name: str,
    tool_input: Optional[Mapping[str, object]],
) -> dict:
    """Classify one tool call and return its decision record.

    The result is a dict with the keys ``decision``, ``rule_id``, ``reason``
    and ``category``. Rule families are evaluated in priority order and the
    first hit wins:

    1. destructive Bash commands,
    2. forbidden paths (path fields of any tool, plus paths in a Bash command),
    3. credential patterns in the ``content``, ``new_string``, ``command``
       and ``prompt`` fields.

    When nothing fires the AUDIT_ONLY ``audit.noop`` record is returned.
    ``tool_input`` may be None (e.g. the hook could not parse the upstream
    payload); it is then treated as an empty mapping.
    """
    payload = {} if tool_input is None else tool_input

    def _verdict(category: str, hit: Tuple[str, str]) -> dict:
        # Shared shape for every rule hit; the decision comes from the policy.
        rule_id, reason = hit
        return {
            "decision": initial_decision_for_category(category),
            "rule_id": rule_id,
            "reason": reason,
            "category": category,
        }

    # 1) destructive (Bash only realistically)
    if tool_name == "Bash":
        command = payload.get("command", "")
        destructive_hit = _scan_destructive(command) if isinstance(command, str) else None
        if destructive_hit:
            return _verdict("destructive", destructive_hit)

    # 2) forbidden path
    for candidate in _extract_tool_paths(tool_name, payload):
        path_hit = _scan_forbidden_path(candidate)
        if path_hit:
            return _verdict("forbidden_path", path_hit)

    # 3) credential pattern in text-bearing fields
    for field in ("content", "new_string", "command", "prompt"):
        text = payload.get(field)
        if not isinstance(text, str):
            continue
        credential_hit = _scan_credentials(text)
        if credential_hit:
            return _verdict("credential_pattern", credential_hit)

    return {
        "decision": DECISION_AUDIT_ONLY,
        "rule_id": "audit.noop",
        "reason": "no rule fired",
        "category": "audit_only",
    }
