"""
server.py - Whisper GPU HTTP 서비스 (FastAPI + faster-whisper)

포트: 8200
모델: medium (기본, int8, cuda) + small (lazy 로딩)
"""

from __future__ import annotations

import asyncio
import contextlib
import glob
import hmac
import logging
import os
import re
import subprocess
import tempfile
import time
from collections.abc import AsyncGenerator
from typing import Any, Optional

import httpx
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, PlainTextResponse
from pydantic import BaseModel

# ---------------------------------------------------------------------------
# 로깅 설정
# ---------------------------------------------------------------------------

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("whisper-gpu")

# ---------------------------------------------------------------------------
# 전역 상태
# ---------------------------------------------------------------------------

_gpu_lock = asyncio.Lock()

# API 키 인증 (환경변수가 없으면 인증 비활성화)
WHISPER_API_KEY: Optional[str] = os.environ.get("WHISPER_API_KEY") or None

# 모델 인스턴스 (lazy-load: 첫 요청 시 로딩, 10분 미사용 시 자동 언로드)
medium_model: Any = None
small_model: Any = None

_last_used: float = 0.0
_UNLOAD_TIMEOUT = 600  # 10분 (초)

# ---------------------------------------------------------------------------
# 모델 로딩
# ---------------------------------------------------------------------------


def _load_whisper_model(model_size: str) -> Any:
    """WhisperModel을 CUDA + int8 양자화로 로드합니다."""
    from faster_whisper import WhisperModel  # noqa: PLC0415

    logger.info("WhisperModel 로딩: size=%s device=cuda compute_type=int8", model_size)
    model = WhisperModel(model_size, device="cuda", compute_type="int8")
    logger.info("WhisperModel 로딩 완료: size=%s", model_size)
    return model


def get_small_model() -> Any:
    """small 모델을 lazy 로딩으로 반환합니다."""
    global small_model, _last_used
    if small_model is None:
        small_model = _load_whisper_model("small")
    _last_used = time.time()
    return small_model


# ---------------------------------------------------------------------------
# FastAPI 앱
# ---------------------------------------------------------------------------


@contextlib.asynccontextmanager
async def lifespan(application: FastAPI) -> AsyncGenerator[None, None]:
    """FastAPI lifespan: lazy-load 모드 + 자동 언로드 태스크."""
    logger.info("서버 시작: lazy-load 모드 (모델 미로딩, %d초 미사용 시 자동 언로드)", _UNLOAD_TIMEOUT)
    unload_task = asyncio.create_task(_unload_checker())
    yield
    unload_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await unload_task
    logger.info("서버 종료")


app = FastAPI(
    title="Whisper GPU HTTP Service",
    description="faster-whisper 기반 오디오 전사 서비스",
    version="1.0.0",
    lifespan=lifespan,
)


# ---------------------------------------------------------------------------
# API 키 인증 미들웨어
# ---------------------------------------------------------------------------

_AUTH_EXEMPT_PATHS = {"/v1/health"}
_LOCALHOST_HOSTS = {"127.0.0.1", "::1", "localhost"}


@app.middleware("http")
async def api_key_auth_middleware(request: Request, call_next: Any) -> Any:
    """X-API-Key 헤더를 검증하는 인증 미들웨어.

    다음 경우에는 인증을 건너뜁니다:
    - WHISPER_API_KEY 환경변수가 설정되지 않은 경우 (인증 비활성화)
    - 요청 경로가 _AUTH_EXEMPT_PATHS에 포함된 경우 (/v1/health 등)
    - 요청 출처가 localhost/127.0.0.1인 경우 (로컬 테스트 호환)
    """
    # 인증 비활성화: 환경변수 미설정 시 통과
    if WHISPER_API_KEY is None:
        return await call_next(request)

    # 인증 예외 경로: /v1/health 등
    if request.url.path in _AUTH_EXEMPT_PATHS:
        return await call_next(request)

    # localhost 예외: 127.0.0.1 / ::1 / localhost 출처
    client_host = request.client.host if request.client else ""
    if client_host in _LOCALHOST_HOSTS:
        return await call_next(request)

    # X-API-Key 헤더 검증 (타이밍 공격 방지: hmac.compare_digest 사용)
    provided_key = request.headers.get("X-API-Key", "")
    key_valid = bool(provided_key) and hmac.compare_digest(provided_key.encode(), WHISPER_API_KEY.encode())
    if not key_valid:
        return JSONResponse(
            status_code=401,
            content={"detail": "Invalid or missing API key"},
        )

    return await call_next(request)


