#!/usr/bin/env python3
"""
health_score.py - Project health score calculator
A2 + A13 + A17 bundle
"""

import argparse
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Any

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Per-category weights for the overall score. They sum to 1.0, so the
# weighted total stays on the same 0-100 scale as the category scores.
WEIGHTS: dict[str, float] = {
    "test_pass_rate": 0.20,
    "pyright_errors": 0.15,
    "code_coverage": 0.10,
    "tech_debt": 0.15,
    "security": 0.15,
    "documentation": 0.10,
    "deploy_stability": 0.15,
}

# Lowercase substrings that mark a task description as a fix. The Korean
# entries mean "modification/fix" and "bug".
FIX_KEYWORDS = ("fix", "수정", "버그")
# Fix-task percentage above which a warning is raised.
FIX_WARNING_THRESHOLD = 30.0

# Workspace root: $WORKSPACE_ROOT when set and non-empty, otherwise the parent
# of this script's directory. An empty value is treated as unset so the default
# paths below never silently become relative to the current directory.
_WORKSPACE_ROOT = os.environ.get("WORKSPACE_ROOT") or str(Path(__file__).resolve().parent.parent)
DEFAULT_BASELINE_PATH = str(Path(_WORKSPACE_ROOT) / "memory/whisper/qa-baseline.json")
DEFAULT_TASK_TIMERS_PATH = str(Path(_WORKSPACE_ROOT) / "memory/task-timers.json")

# ---------------------------------------------------------------------------
# Grade calculation
# ---------------------------------------------------------------------------


def score_to_grade(score: float) -> str:
    """Map a numeric score to a letter grade.

    Each threshold is an inclusive lower bound: 90 and above is "A", 80 is
    "B", 70 is "C" and 60 is "D". Anything that reaches no threshold is "F".
    """
    for floor, letter in ((90, "A"), (80, "B"), (70, "C"), (60, "D")):
        if score >= floor:
            return letter
    return "F"


# ---------------------------------------------------------------------------
# Weighted score
# ---------------------------------------------------------------------------


def calculate_weighted_score(categories: dict[str, dict[str, Any]]) -> float:
    """Compute weighted average score from category scores."""
    total = 0.0
    for key, weight in WEIGHTS.items():
        cat = categories.get(key, {})
        total += cat.get("score", 0) * weight
    return round(total, 2)


# ---------------------------------------------------------------------------
# Baseline I/O
# ---------------------------------------------------------------------------


def load_baseline(path: str) -> dict[str, Any] | None:
    """Load baseline JSON. Returns None if file is missing or invalid."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)  # type: ignore[no-any-return]
    except (FileNotFoundError, json.JSONDecodeError):
        return None


def save_baseline(result: dict[str, Any], path: str) -> None:
    """Write result to path as indented UTF-8 JSON, creating parent directories.

    The result is serialized before the file is opened. If it cannot be
    encoded as JSON, the error from json.dumps propagates and any existing
    baseline at path is left untouched rather than truncated.
    """
    payload = json.dumps(result, ensure_ascii=False, indent=2)
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(payload, encoding="utf-8")


# ---------------------------------------------------------------------------
# Task timers I/O
# ---------------------------------------------------------------------------


def load_task_timers(path: str) -> dict[str, Any]:
    """Return the "tasks" mapping from a task-timers JSON file.

    Returns an empty dict in any of these cases:
    - the file is missing or cannot be read;
    - it is not valid UTF-8 JSON;
    - its top level is not a JSON object;
    - its "tasks" value is absent or not an object (calculate_fix_pct
      iterates ``tasks.values()``).
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        return {}
    if not isinstance(data, dict):
        return {}
    tasks = data.get("tasks", {})
    return tasks if isinstance(tasks, dict) else {}


# ---------------------------------------------------------------------------
# Fix percentage (A17)
# ---------------------------------------------------------------------------


def _is_fix_task(description: str) -> bool:
    """Return True if the task description contains any fix keyword."""
    lower = description.lower()
    return any(kw in lower for kw in FIX_KEYWORDS)


