#!/usr/bin/env python3
"""Token usage tracker: scans Claude JSONL sessions, builds token ledger."""

import argparse
import glob
import json
import os
import re
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Optional

_WORKSPACE_ROOT = os.environ.get("WORKSPACE_ROOT", str(Path(__file__).resolve().parent.parent))
LEDGER_PATH = str(Path(_WORKSPACE_ROOT) / "memory" / "token-ledger.json")
TIMERS_PATH = str(Path(_WORKSPACE_ROOT) / "memory" / "task-timers.json")
BOT_GLOB = "/home/jay/.claude/projects/-home-jay-workspace/*.jsonl"
COKAC_GLOB = "/home/jay/.claude/projects/-home-jay--cokacdir-workspace-*/*.jsonl"
# Prices in USD per million tokens.
PRICES: dict[str, dict[str, float]] = {
    "claude-opus-4-6": {"input": 15.0, "output": 75.0},
    "claude-sonnet-4-6": {"input": 3.0, "output": 15.0},
    "claude-haiku-4-5-20251001": {"input": 0.80, "output": 4.0},
}
TASK_RE = re.compile(r"task-\d+\.\d+")  # e.g. "task-1.2"
# Matches "team_id: <id>", "team: <id>", or Korean "팀: <id>" ("팀" = "team");
# both the ASCII ":" and the full-width "：" separators are accepted.
TEAM_RE = re.compile(
    r"(?:팀|team_id|team)\s*[:：]\s*(dev\d+-team|marketing|anu-direct|[a-z][a-z0-9]*-[a-z][a-z0-9-]*[a-z0-9])"
)
# fmt: off
_TOK = ("input_tokens", "cache_creation_tokens", "cache_read_tokens", "output_tokens")
_EK = ("session_id", "team_id", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "message_count", "timestamp")
# fmt: on
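
# The scanner assumes one JSON object per JSONL line, shaped roughly like the
# sketches below (field names are exactly the ones parse_session() reads;
# everything else is ignored, and the values here are illustrative):
#   {"type": "user", "sessionId": "abc", "timestamp": "2026-01-01T00:00:00Z",
#    "message": {"content": "team: dev1-team ... task-1.2 ..."}}
#   {"type": "assistant", "message": {"model": "claude-sonnet-4-6",
#    "usage": {"input_tokens": 10, "cache_creation_input_tokens": 0,
#              "cache_read_input_tokens": 0, "output_tokens": 42}}}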


def _text(c: Any) -> str:
    """Flatten a message `content` field (plain string or list of content blocks) to text."""
    if isinstance(c, str):
        return c
    if isinstance(c, list):
        return " ".join(i.get("text", "") for i in c if isinstance(i, dict))
    return ""


def parse_session(jsonl_path: str) -> dict[str, Any]:  # noqa: C901
    """Aggregate token usage and task/team metadata from one session JSONL file."""
    r: dict[str, Any] = dict.fromkeys(_TOK, 0)
    r.update(message_count=0, model="", session_id="", task_id="", team_id="", timestamp="", total_tokens=0)
    path = Path(jsonl_path)
    if not path.exists():
        return r
    for raw in path.read_text(encoding="utf-8").splitlines():
        raw = raw.strip()
        if not raw:
            continue
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if not r["session_id"] and obj.get("sessionId"):
            r["session_id"] = obj["sessionId"]
        typ = obj.get("type", "")
        if typ == "user":
            txt = _text(obj.get("message", {}).get("content", ""))
            if not r["task_id"] and (m := TASK_RE.search(txt)):
                r["task_id"] = m.group(0)
            if not r["team_id"] and (tm := TEAM_RE.search(txt)):
                r["team_id"] = tm.group(1)
            if not r["timestamp"]:
                r["timestamp"] = obj.get("timestamp", "")
        elif typ == "assistant":
            msg = obj.get("message", {})
            u = msg.get("usage", {})
            if u:
                r["input_tokens"] += u.get("input_tokens", 0)
                r["cache_creation_tokens"] += u.get("cache_creation_input_tokens", 0)
                r["cache_read_tokens"] += u.get("cache_read_input_tokens", 0)
                r["output_tokens"] += u.get("output_tokens", 0)
                r["message_count"] += 1
                if not r["model"] and msg.get("model"):
                    r["model"] = msg["model"]
    r["total_tokens"] = sum(r[k] for k in _TOK)
    return r


def compute_cost(ud: dict[str, Any], model: str) -> float:
    """Estimate the USD cost of a usage dict; unknown models cost 0.0."""
    p = PRICES.get(model)
    if not p:
        return 0.0
    M = 1_000_000.0
    # Anthropic prompt-caching rates: cache reads bill at 0.1x the base input
    # price, cache writes at 1.25x.
    return (
        ud.get("input_tokens", 0) * p["input"] / M
        + ud.get("cache_read_tokens", 0) * p["input"] * 0.1 / M
        + ud.get("cache_creation_tokens", 0) * p["input"] * 1.25 / M
        + ud.get("output_tokens", 0) * p["output"] / M
    )
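
# Worked example with the PRICES table above: claude-sonnet-4-6 at $3/M input
# and $15/M output, for a session with 200k input, 1M cache-read, 400k
# cache-write, and 50k output tokens:
#   0.2*3 + 1.0*3*0.1 + 0.4*3*1.25 + 0.05*15 = 0.60 + 0.30 + 1.50 + 0.75 = $3.15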


def scan(glob_paths: Optional[list[str]] = None, ledger_path: str = LEDGER_PATH) -> None:
    """Scan session files, aggregate usage per task, and write the ledger JSON."""
    files = (
        glob.glob(BOT_GLOB) + glob.glob(COKAC_GLOB)
        if glob_paths is None
        else [f for gp in glob_paths for f in (glob.glob(gp) if "*" in gp else [gp])]
    )
    tasks: dict[str, Any] = {}
    for f in files:
        s = parse_session(f)
        if not s["task_id"]:
            continue
        mdl = s["model"] or "claude-sonnet-4-6"
        entry = {k: s[k] for k in _EK}
        entry.update(model=mdl, cost_estimate_usd=round(compute_cost(s, mdl), 6))
        tasks[s["task_id"]] = entry
    tot = sum(v["total_tokens"] for v in tasks.values())
    n = len(tasks)
    top = sorted(tasks.items(), key=lambda x: x[1]["total_tokens"], reverse=True)[:5]
    ledger: dict[str, Any] = {
        "tasks": tasks,
        "summary": {
            "total_tokens": tot,
            "total_tasks": n,
            "avg_tokens_per_task": tot // n if n else 0,
            "total_cost_usd": round(sum(v["cost_estimate_usd"] for v in tasks.values()), 6),
            "top_consumers": [{"task_id": k, "total_tokens": v["total_tokens"]} for k, v in top],
        },
        "last_scan": datetime.now().isoformat(timespec="seconds"),
    }
    Path(ledger_path).parent.mkdir(parents=True, exist_ok=True)
    Path(ledger_path).write_text(json.dumps(ledger, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"Scanned {len(files)} files, recorded {n} tasks -> {ledger_path}")


def _load(ledger_path: str) -> dict[str, Any]:
    p = Path(ledger_path)
    return json.loads(p.read_text(encoding="utf-8")) if p.exists() else {"tasks": {}, "summary": {}}


def get_task(task_id: str, ledger_path: str = LEDGER_PATH) -> Optional[dict[str, Any]]:
    return _load(ledger_path)["tasks"].get(task_id)


def get_summary(ledger_path: str = LEDGER_PATH) -> dict[str, Any]:
    return _load(ledger_path).get("summary", {})


def detect_anomaly(ledger_path: str = LEDGER_PATH) -> list[dict[str, Any]]:
    """Return tasks whose token usage is at least twice the per-task average."""
    ld = _load(ledger_path)
    avg = ld.get("summary", {}).get("avg_tokens_per_task", 0)
    if not avg:
        return []
    return [
        {"task_id": k, "total_tokens": v["total_tokens"], "ratio": v["total_tokens"] / avg}
        for k, v in ld.get("tasks", {}).items()
        if v["total_tokens"] >= avg * 2
    ]
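
# Example: with avg_tokens_per_task = 50_000, a task that used 120_000 tokens
# is reported as {"task_id": ..., "total_tokens": 120000, "ratio": 2.4}.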


def enrich(timers_path: str = TIMERS_PATH, ledger_path: str = LEDGER_PATH) -> None:
    """Copy per-task token usage from the ledger into task-timers.json."""
    tp = Path(timers_path)
    if not tp.exists():
        print(f"task-timers.json not found: {timers_path}")
        return
    timers = json.loads(tp.read_text(encoding="utf-8"))
    tl = _load(ledger_path).get("tasks", {})
    changed = 0
    for tid, td in timers.get("tasks", {}).items():
        if tid in tl:
            ld = tl[tid]
            keys2 = ("input_tokens", "output_tokens", "total_tokens", "cost_estimate_usd")
            td["token_usage"] = {k: ld.get(k, 0) for k in keys2}
            changed += 1
    tp.write_text(json.dumps(timers, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"Enriched {changed} tasks in {timers_path}")


def main() -> None:
    parser = argparse.ArgumentParser(description="Token usage tracker")
    sub = parser.add_subparsers(dest="cmd")
    sub.add_parser("scan")
    p_get = sub.add_parser("get")
    p_get.add_argument("--task", required=True)
    sub.add_parser("summary")
    sub.add_parser("anomaly")
    sub.add_parser("enrich")
    args = parser.parse_args()
    if args.cmd == "scan":
        scan()
    elif args.cmd == "get":
        data = get_task(args.task)
        print(json.dumps(data, ensure_ascii=False, indent=2) if data else "Not found")
    elif args.cmd == "summary":
        print(json.dumps(get_summary(), ensure_ascii=False, indent=2))
    elif args.cmd == "anomaly":
        print(json.dumps(detect_anomaly(), ensure_ascii=False, indent=2))
    elif args.cmd == "enrich":
        enrich()
    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()
