"""canonical_workspace_resolver — deterministic workspace snapshot for 6 hooks."""
from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional

# ---------------------------------------------------------------------------
# Type alias for injectable runner (enables test stubbing)
# ---------------------------------------------------------------------------
# Signature of an injectable subprocess runner. ``_run`` calls it as
# ``runner(args, cwd=<str or None>)``; callers then read ``returncode``,
# ``stdout`` and ``stderr`` (as text) from the result. Tests can pass a stub
# here instead of spawning real git processes.
RunnerType = Callable[..., subprocess.CompletedProcess]

# Last-resort workspace root: used by resolve_canonical_workspace only when git
# reports neither a main worktree nor a toplevel and WORKSPACE_ROOT is unset/empty.
# NOTE(review): hard-coded, user-specific absolute path — confirm this is
# intended for every environment that runs these hooks.
_DEFAULT_WORKSPACE = Path("/home/jay/workspace")

# Hook names accepted by resolve_for_hooks; any other name raises ValueError.
_ALLOWED_HOOKS = frozenset(
    {"scope-guard", "finish-task", "guard", "smoke", "qc", "merge_execution"}
)


# ---------------------------------------------------------------------------
# 1. CanonicalWorkspace (frozen dataclass)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class CanonicalWorkspace:
    """Immutable snapshot of the resolved workspace state for a single task.

    Built by ``resolve_canonical_workspace`` (or rebuilt by ``from_dict``) and
    read by the ``assert_*`` / ``evaluate_scope_dirty`` checks. ``frozen=True``
    makes instances read-only after construction. The field order below is also
    the positional constructor order, so do not reorder it.
    """

    task_id: str  # task identifier the snapshot was resolved for
    workspace_root: Path  # main repository root (git's main worktree when available)
    worktree_path: Path  # linked worktree for task_id; may be a derived path not yet on disk
    branch_name: str  # short branch name ("refs/heads/" stripped); may be "" if unresolved
    main_head_sha: str  # origin/main commit SHA locked at resolve time
    base_sha: str  # base commit for the task; resolve sets it equal to main_head_sha
    cwd: Path  # resolved working directory the snapshot was taken from
    is_main: bool  # True when cwd == workspace_root
    is_clean: bool  # resolve currently always sets True; use evaluate_scope_dirty for real status


# ---------------------------------------------------------------------------
# JSON helpers (round-trip: Path ↔ str)
# ---------------------------------------------------------------------------

def to_json(ws: CanonicalWorkspace) -> str:
    """Render *ws* as a pretty-printed (indent=2) JSON object.

    Keys appear in dataclass field order. The three ``Path`` fields are emitted
    as plain strings so ``from_dict(json.loads(...))`` can rebuild an
    equivalent snapshot.
    """
    field_order = (
        "task_id",
        "workspace_root",
        "worktree_path",
        "branch_name",
        "main_head_sha",
        "base_sha",
        "cwd",
        "is_main",
        "is_clean",
    )
    stringified = {"workspace_root", "worktree_path", "cwd"}
    payload = {}
    for key in field_order:
        value = getattr(ws, key)
        payload[key] = str(value) if key in stringified else value
    return json.dumps(payload, indent=2)


def from_dict(d: dict) -> CanonicalWorkspace:
    """Rebuild a CanonicalWorkspace from a plain mapping (e.g. parsed ``to_json`` output).

    Path-valued keys are wrapped in ``Path`` and the two flags are coerced with
    ``bool()``. Every key is required; the first absent one raises ``KeyError``.
    """
    path_keys = frozenset({"workspace_root", "worktree_path", "cwd"})
    flag_keys = frozenset({"is_main", "is_clean"})
    kwargs = {}
    for key in (
        "task_id",
        "workspace_root",
        "worktree_path",
        "branch_name",
        "main_head_sha",
        "base_sha",
        "cwd",
        "is_main",
        "is_clean",
    ):
        raw = d[key]
        if key in path_keys:
            kwargs[key] = Path(raw)
        elif key in flag_keys:
            kwargs[key] = bool(raw)
        else:
            kwargs[key] = raw
    return CanonicalWorkspace(**kwargs)


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _default_runner(
    args: list[str],
    *,
    cwd: Optional[str] = None,
) -> subprocess.CompletedProcess:
    """Default subprocess runner with timeout=30."""
    return subprocess.run(
        args,
        cwd=cwd,
        capture_output=True,
        text=True,
        timeout=30,
    )


