"""Regression tests for ``scripts/auto_merge_controller.py`` (task-2444).

Coverage:

* Enforcement of all 8 required checks
* mergeable_state BLOCKED / behind / dirty / unstable rejection
* gemini-review-gate failure / SKIPPED rejection (with label)
* cancelled marker → close (no merge attempt)
* Forbidden flag tripwire (--admin)
* Forbidden direct push to main
* End-to-end 6 scenarios (A1..A6) using mocked GitHub API responses.

Tests are pure-Python — no network, no real ``gh`` CLI invocations.
"""

from __future__ import annotations

import json
import subprocess
import sys
from pathlib import Path
from typing import Any
from unittest import mock

import pytest

# Directory two levels above the folder that contains this test file.
WORKSPACE = Path(__file__).resolve().parents[2]
SCRIPTS = WORKSPACE / "scripts"
# Both inserts go to index 0, so WORKSPACE ends up ahead of SCRIPTS on sys.path.
# They must run before the project imports below, hence the E402 suppressions.
sys.path.insert(0, str(SCRIPTS))
sys.path.insert(0, str(WORKSPACE))

import auto_merge_controller as amc  # noqa: E402  # pyright: ignore[reportMissingImports]
from auto_merge_lock import FileLock, LockTimeout  # noqa: E402  # pyright: ignore[reportMissingImports]


# Synthetic "owner/name" repository slug handed to every controller call in this module.
REPO = "TestOwner/test-repo"


# ---------------------------------------------------------------------------
# Helpers — synthetic GitHub fixtures
# ---------------------------------------------------------------------------


def make_pr(
    number: int,
    *,
    branch: str = "task/task-9999-dev2",
    base: str = "main",
    sha: str = "deadbeef" * 5,
    mergeable_state: str = "clean",
) -> dict[str, Any]:
    """Build a synthetic, not-yet-merged pull-request payload.

    Args:
        number: PR number stored under ``"number"``.
        branch: Head branch name stored under ``head.ref``.
        base: Base branch name stored under ``base.ref``.
        sha: Head commit SHA stored under ``head.sha``.
        mergeable_state: Value stored under ``"mergeable_state"``.

    Returns:
        A dict with ``merged`` set to ``False`` and both ``merged_at`` and
        ``merge_commit_sha`` set to ``None``.
    """
    head_info = {"ref": branch, "sha": sha}
    base_info = {"ref": base}
    return dict(
        number=number,
        head=head_info,
        base=base_info,
        mergeable_state=mergeable_state,
        merged=False,
        merged_at=None,
        merge_commit_sha=None,
    )


def all_success_check_runs(*extra_status: tuple[str, str]) -> list[dict[str, Any]]:
    """Return one check-run per ``amc.REQUIRED_CHECKS`` entry, defaulting to ``success``.

    Args:
        *extra_status: ``(check_name, conclusion)`` pairs. A pair whose name
            matches a required check replaces that check's conclusion; pairs
            naming other checks have no effect.

    Returns:
        A list of ``{"name": ..., "conclusion": ...}`` dicts in
        ``amc.REQUIRED_CHECKS`` order.
    """
    conclusion_overrides = dict(extra_status)
    return [
        {"name": check_name, "conclusion": conclusion_overrides.get(check_name, "success")}
        for check_name in amc.REQUIRED_CHECKS
    ]


