"""프롬프트 인젝션 탐지 모듈.

텍스트에서 프롬프트 인젝션 패턴(텍스트 기반 + 유니코드)을 탐지합니다.
Hermes Agent의 _scan_context_content() 로직을 참고하여 dev2-team 시스템에 맞게 재설계.

Usage:
    from utils.injection_guard import scan_content
    result = scan_content("some text")
    if not result.is_safe:
        for threat in result.threats:
            print(threat.pattern_name, threat.severity)
"""

import re
from dataclasses import dataclass, field


@dataclass
class ThreatInfo:
    """개별 위협 정보."""

    pattern_name: str
    matched_text: str
    severity: str  # "low" | "medium" | "high"


@dataclass
class ScanResult:
    """스캔 결과."""

    is_safe: bool
    threats: list[ThreatInfo] = field(default_factory=list)


# ---------------------------------------------------------------------------
# 유니코드 인젝션 문자 (ZWSP, RTL mark, 동형문자 관련 제어 문자 등)
# ---------------------------------------------------------------------------
_INVISIBLE_UNICODE: dict[str, str] = {
    "\u200b": "ZWSP (U+200B)",
    "\u200c": "ZWNJ (U+200C)",
    "\u200d": "ZWJ (U+200D)",
    "\u200f": "RTL_MARK (U+200F)",
    "\u2060": "WORD_JOINER (U+2060)",
    "\ufeff": "BOM (U+FEFF)",
    "\u202a": "LRE (U+202A)",
    "\u202b": "RLE (U+202B)",
    "\u202c": "PDF (U+202C)",
    "\u202d": "LRO (U+202D)",
    "\u202e": "RTL_OVERRIDE (U+202E)",
}

# ---------------------------------------------------------------------------
# 텍스트 기반 인젝션 패턴 (최소 10개)
# (pattern_regex, pattern_name, severity)
# ---------------------------------------------------------------------------
_TEXT_PATTERNS: list[tuple[str, str, str]] = [
    # 1. "ignore previous instructions" 계열
    (
        r"ignore\s+(previous|all|above|prior)\s+instructions",
        "ignore_instructions",
        "high",
    ),
    # 2. "forget your instructions"
    (
        r"forget\s+(your|all|my|the)\s+instructions",
        "forget_instructions",
        "high",
    ),
    # 3. "you are now" 역할 교체
    (
        r"you\s+are\s+now\b",
        "you_are_now",
        "high",
    ),
    # 4. "jailbreak"
    (
        r"\bjailbreak\b",
        "jailbreak",
        "high",
    ),
    # 5. "system prompt" 노출 유도
    (
        r"system\s+prompt",
        "system_prompt",
        "medium",
    ),
    # 6. "act as" 역할 가장
    (
        r"\bact\s+as\b",
        "act_as",
        "medium",
    ),
    # 7. "pretend you are"
    (
        r"pretend\s+(you\s+are|to\s+be)",
        "pretend_you_are",
        "medium",
    ),
    # 8. "override" 안전장치 우회
    (
        r"\boverride\b",
        "override",
        "medium",
    ),
    # 9. "bypass" 우회
    (
        r"\bbypass\b",
        "bypass",
        "medium",
    ),
    # 10. "disregard your instructions/rules" 변형
    (
        r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)",
        "disregard_rules",
        "high",
    ),
    # 11. HTML 숨김 주석 인젝션
    (
        r"<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->",
        "html_comment_injection",
        "high",
    ),
    # 12. 원격 실행 유도 (curl|bash)
    (
        r"\b(curl|wget)\b[^\n]*\|\s*(ba)?sh\b",
        "remote_execute",
        "high",
    ),
]


def scan_content(text: str) -> ScanResult:
    """텍스트에서 프롬프트 인젝션 패턴을 탐지합니다.

    Args:
        text: 검사할 문자열.

    Returns:
        ScanResult: is_safe=False이면 threats 목록에 탐지된 위협 포함.
    """
    threats: list[ThreatInfo] = []

    # 1. 유니코드 인젝션 문자 검사
    for char, name in _INVISIBLE_UNICODE.items():
        if char in text:
            threats.append(
                ThreatInfo(
                    pattern_name=f"invisible_unicode_{name.split()[0].lower()}",
                    matched_text=f"U+{ord(char):04X} ({name})",
                    severity="high",
                )
            )

    # 2. 텍스트 기반 패턴 검사
    for pattern, pattern_name, severity in _TEXT_PATTERNS:
        match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
        if match:
            threats.append(
                ThreatInfo(
                    pattern_name=pattern_name,
                    matched_text=match.group(0),
                    severity=severity,
                )
            )

    return ScanResult(is_safe=len(threats) == 0, threats=threats)


__all__ = [
    "ThreatInfo",
    "ScanResult",
    "InjectionBlockedError",
    "scan_content",
    "check_content",
]


class InjectionBlockedError(Exception):
    """프롬프트 인젝션이 탐지되어 처리가 차단됨.

    check_content()가 high severity 위협을 1개 이상 발견했을 때 발생합니다.

    Attributes:
        threats: 탐지된 모든 위협 목록 (ThreatInfo 리스트).
    """

    def __init__(self, threats: list[ThreatInfo]) -> None:
        self.threats = threats
        high_threats = [t for t in threats if t.severity == "high"]
        count = len(high_threats)
        first_name = high_threats[0].pattern_name if high_threats else threats[0].pattern_name
        super().__init__(f"인젝션 차단: {count}개의 위협 패턴 탐지, 첫 번째 패턴: {first_name!r}")


def check_content(text: str) -> ScanResult:
    """텍스트에서 프롬프트 인젝션을 탐지하고 위협 발견 시 즉시 차단합니다 (하드블록).

    내부에서 scan_content()를 호출하여 탐지 후, high severity 위협이 1개 이상이면
    InjectionBlockedError를 발생시킵니다. 안전한 텍스트는 ScanResult를 반환합니다.

    Args:
        text: 검사할 문자열.

    Returns:
        ScanResult: 위협이 없을 경우 is_safe=True인 ScanResult 반환.

    Raises:
        InjectionBlockedError: high severity 위협이 1개 이상 탐지된 경우.
    """
    result = scan_content(text)
    if not result.is_safe:
        high_threats = [t for t in result.threats if t.severity == "high"]
        if high_threats:
            raise InjectionBlockedError(result.threats)
    return result
