#!/usr/bin/env python3
"""
Phase 자동 체이닝 시스템 — 멀티팀 Phase 오케스트레이터 (chain.py)

[역할] 여러 팀이 동일 Phase에서 병렬로 작업하고,
       모든 팀 완료 후 자동으로 다음 Phase를 dispatch한다.
[파일] memory/chains/{chain_id}.json (접두어 없음)
[호출] dispatch.py --chain <chain_id> → chain.py task-done
[구분] 순차 작업 체이닝은 chain_manager.py 참조

Usage:
    python3 chain.py create --id insuwiki-p1p2 --desc "InsuWiki Phase 1-2"
    python3 chain.py add-phase --chain insuwiki-p1p2 --name "Phase 1" --tasks '[{"team":"dev1-team","desc":"로그인 개발","level":"normal"}]'
    python3 chain.py task-done --chain insuwiki-p1p2 --task task-1.1
    python3 chain.py status --chain insuwiki-p1p2
    python3 chain.py list

참고: /home/jay/workspace/memory/docs/chaining-architecture.md
"""

import argparse
import fcntl
import json
import os
import re
import shlex
import subprocess
import sys
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path

try:
    from utils.logger import get_logger
except ImportError:
    sys.path.insert(0, str(Path(__file__).parent))
    from utils.logger import get_logger

logger = get_logger(__name__)

# 환경 설정
WORKSPACE = Path(os.environ.get("WORKSPACE_ROOT", "/home/jay/workspace"))
CHAINS_DIR = WORKSPACE / "memory" / "chains"
CHAT_ID = os.environ.get("COKACDIR_CHAT_ID", "6937032012")
ANU_KEY = os.environ.get("COKACDIR_KEY_ANU", "")
DISPATCH_PY = WORKSPACE / "dispatch.py"
ENV_KEYS = WORKSPACE / ".env.keys"

# ---------------------------------------------------------------------------
# 보안: 쉘 인젝션 방어용 ID 검증
# ---------------------------------------------------------------------------

_SAFE_ID_RE = re.compile(r"^[a-zA-Z0-9_\-\.]+$")


def _validate_safe_id(value: str, field_name: str) -> str:
    """영숫자+하이픈+밑줄+점만 허용. 인젝션 방어."""
    if not _SAFE_ID_RE.match(value):
        raise ValueError(f"{field_name}에 허용되지 않는 문자: {value!r}")
    return value


# ---------------------------------------------------------------------------
# 파일 락 컨텍스트 매니저
# ---------------------------------------------------------------------------


@contextmanager
def locked_chain_file(chain_path: Path, mode: str = "r+"):
    """chain JSON 파일을 독점 락으로 열고, 컨텍스트 종료 시 언락한다."""
    chain_path.parent.mkdir(parents=True, exist_ok=True)
    file_obj = open(chain_path, mode, encoding="utf-8")
    try:
        fcntl.flock(file_obj, fcntl.LOCK_EX)
        yield file_obj
    finally:
        fcntl.flock(file_obj, fcntl.LOCK_UN)
        file_obj.close()


# ---------------------------------------------------------------------------
# 체인 파일 I/O 헬퍼
# ---------------------------------------------------------------------------


def _chain_path(chain_id: str) -> Path:
    return CHAINS_DIR / f"{chain_id}.json"


def _load_chain(chain_id: str) -> dict:
    """체인 파일을 읽어 dict로 반환. 락 없는 단순 읽기용."""
    path = _chain_path(chain_id)
    if not path.exists():
        print(f"[ERROR] 체인 파일이 없습니다: {path}", file=sys.stderr)
        sys.exit(1)
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _save_chain(file_obj, data: dict) -> None:
    """이미 열린(락 보유) 파일 객체에 dict를 JSON으로 덮어쓴다."""
    file_obj.seek(0)
    file_obj.truncate()
    json.dump(data, file_obj, ensure_ascii=False, indent=2)
    file_obj.flush()


