#!/usr/bin/env python3
"""
impact_scanner.py — Reverse impact scanner.

Given a set of modified files, extracts symbols (functions/classes) from them,
then greps the entire project to find OTHER files that reference those symbols
but were NOT modified. Detects cases where a bot modifies file A but misses
file B that also references the same symbol.
"""

import argparse
import ast
import json
import logging
import re
import subprocess
import sys
import time
from pathlib import Path

# Module-level logger; handlers/levels are configured by the entry point.
logger = logging.getLogger(__name__)

# Symbol names that are too generic to be worth grepping for: a project-wide
# search for any of these would match almost everywhere.
COMMON_FILTER = {
    # generic payload / container names
    "data", "item", "items", "list", "result", "value",
    # configuration and UI state
    "config", "options", "props", "state",
    # identifiers and lookup keys
    "id", "index", "key", "name", "path", "type",
    # request / response plumbing
    "error", "event", "request", "response",
}

# Directory names that the reference search never descends into
# (dependencies, VCS metadata, caches and build output).
EXCLUDE_DIRS = [
    "node_modules",
    ".git",
    "__pycache__",
    "dist",
    "build",
    ".worktrees",
    ".next",
]


def extract_symbols_python(file_path: str, diff_lines: list) -> list:
    """Extract function/class symbol names from a Python file using the ast module.

    Every ``def``, ``async def`` and ``class`` statement in the file is
    considered, including methods and nested definitions (the tree is
    traversed with ``ast.walk``).

    If diff_lines is provided, only symbols whose definition span (from the
    ``def``/``class`` line to the last line of the body) overlaps one of the
    changed lines are returned. If diff_lines is empty, all symbols are
    returned.

    Each name is reported at most once, in first-encounter order, so that
    repeated names (e.g. ``__init__`` in several classes) do not crowd out
    other symbols when the caller truncates the list.

    Args:
        file_path: Absolute or relative path to the Python source file.
        diff_lines: List of changed line numbers (1-based). Empty means all.

    Returns:
        List of unique symbol name strings. Empty list if the file cannot be
        read or parsed.
    """
    try:
        source = Path(file_path).read_text(encoding="utf-8", errors="replace")
        tree = ast.parse(source, filename=file_path)
    except (OSError, SyntaxError, ValueError) as exc:
        # ValueError: some interpreter versions reject NUL bytes in the
        # source with ValueError rather than SyntaxError.
        logger.debug("Failed to parse %s: %s", file_path, exc)
        return []

    diff_set = set(diff_lines)
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    symbols = []
    seen = set()

    for node in ast.walk(tree):
        if not isinstance(node, definition_types):
            continue
        if node.name in seen:
            continue
        if diff_set:
            first = node.lineno
            # end_lineno may be missing on very old interpreters; fall back
            # to treating the definition as a single line.
            last = getattr(node, "end_lineno", None) or first
            if not any(first <= changed <= last for changed in diff_set):
                continue
        seen.add(node.name)
        symbols.append(node.name)

    return symbols


def extract_symbols_typescript(file_path: str, diff_lines: list) -> list:
    """Extract exported symbol names from a TypeScript/JavaScript file via regex.

    Matches named export declarations such as:
        export function foo, export async function foo, export function* foo,
        export default function Foo, export default class Foo,
        export abstract class Foo, export declare function foo,
        export class Bar, export type T, export interface I,
        export const baz, export let x, export var y,
        export enum E, export const enum E

    Anonymous default exports (``export default function (...)``,
    ``export default class extends Base``) have no name and are ignored.
    At most one symbol is taken from each line. Re-export lists
    (``export { a, b }``) are not handled.

    If diff_lines is provided, only returns symbols from matching lines.
    If diff_lines is empty, returns all exported symbols.

    Args:
        file_path: Absolute or relative path to the TS/JS source file.
        diff_lines: List of changed line numbers (1-based). Empty means all.

    Returns:
        List of symbol name strings in file order. Empty list on read error.
    """
    pattern = re.compile(
        r'\bexport\s+'
        # optional modifiers that may precede the declaration keyword
        r'(?:default\s+)?(?:declare\s+)?(?:abstract\s+)?(?:async\s+)?'
        # generator functions may have no space between `*` and the name;
        # `const enum` is listed before `const` so the name is not "enum"
        r'(?:function\s*\*\s*'
        r'|(?:function|class|type|interface|const\s+enum|const|let|var|enum)\s+)'
        # an anonymous default class is followed directly by a heritage clause
        r'(?!(?:extends|implements)\b)'
        r'(\w+)'
    )
    try:
        lines = Path(file_path).read_text(encoding="utf-8", errors="replace").splitlines()
    except OSError as exc:
        logger.debug("Failed to read %s: %s", file_path, exc)
        return []

    diff_set = set(diff_lines)
    symbols = []

    for lineno, line in enumerate(lines, start=1):
        if diff_set and lineno not in diff_set:
            continue
        match = pattern.search(line)
        if match:
            symbols.append(match.group(1))

    return symbols


