"""
Tests for embedding_service.py

TDD: 테스트 먼저 작성, 구현은 이후
"""

import os
from unittest.mock import MagicMock, call, patch

import openai
import pytest

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_embedding_response(vectors: list[list[float]]) -> MagicMock:
    """Build a mock that imitates the object returned by openai ``embeddings.create()``.

    Args:
        vectors: One embedding per input text; each becomes the ``embedding``
            attribute of an item in ``.data``, in the same order.

    Returns:
        A MagicMock whose ``data`` attribute is a list of item mocks.
    """
    items = [MagicMock(embedding=vector) for vector in vectors]
    return MagicMock(data=items)


def _make_vector(dim: int = 1536) -> list[float]:
    """Return a placeholder embedding of length *dim* whose elements are all 0.1."""
    filler = 0.1
    return [filler for _ in range(dim)]


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def set_api_key(monkeypatch):
    """Put a dummy OPENAI_API_KEY into the environment for every test (autouse)."""
    env_name, dummy_key = "OPENAI_API_KEY", "test-api-key"
    monkeypatch.setenv(env_name, dummy_key)


# ---------------------------------------------------------------------------
# Test cases
# ---------------------------------------------------------------------------


class TestGetEmbedding:
    """Tests for get_embedding(), which embeds a single text."""

    def test_returns_1536_dim_vector(self):
        """A single text input returns the 1536-dim vector from the API response."""
        from embedding_service import get_embedding

        expected = _make_vector(1536)
        mock_response = _make_embedding_response([expected])

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.return_value = mock_response

            result = get_embedding("hello world")

        assert isinstance(result, list)
        assert len(result) == 1536
        # Beyond the shape, the values must be exactly the mocked embedding.
        assert result == expected

    def test_vector_length_is_1536(self):
        """The returned vector has exactly 1536 elements, in the API's order."""
        from embedding_service import get_embedding

        # Position-dependent values (unlike the constant _make_vector output)
        # so a reordered or otherwise altered vector is detected.
        expected = [i / 1536 for i in range(1536)]
        mock_response = _make_embedding_response([expected])

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.return_value = mock_response

            result = get_embedding("test text")

        assert len(result) == 1536
        assert result == expected

    def test_calls_api_with_correct_model_and_dimensions(self):
        """The API is called once with the expected model and dimensions."""
        from embedding_service import get_embedding

        mock_response = _make_embedding_response([_make_vector(1536)])

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.return_value = mock_response

            get_embedding("sample text")

            # A single text is still sent as a one-element input list.
            mock_client.embeddings.create.assert_called_once_with(
                input=["sample text"],
                model="text-embedding-3-small",
                dimensions=1536,
            )


class TestGetEmbeddingsBatch:
    """Tests for get_embeddings_batch(), which embeds many texts at once."""

    @staticmethod
    def _distinct_vectors(start: int, count: int) -> list[list[float]]:
        """Return *count* 1536-dim vectors whose values encode their index.

        Distinct vectors let the tests verify that results keep input order,
        which identical placeholder vectors cannot.
        """
        return [[float(i)] * 1536 for i in range(start, start + count)]

    def test_returns_same_number_of_vectors_as_input(self):
        """Returns one vector per input text, in input order."""
        from embedding_service import get_embeddings_batch

        texts = ["text1", "text2", "text3"]
        vectors = self._distinct_vectors(0, len(texts))
        mock_response = _make_embedding_response(vectors)

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.return_value = mock_response

            result = get_embeddings_batch(texts)

        assert len(result) == len(texts)
        assert result == vectors

    def test_empty_list_returns_empty_list(self):
        """An empty input list returns an empty list without calling the API."""
        from embedding_service import get_embeddings_batch

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client

            result = get_embeddings_batch([])

        assert result == []
        mock_client.embeddings.create.assert_not_called()

    def test_over_100_texts_splits_into_batches(self):
        """More than 100 texts are split into batches of 100 and results concatenated in order."""
        from embedding_service import get_embeddings_batch

        texts = [f"text_{i}" for i in range(250)]

        # Per-batch responses of 100, 100 and 50 distinct vectors.
        batch_vectors = [
            self._distinct_vectors(0, 100),
            self._distinct_vectors(100, 100),
            self._distinct_vectors(200, 50),
        ]

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.side_effect = [
                _make_embedding_response(vectors) for vectors in batch_vectors
            ]

            result = get_embeddings_batch(texts)

        # Three API calls in total.
        assert mock_client.embeddings.create.call_count == 3
        # 250 vectors, stitched together in batch order.
        assert len(result) == 250
        assert result == [vec for batch in batch_vectors for vec in batch]

    def test_batch_split_calls_correct_input_slices(self):
        """Each batched API call receives the correct slice of the input."""
        from embedding_service import get_embeddings_batch

        texts = [f"t{i}" for i in range(150)]

        batch1_response = _make_embedding_response(self._distinct_vectors(0, 100))
        batch2_response = _make_embedding_response(self._distinct_vectors(100, 50))

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.side_effect = [
                batch1_response,
                batch2_response,
            ]

            get_embeddings_batch(texts)

        calls = mock_client.embeddings.create.call_args_list
        # Check the count first so a wrong number of calls fails as an
        # assertion rather than an IndexError below.
        assert len(calls) == 2
        assert calls[0].kwargs["input"] == texts[:100]
        assert calls[1].kwargs["input"] == texts[100:]


