"""
test_server.py - Whisper GPU HTTP 서비스 단위 테스트 (GPU 모델 로딩 없이 API 구조 검증)
"""

from __future__ import annotations

import io
import json
from typing import Any, Generator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient


# ---------------------------------------------------------------------------
# 픽스처: 모델 로딩 없이 앱 임포트
# ---------------------------------------------------------------------------

@pytest.fixture(autouse=True)
def mock_whisper_model() -> Generator[MagicMock, None, None]:
    """Replace the WhisperModel constructor with a mock so no GPU model is loaded.

    The mock's ``transcribe`` returns a ``(segments, info)`` pair: one segment
    spanning 0.0-3.5 s with text " 안녕하세요." and an info object with
    ``language="ko"`` and ``duration=3.5``.

    While each test runs, ``server.medium_model`` is set to the mock and
    ``server.small_model`` to ``None`` (the not-yet-lazily-loaded state). Both
    globals are restored afterwards so tests do not leak state into each other.

    Yields:
        The mock model instance.
    """
    mock_model = MagicMock()
    mock_model.transcribe.return_value = (
        [
            MagicMock(
                start=0.0,
                end=3.5,
                text=" 안녕하세요.",
                words=None,
            )
        ],
        MagicMock(language="ko", duration=3.5),
    )
    with patch("faster_whisper.WhisperModel", return_value=mock_model):
        import server  # noqa: PLC0415

        # ``server`` is imported only once per test session. If it binds the
        # class via ``from faster_whisper import WhisperModel``, that name was
        # captured on the first import. Patching ``faster_whisper.WhisperModel``
        # alone would then not stop later lazy loads (e.g. get_model_instance)
        # from constructing a real GPU model, so patch the module-level name too.
        # ``create=True`` keeps this harmless if server has no such attribute.
        with patch.object(
            server, "WhisperModel", return_value=mock_model, create=True
        ):
            saved_medium = getattr(server, "medium_model", None)
            saved_small = getattr(server, "small_model", None)
            server.medium_model = mock_model
            server.small_model = None  # lazy-loading state: not loaded yet
            try:
                yield mock_model
            finally:
                # Restore the globals so one test's mutations cannot affect the next.
                server.medium_model = saved_medium
                server.small_model = saved_small


@pytest.fixture()
def client(mock_whisper_model: MagicMock) -> TestClient:
    """Provide a TestClient bound to ``server.app``.

    Requests ``mock_whisper_model`` so the mocked model is installed before the
    client is built. ``raise_server_exceptions=False`` makes unhandled server
    errors come back as HTTP 500 responses instead of being re-raised in the test.
    """
    import server  # noqa: PLC0415

    application = server.app
    return TestClient(application, raise_server_exceptions=False)


# ---------------------------------------------------------------------------
# GET /v1/health 테스트
# ---------------------------------------------------------------------------

class TestHealthEndpoint:
    """Tests for GET /v1/health: response shape and reported model state."""

    @staticmethod
    def _health(client: TestClient) -> dict[str, Any]:
        """Issue GET /v1/health and return the decoded JSON body."""
        return client.get("/v1/health").json()

    def test_health_returns_200(self, client: TestClient) -> None:
        assert client.get("/v1/health").status_code == 200

    def test_health_response_schema(self, client: TestClient) -> None:
        body = self._health(client)
        for key in (
            "status",
            "models",
            "device",
            "gpu_memory_used_mb",
            "last_used",
            "unload_timeout_sec",
        ):
            assert key in body

    def test_health_status_ok(self, client: TestClient) -> None:
        assert self._health(client)["status"] == "ok"

    def test_health_device_cuda(self, client: TestClient) -> None:
        assert self._health(client)["device"] == "cuda"

    def test_health_models_field_structure(self, client: TestClient) -> None:
        models = self._health(client)["models"]
        assert isinstance(models, dict)
        for name in ("medium", "small"):
            assert name in models
        for name in ("medium", "small"):
            assert models[name] in ("loaded", "unloaded")

    def test_health_shows_loaded_model(self, client: TestClient) -> None:
        """The fixture installs a mock medium model, so medium reports "loaded"."""
        assert self._health(client)["models"]["medium"] == "loaded"

    def test_health_shows_unloaded_model(self, client: TestClient) -> None:
        """The fixture sets small_model to None, so small reports "unloaded"."""
        assert self._health(client)["models"]["small"] == "unloaded"

    def test_health_unload_timeout_value(self, client: TestClient) -> None:
        """unload_timeout_sec is reported as 600."""
        assert self._health(client)["unload_timeout_sec"] == 600

    def test_health_gpu_memory_is_int(self, client: TestClient) -> None:
        gpu_mb = self._health(client)["gpu_memory_used_mb"]
        assert isinstance(gpu_mb, int)


