#!/usr/bin/env python3
"""utils/session_search.py — FTS5 기반 세션 메시지 전문 검색."""

from __future__ import annotations

import re
import sqlite3
from typing import TYPE_CHECKING

from utils.logger import get_logger

if TYPE_CHECKING:
    from utils.session_store import SessionStore

# Module-level logger used by the session search helpers in this file.
logger = get_logger(__name__)

# Pre-compiled patterns used by sanitize_fts5_query.
_QUOTED_PHRASE_RE = re.compile(r'"[^"]*"')  # balanced "..." phrase
_BOOL_OP_RE = re.compile(r"\b(AND|OR|NOT)\b")  # stand-alone upper-case operators
_SPECIAL_CHARS_RE = re.compile(r"[()\\+*]")  # (, ), backslash, +, *
_HYPHEN_WORD_RE = re.compile(r"\b(\w+-\w[\w-]*)\b")  # hyphenated word, e.g. chat-send
_MULTI_SPACE_RE = re.compile(r" {2,}")  # runs of two or more spaces


def sanitize_fts5_query(query: str) -> str:
    """Neutralise FTS5 syntax characters in a user-supplied query string.

    Processing steps:

    - Balanced quoted phrases ("...") are kept verbatim (FTS5 phrase query).
    - Unbalanced double quotes are dropped; a dangling quote would otherwise
      reach the MATCH expression and make FTS5 reject the whole query.
    - ``(``, ``)``, ``\\``, ``+`` and ``*`` are replaced by spaces.
    - Stand-alone upper-case AND / OR / NOT are removed (word-boundary match,
      so "ANDROID" is untouched).
    - Hyphenated words (chat-send) are wrapped in double quotes.
    - Runs of spaces are collapsed and the result is stripped.

    Args:
        query: Raw query text.

    Returns:
        The sanitised query, or "" when nothing searchable remains.
    """
    if not query:
        return ""

    # NUL delimits the internal placeholders below; remove it from the input
    # so user text can never be mistaken for (or spliced over by) one.
    query = query.replace("\x00", "")

    quoted_parts: list[str] = []

    def _stash(m: re.Match[str]) -> str:
        # Swap each balanced phrase for a placeholder so the later
        # substitutions cannot modify its contents.
        quoted_parts.append(m.group(0))
        return f"\x00Q{len(quoted_parts) - 1}\x00"

    safe = _QUOTED_PHRASE_RE.sub(_stash, query)
    # Any quote still present has no partner: drop it.
    safe = safe.replace('"', " ")
    safe = _BOOL_OP_RE.sub(" ", safe)
    safe = _SPECIAL_CHARS_RE.sub(" ", safe)
    safe = _HYPHEN_WORD_RE.sub(lambda m: f'"{m.group(1)}"', safe)

    # Put the protected phrases back in their original positions.
    for i, part in enumerate(quoted_parts):
        safe = safe.replace(f"\x00Q{i}\x00", part)

    return _MULTI_SPACE_RE.sub(" ", safe).strip()


def _resolve_lineage_root(session_id: str, db: "SessionStore") -> str:
    """Follow ``parent_session_id`` links upward and return the root session ID.

    The walk ends at the first session that is unknown to ``db`` or has no
    parent. If a session is reached a second time, a warning is logged and
    that session ID is returned instead of looping forever.
    """
    seen: set[str] = set()
    node = session_id

    while node not in seen:
        seen.add(node)
        record = db.get_session(node)
        parent_id = record.get("parent_session_id") if record is not None else None
        if not parent_id:
            # Unknown session or no parent recorded: this is the root.
            return node
        node = parent_id

    # Leaving the loop means `node` was already visited, i.e. a cycle.
    logger.warning("Cycle detected in session lineage at: %s", node)
    return node