class FakeGitHub:
    """In-memory stand-in for the REST and GraphQL callables the controller is given.

    Attributes are plain containers so tests can inspect or mutate them:
    ``prs`` (PR payloads), ``check_runs`` (keyed by commit SHA),
    ``pr_full_overrides`` (keyed by PR number), ``review_threads`` (keyed by
    PR number, one ``isResolved`` flag per thread), ``main_head_history``
    (last element is the current head of ``main``) and ``calls`` (every
    ``api`` path requested, in order).
    """

    def __init__(
        self,
        *,
        prs: list[dict[str, Any]],
        check_runs: dict[str, list[dict[str, Any]]] | None = None,
        pr_full_overrides: dict[int, dict[str, Any]] | None = None,
        review_threads: dict[int, list[bool]] | None = None,
        main_head: str = "0" * 40,
    ):
        # Shallow copy: the list is private, the PR dicts are shared with the caller.
        self.prs = list(prs)
        self.check_runs = check_runs if check_runs else {}
        self.pr_full_overrides = pr_full_overrides if pr_full_overrides else {}
        self.review_threads = review_threads if review_threads else {}
        self.main_head_history = [main_head]
        self.calls: list[tuple[str, str]] = []

    # API surface ----------------------------------------------------------
    def api(self, path: str) -> Any:
        """Answer a REST path, recording the request in ``calls``.

        Routes are tried in order: main-branch lookup, check-runs for a commit,
        a single pull request (with any per-PR override merged on top), then
        the open-PR listing against ``main``.

        Raises:
            RuntimeError: The requested pull-request number is unknown.
            AssertionError: The path matches none of the routes.
        """
        self.calls.append(("api", path))
        is_open_pr_listing = path.endswith("?state=open&base=main")

        if path.endswith("/branches/main"):
            current_head = self.main_head_history[-1]
            return {"commit": {"sha": current_head}}

        if "/check-runs" in path:
            after_commits = path.split("/commits/")[1]
            commit_sha = after_commits.partition("/")[0]
            return {"check_runs": self.check_runs.get(commit_sha, [])}

        if "/pulls/" in path and not is_open_pr_listing:
            pr_number = int(path.rsplit("/", 1)[-1])
            for candidate in self.prs:
                if candidate["number"] == pr_number:
                    extra_fields = self.pr_full_overrides.get(pr_number, {})
                    return {**candidate, **extra_fields}
            raise RuntimeError(f"PR not found: {pr_number}")

        if is_open_pr_listing:
            return [entry for entry in self.prs if entry["base"]["ref"] == "main"]

        raise AssertionError(f"unexpected path: {path}")

    def graphql(self, _query: str, variables: dict[str, Any]) -> Any:
        """Return a review-threads payload for ``variables["number"]``; the query text is ignored."""
        resolved_flags = self.review_threads.get(variables["number"], [])
        thread_nodes = [{"isResolved": flag} for flag in resolved_flags]
        review_threads = {"nodes": thread_nodes}
        return {"data": {"repository": {"pullRequest": {"reviewThreads": review_threads}}}}

    # State mutators (used by safe_merge_fn stub) --------------------------
    def advance_main(self, new_sha: str) -> None:
        """Make ``new_sha`` the current head of ``main``, keeping earlier heads in the history."""
        self.main_head_history.append(new_sha)


# ---------------------------------------------------------------------------
# Lock tests
# ---------------------------------------------------------------------------


def test_filelock_releases_on_exit(tmp_path: Path) -> None:
    """Leaving the ``FileLock`` context must free the lock for the next taker."""
    lock_file = tmp_path / "x.lock"
    # Two back-to-back acquisitions: the second only succeeds within its
    # timeout if the first one was released on exit.
    for _attempt in range(2):
        with FileLock(lock_file, timeout=1):
            pass


def test_filelock_blocks_concurrent_holder(tmp_path: Path) -> None:
    """A second ``FileLock`` on a held path must raise ``LockTimeout``."""
    lock_file = tmp_path / "x.lock"
    first_owner = FileLock(lock_file, timeout=10)
    first_owner.acquire()
    try:
        # The contender uses a short timeout so the test fails fast.
        with pytest.raises(LockTimeout), FileLock(lock_file, timeout=0.3):
            pass
    finally:
        # Always release so a failing assertion does not leave the lock held.
        first_owner.release()


# ---------------------------------------------------------------------------
# Forbidden flag / push tripwire
# ---------------------------------------------------------------------------


