"""
GTX 1060 6GB faster-whisper GPU 최적화 벤치마크 스크립트
Task: 932.2 | Engineer: Kartikeya (dev4)
GPU: GTX 1060 6GB, CC 6.1 | CUDA 12.2
CTranslate2 지원 compute types: int8, int8_float32, float32 (float16 미지원)
"""

import gc
import json
import subprocess
import sys
import time
from datetime import datetime

TEST_AUDIO = "/home/jay/workspace/teams/dev4/task-932.2/test_5min.wav"
RESULTS_FILE = "/home/jay/workspace/teams/dev4/task-932.2/benchmark_results.json"

# 테스트할 설정 조합
CONFIGS = [
    {"device": "cuda", "compute_type": "float32",      "label": "GPU + float32"},
    {"device": "cuda", "compute_type": "int8_float32", "label": "GPU + int8_float32 (혼합)"},
    {"device": "cuda", "compute_type": "int8",         "label": "GPU + int8"},
    {"device": "cpu",  "compute_type": "int8",         "label": "CPU + int8 (baseline)"},
    {"device": "cuda", "compute_type": "float16",      "label": "GPU + float16 (에러 캡처용)"},
]

# 테스트할 모델 크기 (VRAM 6GB 제약 고려 — 작은 것부터)
MODEL_SIZES = ["tiny", "base", "small", "medium", "large-v2"]

TRANSCRIBE_KWARGS = {
    "language": "ko",
    "beam_size": 5,
}


def log(msg: str) -> None:
    """타임스탬프와 함께 stdout 출력"""
    ts = datetime.now().strftime("%H:%M:%S")
    print(f"[{ts}] {msg}", flush=True)


def get_vram_used_mb() -> int | None:
    """nvidia-smi 로 현재 VRAM 사용량(MB) 반환. 실패 시 None."""
    try:
        out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
            stderr=subprocess.DEVNULL,
            text=True,
        )
        return int(out.strip().split("\n")[0])
    except Exception as e:
        log(f"  [WARN] nvidia-smi 쿼리 실패: {e}")
        return None


def free_gpu_memory(model=None) -> None:
    """GPU 메모리 해제"""
    if model is not None:
        del model
    gc.collect()
    try:
        import torch
        torch.cuda.empty_cache()
    except ImportError:
        pass


def run_single_test(model_size: str, device: str, compute_type: str, label: str) -> dict:
    """
    단일 (model_size, device, compute_type) 조합 테스트.
    에러가 발생해도 dict를 반환하며, 다음 테스트는 계속 진행됨.
    """
    from faster_whisper import WhisperModel

    result = {
        "model_size": model_size,
        "device": device,
        "compute_type": compute_type,
        "label": label,
        "load_time_sec": None,
        "transcribe_time_sec": None,
        "vram_after_load_mb": None,
        "vram_after_transcribe_mb": None,
        "transcript_preview": None,
        "segment_count": None,
        "error": None,
        "status": "pending",
    }

    model = None
    try:
        # ── 모델 로딩 ──────────────────────────────────────────────
        log(f"  로딩 중: {model_size} / {device} / {compute_type} ...")
        vram_before = get_vram_used_mb()

        t0 = time.perf_counter()
        model = WhisperModel(model_size, device=device, compute_type=compute_type)
        load_time = time.perf_counter() - t0

        vram_after_load = get_vram_used_mb()
        result["load_time_sec"] = round(load_time, 3)
        result["vram_after_load_mb"] = vram_after_load
        log(f"  로딩 완료: {load_time:.2f}s | VRAM {vram_before}MB → {vram_after_load}MB")

        # ── 전사 ───────────────────────────────────────────────────
        log(f"  전사 시작 ({TEST_AUDIO}) ...")
        t1 = time.perf_counter()
        segments, info = model.transcribe(TEST_AUDIO, **TRANSCRIBE_KWARGS)

        # segments 는 generator — 실제로 소비해야 시간이 측정됨
        collected = list(segments)
        transcribe_time = time.perf_counter() - t1

        vram_after_transcribe = get_vram_used_mb()
        result["transcribe_time_sec"] = round(transcribe_time, 3)
        result["vram_after_transcribe_mb"] = vram_after_transcribe
        result["segment_count"] = len(collected)

        # 전사 결과 앞부분 200자
        full_text = " ".join(seg.text for seg in collected)
        result["transcript_preview"] = full_text[:200]

        log(f"  전사 완료: {transcribe_time:.2f}s | 세그먼트 {len(collected)}개 | VRAM {vram_after_transcribe}MB")
        log(f"  미리보기: {full_text[:100]!r}")

        result["status"] = "success"

    except Exception as e:
        err_msg = f"{type(e).__name__}: {e}"
        result["error"] = err_msg
        result["status"] = "error"
        log(f"  [ERROR] {err_msg}")

    finally:
        free_gpu_memory(model)
        vram_freed = get_vram_used_mb()
        log(f"  메모리 해제 후 VRAM: {vram_freed}MB")

    return result


