#!/usr/bin/env python3
"""Memory violation detector (task-2419 Fix 2).

memory/specs/memory-violation-rules.yaml 기반으로 메모리 위반 자동 감지.

사용:
  python3 scripts/memory_violation_detector.py --task-id task-1234
  python3 scripts/memory_violation_detector.py --commit abc1234
  python3 scripts/memory_violation_detector.py --diff-base main
  python3 scripts/memory_violation_detector.py --staged

종료 코드:
  0: 위반 0건
  1: 위반 발견
  2: 룰 spec 또는 입력 에러
"""
from __future__ import annotations

import argparse
import json
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any

# Workspace root; the WORKSPACE_ROOT environment variable overrides the default.
WORKSPACE = Path(os.getenv("WORKSPACE_ROOT", "/home/jay/workspace"))
# YAML rule spec, the default input of load_rules().
RULES_PATH = WORKSPACE.joinpath("memory", "specs", "memory-violation-rules.yaml")
# JSON check log read by get_log_entries_for_task().
LOG_PATH = WORKSPACE.joinpath("memory", "memory-check-log.json")
# Task timer file location (not referenced by the code visible in this file).
TIMERS_PATH = WORKSPACE.joinpath("memory", "task-timers.json")


@dataclass
class Rule:
    """One violation rule loaded from the YAML rule spec.

    Which optional fields matter depends on ``target``: the commit-message
    and changed-file checks use ``pattern_regex`` (plus ``file_pattern`` for
    files), while the log-entry check uses ``log_field`` and
    ``expected_value``.
    """

    # Rule identifier; copied into each Violation and printed in the report.
    id: str
    # Human-readable rule name; copied into each Violation.
    name: str
    # Reference string from the spec ("" when absent); not read by the checks
    # in this file.
    memory_ref: str
    # Input the rule applies to. main() dispatches on "commit_messages",
    # "changed_files" and "log_entries".
    target: str
    # Severity label ("medium" when the spec omits it); printed upper-cased.
    severity: str
    # Free-text description ("" when absent); not read by the checks here.
    description: str
    # Regular expression searched in commit messages or in file lines.
    pattern_regex: str | None = None
    # fnmatch-style pattern applied to the basename of each changed file.
    file_pattern: str | None = None
    # Key looked up in each log entry by the log-entry check.
    log_field: str | None = None
    # Value the log field must equal; any other value is reported.
    expected_value: Any = None


@dataclass
class Violation:
    """A single rule hit produced by one of the ``check_*`` functions."""

    # Identifier of the rule that fired.
    rule_id: str
    # Name of the rule that fired.
    rule_name: str
    # Where it fired: "commit:<sha prefix>", "<path>:<line>" or
    # "log:<mc_id>/<task_id>", depending on the check.
    location: str
    # Short excerpt shown in the report (commit subject line, offending file
    # line, or the actual-vs-expected log value).
    matched_content: str
    # Severity label copied from the rule.
    severity: str