class TestRetryLogic:
    """Tests for retrying on openai.APIError."""

    @staticmethod
    def _api_error(message: str) -> openai.APIError:
        """Build an openai.APIError with a mocked request and no body."""
        return openai.APIError(message=message, request=MagicMock(), body=None)

    def test_retries_3_times_on_api_error(self):
        """A persistent openai.APIError is attempted at most 3 times in total."""
        from embedding_service import get_embedding

        api_error = self._api_error("Internal Server Error")

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.side_effect = api_error

            with patch("embedding_service.time.sleep"):
                with pytest.raises(openai.APIError):
                    get_embedding("test")

        # 1 initial attempt + 2 retries = 3 calls, then the error is raised.
        assert mock_client.embeddings.create.call_count == 3

    def test_raises_after_3_failed_retries(self):
        """When all 3 attempts fail, the original error object is re-raised."""
        from embedding_service import get_embedding

        api_error = self._api_error("Server Error")

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.side_effect = api_error

            with patch("embedding_service.time.sleep"):
                with pytest.raises(openai.APIError) as exc_info:
                    get_embedding("test")

        # Identity check: a freshly built error of the same type (i.e. one
        # that drops the original details) must not satisfy this test.
        assert exc_info.value is api_error

    def test_succeeds_on_second_attempt_after_api_error(self):
        """If the first attempt fails and the second succeeds, its result is returned."""
        from embedding_service import get_embedding

        api_error = self._api_error("Temporary Error")
        success_response = _make_embedding_response([_make_vector(1536)])

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.side_effect = [api_error, success_response]

            with patch("embedding_service.time.sleep"):
                result = get_embedding("test")

        assert len(result) == 1536
        assert mock_client.embeddings.create.call_count == 2


class TestRateLimitHandling:
    """Tests for handling rate-limit (HTTP 429) errors."""

    @staticmethod
    def _rate_limit_error() -> openai.RateLimitError:
        """Build an openai.RateLimitError backed by a mocked 429 response."""
        return openai.RateLimitError(
            message="Rate limit exceeded",
            response=MagicMock(status_code=429, headers={}),
            body=None,
        )

    def test_rate_limit_error_waits_and_retries(self):
        """A RateLimitError causes a sleep followed by a retry."""
        from embedding_service import get_embedding

        rate_limit_error = self._rate_limit_error()
        success_response = _make_embedding_response([_make_vector(1536)])

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            mock_client.embeddings.create.side_effect = [
                rate_limit_error,
                success_response,
            ]

            with patch("embedding_service.time.sleep") as mock_sleep:
                result = get_embedding("test")

        # At least one wait must happen before the retry.
        assert mock_sleep.call_count >= 1
        # One failed attempt plus one successful retry.
        assert mock_client.embeddings.create.call_count == 2
        # The successful result is ultimately returned.
        assert len(result) == 1536

    def test_rate_limit_uses_exponential_backoff(self):
        """Consecutive RateLimitErrors use exponential backoff (1s, 2s, 4s, ...).

        This scenario fails twice, so only the first two waits (1s, 2s) are
        exercised.
        """
        from embedding_service import get_embedding

        rate_limit_error = self._rate_limit_error()
        success_response = _make_embedding_response([_make_vector(1536)])

        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            mock_client = MagicMock()
            mock_client_cls.return_value = mock_client
            # Two rate-limit errors, then success.
            mock_client.embeddings.create.side_effect = [
                rate_limit_error,
                rate_limit_error,
                success_response,
            ]

            with patch("embedding_service.time.sleep") as mock_sleep:
                result = get_embedding("test")

        sleep_calls = [c.args[0] for c in mock_sleep.call_args_list]
        # Compare a slice rather than indexing so a missing wait fails as an
        # assertion showing the actual waits, instead of an IndexError.
        assert sleep_calls[:2] == [1, 2]
        assert mock_client.embeddings.create.call_count == 3
        assert len(result) == 1536


class TestApiKeyValidation:
    """Tests for validating the OPENAI_API_KEY environment variable."""

    def test_raises_value_error_when_api_key_missing(self, monkeypatch):
        """get_embedding() raises ValueError when OPENAI_API_KEY is unset."""
        from embedding_service import get_embedding

        monkeypatch.delenv("OPENAI_API_KEY", raising=False)

        # Patch the client so a faulty implementation can never reach the
        # real API. The mock also never raises on construction, so the
        # ValueError has to come from the service's own key check.
        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            with pytest.raises(ValueError, match="OPENAI_API_KEY"):
                get_embedding("test")

        mock_client_cls.return_value.embeddings.create.assert_not_called()

    def test_raises_value_error_for_batch_when_api_key_missing(self, monkeypatch):
        """get_embeddings_batch() also raises ValueError when OPENAI_API_KEY is unset."""
        from embedding_service import get_embeddings_batch

        monkeypatch.delenv("OPENAI_API_KEY", raising=False)

        # Same isolation as above: no real network call is possible.
        with patch("embedding_service.openai.OpenAI") as mock_client_cls:
            with pytest.raises(ValueError, match="OPENAI_API_KEY"):
                get_embeddings_batch(["text"])

        mock_client_cls.return_value.embeddings.create.assert_not_called()