def _run(
    args: list[str],
    *,
    cwd: Optional[str | Path] = None,
    runner: Optional[RunnerType] = None,
) -> subprocess.CompletedProcess:
    """Invoke runner with normalised cwd."""
    fn = runner if runner is not None else _default_runner
    return fn(args, cwd=str(cwd) if cwd is not None else None)


def _parse_worktree_list(output: str) -> list[dict]:
    """Parse `git worktree list --porcelain` output into list of dicts."""
    entries: list[dict] = []
    current: dict = {}
    for line in output.splitlines():
        line = line.strip()
        if line.startswith("worktree "):
            if current:
                entries.append(current)
            current = {"path": line[len("worktree "):]}
        elif line.startswith("HEAD "):
            current["head"] = line[len("HEAD "):]
        elif line.startswith("branch "):
            current["branch"] = line[len("branch "):]
        elif line == "bare":
            current["bare"] = True
    if current:
        entries.append(current)
    return entries


# ---------------------------------------------------------------------------
# 2. resolve_canonical_workspace
# ---------------------------------------------------------------------------

def _matches_task_worktree(entry_path: Path, task_id: str) -> bool:
    """Return True if a path component of *entry_path* names *task_id*'s worktree.

    A component matches when it equals *task_id* or starts with ``"<task_id>-"``.
    This is the ``<task_id>-<suffix>`` layout that the ``.worktrees`` glob
    fallback also uses. Comparing whole components (not a substring of the full
    path) keeps ``task-12`` from claiming ``.../task-123-dev1``.
    """
    if not task_id:
        return False
    prefix = f"{task_id}-"
    return any(part == task_id or part.startswith(prefix) for part in entry_path.parts)


