"""모델별 비용 계산 유틸리티.

실제 사용 중인 모델 위주의 간결한 가격표와 비용 계산 함수를 제공한다.
외부 패키지 없이 stdlib만 사용. 가격 단위는 모두 USD / 1M 토큰.

참고 가격 (2026-03 기준 공식 문서 스냅샷):
  Anthropic: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
  OpenAI:    https://openai.com/api/pricing/
  Google:    https://ai.google.dev/pricing
  DeepSeek:  https://api-docs.deepseek.com/quick_start/pricing
"""

from __future__ import annotations

from dataclasses import dataclass, field
from decimal import Decimal
from typing import Optional

_ONE_MILLION = Decimal("1000000")
_ZERO = Decimal("0")


@dataclass(frozen=True)
class PricingEntry:
    """모델 1개에 대한 가격 정보 (USD / 1M 토큰)."""

    input_per_1m: Decimal
    output_per_1m: Decimal
    cache_read_per_1m: Optional[Decimal] = field(default=None)


@dataclass(frozen=True)
class CostResult:
    """calculate_cost() 반환값.

    Attributes:
        input_cost: 일반 입력 토큰 비용 (USD).
        output_cost: 출력 토큰 비용 (USD).
        cache_savings: 캐시 사용으로 절약한 금액 (USD, >= 0).
        total_cost: 실제 청구 금액 = input + output + cache_read.
    """

    input_cost: Decimal
    output_cost: Decimal
    cache_savings: Decimal
    total_cost: Decimal


# ---------------------------------------------------------------------------
# 가격표 (compact tuple 형식)
# 컬럼 순서: input_per_1m, output_per_1m, cache_read_per_1m (None = 미지원)
# ---------------------------------------------------------------------------
def _p(inp: str, out: str, cr: Optional[str] = None) -> PricingEntry:
    return PricingEntry(
        input_per_1m=Decimal(inp),
        output_per_1m=Decimal(out),
        cache_read_per_1m=Decimal(cr) if cr is not None else None,
    )


_PRICING_TABLE: dict[str, PricingEntry] = {
    # Anthropic Claude 4.6
    "claude-opus-4-6": _p("15.00", "75.00", "1.50"),
    "claude-sonnet-4-6": _p("3.00", "15.00", "0.30"),
    # Anthropic Claude 3.5
    "claude-3-5-sonnet-20241022": _p("3.00", "15.00", "0.30"),
    "claude-3-5-haiku-20241022": _p("0.80", "4.00", "0.08"),
    # Anthropic Claude 3 (구형)
    "claude-3-opus-20240229": _p("15.00", "75.00", "1.50"),
    "claude-3-haiku-20240307": _p("0.25", "1.25", "0.03"),
    # OpenAI
    "gpt-4o": _p("2.50", "10.00", "1.25"),
    "gpt-4o-mini": _p("0.15", "0.60", "0.075"),
    # Google Gemini
    "gemini-2.5-pro": _p("1.25", "10.00"),
    "gemini-2.5-flash": _p("0.15", "0.60"),
    "gemini-2.0-flash": _p("0.10", "0.40"),
    # DeepSeek
    "deepseek-chat": _p("0.14", "0.28"),
    "deepseek-reasoner": _p("0.55", "2.19"),
}


def calculate_cost(
    model: str,
    input_tokens: int,
    output_tokens: int,
    cache_read_tokens: int = 0,
) -> CostResult:
    """모델 이름과 토큰 수를 받아 USD 비용을 계산한다.

    Args:
        model: 모델 ID (_PRICING_TABLE 키).
        input_tokens: 일반 입력 토큰 수.
        output_tokens: 출력 토큰 수.
        cache_read_tokens: 캐시 읽기 토큰 수 (기본 0).

    Returns:
        CostResult (input_cost, output_cost, cache_savings, total_cost).

    Raises:
        KeyError: 가격표에 없는 모델 ID.
    """
    entry = _PRICING_TABLE[model]  # 알 수 없는 모델 → KeyError

    input_cost = Decimal(input_tokens) * entry.input_per_1m / _ONE_MILLION
    output_cost = Decimal(output_tokens) * entry.output_per_1m / _ONE_MILLION

    cache_cost = _ZERO
    cache_savings = _ZERO
    if cache_read_tokens > 0 and entry.cache_read_per_1m is not None:
        cache_cost = Decimal(cache_read_tokens) * entry.cache_read_per_1m / _ONE_MILLION
        # 절약액 = (일반 입력 단가 - 캐시 단가) × 토큰 수
        saved_per_token = entry.input_per_1m - entry.cache_read_per_1m
        cache_savings = Decimal(cache_read_tokens) * saved_per_token / _ONE_MILLION

    return CostResult(
        input_cost=input_cost,
        output_cost=output_cost,
        cache_savings=cache_savings,
        total_cost=input_cost + output_cost + cache_cost,
    )


def format_cost(cost_result: CostResult) -> str:
    """CostResult를 사람이 읽기 쉬운 문자열로 변환한다.

    $0.01 이상: "$18.00" 형식 (소수 2자리)
    $0.01 미만: "$0.000045" 형식 (소수 6자리)
    캐시 절약 있을 때: "$0.90 (saved $2.70 with cache)" 형식
    """
    total = cost_result.total_cost
    savings = cost_result.cache_savings

    def _fmt(amount: Decimal) -> str:
        if amount >= Decimal("0.01"):
            return f"${amount:.2f}"
        if amount > _ZERO:
            return f"${amount:.6f}"
        return "$0.00"

    cost_str = _fmt(total)
    if savings > _ZERO:
        return f"{cost_str} (saved {_fmt(savings)} with cache)"
    return cost_str