def test_run_cmd_blocks_admin_flag() -> None:
    """``run_cmd`` must reject a ``gh pr merge`` command that carries ``--admin``."""
    admin_merge_cmd = ["gh", "pr", "merge", "1", "--admin", "--merge"]
    with pytest.raises(RuntimeError, match=r"FORBIDDEN.*--admin"):
        amc.run_cmd(admin_merge_cmd)


def test_run_cmd_blocks_direct_main_push() -> None:
    """``run_cmd`` must reject ``git push origin main``."""
    push_main_cmd = ["git", "push", "origin", "main"]
    with pytest.raises(RuntimeError, match=r"FORBIDDEN.*push to main"):
        amc.run_cmd(push_main_cmd)


def test_run_cmd_blocks_head_to_main_push() -> None:
    """``run_cmd`` must reject the ``HEAD:main`` refspec form of a push to main."""
    push_refspec_cmd = ["git", "push", "origin", "HEAD:main"]
    with pytest.raises(RuntimeError, match=r"FORBIDDEN.*push to main"):
        amc.run_cmd(push_refspec_cmd)


def test_run_cmd_allows_normal_gh_pr_merge(monkeypatch: pytest.MonkeyPatch) -> None:
    """A regular ``gh pr merge --auto`` command reaches ``subprocess.run`` without ``--admin``."""
    recorded: dict[str, Any] = {}

    def fake_run(cmd, *, capture_output, text, check):
        # Stand-in for subprocess.run: remember the argv and report success.
        recorded.update(cmd=cmd)
        return subprocess.CompletedProcess(cmd, 0, "ok", "")

    monkeypatch.setattr(amc.subprocess, "run", fake_run)

    merge_cmd = ["gh", "pr", "merge", "5", "--auto", "--merge", "--delete-branch"]
    completed = amc.run_cmd(merge_cmd)

    assert completed.returncode == 0
    assert "--admin" not in recorded["cmd"]


# ---------------------------------------------------------------------------
# evaluate_pr — pure decision tests
# ---------------------------------------------------------------------------


def test_evaluate_skips_non_main_base() -> None:
    """A PR based on ``develop`` yields a skip decision whose reason says "not main"."""
    develop_pr = make_pr(1, base="develop")
    gh = FakeGitHub(prs=[develop_pr])

    outcome = amc.evaluate_pr(
        develop_pr,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        cancelled_marker_exists=lambda _t: False,
    )

    assert outcome is not None
    assert "not main" in outcome.reason


def test_evaluate_skips_cancelled_marker() -> None:
    """A cancelled marker for the branch's task id yields a "cancelled" skip decision."""
    cancelled_pr = make_pr(1, branch="task/task-1234-dev1")
    head_sha = cancelled_pr["head"]["sha"]
    gh = FakeGitHub(
        prs=[cancelled_pr],
        check_runs={head_sha: all_success_check_runs()},
    )

    outcome = amc.evaluate_pr(
        cancelled_pr,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        # Only task-1234 is reported as cancelled.
        cancelled_marker_exists=lambda tid: tid == "task-1234",
    )

    assert outcome is not None
    assert "cancelled" in outcome.reason


def test_evaluate_skips_missing_required_check() -> None:
    """An absent ``qc-check`` run is named in the skip reason and carries no label."""
    candidate = make_pr(1)
    reported_runs = [
        run for run in all_success_check_runs() if run["name"] != "qc-check"
    ]
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={candidate["head"]["sha"]: reported_runs},
    )

    outcome = amc.evaluate_pr(
        candidate,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        cancelled_marker_exists=lambda _t: False,
    )

    assert outcome is not None
    assert "qc-check" in outcome.reason
    # A missing check is treated as pending, so no label is expected.
    assert outcome.label is None


