#!/usr/bin/env python3
"""Token usage tracker: scans Claude JSONL sessions, builds token ledger."""

import argparse
import glob
import json
import os
import re
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Optional

_WORKSPACE_ROOT = os.environ.get("WORKSPACE_ROOT", str(Path(__file__).resolve().parent.parent))
LEDGER_PATH = str(Path(_WORKSPACE_ROOT) / "memory" / "token-ledger.json")
TIMERS_PATH = str(Path(_WORKSPACE_ROOT) / "memory" / "task-timers.json")
# Default glob patterns for the session transcripts (*.jsonl) scanned by `scan`.
# Overridable through the environment, mirroring WORKSPACE_ROOT above, so the
# tool can run on another machine/user without editing the source.
BOT_GLOB = os.environ.get(
    "TOKEN_TRACKER_BOT_GLOB", "/home/jay/.claude/projects/-home-jay-workspace/*.jsonl"
)
COKAC_GLOB = os.environ.get(
    "TOKEN_TRACKER_COKAC_GLOB", "/home/jay/.claude/projects/-home-jay--cokacdir-workspace-*/*.jsonl"
)
# Model -> price in USD per 1M tokens ("input" / "output"); see compute_cost.
# NOTE(review): verify these figures against current Anthropic pricing.
PRICES: dict[str, dict[str, float]] = {
    "claude-opus-4-6": {"input": 15.0, "output": 75.0},
    "claude-sonnet-4-6": {"input": 3.0, "output": 15.0},
    "claude-haiku-4-5-20251001": {"input": 0.80, "output": 4.0},
}
# Task ids such as "task-12", "task-12_3.4", "task-12_3.4_b", "task-12+2".
TASK_RE = re.compile(r"task-\d+(?:_\d+\.\d+)?(?:_[a-z])?(?:\+\d+)?")  # Mirror of utils.task_id_parser V2 pattern
# Team labels written as "팀: X", "team: X" or "team_id: X" (ASCII or full-width
# colon); group(1) is the team id ("팀" is Korean for "team").
TEAM_RE = re.compile(
    r"(?:팀|team_id|team)\s*[:：]\s*(dev\d+-team|marketing|anu-direct|[a-z][a-z0-9]*-[a-z][a-z0-9-]*[a-z0-9])"
)
# fmt: off
# Token-count fields summed per session (their sum is total_tokens).
_TOK = ("input_tokens", "cache_creation_tokens", "cache_read_tokens", "output_tokens")
# Fields copied from a parsed session into each ledger entry.
_EK = ("session_id", "team_id", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "message_count", "timestamp", "cwd_project", "mcp_used")
# fmt: on
# (path prefix seen in message text, project slug recorded as cwd_project)
_CWD_PROJECTS: list[tuple[str, str]] = [
    ("/home/jay/projects/InsuRo", "InsuRo"),
    ("/home/jay/projects/insuwiki", "insuwiki"),
]
# Substrings in message text / tool_use names that mark a session as mcp_used.
_MCP_TOOL_MARKERS = ("code-review-graph",)


def _text(c: Any) -> str:
    """Flatten message ``content`` into plain text.

    A string is returned unchanged. For a list, the ``"text"`` value of every
    dict item is joined with single spaces. Dicts without a string ``"text"``
    contribute an empty string, and non-dict items are skipped. Any other
    value yields "".
    """
    if isinstance(c, str):
        return c
    if not isinstance(c, list):
        return ""
    parts: list[str] = []
    for item in c:
        if isinstance(item, dict):
            t = item.get("text", "")
            # A null/non-str "text" would otherwise make str.join raise TypeError.
            parts.append(t if isinstance(t, str) else "")
    return " ".join(parts)


def _note_context(txt: str, r: dict[str, Any]) -> None:
    """Fill r["cwd_project"] / r["mcp_used"] from *txt* when they are still unset."""
    if not r["cwd_project"]:
        for path_prefix, slug in _CWD_PROJECTS:
            if path_prefix in txt:
                r["cwd_project"] = slug
                break
    if not r["mcp_used"] and any(marker in txt for marker in _MCP_TOOL_MARKERS):
        r["mcp_used"] = True


def _usage_counts(u: dict[str, Any]) -> dict[str, Any]:
    """Map an API ``usage`` object onto this module's token field names (null -> 0)."""
    return {
        "input_tokens": u.get("input_tokens") or 0,
        "cache_creation_tokens": u.get("cache_creation_input_tokens") or 0,
        "cache_read_tokens": u.get("cache_read_input_tokens") or 0,
        "output_tokens": u.get("output_tokens") or 0,
    }


