# -*- coding: utf-8 -*-
"""Regression: classify_prompt_bytes — 4 ranges verification."""
from __future__ import annotations

import pytest

from utils.callback_registration import classify_prompt_bytes


class TestPromptByteClassification:
    """4-range classification: OK_TARGET / OK_ABOVE_TARGET / WARNING_BUT_ALLOWED / HARD_BLOCK.

    Ranges pinned by these tests, measured in UTF-8 encoded bytes:

    * ``0 .. 3200``    -> ``OK_TARGET``
    * ``3201 .. 3499`` -> ``OK_ABOVE_TARGET``
    * ``3500 .. 3900`` -> ``WARNING_BUT_ALLOWED`` (3900 is the largest allowed size)
    * ``3901 ..``      -> ``HARD_BLOCK``
    """

    @pytest.mark.parametrize("byte_count,expected", [
        (0, "OK_TARGET"),
        (1, "OK_TARGET"),
        (3200, "OK_TARGET"),
        (3201, "OK_ABOVE_TARGET"),
        (3300, "OK_ABOVE_TARGET"),
        (3499, "OK_ABOVE_TARGET"),
        (3500, "WARNING_BUT_ALLOWED"),
        (3700, "WARNING_BUT_ALLOWED"),
        (3900, "WARNING_BUT_ALLOWED"),
        (3901, "HARD_BLOCK"),
        (4096, "HARD_BLOCK"),
        (9999, "HARD_BLOCK"),
    ])
    def test_classification_ranges(self, byte_count: int, expected: str):
        """An ASCII prompt of ``byte_count`` bytes is classified as ``expected``."""
        # ASCII chars encode to 1 byte each, so len(prompt) == byte_count.
        prompt = "x" * byte_count
        result = classify_prompt_bytes(prompt)
        assert result == expected, (
            f"classify_prompt_bytes({byte_count} bytes): "
            f"expected {expected!r}, got {result!r}"
        )

    @pytest.mark.parametrize("prompt,byte_count,expected", [
        # "한" is 3 bytes in UTF-8; pad with ASCII to land exactly on a boundary.
        pytest.param("한" * 1066 + "xx", 3200, "OK_TARGET", id="ko-3200"),
        pytest.param("한" * 1067, 3201, "OK_ABOVE_TARGET", id="ko-3201"),
        pytest.param("한" * 1166 + "x", 3499, "OK_ABOVE_TARGET", id="ko-3499"),
        pytest.param("한" * 1166 + "xx", 3500, "WARNING_BUT_ALLOWED", id="ko-3500"),
        pytest.param("한" * 1300, 3900, "WARNING_BUT_ALLOWED", id="ko-3900"),
        pytest.param("한" * 1300 + "x", 3901, "HARD_BLOCK", id="ko-3901"),
        # U+1F600 is a 4-byte UTF-8 sequence (outside the BMP).
        pytest.param("\U0001F600" * 975, 3900, "WARNING_BUT_ALLOWED", id="emoji-3900"),
        pytest.param("\U0001F600" * 975 + "x", 3901, "HARD_BLOCK", id="emoji-3901"),
    ])
    def test_multibyte_boundaries(self, prompt: str, byte_count: int, expected: str):
        """Every range boundary holds for multi-byte text (bytes, not characters)."""
        # Check the fixture arithmetic first so a miscounted test case is not
        # reported as a classifier bug.
        assert len(prompt.encode("utf-8")) == byte_count
        result = classify_prompt_bytes(prompt)
        assert result == expected, (
            f"classify_prompt_bytes({byte_count} bytes, {len(prompt)} chars): "
            f"expected {expected!r}, got {result!r}"
        )

    def test_empty_string(self):
        """The empty prompt (0 bytes) is OK_TARGET."""
        assert classify_prompt_bytes("") == "OK_TARGET"

    def test_boundary_3200(self):
        """Exact boundary: 3200 bytes → OK_TARGET."""
        assert classify_prompt_bytes("x" * 3200) == "OK_TARGET"

    def test_boundary_3201(self):
        """3201 bytes → OK_ABOVE_TARGET."""
        assert classify_prompt_bytes("x" * 3201) == "OK_ABOVE_TARGET"

    def test_boundary_3499(self):
        """3499 bytes → OK_ABOVE_TARGET."""
        assert classify_prompt_bytes("x" * 3499) == "OK_ABOVE_TARGET"

    def test_boundary_3500(self):
        """3500 bytes → WARNING_BUT_ALLOWED."""
        assert classify_prompt_bytes("x" * 3500) == "WARNING_BUT_ALLOWED"

    def test_boundary_3900(self):
        """3900 bytes → WARNING_BUT_ALLOWED (max allowed)."""
        assert classify_prompt_bytes("x" * 3900) == "WARNING_BUT_ALLOWED"

    def test_boundary_3901(self):
        """3901 bytes → HARD_BLOCK."""
        assert classify_prompt_bytes("x" * 3901) == "HARD_BLOCK"

    def test_multibyte_utf8(self):
        """Korean chars are 3 bytes each — test UTF-8 byte counting."""
        # "한" = 3 bytes, so 1067 chars = 3201 bytes → OK_ABOVE_TARGET, whereas a
        # character-count implementation would see 1067 and return OK_TARGET.
        prompt = "한" * 1067
        byte_count = len(prompt.encode("utf-8"))
        assert byte_count == 3201
        assert classify_prompt_bytes(prompt) == "OK_ABOVE_TARGET"

    def test_all_four_ranges_covered(self):
        """Ensure all 4 class names are returned from their respective ranges."""
        classes = {
            classify_prompt_bytes("x" * 100),    # OK_TARGET
            classify_prompt_bytes("x" * 3250),   # OK_ABOVE_TARGET
            classify_prompt_bytes("x" * 3600),   # WARNING_BUT_ALLOWED
            classify_prompt_bytes("x" * 4000),   # HARD_BLOCK
        }
        assert classes == {"OK_TARGET", "OK_ABOVE_TARGET", "WARNING_BUT_ALLOWED", "HARD_BLOCK"}