def load_rules(path: Path = RULES_PATH) -> list[Rule]:
    """Load violation rules from the YAML rule spec.

    Args:
        path: Location of the rule spec file.

    Returns:
        One ``Rule`` per entry under the top-level ``rules`` key. An empty
        list is returned when the file is empty or has no rules.

    Exits the process with status 2 (``sys.exit``) when PyYAML is missing,
    the file is missing/unreadable, the YAML is malformed or has the wrong
    shape, a rule lacks a required key (``id``, ``name``, ``target``), or a
    rule's ``pattern_regex`` does not compile. Handling these here matters
    because an uncaught exception would end the script with status 1, which
    the script documents as "violations found".
    """
    try:
        import yaml  # PyYAML
    except ImportError:
        print("ERROR: PyYAML 필요. pip install pyyaml", file=sys.stderr)
        sys.exit(2)

    if not path.exists():
        print(f"ERROR: 룰 spec 없음: {path}", file=sys.stderr)
        sys.exit(2)

    try:
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except (OSError, UnicodeError, yaml.YAMLError) as exc:
        print(f"ERROR: 룰 spec 로드 실패: {path}: {exc}", file=sys.stderr)
        sys.exit(2)

    # safe_load() returns None for an empty document.
    if data is None:
        return []
    if not isinstance(data, dict):
        print(f"ERROR: 룰 spec 형식 오류 (mapping 필요): {path}", file=sys.stderr)
        sys.exit(2)

    # "rules: null" is treated like a missing key.
    raw_rules = data.get("rules") or []
    if not isinstance(raw_rules, list):
        print(f"ERROR: 룰 spec 형식 오류 ('rules'는 list 필요): {path}", file=sys.stderr)
        sys.exit(2)

    rules: list[Rule] = []
    for index, r in enumerate(raw_rules):
        if not isinstance(r, dict):
            print(f"ERROR: 룰 #{index} 형식 오류 (mapping 필요)", file=sys.stderr)
            sys.exit(2)

        # Reject a broken regex up front instead of crashing mid-check.
        pattern = r.get("pattern_regex")
        if pattern is not None:
            try:
                re.compile(pattern)
            except (re.error, TypeError) as exc:
                print(
                    f"ERROR: 룰 #{index} pattern_regex 오류: {exc}",
                    file=sys.stderr,
                )
                sys.exit(2)

        try:
            rules.append(Rule(
                id=r["id"],
                name=r["name"],
                memory_ref=r.get("memory_ref", ""),
                target=r["target"],
                severity=r.get("severity", "medium"),
                description=r.get("description", ""),
                pattern_regex=pattern,
                file_pattern=r.get("file_pattern"),
                log_field=r.get("log_field"),
                expected_value=r.get("expected_value"),
            ))
        except KeyError as exc:
            print(f"ERROR: 룰 #{index} 필수 키 누락: {exc}", file=sys.stderr)
            sys.exit(2)
    return rules


def get_commit_messages(
    commit_sha: str | None = None,
    diff_base: str | None = None,
) -> list[tuple[str, str]]:
    """Return ``(sha, message)`` pairs for the requested commits.

    Args:
        commit_sha: Inspect only this commit (takes precedence).
        diff_base: Inspect every commit in ``diff_base..HEAD``.

    Returns:
        A list of ``(full sha, "subject\\nbody")`` tuples. The list is empty
        when neither argument is given or when ``git log`` fails; a failure
        is reported on stderr so it is not mistaken for "nothing to check".
    """
    # Fields are separated by NUL and each commit is terminated by the ASCII
    # record separator (0x1e). A printable terminator such as "END" is unsafe:
    # a body line that merely ends in "END" (e.g. "BACKEND") would split the
    # record in the middle.
    log_format = "--pretty=format:%H%x00%s%x00%b%x1e"
    if commit_sha:
        cmd = ["git", "log", "-1", log_format, commit_sha]
    elif diff_base:
        cmd = ["git", "log", f"{diff_base}..HEAD", log_format]
    else:
        return []
    try:
        out = subprocess.run(
            cmd, capture_output=True, text=True, check=True, cwd=str(WORKSPACE)
        ).stdout
    except (subprocess.SubprocessError, OSError, UnicodeError) as exc:
        # Best effort: keep returning [] but make the failure visible.
        detail = (getattr(exc, "stderr", "") or "").strip() or exc
        print(f"WARN: git log 실패 — commit 검사 생략: {detail}", file=sys.stderr)
        return []

    results: list[tuple[str, str]] = []
    for entry in out.split("\x1e"):
        # git puts a newline between records; strip it along with padding.
        entry = entry.strip()
        if not entry:
            continue
        parts = entry.split("\0")
        if len(parts) >= 2:
            sha = parts[0].strip()
            subject = parts[1].strip()
            body = parts[2].strip() if len(parts) > 2 else ""
            results.append((sha, f"{subject}\n{body}"))
    return results