def _truncate_around_matches(full_text: str, query: str, max_chars: int = 100_000) -> str:
    """Return a window of at most ``max_chars`` characters around the first query hit.

    Tokens are the whitespace-separated parts of ``query`` with surrounding
    double quotes removed; they are searched case-insensitively in order, and
    the first token found anywhere in ``full_text`` anchors the window.

    Args:
        full_text: Text to cut down.
        query: Search query whose tokens locate the window.
        max_chars: Maximum length of the returned text.

    Returns:
        ``full_text`` unchanged when it already fits; otherwise a slice of
        ``max_chars`` characters roughly centred on the first matching token,
        or the leading ``max_chars`` characters when nothing matches.
    """
    if len(full_text) <= max_chars:
        return full_text
    if max_chars <= 0:
        # A negative bound would make the slices below count from the end.
        return ""
    if not query:
        return full_text[:max_chars]

    haystack = full_text.lower()  # lower-case once, not once per token
    pos = -1
    for token in query.split():
        needle = token.strip('"').lower()
        if not needle:
            # A quote-only token would "match" at offset 0 and mask real hits.
            continue
        pos = haystack.find(needle)
        if pos != -1:
            break

    if pos == -1:
        return full_text[:max_chars]

    # Centre the window on the hit, then shift it left if it would run past
    # the end so the full max_chars budget is still used.
    half = max_chars // 2
    start = max(0, pos - half)
    end = min(len(full_text), start + max_chars)
    start = max(0, end - max_chars)
    return full_text[start:end]


def _format_conversation(messages: list[dict]) -> str:  # type: ignore[type-arg]
    """Render session messages as newline-separated ``[role] content`` lines.

    A missing ``role`` key is shown as ``unknown``; missing or falsy content
    is rendered as an empty string.
    """
    rendered: list[str] = []
    for message in messages:
        speaker = message.get("role", "unknown")
        body = message.get("content") or ""
        rendered.append(f"[{speaker}] {body}")
    return "\n".join(rendered)


def _execute_fts_search(
    clean_query: str,
    db: "SessionStore",
    role_filter: list[str] | None,
    limit: int,
    current_session_id: str | None,
) -> list[dict]:  # type: ignore[type-arg]
    """FTS5 SQL 쿼리 실행 후 결과를 구조화한다."""
    role_clause = ""
    params: list[object] = [clean_query]

    if role_filter:
        role_clause = f" AND m.role IN ({','.join('?' * len(role_filter))})"
        params.extend(role_filter)

    sql = f"""
        SELECT m.session_id, snippet(messages_fts, 0, '<b>', '</b>', '...', 32) AS snip, rank
        FROM messages_fts
        JOIN messages m ON messages_fts.rowid = m.id
        WHERE messages_fts MATCH ?{role_clause}
        ORDER BY rank
        LIMIT {limit * 5}
    """

    conn = sqlite3.connect(db.db_path)
    conn.row_factory = sqlite3.Row
    try:
        rows = conn.execute(sql, params).fetchall()
    finally:
        conn.close()

    seen: set[str] = set()
    results: list[dict] = []  # type: ignore[type-arg]

    for row in rows:
        sid: str = row["session_id"]
        if current_session_id and sid == current_session_id:
            continue
        if sid in seen:
            continue
        seen.add(sid)

        msgs = db.get_messages(sid)
        summary = _truncate_around_matches(_format_conversation(msgs), clean_query, max_chars=500)
        results.append(
            {
                "session_id": sid,
                "root_session_id": _resolve_lineage_root(sid, db),
                "summary": summary,
                "snippet": row["snip"],
                "score": row["rank"],
            }
        )
        if len(results) >= limit:
            break

    return results


def search_sessions(
    query: str,
    db: "SessionStore",
    role_filter: list[str] | None = None,
    limit: int = 3,
    current_session_id: str | None = None,
) -> dict:  # type: ignore[type-arg]
    """Sanitise ``query``, run the FTS5 search and wrap the hits in a dict.

    Returns ``{"results": [...]}``. The list is empty when the sanitised
    query is empty or when the search raises; in the latter case the error
    is logged as a warning rather than propagated.
    """
    fts_query = sanitize_fts5_query(query)
    hits: list[dict] = []  # type: ignore[type-arg]

    if fts_query:
        try:
            hits = _execute_fts_search(fts_query, db, role_filter, limit, current_session_id)
        except Exception as exc:  # noqa: BLE001
            # Best-effort search: report the failure and fall back to no hits.
            logger.warning("FTS5 search error (query=%r): %s", fts_query, exc)

    return {"results": hits}
