#!/usr/bin/env python3
"""
utils/session_store.py — SQLite WAL 기반 세션 저장소 (CRUD + 초기화)

FTS5 검색은 SearchMixin(session_store_search.py), 스키마 DDL은 ALL_DDL에서 제공.
Usage: from utils.session_store import SessionStore
"""

from __future__ import annotations

import json
import os
import re
import sqlite3
import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any

from utils.logger import get_logger
from utils.session_store_search import ALL_DDL, SearchMixin

# Module-level logger obtained from the project's logging helper (utils.logger.get_logger).
logger = get_logger(__name__)

# Workspace root; the default session DB lives at <root>/memory/sessions.db.
# `or` (rather than a get() default) also treats an *empty* WORKSPACE_ROOT as unset:
# Path("") is ".", which would otherwise put the DB relative to the current directory.
_WORKSPACE_ROOT = os.environ.get("WORKSPACE_ROOT") or "/home/jay/workspace"
DEFAULT_DB_PATH = str(Path(_WORKSPACE_ROOT) / "memory" / "sessions.db")

# Characters stripped from titles: C0 controls (U+0000-001F), DEL and C1 controls
# (U+007F-009F), zero-width spaces/joiners and LRM/RLM marks (U+200B-200F),
# line/paragraph separators (U+2028-2029), bidi embedding/override controls
# (U+202A-202E), word joiner and invisible operators (U+2060-2064), bidi isolates
# (U+2066-2069), and the BOM / zero-width no-break space (U+FEFF).
_STRIP_RE = re.compile(r"[\x00-\x1f\x7f-\x9f\u200b-\u200f\u2028-\u202e\u2060-\u2064\u2066-\u2069\ufeff]")


def sanitize_title(text: str) -> str:
    """Return *text* with control, zero-width and bidi-control characters removed.

    Bidi overrides and isolates (e.g. U+202E RIGHT-TO-LEFT OVERRIDE) are removed
    together with the LRM/RLM marks, so a stored title cannot visually reorder
    the text displayed around it.
    """
    return _STRIP_RE.sub("", text)


def _now_iso() -> str:
    """Return the current moment in UTC as an ISO 8601 string (``+00:00`` offset)."""
    utc_now = datetime.now(timezone.utc)
    return utc_now.isoformat()