def resolve_canonical_workspace(
    task_id: str,
    *,
    cwd: Optional[Path] = None,
    fetch: bool = True,
    runner: Optional[RunnerType] = None,
) -> CanonicalWorkspace:
    """Resolve and lock a deterministic CanonicalWorkspace for *task_id*.

    Resolution order:
      * workspace_root: the first entry of ``git worktree list --porcelain``
        (the main worktree), then ``git rev-parse --show-toplevel``, then
        ``$WORKSPACE_ROOT``, then ``_DEFAULT_WORKSPACE``.
      * worktree_path: the listed worktree with a path component equal to
        *task_id* or starting with ``<task_id>-``. Failing that, the first
        (sorted) match of ``<workspace_root>/.worktrees/<task_id>-*`` on disk.
        Failing that, the expected ``<task_id>-dev1`` path, which is not created.
      * branch_name: the matched worktree's branch, else
        ``git rev-parse --abbrev-ref HEAD`` in cwd. It stays "" if that fails.
      * main_head_sha / base_sha: ``origin/main``, read after an optional
        best-effort fetch.

    Args:
        task_id: task identifier, e.g. ``task-1234``.
        cwd: directory to resolve from; defaults to the process cwd.
        fetch: run ``git fetch origin main`` first. A non-zero exit or a timeout
            is ignored, and the local ``origin/main`` ref is used as-is.
        runner: injectable subprocess runner (see ``RunnerType``).

    Returns:
        A frozen CanonicalWorkspace. ``is_clean`` is always True here; use
        ``evaluate_scope_dirty`` for a scope-aware check.

    Raises:
        RuntimeError: ``FAILED_TO_RESOLVE_ORIGIN_MAIN`` if ``origin/main`` cannot be read.
    """
    import glob as _glob

    # Step 1: normalise cwd.
    effective_cwd: Path = Path(cwd).resolve() if cwd is not None else Path.cwd().resolve()

    # Step 2: --show-toplevel works from the main checkout and from a linked
    # worktree. From a linked worktree it returns the worktree itself, not the main repo.
    r = _run(
        ["git", "rev-parse", "--show-toplevel"],
        cwd=effective_cwd,
        runner=runner,
    )
    git_toplevel: Optional[Path] = None
    if r.returncode == 0 and r.stdout.strip():
        git_toplevel = Path(r.stdout.strip()).resolve()

    # Step 3: list worktrees. git always lists the main worktree first, which is
    # the reliable way to find the main repo root from inside a linked worktree.
    seed_cwd = git_toplevel if git_toplevel is not None else effective_cwd
    wt_r = _run(
        ["git", "worktree", "list", "--porcelain"],
        cwd=seed_cwd,
        runner=runner,
    )

    worktree_path: Optional[Path] = None
    branch_name: str = ""
    git_main_root: Optional[Path] = None

    if wt_r.returncode == 0:
        entries = _parse_worktree_list(wt_r.stdout)
        if entries:
            git_main_root = Path(entries[0].get("path", "")).resolve()
        for entry in entries:
            entry_path = Path(entry.get("path", "")).resolve()
            if _matches_task_worktree(entry_path, task_id):
                worktree_path = entry_path
                raw_branch = entry.get("branch", "")
                if raw_branch.startswith("refs/heads/"):
                    branch_name = raw_branch[len("refs/heads/"):]
                else:
                    branch_name = raw_branch
                break

    # workspace_root: git's main worktree is authoritative (spec §1). The env
    # override only applies when git gives no answer (spec §2, "git first").
    env_root_str = os.environ.get("WORKSPACE_ROOT", "").strip()
    if git_main_root is not None:
        workspace_root = git_main_root
    elif git_toplevel is not None:
        workspace_root = git_toplevel
    elif env_root_str:
        workspace_root = Path(env_root_str).resolve()
    else:
        workspace_root = _DEFAULT_WORKSPACE

    # Step 4: worktree_path fallback. Derive the expected path without creating
    # it (creation is the caller's responsibility). glob.escape keeps characters
    # such as '[' in the root or task_id from acting as wildcards.
    if worktree_path is None:
        pattern = _glob.escape(str(workspace_root / ".worktrees" / task_id)) + "-*"
        candidates = sorted(_glob.glob(pattern))
        if candidates:
            worktree_path = Path(candidates[0]).resolve()
        else:
            worktree_path = (workspace_root / ".worktrees" / f"{task_id}-dev1").resolve()

    # Step 5: branch_name fallback — the branch checked out in cwd.
    if not branch_name:
        br_r = _run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            cwd=effective_cwd,
            runner=runner,
        )
        if br_r.returncode == 0:
            branch_name = br_r.stdout.strip()

    # Step 6: best-effort fetch, then lock origin/main. A non-zero fetch exit is
    # already tolerated, so a timeout is tolerated the same way rather than
    # aborting the whole resolution.
    if fetch:
        try:
            _run(
                ["git", "fetch", "origin", "main", "--quiet"],
                cwd=workspace_root,
                runner=runner,
            )
        except subprocess.TimeoutExpired:
            pass  # fall back to the local origin/main ref, as on a failed fetch

    sha_r = _run(
        ["git", "rev-parse", "origin/main"],
        cwd=workspace_root,
        runner=runner,
    )
    if sha_r.returncode != 0 or not sha_r.stdout.strip():
        raise RuntimeError(f"FAILED_TO_RESOLVE_ORIGIN_MAIN: {sha_r.stderr.strip()!r}")
    main_head_sha = sha_r.stdout.strip()
    base_sha = main_head_sha  # always identical per spec

    # Step 7: is_main — resolving from the main repo root itself.
    is_main: bool = effective_cwd == workspace_root

    # Step 8: is_clean placeholder (scope-aware dirtiness lives in evaluate_scope_dirty).
    is_clean: bool = True

    return CanonicalWorkspace(
        task_id=task_id,
        workspace_root=workspace_root,
        worktree_path=worktree_path,
        branch_name=branch_name,
        main_head_sha=main_head_sha,
        base_sha=base_sha,
        cwd=effective_cwd,
        is_main=is_main,
        is_clean=is_clean,
    )