# ---------------------------------------------------------------------------
# 유틸리티 함수
# ---------------------------------------------------------------------------


def select_model_for_duration(duration_seconds: float) -> str:
    """오디오 길이에 따라 적절한 모델 크기를 반환합니다.

    Args:
        duration_seconds: 오디오 길이 (초)

    Returns:
        "small" (30분 미만) 또는 "medium" (30분 이상)
    """
    threshold_30min = 30.0 * 60  # 1800초
    if duration_seconds < threshold_30min:
        return "small"
    return "medium"


def get_model_instance(model_name: str) -> Any:
    """모델 이름에 해당하는 인스턴스를 반환합니다. lazy-load 적용."""
    global medium_model, _last_used
    _last_used = time.time()
    if model_name == "small":
        return get_small_model()
    if medium_model is None:
        medium_model = _load_whisper_model("medium")
    return medium_model


def _unload_models() -> None:
    """모든 모델을 CUDA 메모리에서 해제합니다."""
    global medium_model, small_model
    try:
        import torch  # noqa: PLC0415
    except ImportError:
        logger.warning("torch 미설치 — CUDA 캐시 정리 건너뜀")
        medium_model = None
        small_model = None
        return
    unloaded = []
    if medium_model is not None:
        del medium_model
        medium_model = None
        unloaded.append("medium")
    if small_model is not None:
        del small_model
        small_model = None
        unloaded.append("small")
    if unloaded:
        torch.cuda.empty_cache()
        logger.info("모델 언로드 완료: %s, CUDA 캐시 정리", ", ".join(unloaded))


async def _unload_checker() -> None:
    """백그라운드 태스크: 10분 미사용 시 모델 해제."""
    while True:
        await asyncio.sleep(60)
        if _last_used > 0 and time.time() - _last_used > _UNLOAD_TIMEOUT:
            if medium_model is not None or small_model is not None:
                logger.info(
                    "자동 언로드 시작: %.0f초 미사용",
                    time.time() - _last_used,
                )
                async with _gpu_lock:
                    loop = asyncio.get_event_loop()
                    await loop.run_in_executor(None, _unload_models)


def get_audio_duration(file_path: str) -> float:
    """ffprobe로 오디오 파일의 길이(초)를 반환합니다."""
    try:
        result = subprocess.run(
            [
                "ffprobe",
                "-v",
                "error",
                "-show_entries",
                "format=duration",
                "-of",
                "default=noprint_wrappers=1:nokey=1",
                file_path,
            ],
            capture_output=True,
            text=True,
            timeout=30,
        )
        duration_str = result.stdout.strip()
        if duration_str and duration_str != "N/A":
            return float(duration_str)
    except (subprocess.TimeoutExpired, ValueError, FileNotFoundError) as e:
        logger.warning("ffprobe duration 조회 실패: %s", e)
    return 0.0


async def download_audio_from_url(audio_url: str) -> str:
    """URL에서 오디오를 다운로드하여 임시 파일 경로를 반환합니다."""
    logger.info("URL에서 오디오 다운로드: %s", audio_url)
    async with httpx.AsyncClient(timeout=120.0) as http_client:
        response = await http_client.get(audio_url)
        response.raise_for_status()

    suffix = _get_suffix_from_url(audio_url)
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False, prefix="whisper_url_") as tmp:
        tmp.write(response.content)
        return tmp.name