class SessionStore(SearchMixin):
    """SQLite session store in WAL mode; FTS5 search is provided by SearchMixin.

    One connection is shared across threads (``check_same_thread=False``); writes
    and ``close()`` are serialized through ``self._lock``. The store may also be
    used as a context manager, which closes the connection on exit.
    """

    def __init__(self, db_path: str | None = None) -> None:
        """Open (creating if needed) the SQLite database and ensure the schema exists.

        Args:
            db_path: Path to the SQLite file. ``None`` (or an empty string) falls
                back to ``DEFAULT_DB_PATH``. The parent directory is created.
        """
        self.db_path: str = db_path or DEFAULT_DB_PATH
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        self._lock = threading.Lock()
        self._conn = sqlite3.connect(
            self.db_path,
            check_same_thread=False,
            isolation_level=None,  # autocommit: each statement commits on its own
        )
        self._conn.row_factory = sqlite3.Row
        try:
            self._setup()
        except BaseException:
            # Don't leak the open connection (and its file handle) if schema setup fails.
            self._conn.close()
            raise

    def _setup(self) -> None:
        """Enable WAL journaling and foreign keys, then execute every statement in ALL_DDL."""
        with self._lock:
            cur = self._conn.cursor()
            try:
                cur.execute("PRAGMA journal_mode=WAL;")
                cur.execute("PRAGMA foreign_keys=ON;")
                for ddl in ALL_DDL:
                    cur.execute(ddl)
            finally:
                cur.close()
            self._conn.commit()  # no-op under autocommit; kept for explicitness

    def create_session(
        self,
        session_id: str,
        source: str,
        model: str | None = None,
        parent_session_id: str | None = None,
        title: str | None = None,
    ) -> None:
        """Insert a new row into ``sessions`` with ``created_at`` set to now (UTC).

        Args:
            session_id: Identifier of the new session.
            source: Origin label stored in the ``source`` column.
            model: Optional model name.
            parent_session_id: Optional id of the parent session.
            title: Optional title; control/zero-width characters are stripped and
                a falsy title (``None`` or ``""``) is stored as NULL.
        """
        clean_title = sanitize_title(title) if title else None
        sql = "INSERT INTO sessions (session_id, source, model, parent_session_id, title, created_at) VALUES (?, ?, ?, ?, ?, ?)"
        with self._lock:
            self._conn.execute(sql, (session_id, source, model, parent_session_id, clean_title, _now_iso()))
            self._conn.commit()
        logger.debug("Session created: %s (source=%s)", session_id, source)

    def end_session(self, session_id: str, end_reason: str) -> None:
        """Mark a session as ended by setting ``ended_at`` (now, UTC) and ``end_reason``.

        An unknown ``session_id`` updates no rows and is not reported.
        """
        with self._lock:
            self._conn.execute(
                "UPDATE sessions SET ended_at=?, end_reason=? WHERE session_id=?",
                (_now_iso(), end_reason, session_id),
            )
            self._conn.commit()
        logger.debug("Session ended: %s (reason=%s)", session_id, end_reason)

    def get_session(self, session_id: str) -> dict[str, Any] | None:
        """Return the session row as a dict, or ``None`` if it does not exist."""
        cur = self._conn.execute("SELECT * FROM sessions WHERE session_id=?", (session_id,))
        row = cur.fetchone()
        return dict(row) if row else None

    def list_sessions(self, source: str | None = None, limit: int = 20) -> list[dict[str, Any]]:
        """Return up to ``limit`` sessions, newest ``created_at`` first.

        Args:
            source: If given, only sessions whose ``source`` equals it are returned.
            limit: Maximum number of rows (passed straight to SQL ``LIMIT``).
        """
        if source is not None:
            cur = self._conn.execute(
                "SELECT * FROM sessions WHERE source=? ORDER BY created_at DESC LIMIT ?",
                (source, limit),
            )
        else:
            cur = self._conn.execute(
                "SELECT * FROM sessions ORDER BY created_at DESC LIMIT ?",
                (limit,),
            )
        return [dict(row) for row in cur.fetchall()]

    def append_message(
        self,
        session_id: str,
        role: str,
        content: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
    ) -> None:
        """Insert a message for ``session_id``; ``tool_calls`` is stored as JSON text.

        ``ensure_ascii=False`` keeps non-ASCII text readable in the stored JSON.
        """
        tool_calls_json = json.dumps(tool_calls, ensure_ascii=False) if tool_calls is not None else None
        sql = "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, created_at) VALUES (?, ?, ?, ?, ?, ?)"
        with self._lock:
            self._conn.execute(sql, (session_id, role, content, tool_calls_json, tool_call_id, _now_iso()))
            self._conn.commit()

    def get_messages(self, session_id: str) -> list[dict[str, Any]]:
        """Return all messages of a session ordered by ``created_at``, then ``id``.

        A non-empty ``tool_calls`` value is decoded from JSON. If it cannot be
        decoded, the raw stored value is kept and a warning is logged.
        """
        cur = self._conn.execute(
            "SELECT * FROM messages WHERE session_id=? ORDER BY created_at ASC, id ASC",
            (session_id,),
        )
        rows = []
        for row in cur.fetchall():
            d = dict(row)
            if d.get("tool_calls"):
                try:
                    d["tool_calls"] = json.loads(d["tool_calls"])
                except (json.JSONDecodeError, TypeError) as exc:
                    # Best effort: keep the raw value, but make the corruption visible.
                    logger.warning(
                        "get_messages: undecodable tool_calls in message id=%s (session=%s): %s",
                        d.get("id"),
                        session_id,
                        exc,
                    )
            rows.append(d)
        return rows

    def prune_sessions(self, older_than_days: int = 90) -> int:
        """Delete ended sessions whose ``ended_at`` is older than the cutoff.

        Sessions that were never ended (``ended_at IS NULL``) are kept.

        Args:
            older_than_days: Age threshold in days (default 90).

        Returns:
            Number of deleted session rows.
        """
        # NOTE(review): with foreign_keys=ON, removing sessions that still have messages
        # depends on the ON DELETE action declared in ALL_DDL — confirm there.
        cutoff = (datetime.now(tz=timezone.utc) - timedelta(days=older_than_days)).isoformat()
        with self._lock:
            cur = self._conn.execute(
                "DELETE FROM sessions WHERE ended_at IS NOT NULL AND ended_at < ?",
                (cutoff,),
            )
            self._conn.commit()
        deleted = cur.rowcount
        logger.info("prune_sessions: deleted %d sessions older than %d days", deleted, older_than_days)
        return deleted

    def close(self) -> None:
        """Close the DB connection; errors are logged as warnings instead of raised.

        Takes the write lock so the connection is not closed under an in-flight write.
        """
        with self._lock:
            try:
                self._conn.close()
                logger.debug("SessionStore closed: %s", self.db_path)
            except Exception as exc:  # noqa: BLE001
                logger.warning("SessionStore.close() error: %s", exc)

    def __enter__(self) -> "SessionStore":
        """Return the store itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the connection on leaving the ``with`` block; exceptions propagate."""
        self.close()
