"""anu_v3.batch_dependency_matrix — track dependency + overlap checkers.

Authority: task-2553+17.md §4(3,4,5) + §12 9-R.1.

Provides:
  * dependency matrix (declared blocking edges between tracks)
  * expected_files overlap checker
  * forbidden_write_targets overlap checker
  * independence query (does track X completion block track Y?)

Pure stdlib; never mutates existing tracked files.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from itertools import combinations
from typing import Dict, List, Sequence, Tuple


@dataclass
class TrackSpec:
    """Declaration of a single track within a batch.

    Only ``track_id`` is required. Every list field defaults to a fresh,
    independent empty list, so instances never share list objects.
    """

    # Identifier of the track; BatchDependencyMatrix keys tracks by it and
    # rejects duplicates.
    track_id: str
    # Files the track expects to write (treated as writes by the
    # forbidden-write overlap check).
    expected_files: List[str] = field(default_factory=list)
    # Files the track has declared it will not write.
    forbidden_write_targets: List[str] = field(default_factory=list)
    # Ids of the tracks this one is blocked by (declared dependency edges).
    depends_on: List[str] = field(default_factory=list)
    # Files the track owns; writing these is never reported as a violation.
    own_artifacts: List[str] = field(default_factory=list)


class BatchDependencyMatrix:
    """Dependency matrix and file-boundary overlap checks for a batch.

    Construction validates the declared edges. Duplicate track ids, edges
    to unknown tracks, and self-dependencies raise ``ValueError``. Longer
    dependency cycles are *not* rejected here. They surface as
    ``ValueError`` from :meth:`parallelizable_groups`, and therefore also
    from :meth:`to_dict`.

    Query methods index tracks by id and raise ``KeyError`` for an unknown
    id. The track specs are stored by reference, not copied.
    NOTE(review): mutating a spec after construction bypasses the edge
    validation; confirm callers treat specs as read-only.
    """

    def __init__(self, tracks: Sequence[TrackSpec]) -> None:
        self._tracks: Dict[str, TrackSpec] = {}
        for t in tracks:
            if t.track_id in self._tracks:
                raise ValueError(f"duplicate track {t.track_id!r}")
            self._tracks[t.track_id] = t
        self._validate_edges()

    def _validate_edges(self) -> None:
        """Reject edges to unknown tracks and self-dependencies."""
        for t in self._tracks.values():
            for dep in t.depends_on:
                if dep not in self._tracks:
                    raise ValueError(
                        f"track {t.track_id!r} depends on unknown {dep!r}"
                    )
                if dep == t.track_id:
                    raise ValueError(f"track {t.track_id!r} self-dependency")

    # -- matrix ----------------------------------------------------------
    def matrix(self) -> Dict[str, Dict[str, bool]]:
        """matrix[a][b] == True  =>  a depends on (is blocked by) b.

        Only *declared* (direct) edges are reported; transitive
        dependencies do not appear in the matrix.
        """
        ids = list(self._tracks)
        return {
            a: {b: (b in self._tracks[a].depends_on) for b in ids}
            for a in ids
        }

    def _upstream(self, track_id: str) -> set[str]:
        """Return every track *track_id* depends on, directly or not.

        Raises ``KeyError`` if *track_id* is unknown. Each id is expanded
        at most once, so this terminates on cyclic graphs too; a track on
        a cycle then appears in its own upstream set.
        """
        seen: set[str] = set()
        pending = list(self._tracks[track_id].depends_on)
        while pending:
            dep = pending.pop()
            if dep in seen:
                continue
            seen.add(dep)
            # dep is a known id: _validate_edges rejected unknown targets
            # at construction time.
            pending.extend(self._tracks[dep].depends_on)
        return seen

    def is_independent(self, a: str, b: str) -> bool:
        """True if neither track blocks the other (parallel-safe).

        Blocking is transitive: with ``a -> x -> b``, *a* cannot start
        until *b* completes, so the pair is not independent even though
        there is no direct edge between them. Raises ``KeyError`` for an
        unknown id.
        """
        return b not in self._upstream(a) and a not in self._upstream(b)

    def blocks(self, completed_track: str, candidate_track: str) -> bool:
        """Does *candidate_track* have to wait for *completed_track*?

        True when *completed_track* is a direct or transitive dependency
        of *candidate_track*. Raises ``KeyError`` if *candidate_track* is
        unknown; an unknown *completed_track* simply yields False.
        """
        return completed_track in self._upstream(candidate_track)

    def parallelizable_groups(self) -> List[List[str]]:
        """Topological layers; tracks in the same layer run in parallel.

        Each layer holds the tracks whose dependencies all lie in earlier
        layers, sorted by id. Raises ``ValueError`` when no track can be
        scheduled, i.e. the remaining tracks form a dependency cycle.
        """
        remaining = dict(self._tracks)
        done: set[str] = set()
        layers: List[List[str]] = []
        while remaining:
            layer = [
                tid
                for tid, t in remaining.items()
                if all(d in done for d in t.depends_on)
            ]
            if not layer:
                raise ValueError("dependency cycle detected")
            layers.append(sorted(layer))
            # Mark the whole layer done only after it has been computed,
            # so tracks in one layer never depend on each other.
            for tid in layer:
                done.add(tid)
                del remaining[tid]
        return layers

    # -- overlap checkers (9-R.1 boundary) ------------------------------
    def expected_files_overlap(self) -> List[Tuple[str, str, List[str]]]:
        """Return ``(a, b, shared)`` for each unordered track pair whose
        ``expected_files`` intersect; *shared* is sorted.

        Pairs follow track insertion order (``itertools.combinations``).
        """
        # Build each track's file set once rather than once per pair.
        files = {
            tid: set(t.expected_files) for tid, t in self._tracks.items()
        }
        out: List[Tuple[str, str, List[str]]] = []
        for a, b in combinations(self._tracks, 2):
            shared = sorted(files[a] & files[b])
            if shared:
                out.append((a, b, shared))
        return out

    def forbidden_write_overlap(self) -> List[Tuple[str, str, List[str]]]:
        """Real cross-track write violation = a track writing a file that
        another track *owns* and/or has declared no-write, where the file is
        NOT the writer's own artifact.

        Returns (writer, owner, files). A track legitimately writing its own
        artifact that merely appears in a sibling's forbidden list is NOT a
        violation (that is the sibling declaring "I won't touch this")."""
        out: List[Tuple[str, str, List[str]]] = []
        for writer_id, writer in self._tracks.items():
            # The writer's own artifacts are always exempt. Only the files
            # it expects to write that it does NOT own can ever be
            # reported, so compute that set once per writer.
            candidates = set(writer.expected_files) - set(
                writer.own_artifacts
            )
            if not candidates:
                continue
            for owner_id, owner in self._tracks.items():
                if owner_id == writer_id:
                    continue
                protected = set(owner.own_artifacts) | set(
                    owner.forbidden_write_targets
                )
                # Writer reaching into a sibling's protected zone.
                shared = sorted(candidates & protected)
                if shared:
                    out.append((writer_id, owner_id, shared))
        return out

    def has_boundary_conflict(self) -> bool:
        """True if either overlap checker reports at least one conflict."""
        return bool(
            self.expected_files_overlap() or self.forbidden_write_overlap()
        )

    def to_dict(self) -> Dict[str, object]:
        """Plain-dict summary of tracks, matrix, layers and overlaps.

        Raises ``ValueError`` (from :meth:`parallelizable_groups`) if the
        declared dependencies contain a cycle.
        """
        return {
            "tracks": list(self._tracks),
            "matrix": self.matrix(),
            "parallelizable_groups": self.parallelizable_groups(),
            "expected_files_overlap": [
                {"a": a, "b": b, "shared": s}
                for (a, b, s) in self.expected_files_overlap()
            ],
            "forbidden_write_overlap": [
                {"writer": a, "owner": b, "shared": s}
                for (a, b, s) in self.forbidden_write_overlap()
            ],
        }