def parse_session(jsonl_path: str) -> dict[str, Any]:  # noqa: C901
    """Summarise one session transcript (one JSON object per line).

    Returns a dict with:
      * summed token counts (input/cache_creation/cache_read/output) and
        ``total_tokens``;
      * ``message_count``, the number of distinct assistant messages that
        carry usage;
      * the first ``model`` and ``sessionId`` seen, and the ``timestamp`` of
        the first user message;
      * the first task id and team id matched in user-message text;
      * ``cwd_project`` and ``mcp_used``, detected from marker substrings.

    A missing file yields the all-empty/zero record. Blank lines, lines that
    are not valid JSON, and JSON values that are not objects are skipped.

    Assistant lines sharing the same ``message.id`` (and ``requestId``) are
    treated as one message. NOTE: session logs may repeat a message once per
    content block with the same usage. The per-field maximum is kept, so
    repeats are counted once.
    """
    r: dict[str, Any] = dict.fromkeys(_TOK, 0)
    r.update(
        message_count=0,
        model="",
        session_id="",
        task_id="",
        team_id="",
        timestamp="",
        total_tokens=0,
        cwd_project="",
        mcp_used=False,
    )
    path = Path(jsonl_path)
    if not path.exists():
        return r
    keyed: dict[tuple[str, str], dict[str, Any]] = {}  # (message id, requestId) -> usage
    unkeyed: list[dict[str, Any]] = []  # usage of messages without an id
    # Iterate on "\n" only: str.splitlines() also splits on U+2028/U+2029 and
    # other separators that may appear unescaped inside JSON strings, which
    # would break such records. errors="replace" tolerates a truncated
    # multi-byte character at the end of a file that is still being written.
    with path.open(encoding="utf-8", errors="replace") as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            try:
                obj = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                continue
            if not r["session_id"] and obj.get("sessionId"):
                r["session_id"] = obj["sessionId"]
            msg = obj.get("message")
            if not isinstance(msg, dict):
                msg = {}
            typ = obj.get("type", "")
            if typ == "user":
                txt = _text(msg.get("content", ""))
                if not r["task_id"] and (m := TASK_RE.search(txt)):
                    r["task_id"] = m.group(0)
                if not r["team_id"] and (tm := TEAM_RE.search(txt)):
                    r["team_id"] = tm.group(1)
                if not r["timestamp"]:
                    r["timestamp"] = obj.get("timestamp", "")
                _note_context(txt, r)
            elif typ == "assistant":
                content = msg.get("content", "")
                _note_context(_text(content), r)
                if not r["mcp_used"] and isinstance(content, list):
                    # Also detect the MCP marker in tool_use tool names.
                    r["mcp_used"] = any(
                        isinstance(item, dict)
                        and item.get("type") == "tool_use"
                        and any(marker in str(item.get("name") or "") for marker in _MCP_TOOL_MARKERS)
                        for item in content
                    )
                u = msg.get("usage")
                if isinstance(u, dict) and u:
                    counts = _usage_counts(u)
                    msg_id = msg.get("id")
                    if msg_id:
                        key = (str(msg_id), str(obj.get("requestId") or ""))
                        prev = keyed.get(key)
                        keyed[key] = counts if prev is None else {k: max(prev[k], counts[k]) for k in counts}
                    else:
                        unkeyed.append(counts)
                    if not r["model"] and msg.get("model"):
                        r["model"] = msg["model"]
    for counts in (*keyed.values(), *unkeyed):
        for k in _TOK:
            r[k] += counts[k]
        r["message_count"] += 1
    r["total_tokens"] = sum(r[k] for k in _TOK)
    return r


def compute_cost(
    ud: dict[str, Any],
    model: str,
    *,
    prices: Optional[dict[str, dict[str, float]]] = None,
    cache_read_multiplier: float = 0.1,
    cache_write_multiplier: float = 1.25,
) -> float:
    """Estimate the USD cost of the token counts in *ud* for *model*.

    Args:
        ud: Mapping with optional ``input_tokens``, ``cache_read_tokens``,
            ``cache_creation_tokens`` and ``output_tokens`` counts (missing or
            None counts as 0).
        model: Key into the price table.
        prices: Price table of ``{model: {"input": x, "output": y}}`` in USD
            per 1M tokens. Defaults to the module-level ``PRICES``.
        cache_read_multiplier: Cache-read price as a fraction of the input
            price (Anthropic bills cache reads at 0.1x).
        cache_write_multiplier: Cache-write price as a multiple of the input
            price. Anthropic bills 5-minute cache writes at 1.25x and
            1-hour writes at 2x.

    Returns:
        Estimated cost in USD, or 0.0 when *model* has no price entry.
    """
    table = PRICES if prices is None else prices
    p = table.get(model)
    if not p:
        return 0.0
    per_in = p["input"] / 1_000_000.0
    per_out = p["output"] / 1_000_000.0
    return (
        (ud.get("input_tokens") or 0) * per_in
        + (ud.get("cache_read_tokens") or 0) * per_in * cache_read_multiplier
        + (ud.get("cache_creation_tokens") or 0) * per_in * cache_write_multiplier
        + (ud.get("output_tokens") or 0) * per_out
    )