def test_evaluate_skips_failed_check_with_label() -> None:
    """A failed ``qc-check`` is named in the reason and labelled ``auto-merge-blocked``."""
    candidate = make_pr(1)
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={
            candidate["head"]["sha"]: all_success_check_runs(("qc-check", "failure")),
        },
    )

    outcome = amc.evaluate_pr(
        candidate,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        cancelled_marker_exists=lambda _t: False,
    )

    assert outcome is not None
    assert "qc-check" in outcome.reason
    assert outcome.label == "auto-merge-blocked"


def test_evaluate_skips_blocked_mergeable_state() -> None:
    """``mergeable_state == "blocked"`` is skipped and labelled ``auto-merge-blocked``."""
    blocked_pr = make_pr(1, mergeable_state="blocked")
    head_sha = blocked_pr["head"]["sha"]
    gh = FakeGitHub(
        prs=[blocked_pr],
        check_runs={head_sha: all_success_check_runs()},
    )

    outcome = amc.evaluate_pr(
        blocked_pr,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        cancelled_marker_exists=lambda _t: False,
    )

    assert outcome is not None
    assert "blocked" in outcome.reason
    assert outcome.label == "auto-merge-blocked"


def test_evaluate_skips_behind_main() -> None:
    """``mergeable_state == "behind"`` yields a skip decision mentioning "behind"."""
    stale_pr = make_pr(1, mergeable_state="behind")
    head_sha = stale_pr["head"]["sha"]
    gh = FakeGitHub(
        prs=[stale_pr],
        check_runs={head_sha: all_success_check_runs()},
    )

    outcome = amc.evaluate_pr(
        stale_pr,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        cancelled_marker_exists=lambda _t: False,
    )

    assert outcome is not None
    assert "behind" in outcome.reason


def test_evaluate_skips_unresolved_threads() -> None:
    """One unresolved review thread is enough to produce an "unresolved" skip decision."""
    candidate = make_pr(1)
    head_sha = candidate["head"]["sha"]
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={head_sha: all_success_check_runs()},
        # First thread resolved, second still open.
        review_threads={1: [True, False]},
    )

    outcome = amc.evaluate_pr(
        candidate,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        cancelled_marker_exists=lambda _t: False,
    )

    assert outcome is not None
    assert "unresolved" in outcome.reason


def test_evaluate_passes_when_all_clear() -> None:
    """With every check green and every thread resolved, ``evaluate_pr`` returns ``None``."""
    candidate = make_pr(1)
    head_sha = candidate["head"]["sha"]
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={head_sha: all_success_check_runs()},
        review_threads={1: [True, True]},
    )

    outcome = amc.evaluate_pr(
        candidate,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        cancelled_marker_exists=lambda _t: False,
    )

    # None means "no reason to skip".
    assert outcome is None


def test_evaluate_skipped_check_without_success_flag_is_blocked() -> None:
    """A ``skipped`` gemini-review-gate run is rejected with the ``gemini-blocked`` label."""
    candidate = make_pr(1)
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={
            candidate["head"]["sha"]: all_success_check_runs(
                ("gemini-review-gate", "skipped"),
            ),
        },
    )

    outcome = amc.evaluate_pr(
        candidate,
        repo=REPO,
        api=gh.api,
        graphql=gh.graphql,
        cancelled_marker_exists=lambda _t: False,
    )

    assert outcome is not None
    assert outcome.label == "gemini-blocked"


# ---------------------------------------------------------------------------
# 6-scenario end-to-end (A1..A6)
# ---------------------------------------------------------------------------