def get_changed_files(
    commit_sha: str | None = None,
    diff_base: str | None = None,
    staged: bool = False,
) -> list[str]:
    """Return the paths of changed files, relative to the repository root.

    Args:
        commit_sha: List files touched by this commit.
        diff_base: List files that differ between ``diff_base`` and ``HEAD``.
        staged: List files staged in the index (takes precedence).

    Returns:
        A list of path strings. The list is empty when no mode is selected or
        when git fails; a failure is reported on stderr.

    The staged and diff-base modes restrict the result to added, copied and
    modified files (``--diff-filter=ACM``); the commit mode applies no filter.
    """
    # core.quotePath=false stops git from emitting non-ASCII paths (e.g.
    # Korean file names) as quoted octal escapes, which would never resolve
    # to a real file on disk.
    git = ["git", "-c", "core.quotePath=false"]
    if staged:
        cmd = git + ["diff", "--cached", "--name-only", "--diff-filter=ACM"]
    elif commit_sha:
        cmd = git + ["show", "--name-only", "--pretty=format:", commit_sha]
    elif diff_base:
        cmd = git + ["diff", f"{diff_base}..HEAD", "--name-only", "--diff-filter=ACM"]
    else:
        return []
    try:
        out = subprocess.run(
            cmd, capture_output=True, text=True, check=True, cwd=str(WORKSPACE)
        ).stdout
    except (subprocess.SubprocessError, OSError, UnicodeError) as exc:
        # Best effort: keep returning [] but make the failure visible.
        detail = (getattr(exc, "stderr", "") or "").strip() or exc
        print(f"WARN: git 변경 파일 조회 실패 — 파일 검사 생략: {detail}", file=sys.stderr)
        return []
    return [name.strip() for name in out.splitlines() if name.strip()]