def main() -> None:
    log("=" * 60)
    log("faster-whisper GPU 벤치마크 시작")
    log(f"테스트 오디오: {TEST_AUDIO}")
    log(f"결과 저장 위치: {RESULTS_FILE}")
    log("=" * 60)

    # GPU 초기 상태
    vram_init = get_vram_used_mb()
    log(f"초기 VRAM 사용량: {vram_init}MB")

    all_results = []
    total_tests = len(MODEL_SIZES) * len(CONFIGS)
    test_idx = 0

    for model_size in MODEL_SIZES:
        log("")
        log(f"{'='*60}")
        log(f"모델: {model_size.upper()}")
        log(f"{'='*60}")

        for cfg in CONFIGS:
            test_idx += 1
            device = cfg["device"]
            compute_type = cfg["compute_type"]
            label = cfg["label"]

            log("")
            log(f"[{test_idx}/{total_tests}] {model_size} | {label}")
            log(f"  device={device!r}, compute_type={compute_type!r}")

            result = run_single_test(
                model_size=model_size,
                device=device,
                compute_type=compute_type,
                label=label,
            )
            all_results.append(result)

            # 중간 저장 (테스트 도중 중단돼도 결과 보존)
            with open(RESULTS_FILE, "w", encoding="utf-8") as f:
                json.dump(all_results, f, ensure_ascii=False, indent=2)

    # ── 최종 요약 출력 ──────────────────────────────────────────────
    log("")
    log("=" * 60)
    log("벤치마크 완료 — 요약")
    log("=" * 60)

    success_results = [r for r in all_results if r["status"] == "success"]
    error_results   = [r for r in all_results if r["status"] == "error"]

    log(f"총 {total_tests}건 | 성공: {len(success_results)} | 실패: {len(error_results)}")
    log("")

    # 성공한 결과를 전사 시간 기준 정렬
    if success_results:
        log("[ 성공 결과 — 전사 시간 오름차순 ]")
        header = f"{'모델':<10} {'설정':<30} {'로딩(s)':<10} {'전사(s)':<10} {'VRAM(MB)':<10}"
        log(header)
        log("-" * len(header))
        for r in sorted(success_results, key=lambda x: x["transcribe_time_sec"]):
            log(
                f"{r['model_size']:<10} {r['label']:<30} "
                f"{r['load_time_sec']:<10.2f} {r['transcribe_time_sec']:<10.2f} "
                f"{str(r['vram_after_transcribe_mb']):<10}"
            )

    if error_results:
        log("")
        log("[ 실패 결과 ]")
        for r in error_results:
            log(f"  {r['model_size']} / {r['label']}: {r['error']}")

    log("")
    log(f"전체 결과 JSON: {RESULTS_FILE}")
    log("벤치마크 종료.")


if __name__ == "__main__":
    main()