# ---------------------------------------------------------------------------
# cokacdir 알림 헬퍼
# ---------------------------------------------------------------------------


def _cron_notify(prompt: str, delay: str = "10s") -> None:
    """아누에게 cokacdir --cron으로 알림을 등록한다."""
    if not ANU_KEY:
        logger.warning("COKACDIR_KEY_ANU 미설정, 알림 스킵")
        return
    cmd = [
        "cokacdir",
        "--cron",
        prompt,
        "--at",
        delay,
        "--chat",
        CHAT_ID,
        "--key",
        ANU_KEY,
        "--once",
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
    except subprocess.TimeoutExpired:
        logger.warning(f"cokacdir 알림 등록 타임아웃 (60초 초과): {prompt[:60]}")
        return
    if result.returncode != 0:
        logger.warning(f"cokacdir 알림 등록 실패: {result.stderr.strip()}")
    else:
        logger.info(f"알림 등록 완료: {prompt[:60]}")


# ---------------------------------------------------------------------------
# 서브커맨드: create
# ---------------------------------------------------------------------------


def cmd_create(args) -> None:
    """빈 체인 파일을 생성한다."""
    CHAINS_DIR.mkdir(parents=True, exist_ok=True)
    chain_path = _chain_path(args.id)

    if chain_path.exists():
        print(f"[ERROR] 이미 존재하는 체인입니다: {args.id}", file=sys.stderr)
        sys.exit(1)

    data = {
        "chain_id": args.id,
        "description": args.desc,
        "status": "active",
        "current_phase_idx": 0,
        "created_at": datetime.now().isoformat(),
        "phases": [],
    }

    with open(chain_path, "w", encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        json.dump(data, f, ensure_ascii=False, indent=2)
        fcntl.flock(f, fcntl.LOCK_UN)

    print(f"[OK] 체인 생성 완료: {chain_path}")
    logger.info(f"체인 생성: {args.id}")


# ---------------------------------------------------------------------------
# 서브커맨드: add-phase
# ---------------------------------------------------------------------------


def cmd_add_phase(args) -> None:
    """체인에 Phase를 추가한다."""
    try:
        raw_tasks = json.loads(args.tasks)
    except json.JSONDecodeError as e:
        print(f"[ERROR] tasks JSON 파싱 실패: {e}", file=sys.stderr)
        sys.exit(1)

    # tasks 스키마 검증 및 정규화
    normalized_tasks = []
    for t in raw_tasks:
        if "team" not in t or "desc" not in t:
            print("[ERROR] tasks 항목에 'team', 'desc' 필드가 필요합니다.", file=sys.stderr)
            sys.exit(1)
        normalized_tasks.append(
            {
                "team": t["team"],
                "task_id": None,
                "description": t["desc"],
                "level": t.get("level", "normal"),
                "status": "pending",
                "dispatched_at": None,
                "completed_at": None,
            }
        )

    new_phase = {
        "name": args.name,
        "status": "pending",
        "tasks": normalized_tasks,
    }

    chain_path = _chain_path(args.chain)
    if not chain_path.exists():
        print(f"[ERROR] 체인 파일이 없습니다: {args.chain}", file=sys.stderr)
        sys.exit(1)

    with locked_chain_file(chain_path, "r+") as f:
        data = json.load(f)
        data["phases"].append(new_phase)
        _save_chain(f, data)

    phase_idx = len(data["phases"]) - 1
    print(f"[OK] Phase 추가 완료: {args.chain} → Phase[{phase_idx}] '{args.name}' ({len(normalized_tasks)}개 task)")
    logger.info(f"Phase 추가: chain={args.chain}, phase={args.name}, tasks={len(normalized_tasks)}")


# ---------------------------------------------------------------------------
# 서브커맨드: task-done
# ---------------------------------------------------------------------------


def _dispatch_phase(data: dict, phase_idx: int, chain_id: str) -> list[dict]:
    """지정 Phase의 모든 tasks를 dispatch.py를 통해 위임하고 결과를 반환한다."""
    # chain_id 검증
    _validate_safe_id(chain_id, "chain_id")

    phase = data["phases"][phase_idx]
    results = []
    for task in phase["tasks"]:
        # team, level 검증
        _validate_safe_id(task["team"], "team")
        _validate_safe_id(task["level"], "level")

        # description을 임시 파일로 저장하여 특수문자 깨짐 방지
        tasks_dir = WORKSPACE / "memory" / "tasks"
        tasks_dir.mkdir(parents=True, exist_ok=True)
        task_file_name = f"chain-{chain_id}-phase{phase_idx}-{task['team']}.md"
        task_file_path = tasks_dir / task_file_name
        task_file_path.write_text(task["description"], encoding="utf-8")

        # 안전한 방식: .env.keys에서 환경변수 로드 후 subprocess 리스트 방식 호출
        env = os.environ.copy()
        env_keys_path = str(ENV_KEYS)
        if Path(env_keys_path).exists():
            load_result = subprocess.run(
                ["bash", "-c", f"source {shlex.quote(env_keys_path)} && env"],
                capture_output=True,
                text=True,
            )
            for line in load_result.stdout.splitlines():
                if "=" in line:
                    k, v = line.split("=", 1)
                    env[k] = v

        cmd = [
            "python3",
            str(DISPATCH_PY),
            "--team",
            task["team"],
            "--task-file",
            str(task_file_path),
            "--level",
            task["level"],
            "--chain",
            chain_id,
        ]
        logger.info(f"dispatch 호출: team={task['team']}, chain={chain_id}")
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=60,
                env=env,
            )
        except subprocess.TimeoutExpired:
            logger.error(f"dispatch 타임아웃: team={task['team']}, chain={chain_id} (60초 초과)")
            task["status"] = "dispatch_error"
            task["error"] = "dispatch 타임아웃 (60초 초과)"
            results.append(
                {
                    "team": task["team"],
                    "task_id": None,
                    "status": "error",
                    "error": "dispatch 타임아웃 (60초 초과)",
                }
            )
            continue
        if result.returncode == 0:
            try:
                resp = json.loads(result.stdout)
            except json.JSONDecodeError:
                resp = {"raw": result.stdout.strip()}
            task_id = resp.get("task_id")
            task["task_id"] = task_id
            task["dispatched_at"] = datetime.now().isoformat()
            task["status"] = "in_progress"
            results.append({"team": task["team"], "task_id": task_id, "status": "dispatched"})
            logger.info(f"dispatch 성공: team={task['team']}, task_id={task_id}")
        else:
            err_msg = result.stderr.strip() or result.stdout.strip()
            task["status"] = "dispatch_error"
            task["error"] = err_msg
            results.append({"team": task["team"], "task_id": None, "status": "error", "error": err_msg})
            logger.error(f"dispatch 실패: team={task['team']}, error={err_msg}")
    return results


def cmd_task_done(args) -> None:
    """특정 task를 완료 마킹하고, Phase 전환 로직을 수행한다."""
    chain_path = _chain_path(args.chain)
    if not chain_path.exists():
        print(f"[ERROR] 체인 파일이 없습니다: {args.chain}", file=sys.stderr)
        sys.exit(1)

    with locked_chain_file(chain_path, "r+") as f:
        data = json.load(f)

        # 안전장치: chain이 paused 상태면 처리 중단
        if data.get("status") == "paused":
            print(f"[WARNING] 체인 '{args.chain}'이 paused 상태입니다. task-done을 무시합니다.", file=sys.stderr)
            logger.warning(f"paused 체인에 task-done 호출 무시: chain={args.chain}, task={args.task}")
            _save_chain(f, data)
            return

        # 1. 해당 task_id를 찾아 completed 마킹
        target_task = None
        target_phase_idx = None
        for p_idx, phase in enumerate(data["phases"]):
            for task in phase["tasks"]:
                if task.get("task_id") == args.task:
                    target_task = task
                    target_phase_idx = p_idx
                    break
            if target_task:
                break

        if target_task is None:
            print(f"[ERROR] task_id '{args.task}'를 체인 '{args.chain}'에서 찾을 수 없습니다.", file=sys.stderr)
            logger.error(f"task_id 미발견: chain={args.chain}, task={args.task}")
            _save_chain(f, data)
            sys.exit(1)

        target_task["status"] = "completed"
        target_task["completed_at"] = datetime.now().isoformat()
        logger.info(f"task 완료 마킹: chain={args.chain}, task={args.task}, phase_idx={target_phase_idx}")

        # 2. 현재 Phase의 모든 tasks 완료 여부 확인
        current_phase_idx = data["current_phase_idx"]
        current_phase = data["phases"][current_phase_idx]
        all_completed = all(t["status"] == "completed" for t in current_phase["tasks"])

        if not all_completed:
            # 아직 완료되지 않은 task가 있으면 저장 후 종료
            remaining = sum(1 for t in current_phase["tasks"] if t["status"] != "completed")
            print(f"[OK] task '{args.task}' 완료. 현재 Phase에 {remaining}개 task 남음.")
            _save_chain(f, data)
            return

        # 3. 전팀 완료: 현재 Phase status를 completed로 변경
        current_phase["status"] = "completed"
        logger.info(f"Phase 완료: chain={args.chain}, phase_idx={current_phase_idx}, name='{current_phase['name']}'")

        next_phase_idx = current_phase_idx + 1
        total_phases = len(data["phases"])

        if next_phase_idx < total_phases:
            # 4. 다음 Phase가 있으면 dispatch
            next_phase = data["phases"][next_phase_idx]

            # 안전장치 재확인 (dispatch 전 paused 여부)
            if data.get("status") == "paused":
                print(
                    f"[WARNING] 체인 '{args.chain}'이 paused 상태입니다. 다음 Phase dispatch를 건너뜁니다.",
                    file=sys.stderr,
                )
                _save_chain(f, data)
                return

            dispatch_results = _dispatch_phase(data, next_phase_idx, args.chain)

            # dispatch 오류 확인 → 오류 시 체인 pause
            errors = [r for r in dispatch_results if r["status"] == "error"]
            if errors:
                data["status"] = "paused"
                data["error"] = {
                    "phase_idx": next_phase_idx,
                    "phase_name": next_phase["name"],
                    "dispatch_errors": errors,
                    "occurred_at": datetime.now().isoformat(),
                }
                _save_chain(f, data)
                print(f"[ERROR] 다음 Phase dispatch 중 오류 발생. 체인을 paused 상태로 변경합니다.", file=sys.stderr)
                for e in errors:
                    print(f"  - team={e['team']}: {e.get('error', '알 수 없는 오류')}", file=sys.stderr)
                logger.error(f"Phase dispatch 실패로 체인 pause: chain={args.chain}, phase_idx={next_phase_idx}")
                return

            # dispatch 성공: Phase 상태 업데이트
            next_phase["status"] = "in_progress"
            data["current_phase_idx"] = next_phase_idx
            _save_chain(f, data)

            print(f"[OK] {current_phase['name']} 전팀 완료. {next_phase['name']} 자동 dispatch 완료.")
            for r in dispatch_results:
                print(f"  - team={r['team']}, task_id={r['task_id']}")

            # 5. 제이회장님 중간보고 알림 (Phase 전환)
            notify_msg = (
                f"체인 {args.chain}의 {current_phase['name']} 전팀 완료. " f"{next_phase['name']} 자동 dispatch됨."
            )
            _cron_notify(notify_msg)
            logger.info(f"Phase 전환 알림 등록: {notify_msg}")

        else:
            # 5. 마지막 Phase 완료: 체인 전체 완료
            data["status"] = "completed"
            data["completed_at"] = datetime.now().isoformat()
            _save_chain(f, data)

            print(f"[OK] 체인 '{args.chain}' 전체 완료!")
            logger.info(f"체인 전체 완료: chain={args.chain}")

            # 아누에게 종합 보고 알림
            notify_msg = f"체인 {args.chain} 전체 완료. 종합 보고 작성하라"
            _cron_notify(notify_msg)

            # 제이회장님 최종 완료 보고
            final_msg = (
                f"체인 {args.chain}의 {current_phase['name']} 전팀 완료. "
                f"모든 Phase가 완료되어 체인이 종료되었습니다."
            )
            _cron_notify(final_msg)
            logger.info(f"체인 완료 알림 등록: chain={args.chain}")


# ---------------------------------------------------------------------------
# 서브커맨드: status
# ---------------------------------------------------------------------------


def cmd_status(args) -> None:
    """현재 체인 상태를 JSON으로 출력한다."""
    data = _load_chain(args.chain)
    print(json.dumps(data, ensure_ascii=False, indent=2))


# ---------------------------------------------------------------------------
# 서브커맨드: list
# ---------------------------------------------------------------------------


def cmd_list(args) -> None:
    """모든 활성 체인 목록을 출력한다."""
    CHAINS_DIR.mkdir(parents=True, exist_ok=True)
    chain_files = sorted(CHAINS_DIR.glob("*.json"))

    if not chain_files:
        print("활성 체인이 없습니다.")
        return

    active = []
    for path in chain_files:
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            active.append(
                {
                    "chain_id": data.get("chain_id"),
                    "description": data.get("description"),
                    "status": data.get("status"),
                    "current_phase_idx": data.get("current_phase_idx"),
                    "total_phases": len(data.get("phases", [])),
                    "created_at": data.get("created_at"),
                }
            )
        except (json.JSONDecodeError, OSError) as e:
            logger.warning(f"체인 파일 읽기 실패: {path}, {e}")
            active.append({"chain_id": path.stem, "error": str(e)})

    print(json.dumps(active, ensure_ascii=False, indent=2))


# ---------------------------------------------------------------------------
# 메인
# ---------------------------------------------------------------------------


def main() -> None:
    # .env.keys 자동 로드 (환경변수 누락 방지)
    from utils.env_loader import load_env_keys

    load_env_keys()

    parser = argparse.ArgumentParser(
        description="Phase 자동 체이닝 시스템",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # create
    p_create = subparsers.add_parser("create", help="빈 체인 파일 생성")
    p_create.add_argument("--id", required=True, help="체인 ID (예: insuwiki-p1p2)")
    p_create.add_argument("--desc", required=True, help="체인 설명")

    # add-phase
    p_add = subparsers.add_parser("add-phase", help="체인에 Phase 추가")
    p_add.add_argument("--chain", required=True, help="체인 ID")
    p_add.add_argument("--name", required=True, help="Phase 이름 (예: Phase 1)")
    p_add.add_argument(
        "--tasks", required=True, help='tasks JSON 배열 (예: [{"team":"dev1-team","desc":"...","level":"normal"}])'
    )

    # task-done
    p_done = subparsers.add_parser("task-done", help="특정 task 완료 마킹 및 Phase 전환")
    p_done.add_argument("--chain", required=True, help="체인 ID")
    p_done.add_argument("--task", required=True, help="완료할 task_id")

    # status
    p_status = subparsers.add_parser("status", help="체인 상태 조회")
    p_status.add_argument("--chain", required=True, help="체인 ID")

    # list
    subparsers.add_parser("list", help="모든 체인 목록 조회")

    args = parser.parse_args()

    dispatch_table = {
        "create": cmd_create,
        "add-phase": cmd_add_phase,
        "task-done": cmd_task_done,
        "status": cmd_status,
        "list": cmd_list,
    }

    handler = dispatch_table.get(args.command)
    if handler:
        handler(args)
    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()
