#!/usr/bin/env python3
"""
analyze_ab.py - haiku vs sonnet A/B 실험 분석 스크립트

Usage:
    python scripts/analyze_ab.py --input logs/ab_results.jsonl
"""

import argparse
import json
import sys
from collections import defaultdict
from typing import Optional


def load_results(path: str) -> list[dict]:
    """Load experiment records from a JSONL file.

    Blank lines are skipped. Lines that fail to parse as JSON, and lines that
    parse to something other than a JSON object, are reported on stderr and
    skipped, so every returned element is a dict.

    Args:
        path: Path to the JSONL file; it is read as UTF-8.

    Returns:
        The parsed records, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    results: list[dict] = []
    with open(path, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"[WARN] Line {line_num} JSON 파싱 실패: {e}", file=sys.stderr)
                continue
            if not isinstance(record, dict):
                # A bare list/number/string is valid JSON but not a record;
                # keeping it would break code that calls record.get(...).
                print(
                    f"[WARN] Line {line_num} JSON 객체 아님: {type(record).__name__}",
                    file=sys.stderr,
                )
                continue
            results.append(record)
    return results


def compute_fnr(results: list[dict], model: str) -> tuple[float, int]:
    """Average the per-record ``fnr`` values for one model.

    Only records whose ``assigned_model`` equals *model* and that are not
    flagged ``is_recheck`` are considered. A record without an ``fnr`` key
    contributes 0.0 to the average.

    Returns:
        ``(mean_fnr, record_count)``; ``(0.0, 0)`` when no record matches.
    """
    selected = [
        rec
        for rec in results
        if rec.get("assigned_model") == model
        if not rec.get("is_recheck", False)
    ]
    n = len(selected)
    if n == 0:
        return 0.0, 0
    mean_fnr = sum(rec.get("fnr", 0.0) for rec in selected) / n
    return mean_fnr, n


def fishers_exact_test(results: list[dict], threshold: float = 0.15) -> dict:
    """Run a two-sided Fisher's exact test on haiku vs. sonnet pass/fail counts.

    Records flagged ``is_recheck`` are ignored. Each remaining record is
    classified as 'fail' when its ``fnr`` (0.0 if the key is absent) is
    ``>= threshold`` and as 'pass' otherwise. The 2x2 table is
    ``[[haiku_pass, haiku_fail], [sonnet_pass, sonnet_fail]]``.

    Args:
        results: Experiment records (dicts).
        threshold: FNR value at or above which a record counts as a failure.

    Returns:
        On success, a dict with ``table``, ``odds_ratio``, ``p_value``,
        ``haiku_n`` and ``sonnet_n``. ``odds_ratio`` is a float rounded to 4
        places, the string ``"inf"`` when it is infinite, or ``None`` when it
        is undefined (NaN), so that the dict always serializes to valid JSON.
        On failure (scipy missing or the test raising), a dict with
        ``error``, ``p_value`` set to ``None`` and ``verdict`` set to
        ``"ERROR"``.
    """
    import math

    try:
        from scipy.stats import fisher_exact
    except ImportError:
        return {
            "error": "scipy 미설치. pip install scipy 필요",
            "p_value": None,
            "verdict": "ERROR",
        }

    haiku_results = [r for r in results if r.get("assigned_model") == "haiku" and not r.get("is_recheck", False)]
    sonnet_results = [r for r in results if r.get("assigned_model") == "sonnet" and not r.get("is_recheck", False)]

    haiku_fail = sum(1 for r in haiku_results if r.get("fnr", 0.0) >= threshold)
    haiku_pass = len(haiku_results) - haiku_fail
    sonnet_fail = sum(1 for r in sonnet_results if r.get("fnr", 0.0) >= threshold)
    sonnet_pass = len(sonnet_results) - sonnet_fail

    table = [[haiku_pass, haiku_fail], [sonnet_pass, sonnet_fail]]

    try:
        odds_ratio_raw, p_value_raw = fisher_exact(table, alternative="two-sided")  # type: ignore[misc]
    except Exception as e:
        return {"error": str(e), "p_value": None, "verdict": "ERROR"}

    or_val: float = float(odds_ratio_raw)  # type: ignore[arg-type]
    pv_val: float = float(p_value_raw)  # type: ignore[arg-type]

    # The odds ratio comes back as NaN when a whole row or column of the
    # table is zero (for example, neither model has any failure).
    # json.dumps would write that as the bare token NaN, which is not valid
    # JSON, so report it as None instead.
    odds_ratio: object
    if math.isnan(or_val):
        odds_ratio = None
    elif or_val == math.inf:
        odds_ratio = "inf"
    else:
        odds_ratio = round(or_val, 4)

    return {
        "table": table,
        "odds_ratio": odds_ratio,
        "p_value": round(pv_val, 6),
        "haiku_n": len(haiku_results),
        "sonnet_n": len(sonnet_results),
    }


def determine_verdict(
    haiku_fnr: float,
    p_value: Optional[float],
    haiku_n: int,
    sonnet_n: int,
    alpha: float = 0.05,
    fnr_threshold: float = 0.15,
    min_sample: int = 150,
) -> dict:
    """Decide between EXTEND, ERROR, ADOPT and REJECT.

    The checks run in this order:

    1. EXTEND if either group has fewer than ``min_sample`` records.
    2. ERROR if ``p_value`` is ``None``.
    3. ADOPT if ``haiku_fnr < fnr_threshold`` and ``p_value < alpha``.
    4. REJECT otherwise.

    NOTE(review): in this file the p-value passed in comes from a two-sided
    test, so ``p_value < alpha`` alone does not say which model is better.
    Confirm that this is the intended adoption criterion.

    Returns:
        A dict with ``verdict`` and a human-readable ``reason``.
    """
    # The sample-size gate comes first, so a missing p-value on a small
    # sample still yields EXTEND rather than ERROR.
    if any(n < min_sample for n in (haiku_n, sonnet_n)):
        verdict = "EXTEND"
        reason = f"표본 부족: haiku={haiku_n}, sonnet={sonnet_n} (최소 {min_sample} 필요)"
    elif p_value is None:
        verdict = "ERROR"
        reason = "p-value 계산 실패"
    elif haiku_fnr < fnr_threshold and p_value < alpha:
        verdict = "ADOPT"
        reason = f"FNR={haiku_fnr:.3f} (<{fnr_threshold}) AND p={p_value:.6f} (<{alpha})"
    else:
        verdict = "REJECT"
        reason = f"FNR={haiku_fnr:.3f} (threshold={fnr_threshold}), p={p_value:.6f} (alpha={alpha})"

    return {"verdict": verdict, "reason": reason}


def stratification_check(results: list[dict]) -> dict:
    """Report how evenly each task level is split between haiku and sonnet.

    Records flagged ``is_recheck`` and records whose ``assigned_model`` is
    neither "haiku" nor "sonnet" are ignored. A record without a
    ``task_level`` key is counted under the level "unknown".

    Returns:
        A dict keyed by task level (in sorted order). Each value holds the
        per-model counts, their total, the percentages rounded to one
        decimal place, and ``balanced``, which is True when the haiku share
        is within 2 percentage points of 50%.
    """
    arms = ("haiku", "sonnet")
    tallies: dict[str, dict[str, int]] = {}
    for rec in results:
        if rec.get("is_recheck", False):
            continue
        arm = rec.get("assigned_model", "unknown")
        if arm not in arms:
            continue
        stratum = rec.get("task_level", "unknown")
        per_level = tallies.setdefault(stratum, {"haiku": 0, "sonnet": 0})
        per_level[arm] += 1

    report = {}
    for stratum in sorted(tallies):
        n_haiku = tallies[stratum]["haiku"]
        n_sonnet = tallies[stratum]["sonnet"]
        # A level is only stored once a record has been counted for it, so
        # the total is never zero here.
        total = n_haiku + n_sonnet
        haiku_pct = n_haiku / total * 100
        sonnet_pct = n_sonnet / total * 100
        report[stratum] = {
            "haiku": n_haiku,
            "sonnet": n_sonnet,
            "total": total,
            "haiku_pct": round(haiku_pct, 1),
            "sonnet_pct": round(sonnet_pct, 1),
            "balanced": abs(haiku_pct - 50.0) <= 2.0,
        }
    return report


def main() -> int:
    """Command-line entry point: analyze an A/B results file and print a JSON report.

    Progress messages go to stderr; the JSON report (or a JSON error object)
    goes to stdout.

    Returns:
        0 when a report was produced; 1 when the input file could not be
        read or contained no records.
    """
    parser = argparse.ArgumentParser(description="haiku vs sonnet A/B 실험 분석")
    parser.add_argument("--input", required=True, help="JSONL 입력 파일 경로")
    parser.add_argument("--threshold", type=float, default=0.15, help="FNR 임계값 (기본: 0.15)")
    parser.add_argument("--alpha", type=float, default=0.05, help="유의 수준 (기본: 0.05)")
    parser.add_argument("--min-sample", type=int, default=150, help="최소 표본 크기 (기본: 150)")
    args = parser.parse_args()

    try:
        results = load_results(args.input)
    except (OSError, UnicodeDecodeError) as e:
        # Report an unreadable input the same way as an empty one: a JSON
        # error object on stdout and a non-zero exit code, not a traceback.
        print(
            json.dumps(
                {"error": f"입력 파일 읽기 실패: {e}", "verdict": "ERROR"},
                ensure_ascii=False,
                indent=2,
            )
        )
        return 1

    if not results:
        print(json.dumps({"error": "결과 데이터 없음", "verdict": "ERROR"}, ensure_ascii=False, indent=2))
        return 1

    print(f"[AB] 총 {len(results)}건 로드", file=sys.stderr)

    haiku_fnr, haiku_n = compute_fnr(results, "haiku")
    sonnet_fnr, sonnet_n = compute_fnr(results, "sonnet")

    print(f"[AB] haiku FNR={haiku_fnr:.3f} (n={haiku_n})", file=sys.stderr)
    print(f"[AB] sonnet FNR={sonnet_fnr:.3f} (n={sonnet_n})", file=sys.stderr)

    fisher = fishers_exact_test(results, threshold=args.threshold)

    # If the Fisher step failed, its dict carries p_value=None and
    # determine_verdict decides how to report that.
    verdict = determine_verdict(
        haiku_fnr=haiku_fnr,
        p_value=fisher.get("p_value"),
        haiku_n=haiku_n,
        sonnet_n=sonnet_n,
        alpha=args.alpha,
        fnr_threshold=args.threshold,
        min_sample=args.min_sample,
    )

    strat = stratification_check(results)

    output = {
        "total_records": len(results),
        "haiku": {"fnr": round(haiku_fnr, 4), "n": haiku_n},
        "sonnet": {"fnr": round(sonnet_fnr, 4), "n": sonnet_n},
        "fisher_exact": fisher,
        "verdict": verdict,
        "stratification": strat,
    }

    print(json.dumps(output, ensure_ascii=False, indent=2))
    return 0


# When run as a script, execute main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