def _merge_session(entry: dict[str, Any], s: dict[str, Any], mdl: str) -> None:
    """Fold parsed session *s* (for the same task id) into ledger *entry* in place."""
    for k in (*_TOK, "total_tokens", "message_count"):
        entry[k] += s[k]
    entry["mcp_used"] = bool(entry["mcp_used"] or s["mcp_used"])
    for k in ("team_id", "cwd_project"):
        if not entry[k] and s[k]:
            entry[k] = s[k]
    # Keep identity fields from the earliest session. Assumes all timestamps
    # share one ISO-8601 format, so string order equals time order.
    if s["timestamp"] and (not entry["timestamp"] or s["timestamp"] < entry["timestamp"]):
        entry["session_id"] = s["session_id"]
        entry["timestamp"] = s["timestamp"]
        entry["model"] = mdl


def scan(glob_paths: Optional[list[str]] = None, ledger_path: str = LEDGER_PATH) -> None:
    """Parse session transcripts and write the per-task token ledger.

    Args:
        glob_paths: Patterns/paths to scan. Entries containing "*" are
            globbed; others are used verbatim. Defaults to BOT_GLOB plus
            COKAC_GLOB.
        ledger_path: Output JSON path. Parent directories are created.

    Sessions without a detected task id are skipped. When several sessions
    map to the same task id, their token counts, message counts and costs are
    summed, rather than letting whichever file was scanned last overwrite the
    others. Cost is computed per session with that session's model (default
    "claude-sonnet-4-6").
    """
    if glob_paths is None:
        found = glob.glob(BOT_GLOB) + glob.glob(COKAC_GLOB)
    else:
        found = [f for gp in glob_paths for f in (glob.glob(gp) if "*" in gp else [gp])]
    # Dedupe so overlapping patterns cannot double-count a session once sessions
    # are summed. Sort so the scan order (and merge result) is deterministic.
    files = sorted(set(found))
    tasks: dict[str, Any] = {}
    costs: dict[str, float] = {}
    for f in files:
        s = parse_session(f)
        tid = s["task_id"]
        if not tid:
            continue
        mdl = s["model"] or "claude-sonnet-4-6"
        costs[tid] = costs.get(tid, 0.0) + compute_cost(s, mdl)
        entry = tasks.get(tid)
        if entry is None:
            entry = {k: s[k] for k in _EK}
            entry["model"] = mdl
            tasks[tid] = entry
        else:
            _merge_session(entry, s, mdl)
    for tid, entry in tasks.items():
        entry["cost_estimate_usd"] = round(costs[tid], 6)
    tot = sum(v["total_tokens"] for v in tasks.values())
    n = len(tasks)
    top = sorted(tasks.items(), key=lambda x: x[1]["total_tokens"], reverse=True)[:5]
    ledger: dict[str, Any] = {
        "tasks": tasks,
        "summary": {
            "total_tokens": tot,
            "total_tasks": n,
            "avg_tokens_per_task": tot // n if n else 0,
            "total_cost_usd": round(sum(v["cost_estimate_usd"] for v in tasks.values()), 6),
            "top_consumers": [{"task_id": k, "total_tokens": v["total_tokens"]} for k, v in top],
        },
        "last_scan": datetime.now().isoformat(timespec="seconds"),
    }
    out = Path(ledger_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    # Write a sibling temp file, then rename it over the ledger, so readers of
    # the ledger (get/summary/anomaly/enrich) never see a half-written file.
    tmp = out.with_name(out.name + ".tmp")
    tmp.write_text(json.dumps(ledger, ensure_ascii=False, indent=2), encoding="utf-8")
    os.replace(tmp, out)
    print(f"Scanned {len(files)} files, recorded {n} tasks -> {ledger_path}")


def _load(ledger_path: str) -> dict[str, Any]:
    """Read the ledger JSON at *ledger_path*.

    A missing file yields ``{"tasks": {}, "summary": {}}``. For an existing
    file, the "tasks" and "summary" sections are guaranteed to be present
    (defaulting to empty dicts), so callers can index them directly.
    """
    p = Path(ledger_path)
    if not p.exists():
        return {"tasks": {}, "summary": {}}
    data = json.loads(p.read_text(encoding="utf-8"))
    data.setdefault("tasks", {})
    data.setdefault("summary", {})
    return data


def get_task(task_id: str, ledger_path: str = LEDGER_PATH) -> Optional[dict[str, Any]]:
    """Return the ledger entry for *task_id*, or None if it is not recorded.

    A ledger file without a "tasks" section yields None, not a KeyError.
    """
    return _load(ledger_path).get("tasks", {}).get(task_id)


def get_summary(ledger_path: str = LEDGER_PATH) -> dict[str, Any]:
    """Return the ledger's "summary" section, or {} when the ledger has none."""
    ledger = _load(ledger_path)
    summary = ledger.get("summary", {})
    return summary


def detect_anomaly(ledger_path: str = LEDGER_PATH, threshold: float = 2.0) -> list[dict[str, Any]]:
    """List tasks whose total_tokens is at least *threshold* times the average.

    The average is read from ``summary.avg_tokens_per_task`` in the ledger. If
    it is missing or 0, no anomalies are reported. Each result holds
    ``task_id``, ``total_tokens`` and ``ratio`` (total_tokens / average).

    Args:
        ledger_path: Ledger JSON to inspect.
        threshold: Multiple of the average at or above which a task is
            flagged. The default of 2.0 keeps the original cutoff.
    """
    ld = _load(ledger_path)
    avg = ld.get("summary", {}).get("avg_tokens_per_task", 0)
    if not avg:
        return []
    cutoff = avg * threshold
    return [
        {"task_id": k, "total_tokens": v["total_tokens"], "ratio": v["total_tokens"] / avg}
        for k, v in ld.get("tasks", {}).items()
        if v["total_tokens"] >= cutoff
    ]


def enrich(timers_path: str = TIMERS_PATH, ledger_path: str = LEDGER_PATH) -> None:
    """Copy per-task token usage from the ledger into the task-timers file.

    For every task id present in both files, the timer entry gets:
      * ``token_usage``: token counts and cost, plus ``model`` when known;
      * ``cwd_project``: set only when the ledger has one;
      * ``mcp_used``.

    The timers file is rewritten only if at least one task was updated. If
    *timers_path* does not exist, a message is printed and nothing else
    happens.
    """
    tp = Path(timers_path)
    if not tp.exists():
        print(f"task-timers.json not found: {timers_path}")
        return
    timers = json.loads(tp.read_text(encoding="utf-8"))
    ledger_tasks = _load(ledger_path).get("tasks", {})
    # Ledger fields copied into each timer's "token_usage" (loop-invariant).
    usage_keys = (
        "input_tokens",
        "cache_creation_tokens",
        "cache_read_tokens",
        "output_tokens",
        "total_tokens",
        "cost_estimate_usd",
    )
    changed = 0
    for tid, td in timers.get("tasks", {}).items():
        if tid not in ledger_tasks:
            continue
        ld = ledger_tasks[tid]
        usage = {k: ld.get(k, 0) for k in usage_keys}
        if ld.get("model"):
            usage["model"] = ld["model"]
        td["token_usage"] = usage
        if ld.get("cwd_project"):
            td["cwd_project"] = ld["cwd_project"]
        td["mcp_used"] = ld.get("mcp_used", False)
        changed += 1
    # Skip the no-op rewrite: it would only reformat the file and widen the
    # window for clobbering updates made to it by another writer.
    if changed:
        tp.write_text(json.dumps(timers, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"Enriched {changed} tasks in {timers_path}")


def main() -> None:
    """CLI entry point for the subcommands scan, get --task ID, summary, anomaly and enrich.

    Without a subcommand, prints the help text and exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Token usage tracker")
    commands = parser.add_subparsers(dest="cmd")
    commands.add_parser("scan")
    commands.add_parser("get").add_argument("--task", required=True)
    for name in ("summary", "anomaly", "enrich"):
        commands.add_parser(name)
    args = parser.parse_args()

    def _dump(obj: Any) -> None:
        print(json.dumps(obj, ensure_ascii=False, indent=2))

    def _get() -> None:
        found = get_task(args.task)
        print(json.dumps(found, ensure_ascii=False, indent=2) if found else "Not found")

    handlers = {
        "scan": scan,
        "get": _get,
        "summary": lambda: _dump(get_summary()),
        "anomaly": lambda: _dump(detect_anomaly()),
        "enrich": enrich,
    }
    handler = handlers.get(args.cmd)
    if handler is None:
        parser.print_help()
        sys.exit(1)
    handler()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
