#!/usr/bin/env python3
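"""lock_in_verify.py — automated verification of the Lock-in 1 first-line guard (enforced in CI).

Uses the AST to check that the "first-line guard" pattern introduced in task-2439
is still the *first statements* of each merge-path function. Any violation makes
the script exit with code 1, which blocks CI.

Checked targets:
  1) scripts/anu_confirm_bot/main.py::_execute_approve
       - first statement: assignment of the cancelled-marker Path()
       - next statement: cancelled .exists() guard with an early return (or raise)
       - a guard.sh subprocess.run(...) call must come, by line number, before
         *every* merge subprocess call (taskctl merge since task-2449;
         originally gh pr merge)
  2) scripts/auto_merge.py::AutoMerger.execute_merge
       - same pattern; the merge signature is the worktree_manager finish call
"""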

from __future__ import annotations

import argparse
import ast
import sys
from pathlib import Path

# Parent of the directory that contains this script; default for --workspace.
WORKSPACE = Path(__file__).resolve().parent.parent

# One entry per guarded function: (relative file, function name, merge tokens).
CHECKS = [
    {"file": rel_path, "func": func_name, "merge_signature_tokens": merge_tokens}
    for rel_path, func_name, merge_tokens in (
        # task-2449 Fix 5: the direct `gh pr merge` call was replaced by routing
        # through `taskctl merge`. The first-line guards (cancelled / guard.sh)
        # are preserved, and the actual merge subprocess now calls taskctl.py.
        # The tokens ("taskctl", "merge") are chosen so that they are caught
        # whether they appear in variable names or in string literals.
        ("scripts/anu_confirm_bot/main.py", "_execute_approve", ("taskctl", "merge")),
        ("scripts/auto_merge.py", "execute_merge", ("worktree_manager", "finish")),
    )
]


def _find_function(tree: ast.AST, name: str) -> ast.FunctionDef | None:
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name:
            return node  # type: ignore[return-value]
    return None


def _strip_docstring(body: list[ast.stmt]) -> list[ast.stmt]:
    """Return *body* without its leading docstring statement, if it has one.

    A docstring is an expression statement whose value is a string constant.
    When there is none, the very same list object is returned.
    """
    if not body:
        return body
    head = body[0]
    if not isinstance(head, ast.Expr):
        return body
    literal = head.value
    if isinstance(literal, ast.Constant) and isinstance(literal.value, str):
        return body[1:]
    return body


def _is_cancelled_path_assign(stmt: ast.stmt) -> bool:
    """Tell whether *stmt* assigns the cancelled-marker path to a plain name.

    The statement must be a plain assignment whose first target is a name
    containing "cancelled" (case-insensitive), and the unparsed right-hand
    side must mention ".cancelled", "memory" and "events".
    """
    if not (isinstance(stmt, ast.Assign) and stmt.targets):
        return False
    target = stmt.targets[0]
    if not isinstance(target, ast.Name):
        return False
    if "cancelled" not in target.id.lower():
        return False
    value_src = ast.unparse(stmt.value)
    required = (".cancelled", "memory", "events")
    return all(fragment in value_src for fragment in required)


def _is_cancelled_exists_guard(stmt: ast.stmt) -> bool:
    """Tell whether *stmt* is the ``if <cancelled>.exists(): return/raise`` guard.

    The ``if`` condition must mention "cancelled" (case-insensitive) and
    ".exists()", and the first statement of its body must be a ``return`` or
    a ``raise``.
    """
    if not isinstance(stmt, ast.If):
        return False
    condition = ast.unparse(stmt.test)
    mentions_marker = "cancelled" in condition.lower()
    calls_exists = ".exists()" in condition
    if not (mentions_marker and calls_exists):
        return False
    return bool(stmt.body) and isinstance(stmt.body[0], (ast.Return, ast.Raise))


def _is_guard_sh_subprocess(call: ast.Call, guard_var_names: set[str]) -> bool:
    """Tell whether *call* is a ``subprocess.run`` invocation of guard.sh.

    Args:
        call: The call node to inspect.
        guard_var_names: Names of local variables known to hold a guard.sh path.

    Returns:
        True when the callee is exactly ``subprocess.run`` and the call either
        contains the text "guard.sh" or references one of *guard_var_names*.
    """
    if ast.unparse(call.func) != "subprocess.run":
        return False
    if "guard.sh" in ast.unparse(call):
        return True
    if not guard_var_names:
        return False
    # Compare against real identifiers instead of raw source text, so a guard
    # variable named "guard" does not match an unrelated "guard_free_cmd".
    referenced = {node.id for node in ast.walk(call) if isinstance(node, ast.Name)}
    return any(name in referenced for name in guard_var_names)


def _collect_guard_var_names(func: ast.FunctionDef) -> set[str]:
    """Collect names of variables whose assigned value mentions "guard.sh".

    Only single-target assignments to a plain name anywhere inside *func* are
    considered; the check is a substring test on the unparsed right-hand side.
    """
    return {
        node.targets[0].id
        for node in ast.walk(func)
        if isinstance(node, ast.Assign)
        and len(node.targets) == 1
        and isinstance(node.targets[0], ast.Name)
        and "guard.sh" in ast.unparse(node.value)
    }


def _list_contains_tokens(node: ast.AST, tokens: tuple[str, ...]) -> bool:
    """Tell whether a list/tuple literal's string items contain every token.

    For a list or tuple literal, its string-constant elements are joined with
    spaces and each token must occur as a substring of that text. For a call,
    the positional arguments are searched recursively. Any other node yields
    False.
    """
    if isinstance(node, ast.Call):
        return any(_list_contains_tokens(arg, tokens) for arg in node.args)
    if isinstance(node, (ast.List, ast.Tuple)):
        literals = [
            elt.value
            for elt in node.elts
            if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
        ]
        text = " ".join(literals)
        return all(token in text for token in tokens)
    return False


def _find_first_guard_sh_lineno(func: ast.FunctionDef) -> int | None:
    """Return the smallest line number of a guard.sh ``subprocess.run`` call.

    Returns ``None`` when *func* contains no such call.
    """
    guard_names = _collect_guard_var_names(func)
    guard_call_lines = [
        node.lineno
        for node in ast.walk(func)
        if isinstance(node, ast.Call) and _is_guard_sh_subprocess(node, guard_names)
    ]
    return min(guard_call_lines) if guard_call_lines else None


def _find_merge_call_linenos(func: ast.FunctionDef, tokens: tuple[str, ...]) -> list[int]:
    """Return the sorted line numbers of ``subprocess.run`` calls that look like merges.

    A call counts as a merge call when its first positional argument is either
    a list/tuple literal that matches *tokens*, or a plain name that is
    assigned such a literal somewhere inside *func*.

    A list/tuple literal matches when every token occurs in the text formed by
    its string elements plus the unparsed source of its other elements, or
    when one of its non-string elements mentions a variable whose assigned
    value contains every token.

    Assignments are collected in a first pass and calls are scanned in a
    second one. ``ast.walk`` does not yield nodes in source order, so a
    single pass could reach ``subprocess.run(cmd)`` before a more deeply
    nested ``cmd = [...]`` and silently miss the merge call.

    Args:
        func: The function definition to scan.
        tokens: Substrings that together identify the merge command.

    Returns:
        Line numbers of the matching ``subprocess.run`` calls, ascending.
    """
    nodes = list(ast.walk(func))

    # Single-target assignments to a plain name: `name = <value>`.
    simple_assigns = [
        node
        for node in nodes
        if isinstance(node, ast.Assign)
        and len(node.targets) == 1
        and isinstance(node.targets[0], ast.Name)
    ]

    # Variables whose assigned value already carries every token.
    token_var_names = {
        assign.targets[0].id
        for assign in simple_assigns
        if all(t in ast.unparse(assign.value) for t in tokens)
    }

    def list_matches(node: ast.AST) -> bool:
        if not isinstance(node, (ast.List, ast.Tuple)):
            return False
        literal_parts: list[str] = []
        expr_parts: list[str] = []
        for elt in node.elts:
            if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
                literal_parts.append(elt.value)
            else:
                expr_parts.append(ast.unparse(elt))
        expr_text = " ".join(expr_parts)
        combined = " ".join(literal_parts) + " " + expr_text
        if all(t in combined for t in tokens):
            return True
        return any(name in expr_text for name in token_var_names)

    # Pass 1: names bound to a merge-looking command list anywhere in func.
    merge_list_names = {
        assign.targets[0].id
        for assign in simple_assigns
        if list_matches(assign.value)
    }

    # Pass 2: subprocess.run calls whose first argument is such a command.
    merge_lines: list[int] = []
    for node in nodes:
        if not isinstance(node, ast.Call):
            continue
        if ast.unparse(node.func) != "subprocess.run" or not node.args:
            continue
        first_arg = node.args[0]
        if isinstance(first_arg, (ast.List, ast.Tuple)):
            if list_matches(first_arg):
                merge_lines.append(node.lineno)
        elif isinstance(first_arg, ast.Name) and first_arg.id in merge_list_names:
            merge_lines.append(node.lineno)
    return sorted(merge_lines)


def verify_function(file_path: Path, func_name: str, merge_tokens: tuple[str, ...]) -> list[str]:
    """Check one merge-path function against the Lock-in first-line guard rules.

    The rules, evaluated on the function body with any docstring removed:
      1. the first statement assigns the cancelled-marker path;
      2. the second statement is the cancelled ``.exists()`` guard that
         returns or raises;
      3. a guard.sh ``subprocess.run`` call exists;
      4. a merge ``subprocess.run`` call matching *merge_tokens* exists;
      5. the earliest guard.sh call is on an earlier line than the earliest
         merge call.

    Args:
        file_path: Python source file that should contain the function.
        func_name: Bare name of the function or method to check.
        merge_tokens: Substrings identifying the merge command.

    Returns:
        Human-readable violation messages; an empty list means the function
        passes. A missing, unreadable or unparsable file, a missing function
        and a body shorter than three statements are each reported as a
        single message without running the remaining checks.
    """
    if not file_path.exists():
        return [f"{file_path}: file not found"]
    # Report read/parse failures as violations instead of crashing, so the
    # run still fails but lists every problem in one place.
    try:
        src = file_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        return [f"{file_path}: cannot read file — {exc}"]
    try:
        tree = ast.parse(src, filename=str(file_path))
    except (SyntaxError, ValueError) as exc:
        return [f"{file_path}: cannot parse file — {exc}"]
    func = _find_function(tree, func_name)
    if func is None:
        return [f"{file_path}::{func_name}: function not found"]
    body = _strip_docstring(list(func.body))
    if len(body) < 3:
        return [f"{file_path}::{func_name}: body too short ({len(body)} stmts)"]

    errors: list[str] = []
    if not _is_cancelled_path_assign(body[0]):
        errors.append(f"{file_path}::{func_name}: 첫 statement가 cancelled 마커 Path 할당이 아님 — got {ast.unparse(body[0])[:120]}")
    if not _is_cancelled_exists_guard(body[1]):
        errors.append(f"{file_path}::{func_name}: 두 번째 statement가 cancelled.exists() 가드+return/raise 가 아님 — got {ast.unparse(body[1])[:120]}")
    guard_lineno = _find_first_guard_sh_lineno(func)
    if guard_lineno is None:
        errors.append(f"{file_path}::{func_name}: guard.sh subprocess.run 호출 없음")
    merge_linenos = _find_merge_call_linenos(func, merge_tokens)
    if not merge_linenos:
        errors.append(f"{file_path}::{func_name}: merge subprocess({merge_tokens}) 호출 없음")
    # merge_linenos is sorted, so comparing with its first entry checks the
    # guard against the earliest merge call.
    if guard_lineno is not None and merge_linenos and guard_lineno >= merge_linenos[0]:
        errors.append(
            f"{file_path}::{func_name}: guard.sh line({guard_lineno})이 merge subprocess line({merge_linenos[0]})보다 늦음 — Lock-in 위반"
        )
    return errors


def main() -> int:
    """Run every configured check and return the process exit status.

    Command-line options:
        --workspace  root directory the checked file paths are relative to
                     (defaults to WORKSPACE).
        --quiet      suppress the PASS lines on stdout.

    Returns:
        0 when all checks pass; 1 when any violation was found, after the
        violations have been listed on stderr.
    """
    parser = argparse.ArgumentParser(description="Lock-in First-line 가드 검증")
    parser.add_argument("--workspace", default=str(WORKSPACE))
    parser.add_argument("--quiet", action="store_true")
    options = parser.parse_args()

    root = Path(options.workspace).resolve()
    verbose = not options.quiet
    failures: list[str] = []
    for check in CHECKS:
        problems = verify_function(
            root / check["file"], check["func"], check["merge_signature_tokens"]
        )
        if problems:
            failures += problems
            continue
        if verbose:
            print(f"PASS  {check['file']}::{check['func']}")

    if not failures:
        if verbose:
            print("PASS  Lock-in First-line 가드 모든 함수 통과")
        return 0

    print("FAIL  Lock-in First-line 가드 위반:", file=sys.stderr)
    for message in failures:
        print(f"  - {message}", file=sys.stderr)
    return 1


# Script entry point: use main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