# ---------------------------------------------------------------------------
# POST /v1/transcribe 테스트
# ---------------------------------------------------------------------------

def _make_dummy_audio() -> bytes:
    """최소한의 유효한 WebM/Ogg 더미 바이트 (실제 전사는 Mock 처리)."""
    return b"\x1aE\xdf\xa3" + b"\x00" * 64


class TestTranscribeEndpoint:
    """Tests for POST /v1/transcribe.

    Duration probing (``server.get_audio_duration``) and transcription
    (``server.run_transcription``) are patched for every request. These tests
    therefore exercise request parsing, routing and response formatting only.
    """

    @staticmethod
    def _audio_files() -> dict[str, tuple[str, io.BytesIO, str]]:
        """Build a fresh multipart ``files`` payload holding the dummy WebM clip."""
        return {"file": ("test.webm", io.BytesIO(_make_dummy_audio()), "audio/webm")}

    @staticmethod
    def _post_transcribe(
        client: TestClient,
        transcription: dict[str, Any],
        *,
        audio_duration: float = 10.0,
        **request_kwargs: Any,
    ) -> Any:
        """POST to /v1/transcribe with duration probing and transcription mocked.

        Args:
            client: Test client for the app.
            transcription: Value returned by the mocked ``run_transcription`` coroutine.
            audio_duration: Value returned by the mocked ``get_audio_duration``.
            **request_kwargs: Forwarded to ``client.post`` (``files``, ``data``, ``json``).

        Returns:
            The HTTP response.
        """
        with (
            patch("server.get_audio_duration", return_value=audio_duration),
            patch("server.run_transcription", new_callable=AsyncMock) as mock_trans,
        ):
            mock_trans.return_value = transcription
            return client.post("/v1/transcribe", **request_kwargs)

    def test_transcribe_missing_input_returns_422(self, client: TestClient) -> None:
        """With neither a file nor a URL, the endpoint returns 422."""
        response = client.post("/v1/transcribe")
        assert response.status_code == 422

    def test_transcribe_with_file_returns_200(
        self, client: TestClient, mock_whisper_model: MagicMock
    ) -> None:
        """Uploading an audio file returns 200 (transcription mocked)."""
        response = self._post_transcribe(
            client,
            {
                "text": "안녕하세요.",
                "segments": [{"start": 0.0, "end": 3.5, "text": "안녕하세요."}],
                "language": "ko",
                "duration": 3.5,
            },
            files=self._audio_files(),
            data={"language": "ko", "model": "medium", "format": "json"},
        )
        assert response.status_code == 200

    def test_transcribe_response_schema(self, client: TestClient) -> None:
        """The JSON response carries text, segments, language and duration fields."""
        response = self._post_transcribe(
            client,
            {
                "text": "테스트",
                "segments": [{"start": 0.0, "end": 1.0, "text": "테스트"}],
                "language": "ko",
                "duration": 1.0,
            },
            files=self._audio_files(),
        )
        assert response.status_code == 200
        data = response.json()
        for key in ("text", "segments", "language", "duration"):
            assert key in data

    def test_transcribe_segments_schema(self, client: TestClient) -> None:
        """Each segment item carries start, end and text fields."""
        response = self._post_transcribe(
            client,
            {
                "text": "테스트",
                "segments": [{"start": 0.0, "end": 1.0, "text": "테스트"}],
                "language": "ko",
                "duration": 1.0,
            },
            files=self._audio_files(),
        )
        assert response.status_code == 200
        segments = response.json()["segments"]
        assert isinstance(segments, list)
        # The mocked transcription has exactly one segment. An empty list would
        # mean the endpoint dropped it, so do not let the field check pass vacuously.
        assert segments
        seg = segments[0]
        for key in ("start", "end", "text"):
            assert key in seg

    def test_transcribe_with_audio_url_returns_200(self, client: TestClient) -> None:
        """A JSON body with audio_url returns 200 (download mocked)."""
        with patch(
            "server.download_audio_from_url", new_callable=AsyncMock
        ) as mock_dl:
            mock_dl.return_value = "/tmp/fake_audio.wav"
            response = self._post_transcribe(
                client,
                {
                    "text": "URL 전사 테스트",
                    "segments": [],
                    "language": "ko",
                    "duration": 5.0,
                },
                audio_duration=5.0,
                json={"audio_url": "http://example.com/test.wav", "language": "ko"},
            )
        assert response.status_code == 200

    def test_transcribe_default_language_is_ko(self, client: TestClient) -> None:
        """Omitting ``language`` still yields a 200 response with language "ko".

        NOTE(review): ``language`` in the response presumably comes from the
        mocked transcription result. If so, this does not prove on its own that
        the endpoint defaults the request parameter to "ko". Consider also
        asserting on the arguments passed to ``run_transcription``.
        """
        response = self._post_transcribe(
            client,
            {"text": "기본 언어", "segments": [], "language": "ko", "duration": 1.0},
            files=self._audio_files(),
            # no language parameter on purpose
        )
        assert response.status_code == 200
        assert response.json()["language"] == "ko"

    def test_transcribe_format_text(self, client: TestClient) -> None:
        """format=text returns a text/plain body."""
        response = self._post_transcribe(
            client,
            {"text": "텍스트 포맷", "segments": [], "language": "ko", "duration": 1.0},
            files=self._audio_files(),
            data={"format": "text"},
        )
        assert response.status_code == 200
        assert "text/plain" in response.headers.get("content-type", "")

    def test_transcribe_format_srt(self, client: TestClient) -> None:
        """format=srt returns a text/plain body."""
        response = self._post_transcribe(
            client,
            {
                "text": "SRT 포맷",
                "segments": [{"start": 0.0, "end": 1.0, "text": "SRT 포맷"}],
                "language": "ko",
                "duration": 1.0,
            },
            files=self._audio_files(),
            data={"format": "srt"},
        )
        assert response.status_code == 200
        assert "text/plain" in response.headers.get("content-type", "")

    def test_transcribe_format_vtt(self, client: TestClient) -> None:
        """format=vtt returns a text/plain body."""
        response = self._post_transcribe(
            client,
            {
                "text": "VTT 포맷",
                "segments": [{"start": 0.0, "end": 1.0, "text": "VTT 포맷"}],
                "language": "ko",
                "duration": 1.0,
            },
            files=self._audio_files(),
            data={"format": "vtt"},
        )
        assert response.status_code == 200
        assert "text/plain" in response.headers.get("content-type", "")