def calculate_fix_pct(tasks: dict[str, Any]) -> dict[str, Any]:
    """
    Calculate the share of tasks whose description marks them as a fix.

    Args:
        tasks: Mapping of task id to task record. Each record is expected to
            be a dict with an optional "description" string. A record that is
            not a dict, or whose description is not a string, still counts
            toward the total but never as a fix.

    Returns:
        {"total": N, "fix_count": N, "pct": X.X, "warning": bool}. pct is
        rounded to one decimal; warning is True when pct exceeds
        FIX_WARNING_THRESHOLD.
    """
    if not tasks:
        return {"total": 0, "fix_count": 0, "pct": 0.0, "warning": False}

    total = len(tasks)
    fix_count = 0
    for task in tasks.values():
        description = task.get("description") if isinstance(task, dict) else None
        # Guard here as well so malformed records never reach str methods.
        if isinstance(description, str) and _is_fix_task(description):
            fix_count += 1
    pct = round(fix_count / total * 100, 1)
    warning = pct > FIX_WARNING_THRESHOLD
    return {"total": total, "fix_count": fix_count, "pct": pct, "warning": warning}


# ---------------------------------------------------------------------------
# Category collectors
# ---------------------------------------------------------------------------


def _run(cmd: list[str], cwd: str | None = None) -> tuple[int, str, str]:
    """Run subprocess, return (returncode, stdout, stderr)."""
    try:
        proc = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=cwd,
            timeout=60,
        )
        return proc.returncode, proc.stdout, proc.stderr
    except (FileNotFoundError, subprocess.TimeoutExpired) as e:
        return -1, "", str(e)


def collect_test_pass_rate(project_dir: str) -> dict[str, Any]:
    """Run pytest in project_dir and score the pass rate (0-100).

    Counts come from a single line: the last output line that reports
    "N passed", "N failed" or "N error(s)", which is pytest's final summary.
    Errors (e.g. collection or fixture errors) count as unsuccessful tests.
    A suite that cannot even be collected therefore scores 0 with an
    explanatory message, instead of being reported as "no tests found".
    """
    import re  # local: only this collector parses pytest's summary

    rc, stdout, stderr = _run(
        ["python3", "-m", "pytest", "--tb=no", "-q", "--no-header"],
        cwd=project_dir,
    )
    if rc == -1:
        return {"score": 0, "details": f"pytest not available: {stderr}"}

    # "1 failed, 5 passed, 2 errors in 0.12s" ->
    #   [("1", "failed"), ("5", "passed"), ("2", "errors")]
    # "1 xfailed" / "1 xpassed" do not match: the word must follow the count directly.
    count_pattern = re.compile(r"(\d+) (passed|failed|errors?)\b")
    counts: dict[str, int] = {}
    for line in (stdout + "\n" + stderr).splitlines():
        matches = count_pattern.findall(line.lower())
        if matches:
            # Start afresh per matching line so numbers from different lines never mix;
            # pytest prints its summary last, so the final match wins.
            counts = {}
            for number, word in matches:
                counts["error" if word.startswith("error") else word] = int(number)

    passed = counts.get("passed", 0)
    failed = counts.get("failed", 0)
    errors = counts.get("error", 0)
    total = passed + failed + errors
    if total == 0:
        return {"score": 0, "details": "No tests found or pytest failed to run"}

    pct = passed / total * 100
    details = f"{passed}/{total} tests passed ({pct:.1f}%)"
    if errors:
        details += f", {errors} error(s)"
    return {"score": int(round(pct)), "details": details}


def collect_pyright_errors(project_dir: str) -> dict[str, Any]:
    """Run pyright on project_dir and score by its reported error count.

    The count is read from summary.errorCount in pyright's --outputjson
    report. If the report is missing or malformed, lines containing
    " error " are counted instead. If that fallback finds nothing while
    pyright exited non-zero, the run is reported as failed (score 0) rather
    than as a clean "0 errors" (score 100).

    Scoring: 0 errors -> 100, <=5 -> 80, <=10 -> 60, <=20 -> 40,
    <=50 -> 20, otherwise 0.
    """
    rc, stdout, stderr = _run(["pyright", "--outputjson", project_dir])
    if rc == -1:
        return {"score": 0, "details": f"pyright not available: {stderr}"}

    error_count: int | None = None
    try:
        report = json.loads(stdout)
    except json.JSONDecodeError:
        report = None
    if isinstance(report, dict):
        summary = report.get("summary", {})
        if isinstance(summary, dict):
            count = summary.get("errorCount", 0)
            if isinstance(count, int):
                error_count = count

    if error_count is None:
        # Report missing or malformed: fall back to counting "error" lines.
        error_count = sum(1 for line in stdout.splitlines() if " error " in line.lower())
        if error_count == 0 and rc != 0:
            return {
                "score": 0,
                "details": f"pyright failed (exit code {rc}); output could not be parsed",
            }

    if error_count == 0:
        score = 100
    elif error_count <= 5:
        score = 80
    elif error_count <= 10:
        score = 60
    elif error_count <= 20:
        score = 40
    elif error_count <= 50:
        score = 20
    else:
        score = 0

    return {"score": score, "details": f"{error_count} pyright error(s)"}