# ---------------------------------------------------------------------------
# 3. assert_cwd_matches_workspace
# ---------------------------------------------------------------------------

def assert_cwd_matches_workspace(ws: CanonicalWorkspace) -> None:
    """Ensure the snapshot's cwd is exactly the task worktree or the workspace root.

    Subdirectories of either location are rejected as well.

    Raises:
        RuntimeError: ``WRONG_CWD`` when ``ws.cwd`` equals neither path.
    """
    allowed = (ws.worktree_path, ws.workspace_root)
    if ws.cwd in allowed:
        return
    raise RuntimeError(
        f"WRONG_CWD: cwd={ws.cwd}, expected={ws.worktree_path} or {ws.workspace_root}"
    )


# ---------------------------------------------------------------------------
# 4. assert_main_fresh
# ---------------------------------------------------------------------------

def assert_main_fresh(
    ws: CanonicalWorkspace,
    *,
    runner: Optional[RunnerType] = None,
    fetch: bool = False,
) -> None:
    """Re-read ``origin/main`` and raise if it differs from the SHA locked in *ws*.

    By default only the local ``origin/main`` ref is re-read. Drift is then
    detected only if something (e.g. another fetch) updated that ref after
    resolve. Pass ``fetch=True`` to run ``git fetch origin main`` first so the
    comparison reflects the remote tip.

    Args:
        ws: snapshot holding the SHA locked at resolve time.
        runner: injectable subprocess runner (see ``RunnerType``).
        fetch: fetch before re-checking. The fetch is best-effort: a non-zero
            exit or a timeout is ignored, and the local ref is compared as-is.

    Raises:
        RuntimeError: ``STALE_MAIN`` if origin/main cannot be re-read or has moved.
    """
    if fetch:
        try:
            _run(
                ["git", "fetch", "origin", "main", "--quiet"],
                cwd=ws.workspace_root,
                runner=runner,
            )
        except subprocess.TimeoutExpired:
            pass  # best-effort: compare against the local ref, as on a failed fetch

    sha_r = _run(
        ["git", "rev-parse", "origin/main"],
        cwd=ws.workspace_root,
        runner=runner,
    )
    if sha_r.returncode != 0:
        raise RuntimeError(
            f"STALE_MAIN: failed to re-check origin/main: {sha_r.stderr!r}"
        )
    current_sha = sha_r.stdout.strip()
    if current_sha != ws.main_head_sha:
        raise RuntimeError(
            f"STALE_MAIN: locked={ws.main_head_sha}, current={current_sha}"
        )


# ---------------------------------------------------------------------------
# 5. evaluate_scope_dirty
# ---------------------------------------------------------------------------

