#!/usr/bin/env python3
"""auto_merge_controller.py — Auto-merge controller (task-2444).

Goal (회장 명시):

    "회장 승인 없이도 안전한 PR은 자동 merge되고, 위험한 PR은 GitHub
     ruleset과 controller 양쪽에서 차단된다."
    "봇은 ruleset을 우회하지 않는다. GitHub이 허용한 PR만 자동 처리한다."

The controller polls open PRs targeting ``main`` and merges them only when
**all** of the following GitHub-side preconditions hold:

  1. PR ``base.ref == "main"``
  2. No ``memory/events/<task-id>.cancelled`` marker
  3. All 8 required check-runs are present and ``conclusion == "success"``
     ({ci/guard, guard, cancel-kill-switch, qc-check, hidden-path-audit,
       lock-in-check, merge-safety-check, gemini-review-gate})
  4. ``mergeable_state`` is **not** in
     {blocked, behind, dirty, unstable, draft, unknown}
  5. ``gemini-review-gate`` conclusion is ``success`` (re-asserted)
  6. Zero unresolved review threads (GraphQL ``reviewThreads.isResolved``)
  7. PR is up-to-date with main (``mergeable_state != "behind"``)

Only when **every** condition is satisfied do we invoke
``gh pr merge <num> --auto --merge --delete-branch``. The controller never
uses ``--admin``, never calls ``git push origin main``, and never tries to
merge a PR whose ``mergeable_state`` is ``BLOCKED``.

The controller is **idempotent and safe to run on a cron**: every cycle is
protected by a process-wide ``FileLock``; main-HEAD is recorded in
``memory/logs/auto-merge-audit.jsonl`` before/after each merge so we can
prove nothing was force-pushed.
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import re
import subprocess
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable

# ---------------------------------------------------------------------------
# Bootstrap (workspace root + sibling helper)
# ---------------------------------------------------------------------------

# Directory that contains this script.
_HERE = Path(__file__).resolve().parent
# Workspace root: the WORKSPACE_ROOT environment variable when set, otherwise
# the parent of this script's directory.
_WORKSPACE = Path(os.environ.get("WORKSPACE_ROOT", str(_HERE.parent))).resolve()
# Put both directories at the front of sys.path (only if absent) so the
# helper import that follows can be resolved. _HERE is inserted last and
# therefore takes precedence over _WORKSPACE.
if str(_WORKSPACE) not in sys.path:
    sys.path.insert(0, str(_WORKSPACE))
if str(_HERE) not in sys.path:
    sys.path.insert(0, str(_HERE))

from auto_merge_lock import FileLock, LockTimeout  # noqa: E402  # pyright: ignore[reportMissingImports]

# ---------------------------------------------------------------------------
# Constants — Hard-coded safety boundaries
# ---------------------------------------------------------------------------

# Default target repository ("owner/name"); the REPO environment variable
# takes precedence over the built-in value.
REPO_DEFAULT = os.getenv("REPO", "JonghyukJeon/dev_workspace")
# Controller state lives under <workspace>/memory: the cycle lock file, the
# append-only audit log, and the directory holding ``<task-id>.cancelled``
# markers.
LOCK_PATH = _WORKSPACE.joinpath("memory", "cache", "auto_merge_controller.lock")
AUDIT_LOG = _WORKSPACE.joinpath("memory", "logs", "auto-merge-audit.jsonl")
EVENTS_DIR = _WORKSPACE.joinpath("memory", "events")

#: Names of the check-runs that must all be present with
#: ``conclusion == "success"`` before a PR is considered for auto-merge.
REQUIRED_CHECKS: frozenset[str] = frozenset(
    (
        "ci/guard",
        "guard",
        "cancel-kill-switch",
        "qc-check",
        "hidden-path-audit",
        "lock-in-check",
        "merge-safety-check",
        "gemini-review-gate",
    )
)

#: ``mergeable_state`` values that MUST block auto-merge. GitHub returns
#: lowercase strings via the REST API. Per the original author's note,
#: ``"clean"`` and ``"has_hooks"`` are the only two states intended to
#: proceed; every state listed here is rejected outright.
BLOCKED_MERGE_STATES: frozenset[str] = frozenset(
    (
        "blocked",
        "behind",
        "dirty",
        "unstable",
        "draft",
        "unknown",
    )
)

#: CLI flags that must never appear in a command issued by this controller.
#: A command containing one of them is rejected with ``RuntimeError`` BEFORE
#: the subprocess is started.
FORBIDDEN_FLAGS: frozenset[str] = frozenset(("--admin",))
#: Leading argv pair that identifies a ``git push``; such commands are
#: additionally inspected for a push that targets ``main``.
FORBIDDEN_GIT_PUSH = ("git", "push")

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------

# Module-wide logger. A handler is attached only when the logger has none
# yet, so executing this module body again does not duplicate output lines.
LOGGER = logging.getLogger("auto_merge_controller")
if not LOGGER.handlers:
    # Emit to stdout with timestamp, level and logger name on each line.
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    )
    LOGGER.addHandler(handler)
    # INFO by default; ``main`` lowers this to DEBUG when --verbose is given.
    LOGGER.setLevel(logging.INFO)


# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------


@dataclass
class SkipDecision:
    """Why a PR was skipped this cycle (still open, no merge attempt).

    Produced by ``evaluate_pr``. ``process_open_prs`` logs ``reason`` and,
    when ``label`` is set, attaches that label to the PR.
    """

    # GitHub pull-request number.
    pr_number: int
    # Human-readable explanation of which gate rejected the PR.
    reason: str
    # Label to attach to the PR, or None to skip without labelling.
    label: str | None = None  # e.g. "auto-merge-blocked", "gemini-blocked"


@dataclass
class MergeDecision:
    """A PR cleared all gates and the controller invoked ``gh pr merge``.

    ``merged`` / ``merged_at`` are filled from ``post_check`` afterwards, so
    an instance can exist with ``merged=False`` when the merge command was
    issued but GitHub does not (yet) report the PR as merged.
    """

    # GitHub pull-request number.
    pr_number: int
    # Head branch name and head commit sha as reported in the PR listing.
    head_branch: str
    head_sha: str
    # main HEAD sha recorded just before the merge command; "" when that
    # recording failed.
    main_head_before: str
    # main HEAD sha recorded just after the merge command; None when that
    # recording failed.
    main_head_after: str | None = None
    # GitHub's own view of the PR after the merge command (from post_check).
    merged: bool = False
    merged_at: str | None = None


@dataclass
class CycleResult:
    """Result of one controller cycle. Useful for tests + audit.

    Built and returned by ``process_open_prs``; ``main`` derives the process
    exit code from ``errors``.
    """

    # Repository ("owner/name") the cycle ran against.
    repo: str
    # Values of the injected ``now()`` clock at cycle start / end.
    started_at: float
    finished_at: float | None = None
    # main HEAD sha recorded at cycle start; None when recording failed.
    main_head_before_cycle: str | None = None
    # PRs rejected by ``evaluate_pr`` this cycle.
    skipped: list[SkipDecision] = field(default_factory=list)
    # PRs for which a merge was attempted (see MergeDecision.merged for the
    # actual outcome).
    merged: list[MergeDecision] = field(default_factory=list)
    # Numbers of PRs handed to ``handle_cancelled_pr`` without an exception.
    cancelled_closed: list[int] = field(default_factory=list)
    # Human-readable error strings collected while the cycle kept running.
    errors: list[str] = field(default_factory=list)


# ---------------------------------------------------------------------------
# Subprocess wrappers — every external call funnels through here so the
# safety net (FORBIDDEN_FLAGS + push detection) cannot be bypassed.
# ---------------------------------------------------------------------------


def _enforce_forbidden(cmd: list[str]) -> None:
    """Hard-block dangerous commands BEFORE subprocess execution.

    Two tripwires are applied to the argv list:

    * any argument equal to a flag in ``FORBIDDEN_FLAGS`` or written in the
      ``--flag=value`` form of one;
    * any ``git ... push`` whose arguments target ``main``, whether spelled
      ``origin main``, ``HEAD:main``, ``+main``, ``src:main`` or
      ``refs/heads/main``.

    The push check is deliberately conservative: a false positive only
    refuses a command this controller is never supposed to run anyway.

    Raises:
        RuntimeError: when either tripwire fires.
    """
    bad = [
        arg
        for arg in cmd
        if any(arg == flag or arg.startswith(f"{flag}=") for flag in FORBIDDEN_FLAGS)
    ]
    if bad:
        raise RuntimeError(f"[FORBIDDEN] forbidden flag(s) in command: {bad}")

    git_exe, push_verb = FORBIDDEN_GIT_PUSH
    # Also catches ``git -C <dir> push ...`` where the verb is not argv[1].
    if not cmd or cmd[0] != git_exe or push_verb not in cmd[1:]:
        return

    # Original textual patterns, kept so nothing previously blocked gets through.
    joined = " ".join(cmd)
    targets_main = bool(
        re.search(r"\borigin\s+main\b", joined) or re.search(r"\bHEAD:main\b", joined)
    )

    # Refspec-aware check: look at the destination side of every positional
    # argument after the ``push`` verb.
    for arg in cmd[cmd.index(push_verb, 1) + 1:]:
        if arg.startswith("-"):
            continue
        dest = arg.rpartition(":")[2].lstrip("+")
        if dest.startswith("refs/heads/"):
            dest = dest[len("refs/heads/"):]
        if dest == "main":
            targets_main = True
            break

    if targets_main:
        raise RuntimeError("[FORBIDDEN] direct push to main is not allowed")


def run_cmd(
    cmd: list[str],
    *,
    check: bool = True,
    capture: bool = True,
    timeout: float | None = 120.0,
) -> subprocess.CompletedProcess:
    """Run a subprocess with the forbidden-flag guard always engaged.

    Args:
        cmd: argv list; never passed through a shell.
        check: when true, a non-zero exit status raises ``RuntimeError``.
        capture: when true, stdout/stderr are captured as text.
        timeout: seconds to wait before giving up, or None to wait forever.
            The controller runs under a lock on a cron, so a hung child
            would otherwise block every later cycle.

    Returns:
        The ``CompletedProcess`` of the finished command.

    Raises:
        RuntimeError: the guard rejected the command, the command timed
            out, or (with ``check``) it exited non-zero.
    """
    _enforce_forbidden(cmd)
    LOGGER.debug("run_cmd: %s", " ".join(cmd))
    try:
        proc = subprocess.run(
            cmd,
            capture_output=capture,
            text=True,
            check=False,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        # Surface timeouts as the same exception type callers already handle.
        raise RuntimeError(
            f"command timed out after {timeout}s: {' '.join(cmd)}"
        ) from exc
    if check and proc.returncode != 0:
        raise RuntimeError(
            f"command failed ({proc.returncode}): {' '.join(cmd)}\n"
            f"stdout={proc.stdout}\nstderr={proc.stderr}"
        )
    return proc


def gh_api(path: str, *, method: str = "GET") -> Any:
    """Invoke ``gh api -X <method> <path>`` and decode the JSON it prints.

    Returns None when the command produced no (non-whitespace) output.
    A failing command raises ``RuntimeError`` via ``run_cmd``.
    """
    completed = run_cmd(["gh", "api", "-X", method, path], check=True)
    body = completed.stdout
    if body.strip():
        return json.loads(body)
    return None


def gh_api_graphql(query: str, variables: dict[str, Any] | None = None) -> Any:
    """Call ``gh api graphql`` for richer queries (review threads etc.).

    Args:
        query: GraphQL document, always sent as a raw string field.
        variables: GraphQL variables. ``str`` values are sent with ``-f``
            (raw string) so ``gh`` never reinterprets them as numbers,
            booleans or ``@file`` references; ``bool``/``None`` are sent
            with ``-F`` in their JSON spelling (``true``/``false``/``null``);
            everything else (e.g. ``int``) is sent with ``-F`` so ``gh``
            converts it to a typed value.

    Returns:
        The decoded JSON response, or None when the command printed nothing.
    """
    cmd = ["gh", "api", "graphql", "-f", f"query={query}"]
    for key, value in (variables or {}).items():
        if isinstance(value, str):
            cmd.extend(["-f", f"{key}={value}"])
        elif value is None or isinstance(value, bool):
            # str(True) is "True", which gh would pass through as a string.
            cmd.extend(["-F", f"{key}={json.dumps(value)}"])
        else:
            cmd.extend(["-F", f"{key}={value}"])
    proc = run_cmd(cmd, check=True)
    return json.loads(proc.stdout) if proc.stdout.strip() else None


# ---------------------------------------------------------------------------
# Main HEAD recording
# ---------------------------------------------------------------------------


def get_main_head(repo: str) -> str:
    """Look up the commit sha that ``main`` currently points at.

    The value comes from the ``/repos/<repo>/branches/main`` REST endpoint.
    """
    branch_info = gh_api(f"/repos/{repo}/branches/main")
    tip_commit = branch_info["commit"]
    return tip_commit["sha"]


def record_main_head(stage: str, repo: str, *, extra: dict[str, Any] | None = None) -> str:
    """Write one main-HEAD observation to the audit log and return its sha.

    The entry is a single JSON line carrying the epoch timestamp, its UTC
    ISO rendering, ``stage``, ``repo`` and the observed ``main_head``. Keys
    supplied in ``extra`` are merged on top of those fields.
    """
    head_sha = get_main_head(repo)
    AUDIT_LOG.parent.mkdir(parents=True, exist_ok=True)
    base_fields = {
        "timestamp": time.time(),
        "iso": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "stage": stage,
        "repo": repo,
        "main_head": head_sha,
    }
    record = {**base_fields, **extra} if extra else base_fields
    line = json.dumps(record, sort_keys=True) + "\n"
    with AUDIT_LOG.open("a") as log_file:
        log_file.write(line)
    return head_sha


# ---------------------------------------------------------------------------
# Eligibility checks — pure functions over GitHub responses
# ---------------------------------------------------------------------------


def task_id_from_branch(branch: str) -> str | None:
    """Pull the first ``task-<digits>`` token out of a head branch name.

    ``"task/task-2444-dev2"`` yields ``"task-2444"``. A falsy branch or one
    without such a token yields None.
    """
    found = re.search(r"task-\d+", branch if branch else "")
    if found is None:
        return None
    return found.group(0)


def cancelled_marker_path(task_id: str) -> Path:
    """Return the path of the ``<task_id>.cancelled`` marker in ``EVENTS_DIR``."""
    marker_name = f"{task_id}.cancelled"
    return EVENTS_DIR.joinpath(marker_name)


def required_check_state(check_runs: list[dict[str, Any]]) -> tuple[set[str], set[str]]:
    """Return (missing, non_success) for the required checks.

    Args:
        check_runs: the ``check_runs`` list of a GitHub check-runs response
            (may be empty or None).

    Returns:
        ``missing``: required check names with no run at all.
        ``non_success``: required check names for which at least one listed
        run has a conclusion other than ``"success"`` (a conclusion of None,
        i.e. still in progress, counts as non-success).

    When several runs share a name the result is deliberately fail-closed:
    any non-success run marks the check as non-success. The order in which
    GitHub lists the runs is therefore irrelevant.
    """
    seen: set[str] = set()
    non_success: set[str] = set()
    for run in check_runs or []:
        name = run.get("name")
        if name not in REQUIRED_CHECKS:
            continue
        seen.add(name)
        if run.get("conclusion") != "success":
            non_success.add(name)
    missing = set(REQUIRED_CHECKS) - seen
    return missing, non_success


def has_unresolved_threads(graphql_response: dict[str, Any] | None) -> bool:
    """True if any review thread is unresolved or that cannot be ruled out.

    The check is fail-closed. It returns False only when the response holds
    a well-formed ``data.repository.pullRequest.reviewThreads.nodes`` list
    in which every node has ``isResolved`` set to true. It returns True for
    an unresolved thread and also for every case where resolution cannot be
    verified: a missing/empty response, a GraphQL ``errors`` payload, a
    ``null`` anywhere on the path, a node without ``isResolved``, or a
    ``pageInfo.hasNextPage`` flag saying further threads were not fetched.
    """
    if not isinstance(graphql_response, dict) or graphql_response.get("errors"):
        return True

    # Walk down to the reviewThreads connection, tolerating null levels.
    connection: Any = graphql_response
    for key in ("data", "repository", "pullRequest", "reviewThreads"):
        connection = connection.get(key)
        if not isinstance(connection, dict):
            return True

    nodes = connection.get("nodes")
    if not isinstance(nodes, list):
        return True

    page_info = connection.get("pageInfo")
    if isinstance(page_info, dict) and page_info.get("hasNextPage"):
        # Only part of the threads was fetched; the rest is unknown.
        return True

    return any(
        not isinstance(node, dict) or node.get("isResolved") is not True
        for node in nodes
    )


# ---------------------------------------------------------------------------
# Action helpers — close cancelled PRs, add labels, perform safe merge
# ---------------------------------------------------------------------------


def label_blocked(pr_num: int, label: str, repo: str) -> None:
    """Attach a label to the PR (e.g. ``auto-merge-blocked``).

    Best-effort: a failure never raises ``RuntimeError`` to the caller. Both
    a guard/timeout rejection and a non-zero exit of ``gh pr edit`` are
    logged as warnings so a label that could not be applied stays visible.
    """
    cmd = [
        "gh",
        "pr",
        "edit",
        str(pr_num),
        "--add-label",
        label,
        "--repo",
        repo,
    ]
    try:
        proc = run_cmd(cmd, check=False)
    except RuntimeError as exc:
        LOGGER.warning("label_blocked failed for PR #%d: %s", pr_num, exc)
        return
    if proc.returncode != 0:
        # Previously ignored: e.g. the label does not exist in the repo.
        LOGGER.warning(
            "label_blocked failed for PR #%d: rc=%d stderr=%s",
            pr_num,
            proc.returncode,
            (proc.stderr or "").strip(),
        )


def handle_cancelled_pr(pr: dict[str, Any], repo: str) -> None:
    """Close a PR whose task has a cancelled marker and delete its branch.

    Args:
        pr: PR payload; ``pr["number"]`` and ``pr["head"]["ref"]`` are read.
        repo: repository in ``owner/name`` form.

    Raises:
        RuntimeError: when ``gh pr close`` exits non-zero (or the command is
            rejected by the guard), so the caller does not record a PR as
            closed while it is in fact still open.
    """
    pr_num = pr["number"]
    branch = pr["head"]["ref"]
    LOGGER.info("[cancel-close] closing PR #%d (branch=%s)", pr_num, branch)
    proc = run_cmd(
        [
            "gh",
            "pr",
            "close",
            str(pr_num),
            "--delete-branch",
            "--repo",
            repo,
            "--comment",
            "[auto-merge-controller] cancelled marker detected, PR closed.",
        ],
        check=False,
    )
    if proc.returncode != 0:
        raise RuntimeError(
            f"gh pr close #{pr_num} failed rc={proc.returncode} "
            f"stderr={(proc.stderr or '').strip()}"
        )


def safe_merge(pr_num: int, repo: str) -> subprocess.CompletedProcess:
    """Run ``gh pr merge <num> --auto --merge --delete-branch --repo <repo>``.

    The command goes through ``run_cmd``, whose guard refuses ``--admin``
    and direct pushes to main. A non-zero exit does not raise; the caller
    inspects the returned ``CompletedProcess``.
    """
    merge_flags = ["--auto", "--merge", "--delete-branch"]
    argv = ["gh", "pr", "merge", str(pr_num), *merge_flags, "--repo", repo]
    return run_cmd(argv, check=False)


def post_check(pr_num: int, expected_branch: str, repo: str) -> dict[str, Any]:
    """Verify GitHub's view of the merge after we invoked it.

    Args:
        pr_num: pull-request number to look up.
        expected_branch: head branch that should be gone after the merge.
        repo: repository in ``owner/name`` form.

    Returns:
        A dict with ``merged``, ``merged_at``, ``branch_deleted`` and
        ``merge_commit_sha`` so the caller (and audit log) can store the
        truth. ``branch_deleted`` is true when the branch lookup exits
        non-zero.

    Both lookups go through ``run_cmd`` so that no external call bypasses
    the forbidden-command guard.
    """
    pr = gh_api(f"/repos/{repo}/pulls/{pr_num}")
    # check=False: a failing lookup is the expected signal for a deleted branch.
    branch_proc = run_cmd(
        ["gh", "api", f"/repos/{repo}/branches/{expected_branch}"],
        check=False,
    )
    return {
        "merged": bool(pr.get("merged")),
        "merged_at": pr.get("merged_at"),
        "branch_deleted": branch_proc.returncode != 0,
        "merge_commit_sha": pr.get("merge_commit_sha"),
    }


# ---------------------------------------------------------------------------
# Core decision pipeline (pure-ish — only side effect is logging)
# ---------------------------------------------------------------------------


def evaluate_pr(
    pr: dict[str, Any],
    *,
    repo: str,
    api: Callable[[str], Any] = gh_api,
    graphql: Callable[..., Any] = gh_api_graphql,
    cancelled_marker_exists: Callable[[str], bool] | None = None,
) -> SkipDecision | None:
    """Return a SkipDecision (with reason/label) or None if PR is mergeable.

    Performs no close/label/merge itself — callers handle those.
    Dependency-injected ``api`` / ``graphql`` / ``cancelled_marker_exists``
    keep this testable without touching the network or filesystem.

    Every gate is fail-closed: when a precondition cannot be positively
    confirmed the PR is skipped. A label is attached only for conditions
    that need human action; conditions that may clear on their own
    (pending checks, a not-yet-computed merge state, a head that moved
    during evaluation, an unavailable thread listing) skip without a label.

    Args:
        pr: PR payload from the listing (``number``, ``head``, ``base``).
        repo: repository in ``owner/name`` form.
        api: callable taking a REST path and returning decoded JSON.
        graphql: callable taking ``(query, variables)`` and returning
            decoded JSON.
        cancelled_marker_exists: predicate on a task id; defaults to a
            filesystem check of the marker path.

    Returns:
        None when every gate passed, otherwise the ``SkipDecision``.
    """
    if cancelled_marker_exists is None:
        cancelled_marker_exists = lambda tid: cancelled_marker_path(tid).exists()  # noqa: E731

    pr_num = pr["number"]
    head = pr.get("head", {}) or {}
    branch = head.get("ref", "")
    sha = head.get("sha", "")
    base_ref = (pr.get("base") or {}).get("ref", "")

    # 1. base must be main
    if base_ref != "main":
        return SkipDecision(pr_num, f"base.ref={base_ref!r} (not main)")

    # 2. cancelled marker → caller must close, not merge
    task_id = task_id_from_branch(branch)
    if task_id and cancelled_marker_exists(task_id):
        return SkipDecision(pr_num, f"cancelled marker present ({task_id})", label=None)

    # 3. required checks all success
    cr_payload = api(f"/repos/{repo}/commits/{sha}/check-runs")
    check_runs = cr_payload.get("check_runs", [])
    # Every reported conclusion of the gemini gate; used for the label
    # choice here and for the re-assertion in step 5.
    gemini_conclusions = [
        cr.get("conclusion") for cr in check_runs if cr.get("name") == "gemini-review-gate"
    ]
    missing, non_success = required_check_state(check_runs)
    if missing or non_success:
        gemini_failed = next(
            (c for c in gemini_conclusions if c is not None and c != "success"),
            None,
        )
        # Required checks that finished with a non-success conclusion, as
        # opposed to ones that are merely still running (conclusion None).
        completed_failures = {
            cr.get("name")
            for cr in check_runs
            if cr.get("name") in non_success
            and cr.get("conclusion") not in (None, "success")
        }
        # Pending / not-yet-reported checks: don't label, just wait.
        fail_label = "auto-merge-blocked" if completed_failures else None
        if gemini_failed is not None:
            label = "gemini-blocked"
            reason = f"gemini-review-gate not success (={gemini_failed!r})"
        elif missing and non_success:
            label = fail_label
            reason = f"missing={sorted(missing)} non_success={sorted(non_success)}"
        elif missing:
            label = None
            reason = f"missing={sorted(missing)} (in_progress or not started)"
        else:
            label = fail_label
            reason = f"non_success={sorted(non_success)}"
        return SkipDecision(pr_num, reason, label=label)

    # 4. mergeable_state gate (also covers "up-to-date with main": behind blocks)
    pr_full = api(f"/repos/{repo}/pulls/{pr_num}")
    state = (pr_full.get("mergeable_state") or "").lower()
    if state in BLOCKED_MERGE_STATES:
        return SkipDecision(pr_num, f"mergeable_state={state!r}", label="auto-merge-blocked")
    if state not in ("clean", "has_hooks"):
        # Allowlist: an empty or unrecognised state must not slip through
        # just because it is absent from the denylist.
        return SkipDecision(
            pr_num,
            f"mergeable_state={state!r} (not clean/has_hooks)",
            label=None,
        )

    # The checks above were read for ``sha`` from the listing. If the PR
    # head has moved since, those results describe a different commit.
    current_sha = (pr_full.get("head") or {}).get("sha")
    if current_sha and current_sha != sha:
        return SkipDecision(
            pr_num,
            f"head moved during evaluation ({sha} -> {current_sha})",
            label=None,
        )

    # 5. gemini-review-gate explicit re-assertion (defensive): at least one
    #    run, and every run successful.
    if not gemini_conclusions or any(c != "success" for c in gemini_conclusions):
        offending = next((c for c in gemini_conclusions if c != "success"), None)
        return SkipDecision(
            pr_num,
            f"gemini-review-gate not success (={offending!r})",
            label="gemini-blocked",
        )

    # 6. unresolved review threads
    query = """
      query($owner:String!,$name:String!,$number:Int!){
        repository(owner:$owner,name:$name){
          pullRequest(number:$number){
            reviewThreads(first:100){
              nodes { isResolved }
              pageInfo { hasNextPage }
            }
          }
        }
      }
    """
    owner, name = repo.split("/", 1)
    threads = graphql(query, {"owner": owner, "name": name, "number": pr_num})

    # Locate the reviewThreads connection; anything missing on the way
    # means resolution cannot be verified.
    connection: Any = threads
    for key in ("data", "repository", "pullRequest", "reviewThreads"):
        connection = connection.get(key) if isinstance(connection, dict) else None
    if not isinstance(connection, dict):
        return SkipDecision(
            pr_num,
            "review threads unavailable (empty or malformed GraphQL response)",
            label=None,
        )
    if has_unresolved_threads(threads):
        return SkipDecision(pr_num, "unresolved conversations", label="auto-merge-blocked")
    page_info = connection.get("pageInfo")
    if isinstance(page_info, dict) and page_info.get("hasNextPage"):
        # Only the first 100 threads were fetched; the rest is unknown.
        return SkipDecision(
            pr_num,
            "more than 100 review threads; cannot verify all are resolved",
            label="auto-merge-blocked",
        )

    return None


# ---------------------------------------------------------------------------
# Process loop
# ---------------------------------------------------------------------------


def list_open_prs(repo: str) -> list[dict[str, Any]]:
    """Return every open PR targeting ``main``, following pagination.

    The REST listing is paged; without explicit paging only the first page
    is returned and PRs beyond it would never be evaluated. Pages of 100
    are requested until a short (or empty) page is seen. A hard cap on the
    page count protects against a misbehaving endpoint.
    """
    per_page = 100
    max_pages = 50
    prs: list[dict[str, Any]] = []
    for page in range(1, max_pages + 1):
        batch = gh_api(
            f"/repos/{repo}/pulls?state=open&base=main&per_page={per_page}&page={page}"
        ) or []
        prs.extend(batch)
        if len(batch) < per_page:
            break
    return prs


def process_open_prs(
    repo: str = REPO_DEFAULT,
    *,
    api: Callable[[str], Any] = gh_api,
    graphql: Callable[..., Any] = gh_api_graphql,
    list_prs: Callable[[str], list[dict[str, Any]]] | None = None,
    safe_merge_fn: Callable[[int, str], subprocess.CompletedProcess] | None = None,
    head_recorder: Callable[..., str] | None = None,
    cancelled_marker_exists: Callable[[str], bool] | None = None,
    now: Callable[[], float] = time.time,
) -> CycleResult:
    """Run one full cycle. Dependency-injected for tests.

    For every open PR: close it when its task has a cancelled marker;
    otherwise evaluate it, and either skip it (optionally labelling) or
    record main HEAD, invoke the merge, record main HEAD again and verify
    the outcome via ``post_check``.

    A failure while handling one PR is appended to ``CycleResult.errors``
    and never aborts the remaining PRs. Only a failure to list PRs ends the
    cycle early.

    Args:
        repo: repository in ``owner/name`` form.
        api / graphql: forwarded to ``evaluate_pr``.
        list_prs: returns the open PRs; defaults to ``list_open_prs``.
        safe_merge_fn: performs the merge; defaults to ``safe_merge``.
        head_recorder: ``(stage, *, extra=...) -> sha`` audit hook; defaults
            to ``record_main_head`` bound to ``repo``.
        cancelled_marker_exists: predicate on a task id; defaults to a
            filesystem check of the marker path.
        now: clock used for ``started_at`` / ``finished_at``.

    Returns:
        The populated ``CycleResult``.
    """
    list_prs = list_prs or list_open_prs
    safe_merge_fn = safe_merge_fn or safe_merge
    head_recorder = head_recorder or (lambda stage, **kw: record_main_head(stage, repo, **kw))
    # Resolved once instead of rebuilding the fallback for every PR.
    marker_exists = cancelled_marker_exists or (
        lambda tid: cancelled_marker_path(tid).exists()
    )

    result = CycleResult(repo=repo, started_at=now())
    try:
        result.main_head_before_cycle = head_recorder("before-cycle")
    except Exception as exc:  # noqa: BLE001
        result.errors.append(f"head-record before-cycle: {exc}")

    try:
        prs = list_prs(repo)
    except Exception as exc:  # noqa: BLE001
        result.errors.append(f"list_open_prs: {exc}")
        result.finished_at = now()
        return result

    for pr in prs:
        pr_num = pr["number"]
        branch = (pr.get("head") or {}).get("ref", "")
        task_id = task_id_from_branch(branch)
        if task_id and marker_exists(task_id):
            try:
                handle_cancelled_pr(pr, repo)
                result.cancelled_closed.append(pr_num)
            except Exception as exc:  # noqa: BLE001
                result.errors.append(f"close cancelled #{pr_num}: {exc}")
            continue

        try:
            decision = evaluate_pr(
                pr,
                repo=repo,
                api=api,
                graphql=graphql,
                cancelled_marker_exists=cancelled_marker_exists,
            )
        except Exception as exc:  # noqa: BLE001
            result.errors.append(f"evaluate #{pr_num}: {exc}")
            continue

        if decision is not None:
            LOGGER.info("[skip] PR #%d — %s", decision.pr_number, decision.reason)
            if decision.label:
                # Labelling is best-effort; it must not abort the cycle.
                try:
                    label_blocked(decision.pr_number, decision.label, repo)
                except Exception as exc:  # noqa: BLE001
                    result.errors.append(f"label #{decision.pr_number}: {exc}")
            result.skipped.append(decision)
            continue

        # Cleared all gates — record main HEAD, merge, record again, post_check
        head = pr["head"]
        try:
            before = head_recorder(
                f"before-merge-pr-{pr_num}", extra={"pr": pr_num, "branch": head["ref"]}
            )
        except Exception as exc:  # noqa: BLE001
            result.errors.append(f"head-record before-merge #{pr_num}: {exc}")
            before = ""

        merge_decision = MergeDecision(
            pr_number=pr_num,
            head_branch=head["ref"],
            head_sha=head["sha"],
            main_head_before=before,
        )

        try:
            proc = safe_merge_fn(pr_num, repo)
        except RuntimeError as exc:
            # forbidden flag tripwire — bubble up to errors but continue cycle
            result.errors.append(f"safe_merge guarded #{pr_num}: {exc}")
            continue
        except Exception as exc:  # noqa: BLE001
            # e.g. OSError when the gh executable is unavailable
            result.errors.append(f"safe_merge #{pr_num}: {exc}")
            continue

        if proc.returncode != 0:
            stderr_text = (proc.stderr or "").strip()
            result.errors.append(
                f"safe_merge #{pr_num} rc={proc.returncode} stderr={stderr_text}"
            )

        try:
            after = head_recorder(
                f"after-merge-pr-{pr_num}", extra={"pr": pr_num}
            )
            merge_decision.main_head_after = after
        except Exception as exc:  # noqa: BLE001
            result.errors.append(f"head-record after-merge #{pr_num}: {exc}")

        try:
            pc = post_check(pr_num, head["ref"], repo)
            merge_decision.merged = bool(pc["merged"])
            merge_decision.merged_at = pc["merged_at"]
        except Exception as exc:  # noqa: BLE001
            result.errors.append(f"post_check #{pr_num}: {exc}")

        result.merged.append(merge_decision)

    result.finished_at = now()
    return result


# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: run one controller cycle under the process lock.

    Args:
        argv: argument list; None lets argparse read ``sys.argv``.

    Returns:
        0 when the cycle finished without errors or another instance holds
        the lock; 1 when the cycle collected at least one error.

    With ``--dry-run`` PRs are still listed and evaluated, but none of the
    mutating actions run: merge, close-cancelled and label are each
    replaced by a log line. The close/label replacements are installed as
    temporary module-level rebinds and always restored afterwards.
    """
    global handle_cancelled_pr, label_blocked  # noqa: PLW0603

    parser = argparse.ArgumentParser(description="Auto-merge controller (task-2444)")
    parser.add_argument("--repo", default=REPO_DEFAULT, help="GitHub repo (owner/name)")
    parser.add_argument(
        "--lock-timeout",
        type=float,
        default=10.0,
        help="Seconds to wait for the controller lock",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Evaluate PRs but do not perform merge/close/label actions",
    )
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args(argv)

    if args.verbose:
        LOGGER.setLevel(logging.DEBUG)

    merge_fn: Callable[[int, str], subprocess.CompletedProcess] | None = None
    saved_actions = None
    if args.dry_run:

        def _no_merge(pr_num: int, repo: str) -> subprocess.CompletedProcess:
            LOGGER.info("[dry-run] would merge PR #%d (repo=%s)", pr_num, repo)
            return subprocess.CompletedProcess(args=[], returncode=0, stdout="", stderr="")

        def _no_close(pr: dict[str, Any], repo: str) -> None:
            LOGGER.info("[dry-run] would close cancelled PR #%d (repo=%s)", pr["number"], repo)

        def _no_label(pr_num: int, label: str, repo: str) -> None:
            LOGGER.info(
                "[dry-run] would label PR #%d with %r (repo=%s)", pr_num, label, repo
            )

        merge_fn = _no_merge
        # The cycle has no injection point for close/label, so the module
        # globals are swapped for its duration.
        saved_actions = (handle_cancelled_pr, label_blocked)
        handle_cancelled_pr, label_blocked = _no_close, _no_label

    try:
        with FileLock(LOCK_PATH, timeout=args.lock_timeout):
            result = process_open_prs(repo=args.repo, safe_merge_fn=merge_fn)
    except LockTimeout as exc:
        LOGGER.warning("controller already running: %s", exc)
        return 0
    finally:
        if saved_actions is not None:
            handle_cancelled_pr, label_blocked = saved_actions

    LOGGER.info(
        "cycle done — merged=%d skipped=%d cancelled=%d errors=%d",
        len(result.merged),
        len(result.skipped),
        len(result.cancelled_closed),
        len(result.errors),
    )
    if result.errors:
        for err in result.errors:
            LOGGER.error("  - %s", err)
        return 1
    return 0


# Script entry: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