def collect_code_coverage(project_dir: str) -> dict[str, Any]:
    """Run pytest with pytest-cov in project_dir and score total coverage.

    The score is taken from the first "TOTAL" row of the coverage report:
    its rightmost parseable "NN%" token, rounded to an int. The score is 0
    when the tool cannot be started or no such row is found.
    """
    cmd = [
        "python3", "-m", "pytest",
        "--cov=.", "--cov-report=term-missing",
        "--tb=no", "-q", "--no-header",
    ]
    rc, out, err = _run(cmd, cwd=project_dir)
    if rc == -1:
        return {"score": 0, "details": f"pytest-cov not available: {err}"}

    for row in (out + err).splitlines():
        if not row.strip().startswith("TOTAL"):
            continue
        # Scan right-to-left: the percentage is the last column of the row.
        for token in reversed(row.split()):
            if not token.endswith("%"):
                continue
            try:
                coverage = float(token.rstrip("%"))
                rounded = int(round(coverage))
            except ValueError:
                continue
            return {"score": rounded, "details": f"Coverage: {coverage:.1f}%"}

    return {"score": 0, "details": "Coverage data not found"}


def collect_tech_debt(project_dir: str) -> dict[str, Any]:
    """Score technical debt (0-100).

    Uses the first usable tech-debt.json ledger. It checks
    <project_dir>/memory/tech-debt.json first, then the workspace-level
    memory/tech-debt.json. The score is the percentage of ledger items that
    are not "open"; an item without a "status" counts as open. When no
    usable ledger exists, the score instead comes from the number of flake8
    mccabe complexity findings (C901) in project_dir.
    """
    candidates = (
        Path(project_dir) / "memory" / "tech-debt.json",
        Path(_WORKSPACE_ROOT) / "memory/tech-debt.json",
    )
    for candidate in candidates:
        items = _load_debt_items(candidate)
        if items is not None:
            return _score_debt_items(items, str(candidate))
    return _score_complexity(project_dir)