def _get_suffix_from_url(url: str) -> str:
    """URL에서 파일 확장자를 추출합니다."""
    path = url.split("?")[0].split("/")[-1]
    if "." in path:
        return "." + path.rsplit(".", 1)[-1]
    return ".audio"


async def run_transcription(
    file_path: str,
    language: str,
    model_name: str,
) -> dict[str, Any]:
    """GPU Lock을 획득한 후 전사를 실행합니다."""
    loop = asyncio.get_event_loop()

    async with _gpu_lock:
        logger.info("전사 시작: model=%s language=%s file=%s", model_name, language, file_path)
        model = get_model_instance(model_name)
        segments_raw, info = await loop.run_in_executor(
            None,
            lambda: model.transcribe(
                file_path,
                language=language if language else None,
                beam_size=5,
            ),
        )
        segments = [
            {
                "start": round(seg.start, 3),
                "end": round(seg.end, 3),
                "text": seg.text.strip(),
            }
            for seg in segments_raw
        ]
        full_text = " ".join(seg["text"] for seg in segments)
        detected_language = getattr(info, "language", language) or language
        duration = float(getattr(info, "duration", 0.0))

    logger.info(
        "전사 완료: language=%s duration=%.1fs segments=%d",
        detected_language,
        duration,
        len(segments),
    )
    return {
        "text": full_text,
        "segments": segments,
        "language": detected_language,
        "duration": duration,
    }


def get_gpu_memory_used_mb() -> int:
    """nvidia-smi로 현재 GPU 메모리 사용량(MB)을 반환합니다."""
    try:
        result = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=memory.used",
                "--format=csv,noheader,nounits",
            ],
            capture_output=True,
            text=True,
            timeout=10,
        )
        value_str = result.stdout.strip().split("\n")[0].strip()
        if value_str:
            return int(value_str)
    except (subprocess.TimeoutExpired, ValueError, FileNotFoundError) as e:
        logger.warning("GPU 메모리 조회 실패: %s", e)
    return 0


# ---------------------------------------------------------------------------
# 포맷 변환 함수
# ---------------------------------------------------------------------------


def seconds_to_srt_timestamp(seconds: float) -> str:
    """초를 SRT 타임스탬프 형식 HH:MM:SS,mmm으로 변환합니다."""
    total_ms = round(seconds * 1000)
    ms = total_ms % 1000
    total_s = total_ms // 1000
    s = total_s % 60
    total_m = total_s // 60
    m = total_m % 60
    h = total_m // 60
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"


def seconds_to_vtt_timestamp(seconds: float) -> str:
    """초를 VTT 타임스탬프 형식 HH:MM:SS.mmm으로 변환합니다."""
    total_ms = round(seconds * 1000)
    ms = total_ms % 1000
    total_s = total_ms // 1000
    s = total_s % 60
    total_m = total_s // 60
    m = total_m % 60
    h = total_m // 60
    return f"{h:02d}:{m:02d}:{s:02d}.{ms:03d}"


def segments_to_srt(segments: list[dict[str, Any]]) -> str:
    """세그먼트 목록을 SRT 자막 형식 문자열로 변환합니다."""
    lines: list[str] = []
    for i, seg in enumerate(segments, start=1):
        start_ts = seconds_to_srt_timestamp(seg["start"])
        end_ts = seconds_to_srt_timestamp(seg["end"])
        lines.append(f"{i}\n{start_ts} --> {end_ts}\n{seg['text']}\n")
    return "\n".join(lines)


def segments_to_vtt(segments: list[dict[str, Any]]) -> str:
    """세그먼트 목록을 WebVTT 자막 형식 문자열로 변환합니다."""
    lines: list[str] = ["WEBVTT\n"]
    for seg in segments:
        start_ts = seconds_to_vtt_timestamp(seg["start"])
        end_ts = seconds_to_vtt_timestamp(seg["end"])
        lines.append(f"{start_ts} --> {end_ts}\n{seg['text']}\n")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# 요청 모델