def grep_references(symbol: str, project_root: str, exclude_files: list, timeout: float = 3) -> list:
    """Search the project for references to a symbol in Python/TS/JS files.

    Runs the external ``grep`` recursively under project_root, matching the
    symbol as a whole word and as a literal string (``-w -F``), so characters
    in the symbol are never interpreted as regex syntax.

    Skips node_modules, .git, __pycache__, dist, build, .worktrees, .next
    directories and also skips the files listed in exclude_files.

    Args:
        symbol: The symbol name to search for. An empty symbol yields no
            results.
        project_root: Root directory of the project to search.
        exclude_files: List of file paths (relative or absolute) to skip.
            Relative paths are matched as given, resolved against the current
            working directory, and resolved against project_root.
        timeout: Maximum number of seconds to wait for grep.

    Returns:
        List of dicts with keys "file", "line", "content".
        Returns empty list on timeout or error.
    """
    if not symbol:
        # An empty pattern would match every line of every file.
        return []

    cmd = [
        "grep", "-rn", "-w", "-F",
        "--include=*.py",
        "--include=*.ts",
        "--include=*.tsx",
        "--include=*.js",
        "--include=*.jsx",
    ]
    for d in EXCLUDE_DIRS:
        cmd += ["--exclude-dir", d]

    # -e keeps the symbol from ever being parsed as an option; "--" does the
    # same for the search root.
    cmd += ["-e", symbol, "--", project_root]

    try:
        proc = subprocess.run(
            cmd,
            capture_output=True,
            # Matched lines come from arbitrary source files; decode leniently
            # so a stray non-UTF-8 byte cannot raise UnicodeDecodeError.
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        logger.warning("grep timed out for symbol '%s'", symbol)
        return []
    except OSError as exc:
        logger.warning("grep failed for symbol '%s': %s", symbol, exc)
        return []

    if proc.returncode > 1:
        # 0 = matches, 1 = no matches, >1 = error (e.g. unreadable directory).
        # Output produced before/despite the error is still processed below.
        logger.debug(
            "grep exited with status %d for symbol '%s': %s",
            proc.returncode, symbol, proc.stderr.strip(),
        )

    # Normalise exclude_files so they can be compared with grep's output paths.
    root = Path(project_root)
    excluded = set()
    for ef in exclude_files:
        p = Path(ef)
        excluded.add(str(p))  # keep original form too
        excluded.add(str(p.resolve()))
        if not p.is_absolute():
            excluded.add(str((root / p).resolve()))

    results = []
    resolved_cache = {}  # grep path -> resolved path; one lookup per file
    # Split on "\n" only: str.splitlines() would also break on form feeds and
    # similar characters that can legitimately occur inside a matched line.
    for raw_line in proc.stdout.split("\n"):
        # grep output format: <file>:<lineno>:<content>
        # NOTE: a file path that itself contains ':' cannot be parsed reliably
        # in this format; such lines usually fail the integer check below.
        parts = raw_line.split(":", 2)
        if len(parts) < 3:
            continue
        file_path, lineno_str, content = parts
        try:
            lineno = int(lineno_str)
        except ValueError:
            continue

        abs_file = resolved_cache.get(file_path)
        if abs_file is None:
            abs_file = str(Path(file_path).resolve())
            resolved_cache[file_path] = abs_file
        if abs_file in excluded or file_path in excluded:
            continue

        results.append({"file": file_path, "line": lineno, "content": content.strip()})

    return results


def _parse_diff_lines(diff_output: str) -> list:
    """Parse `git diff` output and return added/changed line numbers in the new file.

    Reads @@ hunk headers of the form @@ -a,b +c,d @@ and collects the range
    [c, c+d) as changed lines.

    Args:
        diff_output: Raw text output of a git diff command.

    Returns:
        Sorted list of 1-based line numbers.
    """
    hunk_re = re.compile(r'@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@')
    lines = set()
    for match in hunk_re.finditer(diff_output):
        start = int(match.group(1))
        count = int(match.group(2)) if match.group(2) is not None else 1
        for i in range(start, start + count):
            lines.add(i)
    return sorted(lines)


def get_modified_files_from_git(project_root: str) -> list:
    """Return a list of file paths modified relative to HEAD in the given repo.

    Tries `git diff --name-only HEAD` first; falls back to
    `git diff --name-only --cached` if that fails or reports no files.

    Both commands are run with `--relative` from project_root. Without it git
    prints paths relative to the repository top level, which would be wrong
    whenever project_root is a subdirectory of the repository; with it, only
    changes under project_root are listed and paths are relative to it.

    Args:
        project_root: Directory inside the git repository to inspect.

    Returns:
        List of file path strings relative to project_root. Empty if nothing
        is modified or git cannot be run.
    """
    commands = (
        ["git", "diff", "--name-only", "--relative", "HEAD"],
        ["git", "diff", "--name-only", "--relative", "--cached"],
    )
    for cmd in commands:
        try:
            proc = subprocess.run(
                cmd,
                cwd=project_root,
                capture_output=True,
                text=True,
                timeout=10,
            )
        except (subprocess.TimeoutExpired, OSError) as exc:
            # Best effort: a missing git binary, bad directory or timeout
            # simply means this strategy yields nothing.
            logger.debug("%s failed in %s: %s", " ".join(cmd), project_root, exc)
            continue

        if proc.returncode != 0:
            logger.debug(
                "%s exited with status %d in %s",
                " ".join(cmd), proc.returncode, project_root,
            )
            continue

        files = [name.strip() for name in proc.stdout.splitlines() if name.strip()]
        if files:
            return files

    return []


def scan(project_root: str, modified_files: list, task_id: str = "", max_symbols: int = 5, timeout: int = 15) -> dict:
    """Orchestrate symbol extraction and reference scanning for modified files.

    For each modified file:
      1. Determines the language from the extension (unsupported files are
         skipped before any git call is made).
      2. Obtains changed line numbers via `git diff HEAD -- <file>`. If that
         fails or reports no hunks, the whole file is scanned.
      3. Extracts symbols, drops common names and symbols already checked
         for an earlier file, and limits the rest to max_symbols.
      4. Greps the project for references in unmodified files.

    Each distinct symbol is grepped once per scan, so a name defined in
    several modified files does not have its references counted repeatedly.

    Gate results:
      - "PASS"  — 0 unmodified references found
      - "WARN"  — 1-5 unmodified references found
      - "BLOCK" — 6+ unmodified references found

    If the time budget runs out, the partial results are returned. The gate
    is then never "PASS" (the scan was incomplete): it is "WARN", or "BLOCK"
    when the references collected so far already reach the BLOCK threshold.

    Args:
        project_root: Root directory of the project.
        modified_files: List of modified file paths (relative to project_root).
        task_id: Optional task identifier included in the output dict.
        max_symbols: Maximum number of symbols to check per file.
        timeout: Overall time budget in seconds.

    Returns:
        Dict with keys: task_id, gate_result, unmodified_references,
        symbols_checked. symbols_checked lists only symbols that were
        actually searched for.
    """
    deadline = time.monotonic() + timeout
    unmodified_references: list = []
    symbols_checked: list = []
    seen_symbols: set = set()
    root = Path(project_root)

    def build_result(timed_out: bool) -> dict:
        """Assemble the output dict, deriving the gate from the current state."""
        ref_count = len(unmodified_references)
        if ref_count > 5:
            gate = "BLOCK"
        elif ref_count > 0 or timed_out:
            gate = "WARN"
        else:
            gate = "PASS"
        return {
            "task_id": task_id,
            "gate_result": gate,
            "unmodified_references": unmodified_references,
            "symbols_checked": symbols_checked,
        }

    # Absolute forms of every modified file, used to exclude them from grep hits.
    abs_modified = []
    for mf in modified_files:
        p = Path(mf)
        if not p.is_absolute():
            p = root / p
        abs_modified.append(str(p.resolve()))

    for rel_file in modified_files:
        if time.monotonic() >= deadline:
            logger.warning("Overall timeout reached; stopping early.")
            return build_result(timed_out=True)

        rel_path = Path(rel_file)
        file_path = rel_file if rel_path.is_absolute() else str(root / rel_file)
        suffix = rel_path.suffix.lower()

        is_python = suffix == ".py"
        is_script = suffix in (".ts", ".tsx", ".js", ".jsx")
        if not (is_python or is_script):
            # Checked first so no git subprocess is spent on files we cannot parse.
            logger.debug("Unsupported extension '%s' for %s; skipping.", suffix, rel_file)
            continue

        # Determine diff_lines via git
        diff_lines: list = []
        try:
            proc = subprocess.run(
                ["git", "diff", "HEAD", "--", rel_file],
                cwd=project_root,
                capture_output=True,
                text=True,
                timeout=5,
            )
            if proc.returncode == 0:
                diff_lines = _parse_diff_lines(proc.stdout)
        except (subprocess.TimeoutExpired, OSError) as exc:
            # Deliberate best effort: treat as full-file scan.
            logger.debug("git diff failed for %s: %s", rel_file, exc)

        # Extract symbols
        if is_python:
            raw_symbols = extract_symbols_python(file_path, diff_lines)
        else:
            raw_symbols = extract_symbols_typescript(file_path, diff_lines)

        # Order-preserving de-duplication, then drop common names and anything
        # already searched for, and only then apply the per-file limit so that
        # repeats do not use up the budget.
        candidates = [
            s for s in dict.fromkeys(raw_symbols)
            if s not in COMMON_FILTER and s not in seen_symbols
        ]
        selected = candidates[:max_symbols]

        for symbol in selected:
            if time.monotonic() >= deadline:
                logger.warning("Overall timeout reached during grep phase.")
                return build_result(timed_out=True)
            refs = grep_references(symbol, project_root, abs_modified)
            seen_symbols.add(symbol)
            symbols_checked.append(symbol)
            unmodified_references.extend(refs)

    return build_result(timed_out=False)


if __name__ == "__main__":
    # Command-line entry point: scan the files git reports as modified under
    # --project-root, print the result as one JSON document on stdout, and
    # signal the gate through the process exit status.
    logging.basicConfig(format="%(levelname)s %(name)s: %(message)s", level=logging.WARNING)

    cli = argparse.ArgumentParser(description="Impact Scanner - reverse dependency check")
    cli.add_argument("--project-root", required=True)
    cli.add_argument("--task-id", default="")
    cli.add_argument("--max-symbols", type=int, default=5)
    cli.add_argument("--timeout", type=int, default=15)
    opts = cli.parse_args()

    changed_files = get_modified_files_from_git(opts.project_root)
    result = scan(
        opts.project_root,
        changed_files,
        opts.task_id,
        max_symbols=opts.max_symbols,
        timeout=opts.timeout,
    )

    print(json.dumps(result, ensure_ascii=False))

    # Exit status: 0 = PASS, 1 = WARN, 2 = anything else (i.e. BLOCK).
    exit_codes = {"PASS": 0, "WARN": 1}
    sys.exit(exit_codes.get(result["gate_result"], 2))