def _load_debt_items(path: Path) -> list[Any] | None:
    """Return a ledger's item list, or None if the file is missing, unreadable or neither a list nor an object."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        return None
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        items = data.get("items", [])
        # A ledger object whose "items" is not a list is treated as empty.
        return items if isinstance(items, list) else []
    return None


def _score_debt_items(items: list[Any], source: str) -> dict[str, Any]:
    """Score ledger items by the share that are no longer open; non-dict entries are ignored."""
    records = [item for item in items if isinstance(item, dict)]
    open_count = sum(1 for item in records if item.get("status", "open") == "open")
    if records:
        score = int(round((len(records) - open_count) / len(records) * 100))
    else:
        score = 100
    return {"score": score, "details": f"Tech debt: {open_count} open items from {source}"}


def _score_complexity(project_dir: str) -> dict[str, Any]:
    """Fallback debt score from flake8 C901 findings (mccabe complexity above 10)."""
    # C901 alone: a bare "C" prefix would also select plugin codes such as C4xx.
    rc, stdout, _ = _run(
        ["python3", "-m", "flake8", "--max-complexity=10", "--select=C901", project_dir]
    )
    if rc == -1:
        return {"score": 50, "details": "Tech debt: no data available (flake8 not found)"}

    findings = sum(1 for line in stdout.splitlines() if line.strip())
    if rc != 0 and findings == 0:
        # A non-zero exit with no findings on stdout means flake8 itself failed
        # (e.g. the module is not installed); do not report that as clean code.
        return {"score": 50, "details": "Tech debt: no data available (flake8 failed to run)"}

    if findings == 0:
        return {"score": 90, "details": "No high-complexity code detected"}
    elif findings <= 5:
        return {"score": 70, "details": f"{findings} high-complexity function(s)"}
    elif findings <= 15:
        return {"score": 50, "details": f"{findings} high-complexity function(s)"}
    else:
        return {"score": 30, "details": f"{findings} high-complexity function(s)"}


def collect_security(project_dir: str) -> dict[str, Any]:
    """Run a recursive bandit scan of project_dir and score by issue severity.

    Scoring:
    - no findings -> 100;
    - only LOW findings -> 85;
    - MEDIUM findings but no HIGH -> 65;
    - one or two HIGH -> 40;
    - more HIGH -> 10.

    The score is 0 when bandit cannot be started or its JSON report cannot
    be parsed.
    """
    rc, stdout, stderr = _run(
        ["python3", "-m", "bandit", "-r", "-f", "json", project_dir],
    )
    if rc == -1:
        return {"score": 0, "details": f"bandit not available: {stderr}"}

    try:
        report = json.loads(stdout)
    except json.JSONDecodeError:
        report = None
    # A non-object report or a non-list "results" would crash the counting
    # below; treat both as unparseable output.
    findings = report.get("results", []) if isinstance(report, dict) else None
    if not isinstance(findings, list):
        return {"score": 0, "details": "bandit output parse error"}

    # `or ""` also covers an explicit null severity.
    severities = [
        str(item.get("issue_severity") or "").upper()
        for item in findings
        if isinstance(item, dict)
    ]
    high = severities.count("HIGH")
    medium = severities.count("MEDIUM")
    low = severities.count("LOW")

    if high == 0 and medium == 0 and low == 0:
        score = 100
    elif high == 0 and medium == 0:
        score = 85
    elif high == 0:
        score = 65
    elif high <= 2:
        score = 40
    else:
        score = 10

    return {"score": score, "details": f"Security: {high} high, {medium} medium, {low} low issues"}


def collect_documentation(project_dir: str) -> dict[str, Any]:
    """Estimate documentation coverage by checking docstrings in Python files."""
    py_files = list(Path(project_dir).rglob("*.py"))
    if not py_files:
        return {"score": 0, "details": "No Python files found"}

    import ast

    total_funcs = 0
    documented = 0
    for py_file in py_files:
        try:
            source = py_file.read_text(encoding="utf-8")
            tree = ast.parse(source)
        except (SyntaxError, OSError):
            continue
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                total_funcs += 1
                if ast.get_docstring(node):
                    documented += 1

    if total_funcs == 0:
        return {"score": 0, "details": "No functions/classes found"}

    pct = documented / total_funcs * 100
    score = int(round(pct))
    return {"score": score, "details": f"Documentation: {documented}/{total_funcs} functions/classes ({pct:.1f}%)"}


def collect_deploy_stability(project_dir: str) -> dict[str, Any]:
    """
    Score deploy stability from the subjects of the last 50 git commits.

    A commit is "problematic" when its lowercased subject contains revert,
    hotfix, rollback, emergency or "critical fix". The score falls as the
    share of such commits rises. Without git history, the presence of a
    known CI config file is used as a weak proxy instead.
    """
    root = Path(project_dir)
    ci_markers = (
        ".github/workflows",
        ".gitlab-ci.yml",
        "Jenkinsfile",
        ".circleci/config.yml",
        "azure-pipelines.yml",
    )
    has_ci = any((root / marker).exists() for marker in ci_markers)

    rc, log_output, _ = _run(
        ["git", "log", "--oneline", "-50", "--format=%s"],
        cwd=project_dir,
    )

    if rc != 0 or not log_output.strip():
        # Not a git checkout, git missing, or an empty history.
        if has_ci:
            return {"score": 70, "details": "CI config present; no git history available"}
        return {"score": 50, "details": "No CI config or git history found"}

    subjects = log_output.strip().splitlines()
    red_flags = ("revert", "hotfix", "rollback", "emergency", "critical fix")
    flagged = len([s for s in subjects if any(flag in s.lower() for flag in red_flags)])
    flagged_pct = flagged / len(subjects) * 100 if len(subjects) > 0 else 0

    # First ceiling the percentage fits under decides the score; above 20% -> 30.
    score = 30
    for ceiling, value in ((0, 100), (5, 85), (10, 70), (20, 50)):
        if flagged_pct <= ceiling:
            score = value
            break

    return {
        "score": score,
        "details": f"Deploy stability: {flagged}/{len(subjects)} problematic commits ({flagged_pct:.1f}%)",
    }


def collect_all_categories(project_dir: str) -> dict[str, dict[str, Any]]:
    """Run every category collector against project_dir.

    The result is keyed by category name, using the same seven keys as
    WEIGHTS. Collectors run one after another in the order listed below.
    """
    collectors = (
        ("test_pass_rate", collect_test_pass_rate),
        ("pyright_errors", collect_pyright_errors),
        ("code_coverage", collect_code_coverage),
        ("tech_debt", collect_tech_debt),
        ("security", collect_security),
        ("documentation", collect_documentation),
        ("deploy_stability", collect_deploy_stability),
    )
    return {name: collector(project_dir) for name, collector in collectors}


# ---------------------------------------------------------------------------
# Build result (A13 baseline comparison)
# ---------------------------------------------------------------------------


def build_result(
    categories: dict[str, dict[str, Any]],
    baseline: dict[str, Any] | None,
    fix_pct_result: dict[str, Any] | None,
) -> dict[str, Any]:
    """
    Assemble the final report dict.

    Args:
        categories: Per-category collector results (each with "score" and "details").
        baseline: A previously saved report, or None. A category may also appear
            under the baseline's "categories" as an object. If so, its output gets:
            - "delta": current score minus baseline score;
            - "direction": "improved", "degraded" or "stable".
            A baseline entry without a "score" counts as unchanged. Malformed
            baseline data (non-object categories or entries, non-numeric
            scores) is ignored instead of raising.
        fix_pct_result: Output of calculate_fix_pct, copied under "fix_pct" when not None.

    Returns:
        {"score": int, "grade": str, "categories": {...}}, plus "fix_pct" when provided.
    """
    score = int(round(calculate_weighted_score(categories)))
    grade = score_to_grade(score)

    baseline_cats = baseline.get("categories") if isinstance(baseline, dict) else None
    if not isinstance(baseline_cats, dict):
        baseline_cats = {}

    output_categories: dict[str, dict[str, Any]] = {}
    for key, data in categories.items():
        cat_out: dict[str, Any] = dict(data)
        previous = baseline_cats.get(key)
        if isinstance(previous, dict):
            curr_score = data.get("score", 0)
            prev_score = previous.get("score", curr_score)
            if isinstance(prev_score, (int, float)):
                delta = curr_score - prev_score
                if delta > 0:
                    direction = "improved"
                elif delta < 0:
                    direction = "degraded"
                else:
                    direction = "stable"
                cat_out["delta"] = delta
                cat_out["direction"] = direction
        output_categories[key] = cat_out

    result: dict[str, Any] = {
        "score": score,
        "grade": grade,
        "categories": output_categories,
    }

    if fix_pct_result is not None:
        result["fix_pct"] = fix_pct_result

    return result


# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the CLI options; when argv is None, argparse reads sys.argv[1:]."""
    parser = argparse.ArgumentParser(description="Project health score calculator (A2 + A13 + A17)")
    add = parser.add_argument

    add("--project-dir", default=".",
        help="Project directory to analyze (default: current directory)")
    add("--baseline", default=DEFAULT_BASELINE_PATH,
        help=f"Baseline JSON file path (default: {DEFAULT_BASELINE_PATH})")
    add("--save-baseline", action="store_true",
        help="Save current result as baseline")
    add("--task-timers", default=DEFAULT_TASK_TIMERS_PATH,
        help=f"task-timers.json path (default: {DEFAULT_TASK_TIMERS_PATH})")
    add("--output", default=None,
        help="Output JSON file path (default: stdout)")

    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Run the health-score CLI and return the process exit code.

    Steps:
    - collect all categories for --project-dir;
    - compare against the --baseline file, if it loads;
    - add the fix-task ratio from --task-timers;
    - optionally save the result as the new baseline;
    - write the JSON report to --output or stdout.

    A fix-ratio warning is printed to stderr.

    Returns:
        0 on success; 2 when --project-dir is not an existing directory.
    """
    args = parse_args(argv)

    project_path = Path(args.project_dir).resolve()
    if not project_path.is_dir():
        # Scoring a missing directory yields a meaningless report and, with
        # --save-baseline, would overwrite a good baseline with it.
        print(f"[ERROR] Project directory not found: {project_path}", file=sys.stderr)
        return 2
    project_dir = str(project_path)

    categories = collect_all_categories(project_dir)

    # Loaded before any save below, so deltas compare against the previous baseline.
    baseline = load_baseline(args.baseline)

    # Fix-task ratio (A17)
    tasks = load_task_timers(args.task_timers)
    fix_pct_result = calculate_fix_pct(tasks)

    result = build_result(categories, baseline=baseline, fix_pct_result=fix_pct_result)

    if args.save_baseline:
        save_baseline(result, args.baseline)

    output_json = json.dumps(result, ensure_ascii=False, indent=2)
    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(output_json, encoding="utf-8")
    else:
        print(output_json)

    if fix_pct_result.get("warning"):
        pct = fix_pct_result["pct"]
        count = fix_pct_result["fix_count"]
        total = fix_pct_result["total"]
        # Use the configured threshold so the message cannot drift from the check.
        print(
            f"[WARNING] Fix task ratio is {pct}% ({count}/{total}) - "
            f"exceeds {FIX_WARNING_THRESHOLD:g}% threshold",
            file=sys.stderr,
        )

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