# ---------------------------------------------------------------------------


class TranscribeURLRequest(BaseModel):
    audio_url: str
    language: str = "ko"
    model: str = "medium"
    format: str = "json"


class YouTubeTranscribeRequest(BaseModel):
    video_id: str
    language: str = "ko"


# ---------------------------------------------------------------------------
# 엔드포인트
# ---------------------------------------------------------------------------


@app.get("/v1/health")
async def health() -> JSONResponse:
    """서비스 상태, 모델 로딩 상태, GPU 메모리 정보를 반환합니다."""
    gpu_mem_mb = get_gpu_memory_used_mb()
    return JSONResponse(
        content={
            "status": "ok",
            "device": "cuda",
            "gpu_memory_used_mb": gpu_mem_mb,
            "models": {
                "medium": "loaded" if medium_model is not None else "unloaded",
                "small": "loaded" if small_model is not None else "unloaded",
            },
            "last_used": _last_used if _last_used > 0 else None,
            "unload_timeout_sec": _UNLOAD_TIMEOUT,
        }
    )


@app.post("/v1/transcribe")
async def transcribe(request: Request) -> Any:
    """오디오 파일을 전사하여 결과를 반환합니다.

    multipart/form-data로 파일을 업로드하거나
    JSON body로 audio_url을 전달할 수 있습니다.

    form-data:
        - file: 오디오 파일 (필수, audio_url 없을 때)
        - language: 언어 코드 (기본: ko)
        - model: 모델 크기 (기본: medium)
        - format: 출력 형식 text|json|srt|vtt (기본: json)

    JSON body:
        - audio_url: 오디오 파일 URL (필수)
        - language: 언어 코드 (기본: ko)
        - model: 모델 크기 (기본: medium)
        - format: 출력 형식 (기본: json)
    """
    content_type = request.headers.get("content-type", "")

    # JSON body 처리 (audio_url 방식)
    if "application/json" in content_type:
        try:
            body = await request.json()
        except Exception as e:
            raise HTTPException(status_code=422, detail=f"JSON 파싱 실패: {str(e)}") from e

        audio_url: Optional[str] = body.get("audio_url")
        if not audio_url:
            raise HTTPException(
                status_code=422,
                detail="JSON body에 audio_url 필드가 필요합니다.",
            )
        language = body.get("language", "ko")
        model = body.get("model", "medium")
        output_format = body.get("format", "json")

        tmp_path: Optional[str] = None
        try:
            tmp_path = await download_audio_from_url(audio_url)
            return await _process_transcription(tmp_path, language, model, output_format)
        except HTTPException:
            raise
        except httpx.HTTPStatusError as e:
            logger.error("오디오 URL 다운로드 실패: %s", e)
            raise HTTPException(status_code=422, detail=f"오디오 URL 접근 실패: {str(e)}") from e
        except Exception as e:
            logger.exception("URL 전사 처리 중 오류: %s", e)
            raise HTTPException(status_code=500, detail=f"전사 처리 실패: {str(e)}") from e
        finally:
            if tmp_path and os.path.exists(tmp_path):
                try:
                    os.unlink(tmp_path)
                except OSError as unlink_err:
                    logger.warning("임시 파일 삭제 실패: %s", unlink_err)

    # multipart/form-data 처리 (파일 업로드 방식)
    try:
        form = await request.form()
    except Exception as e:
        raise HTTPException(status_code=422, detail=f"폼 데이터 파싱 실패: {str(e)}") from e

    file_field = form.get("file")
    if file_field is None:
        raise HTTPException(
            status_code=422,
            detail="파일(file) 또는 audio_url을 제공해야 합니다.",
        )

    upload_file: UploadFile = file_field  # type: ignore[assignment]
    language = str(form.get("language", "ko"))
    model = str(form.get("model", "medium"))
    output_format = str(form.get("format", "json"))

    file_tmp_path: Optional[str] = None
    try:
        suffix = _get_suffix_from_upload(upload_file.filename)
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False, prefix="whisper_upload_") as tmp:
            content = await upload_file.read()
            tmp.write(content)
            file_tmp_path = tmp.name

        return await _process_transcription(file_tmp_path, language, model, output_format)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("전사 처리 중 오류 발생: %s", e)
        raise HTTPException(status_code=500, detail=f"전사 처리 실패: {str(e)}") from e
    finally:
        if file_tmp_path and os.path.exists(file_tmp_path):
            try:
                os.unlink(file_tmp_path)
            except OSError as unlink_err:
                logger.warning("임시 파일 삭제 실패: %s", unlink_err)


