#!/usr/bin/env python3
"""Unit tests for impact_scanner.py.

Covers:
  - Scenario 1: Python symbol reverse tracing
  - Scenario 2: TypeScript symbol reverse tracing
  - Scenario 3: COMMON_FILTER behavior
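  - Scenario 4: Timeout behavior (plus basic unified-diff hunk parsing)
  - Scenario 5: Zero references (PASS)
  - Additional edge-case tests: symbol extraction, diff parsing, gate
    thresholds, grep exclusions, and scan output structure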
"""

import keyword
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from impact_scanner import (  # noqa: E402
    COMMON_FILTER,
    _parse_diff_lines,
    extract_symbols_python,
    extract_symbols_typescript,
    grep_references,
    scan,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_proc(stdout: str = "", returncode: int = 0, stderr: str = "") -> MagicMock:
    mock = MagicMock()
    mock.returncode = returncode
    mock.stdout = stdout
    mock.stderr = stderr
    return mock


# ---------------------------------------------------------------------------
# Scenario 1: Python symbol reverse tracing
# ---------------------------------------------------------------------------


class TestPythonSymbolReverseTracing:
    """Scenario 1: a modified Python function is traced back to the files that use it."""

    def test_extract_symbols_returns_calculate_premium(self, tmp_path):
        premium = tmp_path / "premium.py"
        premium.write_text("def calculate_premium(age, risk_factor):\n    return age * risk_factor\n")
        # An empty diff-line list asks for every definition in the file.
        assert "calculate_premium" in extract_symbols_python(str(premium), [])

    def test_grep_references_finds_referencing_file(self, tmp_path):
        premium = tmp_path / "premium.py"
        premium.write_text("def calculate_premium(age, risk_factor):\n    return age * risk_factor\n")
        billing = tmp_path / "billing.py"
        billing.write_text(
            "from premium import calculate_premium\n"
            "\n"
            "def process_billing(age):\n"
            "    return calculate_premium(age, 1.5)\n"
        )
        refs = grep_references("calculate_premium", str(tmp_path), [str(premium)])
        assert any("billing.py" in ref["file"] for ref in refs)

    def test_grep_references_excludes_source_file(self, tmp_path):
        premium = tmp_path / "premium.py"
        premium.write_text("def calculate_premium(): pass\n")
        billing = tmp_path / "billing.py"
        billing.write_text("from premium import calculate_premium\n\ncalculate_premium()\n")
        refs = grep_references("calculate_premium", str(tmp_path), [str(premium)])
        # Compare basenames rather than substrings of the full path, so the
        # check is not affected by whichever temp directory pytest chose.
        ref_names = {Path(ref["file"]).name for ref in refs}
        # Require the real caller first; otherwise an empty result would
        # satisfy the exclusion assertion vacuously.
        assert "billing.py" in ref_names
        assert "premium.py" not in ref_names


# ---------------------------------------------------------------------------
# Scenario 2: TypeScript symbol reverse tracing
# ---------------------------------------------------------------------------


class TestTypescriptSymbolReverseTracing:
    """Scenario 2: an exported TSX component is traced to the modules importing it."""

    def test_extract_symbols_typescript_returns_feature_gate(self, tmp_path):
        component = tmp_path / "FeatureGate.tsx"
        component.write_text(
            "import React from 'react';\n"
            "\n"
            "export function FeatureGate({ flag, children }: Props) {\n"
            "    if (!flag) return null;\n"
            "    return <>{children}</>;\n"
            "}\n"
        )
        found = extract_symbols_typescript(str(component), [])
        assert "FeatureGate" in found

    def test_grep_references_finds_use_feature_access(self, tmp_path):
        component = tmp_path / "FeatureGate.tsx"
        component.write_text(
            "export function FeatureGate({ flag, children }: Props) {\n"
            "    return null;\n"
            "}\n"
        )
        consumer = tmp_path / "use-feature-access.ts"
        consumer.write_text(
            "import { FeatureGate } from './FeatureGate';\n"
            "\n"
            "export function useFeatureAccess() {\n"
            "    return FeatureGate;\n"
            "}\n"
        )
        hits = grep_references("FeatureGate", str(tmp_path), [str(component)])
        assert any("use-feature-access.ts" in hit["file"] for hit in hits)


# ---------------------------------------------------------------------------
# Scenario 3: COMMON_FILTER behavior
# ---------------------------------------------------------------------------


class TestCommonFilter:
    """Scenario 3: names listed in ``COMMON_FILTER`` never appear in ``symbols_checked``.

    Each fixture also defines one distinctive, unfiltered function. Asserting
    that it *is* reported proves the file was parsed and scanned, so the
    exclusion checks cannot pass merely because nothing was extracted.
    """

    # Deliberately unusual name so it cannot collide with a filtered one.
    SENTINEL = "sentinel_probe_func_q7w3"

    def test_data_not_in_symbols_checked(self, tmp_path):
        assert "data" in COMMON_FILTER
        src = tmp_path / "utils.py"
        src.write_text(f"def data():\n    return {{}}\n\ndef {self.SENTINEL}():\n    return None\n")
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = _make_proc(stdout="", returncode=0)
            result = scan(project_root=str(tmp_path), modified_files=[str(src)], task_id="test-filter")
        assert self.SENTINEL in result["symbols_checked"]
        assert "data" not in result["symbols_checked"]

    def test_all_common_filter_names_excluded(self, tmp_path):
        # Only names that can legally follow ``def`` go into the fixture: a single
        # keyword or non-identifier would turn the whole file into a SyntaxError.
        # Sorting keeps the generated file deterministic.
        definable = sorted(
            n for n in COMMON_FILTER if n.isidentifier() and not keyword.iskeyword(n)
        )
        body = "\n".join(f"def {n}(): pass" for n in [*definable, self.SENTINEL])
        src = tmp_path / "all_common.py"
        src.write_text(body + "\n")
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = _make_proc(stdout="", returncode=0)
            result = scan(project_root=str(tmp_path), modified_files=[str(src)])
        assert self.SENTINEL in result["symbols_checked"]
        for n in COMMON_FILTER:
            assert n not in result["symbols_checked"]


# ---------------------------------------------------------------------------
# Scenario 4: Timeout behavior
# ---------------------------------------------------------------------------


class TestTimeoutBehavior:
    """Scenario 4: unified-diff hunk parsing, and a scan that exceeds its time budget."""

    def test_parse_diff_lines_basic_hunk(self):
        diff = "@@ -1,3 +10,5 @@\n context\n+added line\n"
        # "+10,5" covers five new-file lines starting at line 10.
        assert _parse_diff_lines(diff) == list(range(10, 15))

    def test_parse_diff_lines_multiple_hunks(self):
        parsed = set(_parse_diff_lines("@@ -1,2 +1,3 @@\n a\n@@ -10,1 +20,2 @@\n b\n"))
        assert {1, 2, 3} <= parsed
        assert {20, 21} <= parsed

    def test_parse_diff_lines_no_count(self):
        # A hunk without ",count" refers to a single line.
        single = _parse_diff_lines("@@ -5 +7 @@\n")
        assert single == [7]

    def test_scan_returns_warn_on_timeout(self, tmp_path):
        modules = []
        for idx in range(20):
            path = tmp_path / f"module_{idx}.py"
            path.write_text(f"def unique_function_{idx}(): pass\n")
            modules.append(str(path))

        clock_reads = 0

        def fake_monotonic():
            # The first reading is presumably the scan's start timestamp; every
            # later reading is 100s on, far beyond the 1s budget passed below.
            nonlocal clock_reads
            clock_reads += 1
            return 1_000_000.0 if clock_reads == 1 else 1_000_100.0

        with patch("impact_scanner.time.monotonic", side_effect=fake_monotonic):
            with patch("subprocess.run") as run_mock:
                run_mock.return_value = _make_proc(stdout="", returncode=0)
                outcome = scan(project_root=str(tmp_path), modified_files=modules, timeout=1)
        assert outcome["gate_result"] == "WARN"


# ---------------------------------------------------------------------------
# Scenario 5: Zero references -> PASS
# ---------------------------------------------------------------------------


class TestZeroReferencesPass:
    """Scenario 5: a symbol that nothing references yields a PASS gate."""

    def test_unique_function_no_references_returns_pass(self, tmp_path):
        symbol = "xyzzy_totally_unique_func_a1b2c3"
        module = tmp_path / "unique_module.py"
        module.write_text(f"def {symbol}():\n    return 42\n")
        with patch("subprocess.run") as run_mock:
            run_mock.return_value = _make_proc(stdout="", returncode=0)
            outcome = scan(project_root=str(tmp_path), modified_files=[str(module)], task_id="zero-refs")
        assert outcome["gate_result"] == "PASS"
        assert symbol in outcome["symbols_checked"]


# ---------------------------------------------------------------------------
# Additional edge-case tests
# ---------------------------------------------------------------------------


class TestExtractSymbolsPythonEdgeCases:
    """Edge cases for ``extract_symbols_python``: diff filtering, bad input, nesting, async."""

    def test_empty_diff_lines_returns_all_symbols(self, tmp_path):
        module = tmp_path / "module.py"
        module.write_text("def alpha(): pass\nclass Beta: pass\ndef gamma(): pass\n")
        found = set(extract_symbols_python(str(module), []))
        assert {"alpha", "Beta", "gamma"} <= found

    def test_specific_diff_lines_returns_only_overlapping(self, tmp_path):
        module = tmp_path / "module.py"
        module.write_text("def alpha(): pass\ndef beta(): pass\n")
        # Line 1 holds only ``alpha``; ``beta`` lives on line 2.
        found = extract_symbols_python(str(module), [1])
        assert "alpha" in found
        assert "beta" not in found

    def test_syntax_error_file_returns_empty_list(self, tmp_path):
        broken = tmp_path / "broken.py"
        broken.write_text("def (:\n    pass\n")
        result = extract_symbols_python(str(broken), [])
        assert result == []

    def test_nonexistent_file_returns_empty_list(self, tmp_path):
        missing = tmp_path / "ghost.py"
        assert extract_symbols_python(str(missing), []) == []

    def test_nested_functions_are_included(self, tmp_path):
        module = tmp_path / "nested.py"
        module.write_text("def outer():\n    def inner(): pass\n")
        found = set(extract_symbols_python(str(module), []))
        assert {"outer", "inner"} <= found

    def test_async_function_is_extracted(self, tmp_path):
        module = tmp_path / "async_mod.py"
        module.write_text("async def fetch_data(): pass\n")
        found = extract_symbols_python(str(module), [])
        assert "fetch_data" in found


class TestExtractSymbolsTypescriptEdgeCases:
    """Edge cases for ``extract_symbols_typescript``: missing files, export kinds, diff filtering."""

    def test_nonexistent_file_returns_empty_list(self, tmp_path):
        missing = tmp_path / "ghost.ts"
        assert extract_symbols_typescript(str(missing), []) == []

    def test_multiple_export_types(self, tmp_path):
        module = tmp_path / "exports.ts"
        module.write_text(
            "export function MyFunction() {}\n"
            "export class MyClass {}\n"
            "export const MY_CONST = 42;\n"
            "export enum MyEnum { A, B }\n"
            "export interface MyInterface {}\n"
            "export type MyType = string;\n"
        )
        found = extract_symbols_typescript(str(module), [])
        expected = ("MyFunction", "MyClass", "MY_CONST", "MyEnum", "MyInterface", "MyType")
        missing = [name for name in expected if name not in found]
        assert not missing

    def test_specific_diff_lines_filters_correctly(self, tmp_path):
        module = tmp_path / "module.ts"
        module.write_text("export function Alpha() {}\nexport function Beta() {}\n")
        # Only line 2 (``Beta``) is marked as changed.
        found = extract_symbols_typescript(str(module), [2])
        assert "Beta" in found
        assert "Alpha" not in found

    def test_non_exported_symbols_not_included(self, tmp_path):
        module = tmp_path / "internal.ts"
        module.write_text("function privateHelper() {}\nexport function PublicApi() {}\n")
        found = extract_symbols_typescript(str(module), [])
        assert "PublicApi" in found
        assert "privateHelper" not in found


class TestParseDiffLines:
    """Boundary cases for ``_parse_diff_lines`` hunk-header parsing."""

    def test_empty_string(self):
        parsed = _parse_diff_lines("")
        assert parsed == []

    def test_count_zero(self):
        # A zero count on the new side means no lines were added.
        parsed = _parse_diff_lines("@@ -1,0 +5,0 @@\n")
        assert parsed == []

    def test_single_line(self):
        parsed = _parse_diff_lines("@@ -3,1 +7,1 @@\n")
        assert parsed == [7]

    def test_sorted_output(self):
        # Hunks arrive out of order; the parser must still return ascending lines.
        parsed = _parse_diff_lines("@@ -1,3 +20,3 @@\n@@ -10,2 +1,2 @@\n")
        assert all(prev <= nxt for prev, nxt in zip(parsed, parsed[1:]))

    def test_no_hunk_headers(self):
        header_only = "diff --git a/f.py b/f.py\nindex abc..def 100644\n"
        assert _parse_diff_lines(header_only) == []


class TestGateThresholds:
    """Gate verdict versus reported reference count.

    These tests expect 0 references -> PASS, 1-5 -> WARN, and 6 or more -> BLOCK.
    """

    def _run_scan_with_n_refs(self, tmp_path, n_refs: int) -> dict:
        """Scan a one-function module while faking ``n_refs`` grep hits in distinct files."""
        src = tmp_path / "source.py"
        src.write_text("def threshold_func(): pass\n")
        canned_grep = "\n".join(
            f"{tmp_path}/other_{i}.py:1:threshold_func()" for i in range(n_refs)
        )

        def fake_run(cmd, **_kw):
            # Only grep invocations get the canned hits; any other command
            # succeeds with empty output.
            output = canned_grep if cmd[0] == "grep" else ""
            return _make_proc(stdout=output, returncode=0)

        with patch("subprocess.run", side_effect=fake_run):
            return scan(project_root=str(tmp_path), modified_files=[str(src)], task_id="threshold-test")

    def test_zero_refs_is_pass(self, tmp_path):
        verdict = self._run_scan_with_n_refs(tmp_path, 0)["gate_result"]
        assert verdict == "PASS"

    def test_one_ref_is_warn(self, tmp_path):
        verdict = self._run_scan_with_n_refs(tmp_path, 1)["gate_result"]
        assert verdict == "WARN"

    def test_five_refs_is_warn(self, tmp_path):
        verdict = self._run_scan_with_n_refs(tmp_path, 5)["gate_result"]
        assert verdict == "WARN"

    def test_six_refs_is_block(self, tmp_path):
        verdict = self._run_scan_with_n_refs(tmp_path, 6)["gate_result"]
        assert verdict == "BLOCK"

    def test_many_refs_is_block(self, tmp_path):
        verdict = self._run_scan_with_n_refs(tmp_path, 20)["gate_result"]
        assert verdict == "BLOCK"


class TestGrepReferencesExcludeFiles:
    """``grep_references`` must drop every path passed in ``exclude_files``."""

    def test_exclude_files_are_omitted(self, tmp_path):
        define_file = tmp_path / "define.py"
        define_file.write_text("def my_func(): pass\n")
        other_file = tmp_path / "other.py"
        other_file.write_text("my_func()\n")
        refs = grep_references("my_func", str(tmp_path), [str(define_file), str(other_file)])
        # Resolve both sides so relative/absolute spellings compare equal.
        reported = {Path(ref["file"]).resolve() for ref in refs}
        for excluded in (str(define_file), str(other_file)):
            assert Path(excluded).resolve() not in reported

    def test_non_excluded_file_is_included(self, tmp_path):
        define_file = tmp_path / "define.py"
        define_file.write_text("def visible_func(): pass\n")
        caller = tmp_path / "ref.py"
        caller.write_text("visible_func()\n")
        refs = grep_references("visible_func", str(tmp_path), [str(define_file)])
        hit_paths = [ref["file"] for ref in refs]
        assert any("ref.py" in path for path in hit_paths)


class TestScanOutputStructure:
    """Shape of the dict returned by ``scan`` and handling of unsupported file types."""

    def test_result_has_required_keys(self, tmp_path):
        module = tmp_path / "check.py"
        module.write_text("def check_something(): pass\n")
        with patch("subprocess.run") as run_mock:
            run_mock.return_value = _make_proc(stdout="", returncode=0)
            report = scan(project_root=str(tmp_path), modified_files=[str(module)], task_id="structure-test")
        required = ("task_id", "gate_result", "unmodified_references", "symbols_checked")
        missing = [key for key in required if key not in report]
        assert not missing
        assert report["task_id"] == "structure-test"

    def test_unsupported_extension_skipped(self, tmp_path):
        table = tmp_path / "data.csv"
        table.write_text("col1,col2\n1,2\n")
        with patch("subprocess.run") as run_mock:
            run_mock.return_value = _make_proc(stdout="", returncode=0)
            report = scan(project_root=str(tmp_path), modified_files=[str(table)])
        # A file type with no symbol extractor contributes nothing and cannot fail the gate.
        assert report["gate_result"] == "PASS"
        assert report["symbols_checked"] == []

    def test_task_id_propagated(self, tmp_path):
        module = tmp_path / "mod.py"
        module.write_text("def func(): pass\n")
        with patch("subprocess.run") as run_mock:
            run_mock.return_value = _make_proc(stdout="", returncode=0)
            report = scan(project_root=str(tmp_path), modified_files=[str(module)], task_id="my-task-99")
        assert report["task_id"] == "my-task-99"