def get_log_entries_for_task(task_id: str) -> list[dict[str, Any]]:
    """Return the entries of the memory check log that belong to ``task_id``.

    The log is expected to be a JSON object whose ``checks`` key holds a list
    of entry objects; entries are selected by their ``task_id`` field.

    Args:
        task_id: Task identifier to filter on.

    Returns:
        Matching entries, or an empty list when the log file does not exist,
        cannot be read or parsed, or does not have the expected shape. Read,
        parse and shape problems are reported on stderr. Individual entries
        that are not JSON objects are skipped rather than discarding the
        whole log.
    """
    if not LOG_PATH.exists():
        return []
    try:
        with open(LOG_PATH, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        print(f"WARN: {LOG_PATH} 읽기 실패 — log 검사 생략: {exc}", file=sys.stderr)
        return []

    checks = data.get("checks", []) if isinstance(data, dict) else None
    if not isinstance(checks, list):
        print(f"WARN: {LOG_PATH} 형식 오류 — log 검사 생략", file=sys.stderr)
        return []
    return [
        entry
        for entry in checks
        if isinstance(entry, dict) and entry.get("task_id") == task_id
    ]


def check_commit_message_rule(
    rule: Rule, commits: list[tuple[str, str]]
) -> list[Violation]:
    """Report every commit whose message matches the rule's regex.

    Args:
        rule: Rule to apply; a rule without ``pattern_regex`` never matches.
        commits: ``(sha, message)`` pairs to inspect.

    Returns:
        One ``Violation`` per matching commit. The location is the first
        eight characters of the sha and the excerpt is the first line of the
        message, cut to 80 characters.
    """
    regex_source = rule.pattern_regex
    if not regex_source:
        return []
    # MULTILINE lets ^ and $ anchor on every line of the message.
    matcher = re.compile(regex_source, re.MULTILINE)
    return [
        Violation(
            rule_id=rule.id,
            rule_name=rule.name,
            location=f"commit:{sha[:8]}",
            matched_content=message.partition("\n")[0][:80],
            severity=rule.severity,
        )
        for sha, message in commits
        if matcher.search(message)
    ]


def check_changed_file_rule(rule: Rule, files: list[str]) -> list[Violation]:
    """Report changed files that contain a line matching the rule's regex.

    Args:
        rule: Rule to apply; a rule without ``pattern_regex`` never matches.
            When ``file_pattern`` is set, only files whose basename matches
            it (fnmatch semantics) are scanned.
        files: Paths relative to the workspace root.

    Returns:
        At most one ``Violation`` per file, for the first matching line. The
        excerpt is that line, stripped and cut to 80 characters. Files that
        do not exist or cannot be read are skipped.
    """
    import fnmatch

    if not rule.pattern_regex:
        return []
    line_matcher = re.compile(rule.pattern_regex)
    name_glob = rule.file_pattern

    found: list[Violation] = []
    for rel_path in files:
        if name_glob and not fnmatch.fnmatch(os.path.basename(rel_path), name_glob):
            continue
        target = WORKSPACE / rel_path
        if not target.exists():
            continue
        try:
            text = target.read_text(encoding="utf-8", errors="ignore")
        except Exception:
            continue
        # Only the first offending line of each file is reported.
        first_hit = next(
            (
                (number, line)
                for number, line in enumerate(text.splitlines(), 1)
                if line_matcher.search(line)
            ),
            None,
        )
        if first_hit is None:
            continue
        number, line = first_hit
        found.append(Violation(
            rule_id=rule.id,
            rule_name=rule.name,
            location=f"{rel_path}:{number}",
            matched_content=line.strip()[:80],
            severity=rule.severity,
        ))
    return found


def check_log_entry_rule(
    rule: Rule, entries: list[dict[str, Any]]
) -> list[Violation]:
    """Report log entries whose field differs from the rule's expected value.

    Args:
        rule: Rule to apply; a rule without ``log_field`` never matches.
        entries: Log entries (dicts) to inspect.

    Returns:
        One ``Violation`` per entry where ``entry.get(rule.log_field)`` is
        not equal to ``rule.expected_value``. A missing field reads as
        ``None``.
    """
    field_name = rule.log_field
    if not field_name:
        return []
    observed = ((entry, entry.get(field_name)) for entry in entries)
    return [
        Violation(
            rule_id=rule.id,
            rule_name=rule.name,
            location=f"log:{entry.get('mc_id', '?')}/{entry.get('task_id', '?')}",
            matched_content=f"{rule.log_field}={actual} (expected={rule.expected_value})",
            severity=rule.severity,
        )
        for entry, actual in observed
        if actual != rule.expected_value
    ]


def main() -> int:
    """Run the detector from the command line.

    Exactly one of ``--task-id``, ``--commit``, ``--diff-base`` or
    ``--staged`` selects which inputs are gathered; every loaded rule is then
    applied to the input that matches its ``target``. Inputs that the chosen
    mode does not gather stay empty, so rules for them cannot fire.

    Returns:
        0 when no violation was found (or no rules are defined), 1 when at
        least one violation was found. Violations are printed to stderr.
    """
    parser = argparse.ArgumentParser(description="Memory violation detector")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--task-id", help="task ID 기반 검사 (memory-check-log.json)")
    group.add_argument("--commit", help="특정 commit SHA 검사")
    group.add_argument("--diff-base", help="ref..HEAD 범위 검사")
    group.add_argument("--staged", action="store_true", help="staged 변경 검사")
    args = parser.parse_args()

    rules = load_rules()
    if not rules:
        print("WARN: 룰 0건. 룰 spec 작성 필요.", file=sys.stderr)
        return 0

    commits: list[tuple[str, str]] = []
    files: list[str] = []
    log_entries: list[dict[str, Any]] = []

    if args.task_id:
        log_entries = get_log_entries_for_task(args.task_id)
    elif args.commit:
        commits = get_commit_messages(commit_sha=args.commit)
        files = get_changed_files(commit_sha=args.commit)
    elif args.diff_base:
        commits = get_commit_messages(diff_base=args.diff_base)
        files = get_changed_files(diff_base=args.diff_base)
    elif args.staged:
        files = get_changed_files(staged=True)

    # Built after the inputs are gathered so each checker sees the final lists.
    checkers: dict[str, tuple[Any, Any]] = {
        "commit_messages": (check_commit_message_rule, commits),
        "changed_files": (check_changed_file_rule, files),
        "log_entries": (check_log_entry_rule, log_entries),
    }

    all_violations: list[Violation] = []
    for rule in rules:
        handler = checkers.get(rule.target) if isinstance(rule.target, str) else None
        if handler is None:
            # A misspelled target would otherwise disable the rule silently.
            print(
                f"WARN: 지원하지 않는 target '{rule.target}' (rule {rule.id}) — 건너뜀.",
                file=sys.stderr,
            )
            continue
        check, inputs = handler
        all_violations.extend(check(rule, inputs))

    if not all_violations:
        print("[memory-violation-detector] 위반 0건. PASS.")
        return 0

    print(
        f"\n[memory-violation-detector] 위반 {len(all_violations)}건 발견:\n",
        file=sys.stderr,
    )
    for v in all_violations:
        print(f"  [{v.severity.upper()}] {v.rule_id} ({v.rule_name})", file=sys.stderr)
        print(f"    위치: {v.location}", file=sys.stderr)
        print(f"    내용: {v.matched_content}", file=sys.stderr)
        print(file=sys.stderr)
    return 1


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