@app.post("/v1/transcribe/url")
async def transcribe_url(request: TranscribeURLRequest) -> Any:
    """JSON body의 audio_url로 오디오를 다운로드하여 전사합니다."""
    tmp_path: Optional[str] = None
    try:
        tmp_path = await download_audio_from_url(request.audio_url)
        return await _process_transcription(tmp_path, request.language, request.model, request.format)
    except HTTPException:
        raise
    except httpx.HTTPStatusError as e:
        logger.error("오디오 URL 다운로드 실패: %s", e)
        raise HTTPException(status_code=422, detail=f"오디오 URL 접근 실패: {str(e)}") from e
    except Exception as e:
        logger.exception("URL 전사 처리 중 오류 발생: %s", e)
        raise HTTPException(status_code=500, detail=f"전사 처리 실패: {str(e)}") from e
    finally:
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError as e:
                logger.warning("임시 파일 삭제 실패: %s", e)


@app.post("/v1/youtube-transcribe")
async def youtube_transcribe(request: YouTubeTranscribeRequest) -> JSONResponse:
    """YouTube video_id로 오디오를 다운로드하여 전사합니다.

    yt-dlp로 베스트 오디오를 /tmp에 다운로드한 뒤 Whisper로 전사하고
    임시 파일을 삭제합니다.

    Request body:
        - video_id: YouTube 동영상 ID (영숫자 + 하이픈/언더스코어, 11자)
        - language: 언어 코드 (기본: ko)

    Returns:
        { "text": "...", "source": "whisper_local_ytdlp", "duration": 123.4 }
    """
    # 1. video_id 유효성 검증
    _VIDEO_ID_RE = re.compile(r"^[a-zA-Z0-9_-]{11}$")
    if not _VIDEO_ID_RE.match(request.video_id):
        raise HTTPException(
            status_code=422,
            detail="video_id는 영숫자·하이픈·언더스코어로 구성된 11자여야 합니다.",
        )

    video_id = request.video_id
    output_template = f"/tmp/whisper_yt_{video_id}.%(ext)s"
    youtube_url = f"https://youtube.com/watch?v={video_id}"

    audio_path: Optional[str] = None
    try:
        # 2. yt-dlp로 오디오 다운로드 (비동기 subprocess, 300초 타임아웃)
        logger.info("yt-dlp 다운로드 시작: video_id=%s", video_id)
        try:
            proc = await asyncio.create_subprocess_exec(
                "yt-dlp",
                "-f",
                "ba",
                "--no-playlist",
                "-o",
                output_template,
                youtube_url,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)
        except asyncio.TimeoutError as e:
            logger.error("yt-dlp 다운로드 타임아웃: video_id=%s", video_id)
            raise HTTPException(
                status_code=422,
                detail=f"yt-dlp 다운로드 타임아웃 (300초): video_id={video_id}",
            ) from e

        if proc.returncode != 0:
            stderr_text = stderr.decode(errors="replace").strip()
            logger.error("yt-dlp 실패: returncode=%d stderr=%s", proc.returncode, stderr_text)
            raise HTTPException(
                status_code=422,
                detail=f"yt-dlp 다운로드 실패 (returncode={proc.returncode}): {stderr_text}",
            )

        logger.info("yt-dlp 다운로드 완료: video_id=%s", video_id)

        # 3. 다운로드된 파일 경로 찾기
        matched = glob.glob(f"/tmp/whisper_yt_{video_id}.*")
        if not matched:
            raise HTTPException(
                status_code=422,
                detail=f"yt-dlp 다운로드 후 파일을 찾을 수 없습니다: video_id={video_id}",
            )
        audio_path = matched[0]
        logger.info("다운로드된 오디오 파일: %s", audio_path)

        # 4. 오디오 길이 확인 및 모델 선택 후 전사
        duration_sec = get_audio_duration(audio_path)
        model_name = select_model_for_duration(duration_sec)
        logger.info(
            "전사 모델 선택: model=%s duration=%.1fs video_id=%s",
            model_name,
            duration_sec,
            video_id,
        )

        try:
            result = await run_transcription(audio_path, request.language, model_name)
        except Exception as e:
            logger.exception("전사 실패: video_id=%s error=%s", video_id, e)
            raise HTTPException(
                status_code=500,
                detail=f"전사 처리 실패: {str(e)}",
            ) from e

        # 5. 결과 반환
        return JSONResponse(
            content={
                "text": result["text"],
                "source": "whisper_local_ytdlp",
                "duration": result["duration"],
            }
        )

    finally:
        # 6. 임시 파일 삭제 (항상 실행)
        if audio_path and os.path.exists(audio_path):
            try:
                os.unlink(audio_path)
                logger.info("임시 파일 삭제 완료: %s", audio_path)
            except OSError as unlink_err:
                logger.warning("임시 파일 삭제 실패: %s", unlink_err)