def _make_cycle_runner(fake: FakeGitHub) -> dict[str, Any]:
    """Build the recorders and stub callables used to drive one controller cycle.

    Args:
        fake: The in-memory GitHub stub. The stubs read ``fake.prs`` and
            ``fake.main_head_history``, call ``fake.advance_main`` and look
            for an optional ``fake._cancelled_set`` attribute.

    Returns:
        A dict holding four recorder lists (``audit``, ``merged_calls``,
        ``closed_calls``, ``labels``) and six callables (``head_recorder``,
        ``safe_merge_fn``, ``list_prs``, ``cancelled_marker``,
        ``label_blocked``, ``handle_cancelled``) that append to them.
    """
    from datetime import datetime, timedelta, timezone

    audit: list[dict[str, Any]] = []
    merged_calls: list[int] = []
    closed_calls: list[int] = []
    labels: list[tuple[int, str]] = []

    # Merge timestamps are offset from this instant by ``pr_num`` seconds.
    # That keeps them well-formed and increasing with the PR number; the old
    # "00:00:0{pr_num}Z" interpolation was malformed for any pr_num >= 10.
    base_time = datetime(2026, 5, 4, tzinfo=timezone.utc)

    def head_recorder(stage: str, extra: dict[str, Any] | None = None):
        # Log the current head of main for this stage and hand it back.
        sha = fake.main_head_history[-1]
        entry: dict[str, Any] = {"stage": stage, "main_head": sha}
        if extra:
            entry.update(extra)
        audit.append(entry)
        return sha

    def safe_merge_fn(pr_num: int, _repo: str):
        merged_calls.append(pr_num)
        # Simulate GitHub advancing main + setting merged=True
        new_sha = f"merge-{pr_num:08x}".ljust(40, "0")
        fake.advance_main(new_sha)
        merged_at = (base_time + timedelta(seconds=pr_num)).strftime("%Y-%m-%dT%H:%M:%SZ")
        for p in fake.prs:
            if p["number"] == pr_num:
                p["merged"] = True
                p["merged_at"] = merged_at
                p["merge_commit_sha"] = new_sha
        return subprocess.CompletedProcess(args=[], returncode=0, stdout="", stderr="")

    def list_prs(_repo):
        return [p for p in fake.prs if not p["merged"] and p["base"]["ref"] == "main"]

    def cancelled_marker(tid: str) -> bool:
        # Tests opt in by assigning ``fake._cancelled_set``; absent means nothing is cancelled.
        return tid in getattr(fake, "_cancelled_set", set())

    def label_blocked(pr_num: int, label: str, _repo: str) -> None:
        labels.append((pr_num, label))

    def handle_cancelled(pr, _repo):
        closed_calls.append(pr["number"])
        for p in fake.prs:
            if p["number"] == pr["number"]:
                p["state"] = "closed"
                p["merged"] = False

    return {
        "audit": audit,
        "merged_calls": merged_calls,
        "closed_calls": closed_calls,
        "labels": labels,
        "head_recorder": head_recorder,
        "safe_merge_fn": safe_merge_fn,
        "list_prs": list_prs,
        "cancelled_marker": cancelled_marker,
        "label_blocked": label_blocked,
        "handle_cancelled": handle_cancelled,
    }


def _run_with_overrides(fake: FakeGitHub, harness: dict[str, Any]):
    """Invoke ``amc.process_open_prs`` with every side effect replaced by a stub.

    ``amc.label_blocked``, ``amc.handle_cancelled_pr`` and ``amc.post_check``
    are patched for the duration of the call; all other collaborators are
    passed in as arguments. The clock is pinned to a constant.

    Args:
        fake: In-memory GitHub stub providing ``api``, ``graphql`` and
            ``main_head_history``.
        harness: The dict returned by ``_make_cycle_runner`` for ``fake``.

    Returns:
        Whatever ``amc.process_open_prs`` returns.
    """
    from datetime import datetime, timedelta, timezone

    base_time = datetime(2026, 5, 4, tzinfo=timezone.utc)

    def fake_post_check(pn, br, repo):
        # Report a successful merge. The timestamp is base_time + pn seconds,
        # so it stays well-formed for multi-digit PR numbers; the old
        # "00:00:0{pn}Z" interpolation did not.
        merged_at = (base_time + timedelta(seconds=pn)).strftime("%Y-%m-%dT%H:%M:%SZ")
        return {
            "merged": True,
            "merged_at": merged_at,
            "branch_deleted": True,
            "merge_commit_sha": fake.main_head_history[-1],
        }

    with mock.patch.object(amc, "label_blocked", harness["label_blocked"]), \
         mock.patch.object(amc, "handle_cancelled_pr", harness["handle_cancelled"]), \
         mock.patch.object(amc, "post_check", fake_post_check):
        return amc.process_open_prs(
            repo=REPO,
            api=fake.api,
            graphql=fake.graphql,
            list_prs=harness["list_prs"],
            safe_merge_fn=harness["safe_merge_fn"],
            head_recorder=harness["head_recorder"],
            cancelled_marker_exists=harness["cancelled_marker"],
            now=lambda: 1730000000.0,
        )


