#!/usr/bin/env python3
"""utils/usage_pricing.py 테스트 스위트"""

import sys
from decimal import Decimal
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from utils.usage_pricing import (
    CostResult,
    PricingEntry,
    calculate_cost,
    format_cost,
)


class TestPricingEntry:
    """Tests for the PricingEntry dataclass."""

    def test_pricing_entry_basic_fields(self) -> None:
        """All three per-1M-token rates are stored exactly as supplied."""
        sonnet_like = PricingEntry(
            input_per_1m=Decimal("3.00"),
            output_per_1m=Decimal("15.00"),
            cache_read_per_1m=Decimal("0.30"),
        )
        stored = (
            sonnet_like.input_per_1m,
            sonnet_like.output_per_1m,
            sonnet_like.cache_read_per_1m,
        )
        assert stored == (Decimal("3.00"), Decimal("15.00"), Decimal("0.30"))

    def test_pricing_entry_optional_cache(self) -> None:
        """cache_read_per_1m may be omitted and then reads back as None."""
        no_cache = PricingEntry(
            input_per_1m=Decimal("0.14"),
            output_per_1m=Decimal("0.28"),
        )
        assert no_cache.cache_read_per_1m is None

    def test_pricing_entry_frozen(self) -> None:
        """Assigning to a field after construction must fail (frozen dataclass)."""
        locked = PricingEntry(
            input_per_1m=Decimal("1.00"),
            output_per_1m=Decimal("2.00"),
        )
        # setattr goes through the same __setattr__ as plain attribute
        # assignment, so a frozen instance raises here as well.
        with pytest.raises((TypeError, AttributeError)):
            setattr(locked, "input_per_1m", Decimal("99.00"))


class TestCostResult:
    """Tests for the CostResult dataclass."""

    def test_cost_result_fields(self) -> None:
        """Every field passed to the constructor reads back unchanged."""
        result = CostResult(
            input_cost=Decimal("0.003"),
            output_cost=Decimal("0.015"),
            cache_savings=Decimal("0.0"),
            total_cost=Decimal("0.018"),
        )
        assert result.input_cost == Decimal("0.003")
        assert result.output_cost == Decimal("0.015")
        # cache_savings is a field like the others; assert it so a renamed or
        # dropped field is caught by this structural test too.
        assert result.cache_savings == Decimal("0.0")
        assert result.total_cost == Decimal("0.018")

    def test_cost_result_total_equals_sum(self) -> None:
        """Supplied values satisfy total_cost == input_cost + output_cost.

        NOTE(review): CostResult is only constructed here, not computed, so
        this checks the fixture values round-trip; the real arithmetic is
        exercised by the calculate_cost() tests.
        """
        result = CostResult(
            input_cost=Decimal("0.005"),
            output_cost=Decimal("0.020"),
            cache_savings=Decimal("0.0"),
            total_cost=Decimal("0.025"),
        )
        assert result.total_cost == result.input_cost + result.output_cost