# ---------------------------------------------------------------------------
# 내부 처리 함수
# ---------------------------------------------------------------------------


async def _process_transcription(
    file_path: str,
    language: str,
    model_hint: str,
    output_format: str,
) -> Any:
    """실제 전사 처리 및 포맷 변환을 수행합니다."""
    # 오디오 길이 확인 후 모델 자동 선택
    duration = get_audio_duration(file_path)
    auto_model = select_model_for_duration(duration)

    # 사용자가 명시적으로 모델을 지정한 경우 그것을 사용,
    # 아니면 자동 선택
    chosen_model = model_hint if model_hint in ("small", "medium", "large") else auto_model
    # auto 선택을 항상 우선: 요청 모델이 auto이거나 기본값이면 duration 기반 선택
    if model_hint == "medium" and duration > 0:
        # 짧은 오디오라면 small로 최적화
        chosen_model = auto_model

    logger.info(
        "모델 선택: hint=%s auto=%s chosen=%s duration=%.1fs",
        model_hint,
        auto_model,
        chosen_model,
        duration,
    )

    result = await run_transcription(file_path, language, chosen_model)

    # 포맷에 따라 응답 생성
    if output_format == "text":
        return PlainTextResponse(content=result["text"])
    elif output_format == "srt":
        srt_content = segments_to_srt(result["segments"])
        return PlainTextResponse(content=srt_content, media_type="text/plain")
    elif output_format == "vtt":
        vtt_content = segments_to_vtt(result["segments"])
        return PlainTextResponse(content=vtt_content, media_type="text/plain")
    else:
        # json (기본)
        return JSONResponse(content=result)


def _get_suffix_from_upload(filename: Optional[str]) -> str:
    """업로드 파일명에서 확장자를 추출합니다."""
    if filename and "." in filename:
        return "." + filename.rsplit(".", 1)[-1]
    return ".audio"


# ---------------------------------------------------------------------------
# 메인 실행 (직접 실행 시)
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=8200,
        log_level="info",
        workers=1,  # GPU 1개이므로 단일 워커
    )
