"""
test_auto_merge_ttl.py

task-1942: BatchWatchdog.check_batch_ttl() 및 cleanup_expired_batches() 테스트
작성자: 하누만 (개발4팀 테스터)
"""

import json
import os
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import patch

# Make the workspace root and its scripts/ directory importable so that
# `auto_merge` (imported below) can be resolved. The root comes from the
# WORKSPACE_ROOT environment variable, falling back to a hard-coded path.
# Each directory is prepended only when it is not already on sys.path;
# because scripts/ is handled second, it ends up ahead of the root.
WORKSPACE = Path(os.environ.get("WORKSPACE_ROOT", "/home/jay/workspace"))
SCRIPTS = WORKSPACE / "scripts"
for _import_root in (WORKSPACE, SCRIPTS):
    if str(_import_root) not in sys.path:
        sys.path.insert(0, str(_import_root))
del _import_root

from auto_merge import BatchWatchdog  # pyright: ignore[reportMissingImports]


# ---------------------------------------------------------------------------
# 헬퍼: task-timers.json 생성
# ---------------------------------------------------------------------------

def _write_timer_file(tmp_path: Path, tasks: dict) -> Path:
    """tmp_path/memory/task-timers.json에 tasks 데이터를 기록하고 경로를 반환한다."""
    memory_dir = tmp_path / "memory"
    memory_dir.mkdir(parents=True, exist_ok=True)
    timer_file = memory_dir / "task-timers.json"
    timer_file.write_text(
        json.dumps({"tasks": tasks}, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return timer_file


# ---------------------------------------------------------------------------
# test_check_batch_ttl_no_expiry
# ---------------------------------------------------------------------------


def test_check_batch_ttl_no_expiry(tmp_path: Path):
    """Only 1 hour has elapsed since the batch started -> 0 expired batches.

    Besides the empty return value, the timer file must be left untouched:
    tasks still within the TTL keep their ``running`` status. Without this
    check, an implementation that wrongly marks tasks ``expired`` on disk
    while still returning ``[]`` would pass.
    """
    now = datetime.now(timezone.utc)
    start_time = (now - timedelta(hours=1)).isoformat()

    tasks = {
        "task-100": {
            "batch_id": "batch-001",
            "status": "running",
            "start_time": start_time,
        },
        "task-101": {
            "batch_id": "batch-001",
            "status": "running",
            "start_time": start_time,
        },
    }
    timer_file = _write_timer_file(tmp_path, tasks)

    wd = BatchWatchdog(workspace_path=str(tmp_path))
    # 24h TTL (presumably the default; passed explicitly here) with only 1h
    # elapsed -> nothing should expire.
    result = wd.check_batch_ttl(batch_ttl_hours=24.0)

    assert result == [], f"만료 없어야 하는데 결과: {result}"

    # A TTL check that finds nothing expired must not change any status.
    updated = json.loads(timer_file.read_text(encoding="utf-8"))
    for task_id in tasks:
        assert updated["tasks"][task_id]["status"] == "running", (
            f"만료되지 않은 task의 status가 변경됨: {task_id}"
        )


# ---------------------------------------------------------------------------
# test_check_batch_ttl_expired
# ---------------------------------------------------------------------------


def test_check_batch_ttl_expired(tmp_path: Path):
    """25 hours after batch start with a 24h TTL.

    Both ``running`` tasks must be reported and marked ``expired`` in the
    timer file, while the ``done`` task is neither reported nor modified.
    """
    started_at = (datetime.now(timezone.utc) - timedelta(hours=25)).isoformat()
    status_by_task = {
        "task-200": "running",
        "task-201": "running",
        "task-202": "done",
    }
    timer_file = _write_timer_file(
        tmp_path,
        {
            task_id: {
                "batch_id": "batch-002",
                "status": status,
                "start_time": started_at,
            }
            for task_id, status in status_by_task.items()
        },
    )

    result = BatchWatchdog(workspace_path=str(tmp_path)).check_batch_ttl(
        batch_ttl_hours=24.0
    )

    assert len(result) == 1, f"batch 1개가 만료되어야 함: {result}"
    entry = result[0]
    assert entry["batch_id"] == "batch-002"
    expired_ids = entry["expired_tasks"]
    assert len(expired_ids) == 2, (
        f"running 2개만 expired 되어야 함: {expired_ids}"
    )
    for task_id in ("task-200", "task-201"):
        assert task_id in expired_ids
    assert "task-202" not in expired_ids, "done 상태는 expired 대상 아님"

    # The TTL check persists its verdict back into the timer file; the
    # helper's return value is the file it wrote.
    persisted = json.loads(timer_file.read_text(encoding="utf-8"))["tasks"]
    for task_id in ("task-200", "task-201"):
        assert persisted[task_id]["status"] == "expired"
    assert persisted["task-202"]["status"] == "done", "done 상태는 그대로 유지"


# ---------------------------------------------------------------------------
# test_check_batch_ttl_custom_hours
# ---------------------------------------------------------------------------


def test_check_batch_ttl_custom_hours(tmp_path: Path):
    """Custom ``batch_ttl_hours=2.0`` with 3 hours elapsed -> the batch expires.

    Also checks that the reported ``elapsed_hours`` reflects the real 3h age,
    and that the expired task is marked ``expired`` in the timer file.
    """
    now = datetime.now(timezone.utc)
    start_time = (now - timedelta(hours=3)).isoformat()

    tasks = {
        "task-300": {
            "batch_id": "batch-003",
            "status": "running",
            "start_time": start_time,
        },
    }
    timer_file = _write_timer_file(tmp_path, tasks)

    wd = BatchWatchdog(workspace_path=str(tmp_path))
    result = wd.check_batch_ttl(batch_ttl_hours=2.0)

    assert len(result) == 1, f"TTL 2h 기준 3h 경과이므로 만료 1건이어야 함: {result}"
    entry = result[0]
    assert entry["batch_id"] == "batch-003"
    assert "task-300" in entry["expired_tasks"]
    # `>= 2.0` merely restated the TTL (already implied by expiry). Pin the
    # actual ~3h age instead; the small tolerance absorbs any rounding or
    # sub-second truncation the implementation may apply.
    assert entry["elapsed_hours"] >= 2.9, (
        f"3h 경과가 elapsed_hours에 반영되어야 함: {entry['elapsed_hours']}"
    )

    updated = json.loads(timer_file.read_text(encoding="utf-8"))
    assert updated["tasks"]["task-300"]["status"] == "expired"


# ---------------------------------------------------------------------------
# test_cleanup_expired_batches
# ---------------------------------------------------------------------------


def test_cleanup_expired_batches(tmp_path: Path):
    """A batch past its TTL is cleaned up.

    - Every expired task that has a ``.done`` file gets a ``.done.expired``
      marker whose JSON records the batch, task, elapsed hours and
      expiry time (checked for *both* tasks, not just the first).
    - ``send_ttl_warning`` (mocked) is called exactly once for the batch,
      naming exactly the expired tasks.
    """
    now = datetime.now(timezone.utc)
    start_time = (now - timedelta(hours=25)).isoformat()
    expired_task_ids = ("task-400", "task-401")

    tasks = {
        task_id: {
            "batch_id": "batch-004",
            "status": "running",
            "start_time": start_time,
        }
        for task_id in expired_task_ids
    }
    _write_timer_file(tmp_path, tasks)

    # Create the events directory and a .done file per task.
    events_dir = tmp_path / "memory" / "events"
    events_dir.mkdir(parents=True, exist_ok=True)
    for task_id in expired_task_ids:
        (events_dir / f"{task_id}.done").write_text(
            json.dumps({"task_id": task_id, "batch_id": "batch-004"}),
            encoding="utf-8",
        )

    wd = BatchWatchdog(workspace_path=str(tmp_path))

    with patch.object(wd, "send_ttl_warning") as mock_warn:
        result = wd.cleanup_expired_batches(batch_ttl_hours=24.0)

    # Return value
    assert result["expired_count"] == 1
    assert "batch-004" in result["expired_batches"]

    # Each .done.expired marker must exist and carry the expected fields.
    for task_id in expired_task_ids:
        marker = events_dir / f"{task_id}.done.expired"
        assert marker.exists(), f"{task_id}.done.expired 파일이 없음"
        expired_data = json.loads(marker.read_text(encoding="utf-8"))
        assert expired_data["batch_id"] == "batch-004"
        assert expired_data["task_id"] == task_id
        assert "elapsed_hours" in expired_data
        assert "expired_at" in expired_data

    # send_ttl_warning: one call, positional (batch_id, stale_tasks) as in
    # the original assertions, warning about exactly the expired tasks.
    mock_warn.assert_called_once()
    call_args = mock_warn.call_args
    positional, _keyword = call_args
    assert positional[0] == "batch-004", f"batch_id 불일치: {call_args}"
    task_ids_warned = {t["task_id"] for t in positional[1]}
    assert task_ids_warned == set(expired_task_ids), (
        f"경고 대상 task 불일치: {sorted(task_ids_warned)}"
    )


# ---------------------------------------------------------------------------
# test_cleanup_expired_batches_no_expired
# ---------------------------------------------------------------------------


def test_cleanup_expired_batches_no_expired(tmp_path: Path):
    """No batch exceeds the TTL -> ``expired_count == 0`` and no side effects.

    A ``.done`` file is present for the running task. This makes the test
    able to detect a cleanup that wrongly writes ``.done.expired`` markers
    or touches files/status for tasks that are still within the TTL.
    """
    now = datetime.now(timezone.utc)
    # Started 1 hour ago -> does not exceed the 24h TTL.
    start_time = (now - timedelta(hours=1)).isoformat()

    tasks = {
        "task-500": {
            "batch_id": "batch-005",
            "status": "running",
            "start_time": start_time,
        },
    }
    timer_file = _write_timer_file(tmp_path, tasks)

    events_dir = tmp_path / "memory" / "events"
    events_dir.mkdir(parents=True, exist_ok=True)
    done_500 = events_dir / "task-500.done"
    done_500.write_text(
        json.dumps({"task_id": "task-500", "batch_id": "batch-005"}),
        encoding="utf-8",
    )

    wd = BatchWatchdog(workspace_path=str(tmp_path))

    with patch.object(wd, "send_ttl_warning") as mock_warn:
        result = wd.cleanup_expired_batches(batch_ttl_hours=24.0)

    assert result["expired_count"] == 0
    assert result["expired_batches"] == []
    mock_warn.assert_not_called()

    # No expiry markers for a batch that has not expired, and the existing
    # .done file and timer status are left as they were.
    stray_markers = sorted(p.name for p in events_dir.glob("*.expired"))
    assert stray_markers == [], f"만료되지 않았는데 marker 생성됨: {stray_markers}"
    assert done_500.exists(), "만료되지 않은 task의 .done 파일이 사라짐"
    updated = json.loads(timer_file.read_text(encoding="utf-8"))
    assert updated["tasks"]["task-500"]["status"] == "running"