def evaluate_scope_dirty(
    ws: CanonicalWorkspace,
    expected_files: list[str],
    *,
    runner: Optional[RunnerType] = None,
) -> bool:
    """Return True (dirty) if any of *expected_files* has uncommitted changes.

    Runs ``git status --porcelain -- <files>`` in ``ws.worktree_path``. It
    retries from ``ws.workspace_root`` if that attempt fails, whether by a
    non-zero exit or by an ``OSError`` (e.g. ``FileNotFoundError``) raised when
    ``worktree_path`` is a derived directory that does not exist yet.
    An empty file list is always clean.

    Raises:
        RuntimeError: ``GIT_STATUS_FAILED`` if the status check fails in both locations.
    """
    if not expected_files:
        return False
    args = ["git", "status", "--porcelain", "--", *expected_files]

    result = None
    try:
        result = _run(args, cwd=ws.worktree_path, runner=runner)
    except OSError:
        # Spawning a process in a missing cwd raises rather than returning a
        # non-zero exit; treat it as a failed attempt and use the fallback below.
        result = None
    if result is None or result.returncode != 0:
        result = _run(args, cwd=ws.workspace_root, runner=runner)
    if result.returncode != 0:
        raise RuntimeError(f"GIT_STATUS_FAILED: {result.stderr.strip()!r}")
    return bool(result.stdout.strip())


# ---------------------------------------------------------------------------
# 6. assert_finish_task_context
# ---------------------------------------------------------------------------

def assert_finish_task_context(
    ws: CanonicalWorkspace,
    finish_target: dict,
    *,
    runner: Optional[RunnerType] = None,
) -> None:
    """Check that the live HEAD and the snapshot agree with *finish_target*.

    Comparisons:
      * ``git rev-parse HEAD`` run in ``ws.cwd`` vs ``finish_target["head_sha"]``.
      * ``ws.branch_name`` vs ``finish_target["branch_name"]``.
      * resolved ``ws.worktree_path`` vs resolved ``finish_target["worktree_path"]``.

    A missing or empty ``branch_name`` / ``worktree_path`` in *finish_target* is
    reported as a mismatch rather than compared.

    Raises:
        RuntimeError: ``GIT_REV_PARSE_HEAD_FAILED`` if HEAD cannot be read, or
            ``FINISH_TASK_CONTEXT_MISMATCH`` listing every differing field.
    """
    head_r = _run(
        ["git", "rev-parse", "HEAD"],
        cwd=ws.cwd,
        runner=runner,
    )
    if head_r.returncode != 0:
        raise RuntimeError(f"GIT_REV_PARSE_HEAD_FAILED: {head_r.stderr.strip()!r}")
    current_head_sha = head_r.stdout.strip()

    expected_head = finish_target.get("head_sha", "")
    expected_branch = finish_target.get("branch_name", "")
    raw_expected_wt = finish_target.get("worktree_path", "")

    mismatches: list[str] = []
    if current_head_sha != expected_head:
        mismatches.append(
            f"head_sha: current={current_head_sha!r} != expected={expected_head!r}"
        )

    # An absent branch would otherwise "match" an unresolved (empty) ws.branch_name.
    if not expected_branch:
        mismatches.append(
            f"branch_name: missing in finish_target (ws={ws.branch_name!r})"
        )
    elif ws.branch_name != expected_branch:
        mismatches.append(
            f"branch_name: ws={ws.branch_name!r} != expected={expected_branch!r}"
        )

    # Path("").resolve() is the process cwd, so an absent worktree_path must not
    # be resolved and compared; it could silently equal ws.worktree_path.
    if not raw_expected_wt:
        mismatches.append(
            f"worktree_path: missing in finish_target (ws={ws.worktree_path})"
        )
    else:
        expected_wt = Path(str(raw_expected_wt)).resolve()
        if ws.worktree_path.resolve() != expected_wt:
            mismatches.append(
                f"worktree_path: ws={ws.worktree_path} != expected={expected_wt}"
            )

    if mismatches:
        raise RuntimeError(
            "FINISH_TASK_CONTEXT_MISMATCH: " + "; ".join(mismatches)
        )


# ---------------------------------------------------------------------------
# 7. resolve_for_hooks
# ---------------------------------------------------------------------------

