"""Anthropic 프롬프트 캐싱 유틸리티 (system_and_3 전략).

멀티턴 대화에서 입력 토큰 비용을 절감하기 위해 Anthropic의 cache_control
breakpoint를 메시지에 삽입합니다.

system_and_3 전략:
  - breakpoint 1: 시스템 프롬프트 (모든 턴에서 동일, 캐시 효과 최대)
  - breakpoint 2-4: 최근 비시스템 메시지 최대 3개 (롤링 윈도우)

총 최대 4개의 breakpoint를 사용하며 이는 Anthropic API 제한과 일치합니다.

순수 함수 설계: 입력을 deep copy하여 원본 메시지 목록을 변경하지 않습니다.

Usage:
    from utils.prompt_cache import apply_cache_markers

    cached_messages = apply_cache_markers(messages)
    # Anthropic API 호출 시 cached_messages 사용
"""

import copy
from typing import Any


def _attach_cache_control(msg: dict[str, Any], marker: dict[str, str]) -> None:
    """단일 메시지의 마지막 content 블록에 cache_control을 삽입합니다.

    메시지 content 형식에 따라 처리:
    - str: [{"type": "text", "text": ..., "cache_control": ...}] 로 변환
    - list: 마지막 블록에 cache_control 추가
    - None/빈 문자열: 메시지 레벨에 직접 cache_control 추가
    """
    content = msg.get("content")

    if content is None or content == "":
        msg["cache_control"] = marker
        return

    if isinstance(content, str):
        msg["content"] = [{"type": "text", "text": content, "cache_control": marker}]
        return

    if isinstance(content, list) and content:
        last_block = content[-1]
        if isinstance(last_block, dict):
            last_block["cache_control"] = marker


def apply_cache_markers(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """system_and_3 전략으로 메시지에 Anthropic 캐시 마커를 삽입합니다.

    시스템 프롬프트 1개 + 최근 비시스템 메시지 최대 3개에 cache_control breakpoint를
    삽입하여 총 최대 4개의 breakpoint를 사용합니다.

    원본 messages 리스트를 변경하지 않습니다 (deep copy 반환).

    Args:
        messages: Anthropic API 형식의 메시지 목록.
                  각 메시지는 {"role": ..., "content": ...} 형태.

    Returns:
        cache_control breakpoint가 삽입된 메시지 목록의 deep copy.
    """
    if not messages:
        return []

    result = copy.deepcopy(messages)
    marker: dict[str, str] = {"type": "ephemeral"}
    breakpoints_remaining = 4

    # breakpoint 1: 첫 번째 메시지가 system이면 캐시 마커 적용
    if result[0].get("role") == "system":
        _attach_cache_control(result[0], marker)
        breakpoints_remaining -= 1

    # breakpoint 2-4: 비시스템 메시지 중 마지막 최대 3개
    non_system_indices = [i for i, m in enumerate(result) if m.get("role") != "system"]
    for idx in non_system_indices[-breakpoints_remaining:]:
        _attach_cache_control(result[idx], marker)

    return result