class TestCalculateCost:
    """Tests for calculate_cost().

    The expected per-1M-token rates below are the ones these tests assume the
    pricing table in utils.usage_pricing contains. With no cache tokens,
    total_cost is expected to be input_cost + output_cost.

    NOTE(review): these rates are not checked against live provider price
    lists; some (e.g. claude-opus-4-6, gemini-2.5-flash, deepseek-chat) may be
    out of date — confirm against the providers' current pricing pages.
    """

    def test_claude_sonnet_basic(self) -> None:
        """Claude Sonnet 4.6 basic cost calculation."""
        result = calculate_cost(
            model="claude-sonnet-4-6",
            input_tokens=1_000_000,
            output_tokens=1_000_000,
        )
        # $3.00 input + $15.00 output = $18.00
        assert result.input_cost == Decimal("3.00")
        assert result.output_cost == Decimal("15.00")
        assert result.total_cost == Decimal("18.00")

    def test_claude_opus_basic(self) -> None:
        """Claude Opus 4.6 cost calculation."""
        result = calculate_cost(
            model="claude-opus-4-6",
            input_tokens=1_000_000,
            output_tokens=1_000_000,
        )
        # $15.00 input + $75.00 output = $90.00
        assert result.input_cost == Decimal("15.00")
        assert result.output_cost == Decimal("75.00")
        assert result.total_cost == Decimal("90.00")

    def test_claude_haiku_basic(self) -> None:
        """Claude Haiku cost calculation."""
        result = calculate_cost(
            model="claude-3-5-haiku-20241022",
            input_tokens=1_000_000,
            output_tokens=1_000_000,
        )
        assert result.input_cost == Decimal("0.80")
        assert result.output_cost == Decimal("4.00")
        assert result.total_cost == Decimal("4.80")

    def test_gpt4o_basic(self) -> None:
        """GPT-4o basic cost calculation."""
        result = calculate_cost(
            model="gpt-4o",
            input_tokens=1_000_000,
            output_tokens=1_000_000,
        )
        assert result.input_cost == Decimal("2.50")
        assert result.output_cost == Decimal("10.00")
        assert result.total_cost == Decimal("12.50")

    def test_gpt4o_mini_basic(self) -> None:
        """GPT-4o-mini cost calculation."""
        result = calculate_cost(
            model="gpt-4o-mini",
            input_tokens=1_000_000,
            output_tokens=1_000_000,
        )
        assert result.input_cost == Decimal("0.15")
        assert result.output_cost == Decimal("0.60")
        assert result.total_cost == Decimal("0.75")

    def test_gemini_pro_basic(self) -> None:
        """Gemini Pro cost calculation."""
        result = calculate_cost(
            model="gemini-2.5-pro",
            input_tokens=1_000_000,
            output_tokens=1_000_000,
        )
        assert result.input_cost == Decimal("1.25")
        assert result.output_cost == Decimal("10.00")
        assert result.total_cost == Decimal("11.25")

    def test_gemini_flash_basic(self) -> None:
        """Gemini Flash cost calculation."""
        result = calculate_cost(
            model="gemini-2.5-flash",
            input_tokens=1_000_000,
            output_tokens=1_000_000,
        )
        assert result.input_cost == Decimal("0.15")
        assert result.output_cost == Decimal("0.60")
        assert result.total_cost == Decimal("0.75")

    def test_deepseek_v3_basic(self) -> None:
        """DeepSeek v3 (deepseek-chat) cost calculation."""
        result = calculate_cost(
            model="deepseek-chat",
            input_tokens=1_000_000,
            output_tokens=1_000_000,
        )
        assert result.input_cost == Decimal("0.14")
        assert result.output_cost == Decimal("0.28")
        assert result.total_cost == Decimal("0.42")

    def test_cache_read_savings(self) -> None:
        """Cache-read savings: cached input is cheaper than regular input."""
        result = calculate_cost(
            model="claude-sonnet-4-6",
            input_tokens=0,
            output_tokens=0,
            cache_read_tokens=1_000_000,
        )
        # cache_read: $0.30/1M (vs input $3.00/1M)
        # savings = (input_price - cache_price) * tokens / 1M
        #         = ($3.00 - $0.30) * 1M / 1M = $2.70
        assert result.cache_savings == Decimal("2.70")
        # The cache reads themselves are billed at $0.30/1M and count toward
        # the total (same model as test_total_includes_cache_cost).
        assert result.total_cost == Decimal("0.30")

    def test_zero_tokens(self) -> None:
        """Zero tokens cost nothing and save nothing."""
        result = calculate_cost(
            model="claude-sonnet-4-6",
            input_tokens=0,
            output_tokens=0,
        )
        assert result.total_cost == Decimal("0")
        assert result.cache_savings == Decimal("0")

    def test_unknown_model_raises(self) -> None:
        """An unknown model name raises KeyError."""
        with pytest.raises(KeyError):
            calculate_cost(
                model="unknown-model-xyz",
                input_tokens=100,
                output_tokens=100,
            )

    def test_small_token_count(self) -> None:
        """Small token counts are computed exactly (Decimal, no rounding)."""
        result = calculate_cost(
            model="gpt-4o-mini",
            input_tokens=1000,
            output_tokens=500,
        )
        # input: 1000/1M * $0.15 = $0.00015
        # output: 500/1M * $0.60 = $0.00030
        assert result.input_cost == Decimal("0.00015")
        assert result.output_cost == Decimal("0.00030")
        assert result.total_cost == Decimal("0.00045")

    def test_total_includes_cache_cost(self) -> None:
        """total_cost includes the cost of cache-read tokens."""
        result = calculate_cost(
            model="claude-sonnet-4-6",
            input_tokens=500_000,
            output_tokens=500_000,
            cache_read_tokens=500_000,
        )
        # input: 0.5M * $3.00 = $1.50
        # output: 0.5M * $15.00 = $7.50
        # cache: 0.5M * $0.30 = $0.15
        expected_total = Decimal("1.50") + Decimal("7.50") + Decimal("0.15")
        assert result.total_cost == expected_total


class TestFormatCost:
    """Tests for format_cost()."""

    @staticmethod
    def _build(
        input_cost: str,
        output_cost: str,
        cache_savings: str,
        total_cost: str,
    ) -> CostResult:
        """Construct a CostResult from decimal string literals."""
        return CostResult(
            input_cost=Decimal(input_cost),
            output_cost=Decimal(output_cost),
            cache_savings=Decimal(cache_savings),
            total_cost=Decimal(total_cost),
        )

    def test_format_small_cost(self) -> None:
        """Sub-cent totals (below $0.001) still render as non-empty text."""
        text = format_cost(self._build("0.00015", "0.00030", "0", "0.00045"))
        assert isinstance(text, str)
        assert text

    def test_format_dollar_cost(self) -> None:
        """Dollar-scale totals include a '$' sign and the amount."""
        text = format_cost(self._build("3.00", "15.00", "0", "18.00"))
        assert "$" in text
        assert "18" in text

    def test_format_zero_cost(self) -> None:
        """A zero cost still formats to a string."""
        text = format_cost(self._build("0", "0", "0", "0"))
        assert isinstance(text, str)

    def test_format_with_savings(self) -> None:
        """A result carrying cache savings formats to a string.

        NOTE(review): the savings figure itself is not asserted because the
        output format is not visible from here — tighten once it is fixed.
        """
        text = format_cost(self._build("0.15", "0.75", "2.70", "0.90"))
        assert isinstance(text, str)