def resolve_for_hooks(
    task_id: str,
    hook_name: str,
    *,
    cwd: Optional[Path] = None,
    fetch: bool = True,
    runner: Optional[RunnerType] = None,
) -> CanonicalWorkspace:
    """Single entry point shared by the six hooks.

    Rejects unknown *hook_name* values, then delegates to
    ``resolve_canonical_workspace`` with the remaining arguments. *hook_name*
    does not otherwise influence the result.

    Raises:
        ValueError: ``UNKNOWN_HOOK`` if *hook_name* is not in ``_ALLOWED_HOOKS``.
    """
    is_known = hook_name in _ALLOWED_HOOKS
    if not is_known:
        allowed = sorted(_ALLOWED_HOOKS)
        raise ValueError(f"UNKNOWN_HOOK: {hook_name}, allowed={allowed}")
    options = {"cwd": cwd, "fetch": fetch, "runner": runner}
    return resolve_canonical_workspace(task_id, **options)


# ---------------------------------------------------------------------------
# 8. CLI entrypoint
# ---------------------------------------------------------------------------

def _build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        description="Resolve and inspect CanonicalWorkspace for a task.",
    )
    p.add_argument("--task-id", required=True, help="task-NNNN[+M] identifier")
    p.add_argument("--json", action="store_true", dest="output_json", help="Print JSON output")
    p.add_argument("--assert-cwd", action="store_true", help="Assert cwd matches workspace")
    p.add_argument("--assert-fresh", action="store_true", help="Assert origin/main is not stale")
    p.add_argument("--scope", metavar="FILE1,FILE2,...", help="Evaluate scope dirty (comma-separated)")
    p.add_argument("--no-fetch", action="store_true", help="Skip git fetch (test mode)")
    return p


def main(argv: Optional[list[str]] = None) -> int:
    """Run the CLI with *argv* (defaults to ``sys.argv[1:]``) and return the exit status.

    Exit statuses:
      * 0: success.
      * 1: resolution failed, or the ``--scope`` git status check failed.
      * 2: ``--assert-cwd`` failed.
      * 3: ``--assert-fresh`` failed.
      * 4: files listed in ``--scope`` are dirty.
    """
    args = _build_parser().parse_args(argv)
    fetch_flag = not args.no_fetch

    try:
        ws = resolve_canonical_workspace(args.task_id, fetch=fetch_flag)
    except Exception as exc:  # top-level CLI boundary: report, don't trace back
        print(f"ERROR: {exc}", file=sys.stderr)
        return 1

    if args.assert_cwd:
        try:
            assert_cwd_matches_workspace(ws)
        except RuntimeError as exc:
            print(f"ERROR: {exc}", file=sys.stderr)
            return 2

    if args.assert_fresh:
        try:
            assert_main_fresh(ws)
        except RuntimeError as exc:
            print(f"ERROR: {exc}", file=sys.stderr)
            return 3

    if args.scope:
        files = [f.strip() for f in args.scope.split(",") if f.strip()]
        # evaluate_scope_dirty raises RuntimeError (GIT_STATUS_FAILED) and may
        # propagate an OSError from the runner. Report it like the other checks
        # instead of dumping a traceback; the status stays 1, as for an uncaught error.
        try:
            dirty = evaluate_scope_dirty(ws, files)
        except (RuntimeError, OSError) as exc:
            print(f"ERROR: {exc}", file=sys.stderr)
            return 1
        if dirty:
            print(f"SCOPE_DIRTY: {files}", file=sys.stderr)
            return 4

    if args.output_json:
        print(to_json(ws))
    else:
        print(f"task_id      : {ws.task_id}")
        print(f"workspace_root: {ws.workspace_root}")
        print(f"worktree_path : {ws.worktree_path}")
        print(f"branch_name  : {ws.branch_name}")
        print(f"main_head_sha: {ws.main_head_sha}")
        print(f"base_sha     : {ws.base_sha}")
        print(f"cwd          : {ws.cwd}")
        print(f"is_main      : {ws.is_main}")
        print(f"is_clean     : {ws.is_clean}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
