from faster_whisper import WhisperModel
import sys

# Task working directory: contains the input audio and receives the transcript files.
out_dir = "/home/jay/workspace/teams/dev4/task-932.1"
# Input audio file to transcribe, stored inside the task directory.
audio_path = f"{out_dir}/audio.wav"

def run_transcription(model_size, device, compute_type, language="ko", beam_size=5):
    """Transcribe the module-level ``audio_path`` and write transcripts to ``out_dir``.

    Loads a ``WhisperModel`` with the given settings, transcribes ``audio_path``,
    echoes each segment to stdout, and writes two UTF-8 files into ``out_dir``:

    * ``transcript.txt`` -- one segment text per line.
    * ``transcript_timestamped.txt`` -- one ``[start s -> end s] text`` line
      per segment, times formatted to two decimals.

    Args:
        model_size: Model name passed to ``WhisperModel`` (e.g. ``"small"``).
        device: Device passed to ``WhisperModel`` (e.g. ``"cuda"``, ``"cpu"``).
        compute_type: Compute type passed to ``WhisperModel``
            (e.g. ``"float16"``, ``"int8"``).
        language: Language code passed to ``transcribe``; defaults to ``"ko"``.
        beam_size: Beam width passed to ``transcribe``; defaults to 5.

    Returns:
        True once both transcript files have been written.

    Raises:
        Any exception from model loading, transcription, or file I/O is left
        to propagate so the caller can try another configuration.
    """
    print(f"[INFO] Loading model: {model_size}, device={device}, compute_type={compute_type}")
    model = WhisperModel(model_size, device=device, compute_type=compute_type)

    print(f"[INFO] Transcribing: {audio_path}")
    segments, info = model.transcribe(
        audio_path,
        language=language,
        beam_size=beam_size,
    )

    # NOTE(review): since `language` is forced above, this probably echoes the
    # requested language rather than a genuine detection result -- confirm.
    print(f"[INFO] Detected language: {info.language} (probability: {info.language_probability:.2f})")

    # `segments` is a generator; materialize it once so it can be iterated for
    # the console echo and for both output files. (Presumably decoding happens
    # lazily during this iteration, so runtime errors may surface here.)
    segment_list = list(segments)
    print(f"[INFO] Total segments: {len(segment_list)}")

    # Whisper segment text usually starts with a space; strip it so transcript
    # lines don't begin with whitespace and timestamped lines don't get "]  ".
    texts = [segment.text.strip() for segment in segment_list]
    # Build the timestamped form once and reuse it for stdout and the file.
    timestamped = [
        f"[{segment.start:.2f}s -> {segment.end:.2f}s] {text}"
        for segment, text in zip(segment_list, texts)
    ]
    for line in timestamped:
        print(line)

    # Plain transcript: one segment per line, newline-terminated like the
    # timestamped file (an empty segment list still yields an empty file).
    with open(f"{out_dir}/transcript.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{text}\n" for text in texts)

    # Timestamped transcript.
    with open(f"{out_dir}/transcript_timestamped.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in timestamped)

    print(f"[INFO] transcript.txt saved ({len(texts)} lines)")
    print("[INFO] transcript_timestamped.txt saved")
    return True


# Try the small model on GPU first; on failure, the base model on GPU; as the
# final fallback, the base model on CPU with int8.
attempts = [
    ("small", "cuda", "float16"),
    ("base",  "cuda", "float16"),
    ("base",  "cpu",  "int8"),
]

for model_size, device, compute_type in attempts:
    # Keep the try body to the transcription call only, so an error while
    # reporting success does not trigger a needless retry with the next model.
    try:
        success = run_transcription(model_size, device, compute_type)
    except Exception as e:
        # Top-level fallback boundary: any failure moves on to the next entry in
        # `attempts`. The exception type is included because some exceptions
        # stringify to an empty or ambiguous message.
        print(
            f"[WARN] Failed with model={model_size}, device={device}, "
            f"compute_type={compute_type}: {type(e).__name__}: {e}",
            file=sys.stderr,
        )
        continue
    if success:
        print(f"[DONE] Pipeline completed with model={model_size}, device={device}")
        sys.exit(0)

print("[ERROR] All attempts failed.", file=sys.stderr)
sys.exit(1)