# ---------------------------------------------------------------------------
# 모델 자동 선택 로직 테스트
# ---------------------------------------------------------------------------

class TestModelSelection:
    """Tests for ``server.select_model_for_duration`` (audio length in seconds → model name)."""

    @staticmethod
    def _pick(seconds: float) -> str:
        """Return the model name the server selects for ``seconds`` of audio."""
        import server  # noqa: PLC0415

        return server.select_model_for_duration(seconds)

    def test_select_small_model_for_short_audio(self) -> None:
        """Under 30 minutes selects the small model."""
        assert self._pick(10.0 * 60) == "small"  # 10 minutes

    def test_select_medium_model_for_medium_audio(self) -> None:
        """Between 30 minutes and 2 hours selects the medium model."""
        assert self._pick(60.0 * 60) == "medium"  # 60 minutes

    def test_select_medium_model_for_long_audio(self) -> None:
        """2 hours or more still selects the medium model."""
        assert self._pick(150.0 * 60) == "medium"  # 150 minutes

    def test_boundary_30min(self) -> None:
        """Boundary: exactly 30 minutes selects the medium model."""
        assert self._pick(30.0 * 60) == "medium"

    def test_boundary_just_under_30min(self) -> None:
        """Boundary: 29 minutes 59 seconds selects the small model."""
        assert self._pick(29.0 * 60 + 59) == "small"


