#!/usr/bin/env python3
"""utils/audit_logger.py 테스트 스위트"""

import json
import os
import sys
import threading
import time
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from utils.audit_logger import log_batch_operations, log_file_operation

# ---------------------------------------------------------------------------
# 헬퍼
# ---------------------------------------------------------------------------


def _read_jsonl(path: Path) -> list[dict]:
    """JSONL 파일을 파싱하여 레코드 목록 반환."""
    records = []
    if path.exists():
        for line in path.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


# ---------------------------------------------------------------------------
# 1. log_file_operation — 정상 케이스
# ---------------------------------------------------------------------------


class TestLogFileOperationBasic:
    """Basic behaviour of log_file_operation()."""

    def test_record_written_to_file(self, tmp_path):
        """A record is appended to the audit file."""
        audit_file = tmp_path / "audit-trail.jsonl"
        log_file_operation("task-001", "/tmp/foo.py", "Write", audit_path=str(audit_file))
        assert len(_read_jsonl(audit_file)) == 1

    def test_record_fields_present(self, tmp_path):
        """Every required field is present in the record."""
        audit_file = tmp_path / "audit-trail.jsonl"
        log_file_operation("task-002", "/tmp/bar.py", "Edit", audit_path=str(audit_file))
        entry = _read_jsonl(audit_file)[0]
        expected = {
            "task_id": "task-002",
            "file": "/tmp/bar.py",
            "tool": "Edit",
            "operation": "write",
            "agent": "subagent",
        }
        for key, value in expected.items():
            assert entry[key] == value
        assert "ts" in entry

    def test_ts_is_iso8601(self, tmp_path):
        """The ts field is formatted as ISO 8601."""
        from datetime import datetime

        audit_file = tmp_path / "audit-trail.jsonl"
        log_file_operation("task-003", "/tmp/x.py", "Write", audit_path=str(audit_file))
        entry = _read_jsonl(audit_file)[0]
        # datetime.fromisoformat must be able to parse it.
        parsed = datetime.fromisoformat(entry["ts"].replace("Z", "+00:00"))
        assert parsed is not None

    def test_multiple_calls_append(self, tmp_path):
        """Repeated calls accumulate records."""
        audit_file = tmp_path / "audit-trail.jsonl"
        for idx in range(5):
            log_file_operation(f"task-{idx:03d}", f"/tmp/f{idx}.py", "Write", audit_path=str(audit_file))
        assert len(_read_jsonl(audit_file)) == 5

    def test_custom_operation_parameter(self, tmp_path):
        """The operation parameter is reflected in the record."""
        audit_file = tmp_path / "audit-trail.jsonl"
        log_file_operation("task-004", "/tmp/z.py", "Write", operation="delete", audit_path=str(audit_file))
        assert _read_jsonl(audit_file)[0]["operation"] == "delete"

    def test_creates_parent_directory(self, tmp_path):
        """Missing parent directories of the audit file are created automatically."""
        audit_file = tmp_path / "nested" / "deep" / "audit-trail.jsonl"
        # Call while the directories do not exist yet.
        log_file_operation("task-005", "/tmp/abc.py", "Write", audit_path=str(audit_file))
        assert audit_file.exists()
        assert len(_read_jsonl(audit_file)) == 1

    def test_each_line_is_valid_json(self, tmp_path):
        """Every non-empty line is valid JSON."""
        audit_file = tmp_path / "audit-trail.jsonl"
        for idx in range(3):
            log_file_operation(f"task-{idx}", f"/tmp/f{idx}.py", "Write", audit_path=str(audit_file))
        for line in filter(str.strip, audit_file.read_text().splitlines()):
            parsed = json.loads(line)  # raises on parse failure
            assert isinstance(parsed, dict)


# ---------------------------------------------------------------------------
# 2. log_file_operation — 잘못된 인자
# ---------------------------------------------------------------------------