def test_A1_all_checks_success_auto_merges() -> None:
    """A1: a clean PR with all checks green is merged and the main head moves."""
    candidate = make_pr(101)
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={candidate["head"]["sha"]: all_success_check_runs()},
        review_threads={101: [True]},
        main_head="aaa" * 13 + "a",
    )
    harness = _make_cycle_runner(gh)
    outcome = _run_with_overrides(gh, harness)

    assert harness["merged_calls"] == [101]
    assert len(outcome.merged) == 1

    record = outcome.merged[0]
    assert record.pr_number == 101
    assert record.merged is True
    # The merge must have moved the head of main.
    assert record.main_head_before != record.main_head_after

    # The audit trail must contain the cycle start plus both merge stages.
    recorded_stages = [entry["stage"] for entry in harness["audit"]]
    for expected_stage in ("before-cycle", "before-merge-pr-101", "after-merge-pr-101"):
        assert expected_stage in recorded_stages


def test_A2_ci_failure_blocks_merge_with_label() -> None:
    """A2: a failed required check prevents the merge and applies ``auto-merge-blocked``."""
    candidate = make_pr(102)
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={
            candidate["head"]["sha"]: all_success_check_runs(
                ("cancel-kill-switch", "failure"),
            ),
        },
        # The single-PR endpoint reports the PR as blocked.
        pr_full_overrides={102: {"mergeable_state": "blocked"}},
        review_threads={102: []},
    )
    harness = _make_cycle_runner(gh)
    outcome = _run_with_overrides(gh, harness)

    assert harness["merged_calls"] == []
    assert any("cancel-kill-switch" in skip.reason for skip in outcome.skipped)
    applied_labels = [label for _pr_number, label in harness["labels"]]
    assert "auto-merge-blocked" in applied_labels


def test_A3_pending_checks_skip_no_label() -> None:
    """A3: when some required checks have not reported, the PR is skipped without a label."""
    candidate = make_pr(103)
    # Keep only these check names; every other required check counts as not reported.
    already_reported = {
        "ci/guard",
        "guard",
        "cancel-kill-switch",
        "qc-check",
        "hidden-path-audit",
    }
    partial_runs = [
        run for run in all_success_check_runs() if run["name"] in already_reported
    ]
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={candidate["head"]["sha"]: partial_runs},
        review_threads={103: []},
    )
    harness = _make_cycle_runner(gh)
    outcome = _run_with_overrides(gh, harness)

    assert harness["merged_calls"] == []
    # Pending is not a failure, so no label is applied.
    assert harness["labels"] == []
    assert any("missing" in skip.reason for skip in outcome.skipped)


def test_A4_gemini_blocked_with_label() -> None:
    """A4: a ``skipped`` gemini-review-gate prevents the merge and applies ``gemini-blocked``."""
    candidate = make_pr(104)
    gh = FakeGitHub(
        prs=[candidate],
        check_runs={
            candidate["head"]["sha"]: all_success_check_runs(
                ("gemini-review-gate", "skipped"),
            ),
        },
        review_threads={104: []},
    )
    harness = _make_cycle_runner(gh)
    outcome = _run_with_overrides(gh, harness)

    assert harness["merged_calls"] == []
    applied_labels = [label for _pr_number, label in harness["labels"]]
    assert "gemini-blocked" in applied_labels
    assert any("gemini-review-gate" in skip.reason for skip in outcome.skipped)