# ---------------------------------------------------------------------------
# SRT / VTT 포맷 변환 함수 테스트
# ---------------------------------------------------------------------------

class TestFormatConversion:
    """Tests for the subtitle helpers (segments → SRT/VTT text, seconds → timestamp strings)."""

    def test_to_srt_format(self) -> None:
        """SRT output contains both cue numbers, a comma-style timestamp and the cue text."""
        from server import segments_to_srt  # noqa: PLC0415

        srt = segments_to_srt(
            [
                {"start": 0.0, "end": 3.5, "text": "안녕하세요."},
                {"start": 3.5, "end": 7.0, "text": "반갑습니다."},
            ]
        )
        for fragment in ("1\n", "2\n", "00:00:00,000", "안녕하세요."):
            assert fragment in srt

    def test_to_vtt_format(self) -> None:
        """VTT output starts with the WEBVTT header and uses dot-style timestamps."""
        from server import segments_to_vtt  # noqa: PLC0415

        vtt = segments_to_vtt([{"start": 0.0, "end": 3.5, "text": "안녕하세요."}])
        assert vtt.startswith("WEBVTT")
        for fragment in ("00:00:00.000", "안녕하세요."):
            assert fragment in vtt

    def test_srt_timestamp_format(self) -> None:
        """SRT timestamps use HH:MM:SS,mmm (3661.5 s → 01:01:01,500)."""
        from server import seconds_to_srt_timestamp  # noqa: PLC0415

        assert seconds_to_srt_timestamp(3661.5) == "01:01:01,500"

    def test_vtt_timestamp_format(self) -> None:
        """VTT timestamps use HH:MM:SS.mmm (3661.5 s → 01:01:01.500)."""
        from server import seconds_to_vtt_timestamp  # noqa: PLC0415

        assert seconds_to_vtt_timestamp(3661.5) == "01:01:01.500"


# ---------------------------------------------------------------------------
# Lazy-load 기능 테스트
# ---------------------------------------------------------------------------

class TestLazyLoad:
    """Tests for lazy model loading/unloading and the ``_last_used`` bookkeeping."""

    def test_unload_models_clears_globals(self, mock_whisper_model: MagicMock) -> None:
        """After ``_unload_models`` both medium_model and small_model are None."""
        import server  # noqa: PLC0415

        server.medium_model = server.small_model = mock_whisper_model
        # Stub ``server.torch`` (created if absent) in case _unload_models
        # presumably flushes the CUDA cache through it.
        with patch("server.torch", create=True) as fake_torch:
            fake_torch.cuda.empty_cache = MagicMock()
            server._unload_models()
        assert server.medium_model is None
        assert server.small_model is None

    def test_get_model_instance_lazy_loads_medium(self, mock_whisper_model: MagicMock) -> None:
        """With medium_model unset, get_model_instance("medium") loads and stores it."""
        import server  # noqa: PLC0415

        server.medium_model = None
        loaded = server.get_model_instance("medium")
        assert loaded is not None
        assert server.medium_model is not None

    def test_get_model_instance_updates_last_used(self, mock_whisper_model: MagicMock) -> None:
        """get_model_instance sets ``_last_used`` to no earlier than time.time() before the call."""
        import time  # noqa: PLC0415

        import server  # noqa: PLC0415

        server._last_used = 0.0
        started_at = time.time()
        server.get_model_instance("medium")
        assert server._last_used >= started_at