class TestLogFileOperationInvalidArgs:
    """Rejection of invalid arguments."""

    def test_empty_task_id_raises(self, tmp_path):
        """An empty task_id raises ValueError."""
        audit_file = tmp_path / "audit-trail.jsonl"
        with pytest.raises((ValueError, TypeError)):
            log_file_operation("", "/tmp/foo.py", "Write", audit_path=str(audit_file))

    def test_none_task_id_raises(self, tmp_path):
        """A None task_id raises an exception."""
        audit_file = tmp_path / "audit-trail.jsonl"
        with pytest.raises((ValueError, TypeError)):
            log_file_operation(None, "/tmp/foo.py", "Write", audit_path=str(audit_file))  # type: ignore[arg-type]

    def test_empty_filepath_raises(self, tmp_path):
        """An empty filepath raises ValueError."""
        audit_file = tmp_path / "audit-trail.jsonl"
        with pytest.raises((ValueError, TypeError)):
            log_file_operation("task-x", "", "Write", audit_path=str(audit_file))

    def test_empty_tool_raises(self, tmp_path):
        """An empty tool name raises ValueError."""
        audit_file = tmp_path / "audit-trail.jsonl"
        with pytest.raises((ValueError, TypeError)):
            log_file_operation("task-x", "/tmp/foo.py", "", audit_path=str(audit_file))


# ---------------------------------------------------------------------------
# 3. log_batch_operations — 정상 케이스
# ---------------------------------------------------------------------------


class TestLogBatchOperations:
    """Tests for log_batch_operations()."""

    def test_batch_writes_all_files(self, tmp_path):
        """The batch helper records an entry for every file."""
        audit_file = tmp_path / "audit-trail.jsonl"
        paths = [f"/tmp/file{n}.py" for n in range(4)]
        log_batch_operations("task-010", paths, "Write", audit_path=str(audit_file))
        assert len(_read_jsonl(audit_file)) == 4

    def test_batch_empty_list_writes_nothing(self, tmp_path):
        """An empty file list produces no records."""
        audit_file = tmp_path / "audit-trail.jsonl"
        log_batch_operations("task-011", [], "Write", audit_path=str(audit_file))
        assert len(_read_jsonl(audit_file)) == 0

    def test_batch_records_have_correct_task_id(self, tmp_path):
        """Every batch record carries the same task_id."""
        audit_file = tmp_path / "audit-trail.jsonl"
        paths = ["/tmp/a.py", "/tmp/b.py", "/tmp/c.py"]
        log_batch_operations("task-012", paths, "Edit", audit_path=str(audit_file))
        for entry in _read_jsonl(audit_file):
            assert entry["task_id"] == "task-012"

    def test_batch_records_correct_filepaths(self, tmp_path):
        """The file field of each record matches an input path."""
        audit_file = tmp_path / "audit-trail.jsonl"
        paths = ["/tmp/x.py", "/tmp/y.py"]
        log_batch_operations("task-013", paths, "Write", audit_path=str(audit_file))
        logged = {entry["file"] for entry in _read_jsonl(audit_file)}
        assert logged == set(paths)


# ---------------------------------------------------------------------------
# 4. 동시 쓰기 안전성
# ---------------------------------------------------------------------------


class TestConcurrentWrites:
    """Concurrent writers are recorded safely via file locking."""

    def test_concurrent_writes_no_data_loss(self, tmp_path):
        """All records survive when N threads write simultaneously."""
        audit_file = tmp_path / "audit-trail.jsonl"
        thread_count = 20
        failures: list[Exception] = []

        def write_one(n: int):
            try:
                log_file_operation(f"task-{n:03d}", f"/tmp/concurrent_{n}.py", "Write", audit_path=str(audit_file))
            except Exception as exc:
                failures.append(exc)

        workers = [threading.Thread(target=write_one, args=(n,)) for n in range(thread_count)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        assert not failures, f"스레드 오류 발생: {failures}"
        assert len(_read_jsonl(audit_file)) == thread_count

    def test_concurrent_writes_all_valid_json(self, tmp_path):
        """Every line is still valid JSON after concurrent writes."""
        audit_file = tmp_path / "audit-trail.jsonl"
        thread_count = 15

        def write_one(n: int):
            log_file_operation(f"task-{n}", f"/tmp/f{n}.py", "Write", audit_path=str(audit_file))

        workers = [threading.Thread(target=write_one, args=(n,)) for n in range(thread_count)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        lines = [line for line in audit_file.read_text().splitlines() if line.strip()]
        assert len(lines) == thread_count
        for line in lines:
            json.loads(line)  # raises if any line is corrupt