def test_A5_cancelled_pr_is_closed_not_merged() -> None:
    """A5: a PR whose task is marked cancelled is closed and never merged."""
    cancelled_pr = make_pr(105, branch="task/task-7777-dev1")
    gh = FakeGitHub(
        prs=[cancelled_pr],
        check_runs={cancelled_pr["head"]["sha"]: all_success_check_runs()},
        review_threads={105: []},
    )
    # The harness's cancelled_marker stub reads this ad-hoc attribute.
    gh._cancelled_set = {"task-7777"}  # type: ignore[attr-defined]
    harness = _make_cycle_runner(gh)
    outcome = _run_with_overrides(gh, harness)

    assert harness["merged_calls"] == []
    assert harness["closed_calls"] == [105]
    assert outcome.cancelled_closed == [105]


def test_A6_three_prs_serialized() -> None:
    """A6: three mergeable PRs are merged one after another within a single cycle.

    Checks that the merge stub is called in PR order, that the main head seen
    in the before/after-merge audit entries moves with every merge, and that
    the reported ``merged_at`` values are in non-decreasing order.
    """
    # Pad each number to a 40-character SHA: ``f"{n}" * 40`` was 120 characters
    # long for these three-digit numbers.
    prs = [make_pr(n, sha=f"{n}".ljust(40, "0")) for n in (201, 202, 203)]
    fake = FakeGitHub(
        prs=prs,
        check_runs={p["head"]["sha"]: all_success_check_runs() for p in prs},
        review_threads={p["number"]: [] for p in prs},
        main_head="0" * 40,
    )
    harness = _make_cycle_runner(fake)
    result = _run_with_overrides(fake, harness)

    assert harness["merged_calls"] == [201, 202, 203]
    # main head progresses sequentially
    seen_sha = {
        entry["main_head"]
        for entry in harness["audit"]
        if entry["stage"].startswith(("before-merge", "after-merge"))
    }
    # before/after captures of 3 merges → at least 4 distinct shas (initial + 3 merges)
    assert len(seen_sha) >= 4
    # Guard the ordering check below: without this it passes vacuously when
    # ``result.merged`` is empty.
    assert len(result.merged) == 3
    # merged_at is reported in non-decreasing order
    merged_at = [m.merged_at for m in result.merged]
    assert merged_at == sorted(merged_at)


# ---------------------------------------------------------------------------
# Audit log / main HEAD recording
# ---------------------------------------------------------------------------


def test_record_main_head_appends_jsonl(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """``record_main_head`` returns the head SHA and writes it, with extras, as one JSON line."""
    audit_file = tmp_path / "audit.jsonl"
    monkeypatch.setattr(amc, "AUDIT_LOG", audit_file)
    monkeypatch.setattr(amc, "get_main_head", lambda _r: "abc123")

    returned_sha = amc.record_main_head("test-stage", REPO, extra={"foo": "bar"})
    assert returned_sha == "abc123"

    logged = json.loads(audit_file.read_text().strip())
    expected_fields = (
        ("stage", "test-stage"),
        ("main_head", "abc123"),
        ("repo", REPO),
        ("foo", "bar"),
    )
    for field, value in expected_fields:
        assert logged[field] == value


def test_required_check_state_detects_partial_success() -> None:
    """``required_check_state`` separates missing checks from reported non-successes."""
    # Four checks reported (one of them failing); the rest never reported.
    reported = (
        ("ci/guard", "success"),
        ("guard", "success"),
        ("cancel-kill-switch", "success"),
        ("qc-check", "failure"),
    )
    runs = [{"name": name, "conclusion": conclusion} for name, conclusion in reported]

    missing, non_success = amc.required_check_state(runs)

    assert "qc-check" in non_success
    assert "gemini-review-gate" in missing
    assert len(missing) == 4
    assert len(non_success) == 1
