#!/usr/bin/env python3
"""
ast_dependency_map.py — AST 기반 Python 의존성 맵 생성기

주어진 변경 파일 목록에 대해 영향 범위(blast radius)를 분석합니다.
- 직접 임포터 (direct importers)
- 전이적 의존자 (transitive dependents, 최대 2홉)
- 테스트 파일 (test_*.py 또는 tests/ 디렉토리)
- 함수 레벨 호출자 (function-level callers)

Usage:
    python3 ast_dependency_map.py --root /home/jay/workspace/dashboard --files data_loader.py
    python3 ast_dependency_map.py --root /home/jay/workspace/dashboard --files data_loader.py --function get_member_status
    python3 ast_dependency_map.py --root /home/jay/workspace/dashboard --files wiki_engine.py --function sync_firestore
"""

import argparse
import ast
import json
import logging
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# AST 파싱 헬퍼
# ---------------------------------------------------------------------------


def _parse_file(file_path: Path) -> Optional[ast.Module]:
    """파일을 AST로 파싱합니다. 구문 오류가 있으면 None을 반환합니다."""
    try:
        # 대형 파일 보호: 100KB 초과 파일은 스킵
        file_size = file_path.stat().st_size
        if file_size > 100_000:
            logger.warning("대형 파일 스킵 (%.1fKB): %s", file_size / 1024, file_path)
            return None
        source = file_path.read_text(encoding="utf-8", errors="replace")
        return ast.parse(source, filename=str(file_path))
    except SyntaxError:
        return None
    except Exception:
        return None


def _extract_imported_modules(tree: ast.Module, pkg_name: str) -> Set[str]:
    """
    AST에서 임포트된 모듈 이름을 추출합니다.

    다음 패턴을 처리합니다:
    - import X
    - from X import Y
    - from dashboard.X import Y  →  X
    - from .X import Y  →  X (상대 임포트)
    - try/except 블록 내의 임포트 패턴

    반환값: 해당 파일이 의존하는 모듈명 집합 (예: {"data_loader", "server_utils"})
    """
    modules: Set[str] = set()

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # import X  또는  import dashboard.X
                name = alias.name
                # dashboard.X 패턴 → X 추출
                if "." in name:
                    parts = name.split(".")
                    # 패키지 이름이 첫 번째 부분인 경우 두 번째 부분 사용
                    if parts[0] == pkg_name and len(parts) > 1:
                        modules.add(parts[1])
                    else:
                        # 일반 모듈 (마지막 이름)
                        modules.add(parts[-1])
                else:
                    modules.add(name)

        elif isinstance(node, ast.ImportFrom):
            module = node.module or ""
            level = node.level  # 상대 임포트 레벨 (0=절대, 1=., 2=..)

            if level > 0:
                # from . import X  또는  from .X import Y
                if module:
                    # from .X import Y → X
                    top = module.split(".")[0]
                    modules.add(top)
                else:
                    # from . import X → X (각 name이 모듈)
                    for alias in node.names:
                        modules.add(alias.name)
            else:
                # 절대 임포트
                if module:
                    parts = module.split(".")
                    if parts[0] == pkg_name and len(parts) > 1:
                        # from dashboard.X import Y → X
                        modules.add(parts[1])
                    elif len(parts) == 1:
                        # from X import Y → X
                        modules.add(parts[0])
                    else:
                        # from some.other.package import ... → 첫 부분만
                        modules.add(parts[0])

    return modules


def _extract_function_calls(tree: ast.Module, imported_names: Dict[str, str]) -> List[Tuple[str, int]]:
    """
    AST에서 특정 함수 호출을 찾습니다.

    imported_names: {함수명: 소스모듈명} 매핑
    반환값: [(함수명, 줄번호), ...]
    """
    calls: List[Tuple[str, int]] = []

    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            func = node.func
            # 단순 함수 호출: func_name(...)
            if isinstance(func, ast.Name):
                if func.id in imported_names:
                    calls.append((func.id, node.lineno))
            # 속성 호출: obj.method(...)
            elif isinstance(func, ast.Attribute):
                name = func.attr
                if name in imported_names:
                    calls.append((name, node.lineno))

    return calls


def _extract_imported_names_from_module(tree: ast.Module, target_module: str, pkg_name: str) -> Set[str]:
    """
    특정 모듈에서 임포트된 이름들을 추출합니다.

    target_module: 소스 모듈 이름 (예: "data_loader")
    반환값: 임포트된 이름 집합
    """
    names: Set[str] = set()

    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            module = node.module or ""
            level = node.level

            # 모듈 이름 정규화
            if level > 0:
                # 상대 임포트
                base = module.split(".")[0] if module else ""
                normalized = base
            else:
                parts = module.split(".")
                if parts[0] == pkg_name and len(parts) > 1:
                    normalized = parts[1]
                else:
                    normalized = parts[0] if parts else ""

            if normalized == target_module:
                for alias in node.names:
                    names.add(alias.name)  # 원래 함수명
                    if alias.asname:
                        names.add(alias.asname)  # 별칭도 추가

    return names


# ---------------------------------------------------------------------------
# 의존성 그래프 구축
# ---------------------------------------------------------------------------


class DependencyGraph:
    """
    대상 디렉토리 내 Python 파일들의 의존성 그래프를 구축합니다.
    """

    def __init__(self, root: Path):
        self.root = root
        self.pkg_name = root.name  # 예: "dashboard"

        # 모듈명 → 파일 경로 매핑
        self.module_to_file: Dict[str, Path] = {}
        # 파일 경로 → AST 매핑
        self.file_to_ast: Dict[Path, Optional[ast.Module]] = {}
        # 파일 경로 → 의존하는 모듈명 집합
        self.file_imports: Dict[Path, Set[str]] = {}
        # 모듈명 → 해당 모듈을 임포트하는 파일 집합 (역방향)
        self.module_importers: Dict[str, Set[Path]] = defaultdict(set)

        self._build()

    def _build(self):
        """디렉토리를 스캔하고 그래프를 구축합니다."""
        build_start = time.time()
        BUILD_TIMEOUT = 60  # 초

        py_files = list(self.root.rglob("*.py"))

        # 1단계: 모듈명 → 파일 경로 매핑 구축 (빠름, 타임아웃 불필요)
        for py_file in py_files:
            rel = py_file.relative_to(self.root)
            # stem: data_loader, server 등
            module_name = rel.stem
            # 이미 존재하면 최상위 파일 우선 (더 짧은 경로)
            if module_name not in self.module_to_file:
                self.module_to_file[module_name] = py_file
            else:
                existing = self.module_to_file[module_name]
                # 더 짧은(상위) 경로 우선
                if len(py_file.parts) < len(existing.parts):
                    self.module_to_file[module_name] = py_file

        # 2단계: 각 파일 파싱 및 임포트 추출 (타임아웃 적용)
        for py_file in py_files:
            if time.time() - build_start > BUILD_TIMEOUT:
                logger.warning(
                    "그래프 구축 타임아웃 (%ds): %d/%d 파일 처리됨",
                    BUILD_TIMEOUT,
                    len(self.file_to_ast),
                    len(py_files),
                )
                # 미처리 파일은 빈 임포트로 등록
                for remaining in py_files:
                    if remaining not in self.file_to_ast:
                        self.file_to_ast[remaining] = None
                        self.file_imports[remaining] = set()
                break
            tree = _parse_file(py_file)
            self.file_to_ast[py_file] = tree

            if tree is not None:
                imported = _extract_imported_modules(tree, self.pkg_name)
                self.file_imports[py_file] = imported
            else:
                self.file_imports[py_file] = set()

        # 3단계: 역방향 인덱스 구축 (모듈 → 임포터 파일들)
        for py_file, imports in self.file_imports.items():
            for mod in imports:
                self.module_importers[mod].add(py_file)

    def get_file_path(self, filename: str) -> Optional[Path]:
        """
        파일명 또는 모듈명으로 파일 경로를 반환합니다.
        예: "data_loader.py" 또는 "data_loader"
        """
        # .py 확장자 처리
        if filename.endswith(".py"):
            module_name = filename[:-3]
        else:
            module_name = filename

        # 직접 모듈명 조회
        if module_name in self.module_to_file:
            return self.module_to_file[module_name]

        # 상대 경로로 조회
        candidate = self.root / filename
        if candidate.exists():
            return candidate

        return None

    def get_direct_importers(self, module_name: str) -> Set[Path]:
        """해당 모듈을 직접 임포트하는 파일들을 반환합니다."""
        return set(self.module_importers.get(module_name, set()))

    def get_transitive_dependents(self, module_name: str, hops: int = 2) -> Set[Path]:
        """
        전이적 의존자를 최대 hops 단계까지 반환합니다.
        직접 임포터는 포함하지 않습니다.
        """
        visited: Set[Path] = set()
        current_level: Set[Path] = self.get_direct_importers(module_name)
        visited.update(current_level)

        for _ in range(hops - 1):
            next_level: Set[Path] = set()
            for file_path in current_level:
                # 이 파일의 모듈명을 구해서 역방향 조회
                file_module = file_path.stem
                importers = self.get_direct_importers(file_module)
                for imp in importers:
                    if imp not in visited:
                        next_level.add(imp)
                        visited.add(imp)
            current_level = next_level

        return visited

    def get_function_callers(self, module_name: str, function_name: str) -> List[Tuple[Path, int]]:
        """
        특정 모듈의 특정 함수를 호출하는 (파일, 줄번호) 목록을 반환합니다.
        """
        callers: List[Tuple[Path, int]] = []
        direct_importers = self.get_direct_importers(module_name)

        # 전이적 의존자도 포함 (2홉)
        all_dependents = self.get_transitive_dependents(module_name, hops=2)
        all_candidates = direct_importers | all_dependents

        for file_path in all_candidates:
            tree = self.file_to_ast.get(file_path)
            if tree is None:
                continue

            # 이 파일에서 target_module로부터 임포트된 이름들 확인
            imported_names = _extract_imported_names_from_module(tree, module_name, self.pkg_name)

            # function_name이 임포트되었는지 확인
            if function_name not in imported_names and function_name not in [n.split(".")[-1] for n in imported_names]:
                # 직접 임포트하지 않아도 호출할 수 있으므로 전체 검색
                pass

            # 함수 호출 검색 (임포트 여부와 관계없이 이름으로 검색)
            calls = _extract_function_calls(tree, {function_name: module_name})
            for _, lineno in calls:
                callers.append((file_path, lineno))

        return callers

    def is_test_file(self, file_path: Path) -> bool:
        """테스트 파일 여부를 확인합니다."""
        name = file_path.name
        # test_*.py 패턴
        if name.startswith("test_") and name.endswith(".py"):
            return True
        # tests/ 디렉토리 내 파일
        try:
            rel = file_path.relative_to(self.root)
            parts = rel.parts
            if "tests" in parts[:-1]:  # 마지막 부분(파일명) 제외
                return True
        except ValueError:
            pass
        return False


# ---------------------------------------------------------------------------
# 분석 메인 로직
# ---------------------------------------------------------------------------


def analyze(
    root: Path,
    changed_files: List[str],
    function_name: Optional[str] = None,
) -> List[dict]:
    """
    변경된 파일들에 대한 의존성 분석을 수행합니다.

    반환값: 각 변경 파일에 대한 분석 결과 리스트
    """
    graph = DependencyGraph(root)

    results = []

    for changed_file in changed_files:
        file_start_ms = time.time() * 1000

        # 파일 경로 해석 (존재 확인용, 결과에는 미사용)
        graph.get_file_path(changed_file)
        module_name = Path(changed_file).stem

        # 직접 임포터
        direct_importers: Set[Path] = graph.get_direct_importers(module_name)

        # 직접 임포터에서 분석 대상 파일 자체를 제외 (self-reference 방지)
        target_path = graph.get_file_path(changed_file)
        if target_path:
            direct_importers.discard(target_path)

        # 전이적 의존자 (직접 임포터 제외)
        all_dependents: Set[Path] = graph.get_transitive_dependents(module_name, hops=2)
        transitive_dependents = all_dependents - direct_importers

        # 테스트 파일 필터링
        test_files: Set[Path] = set()
        non_test_direct: Set[Path] = set()
        for p in direct_importers:
            if graph.is_test_file(p):
                test_files.add(p)
            else:
                non_test_direct.add(p)

        for p in transitive_dependents:
            if graph.is_test_file(p):
                test_files.add(p)

        # 함수 레벨 호출자
        callers_with_lines: List[Tuple[Path, int]] = []
        if function_name:
            callers_with_lines = graph.get_function_callers(module_name, function_name)

        # 경로를 상대 경로 문자열로 변환
        def to_rel(p: Path) -> str:
            try:
                return str(p.relative_to(root))
            except ValueError:
                return str(p)

        # 전체 영향 파일 수 계산
        all_affected: Set[Path] = direct_importers | transitive_dependents | test_files
        total_affected = len(all_affected)

        # 결과 조립
        result: dict = {
            "changed_file": changed_file,
        }
        if function_name:
            result["changed_function"] = function_name

        # 직접 임포터 (테스트 파일 제외)
        direct_importers_rel = sorted(to_rel(p) for p in direct_importers if not graph.is_test_file(p))
        # 전이적 의존자 (테스트 파일 제외)
        transitive_rel = sorted(to_rel(p) for p in transitive_dependents if not graph.is_test_file(p))
        # 테스트 파일
        test_files_rel = sorted(to_rel(p) for p in test_files)
        # 함수 호출자
        callers_rel = sorted(f"{to_rel(p)}:{lineno}" for p, lineno in callers_with_lines)

        result["blast_radius"] = {
            "direct_importers": direct_importers_rel,
            "callers": callers_rel,
            "transitive_dependents": transitive_rel,
            "test_files": test_files_rel,
            "total_affected": total_affected,
        }

        elapsed = time.time() * 1000 - file_start_ms
        if elapsed > 10_000:  # 10초 초과
            print(f"⚠️ 대형 파일 분석 경고: {changed_file} ({elapsed:.0f}ms)", file=sys.stderr)
        result["analysis_time_ms"] = round(elapsed, 2)

        results.append(result)

    return results


# ---------------------------------------------------------------------------
# CLI 진입점
# ---------------------------------------------------------------------------


def main():
    parser = argparse.ArgumentParser(
        description="AST 기반 Python 의존성 맵 생성기",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--root",
        type=str,
        default=".",
        help="분석 대상 Python 코드베이스 루트 디렉토리 (기본값: 현재 디렉토리)",
    )
    parser.add_argument(
        "--files",
        nargs="+",
        metavar="FILE",
        required=True,
        help="분석할 변경 파일 목록 (예: data_loader.py server.py)",
    )
    parser.add_argument(
        "--function",
        type=str,
        default=None,
        help="함수 레벨 분석을 위한 함수명 (선택 사항)",
    )
    parser.add_argument(
        "--pretty",
        action="store_true",
        default=True,
        help="들여쓰기된 JSON 출력 (기본값: True)",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        default=False,
        help="컴팩트 JSON 출력 (--pretty 오버라이드)",
    )

    args = parser.parse_args()

    root = Path(args.root).resolve()
    if not root.exists():
        print(
            json.dumps({"error": f"루트 디렉토리를 찾을 수 없습니다: {root}"}),
            file=sys.stderr,
        )
        sys.exit(1)
    if not root.is_dir():
        print(
            json.dumps({"error": f"루트 경로가 디렉토리가 아닙니다: {root}"}),
            file=sys.stderr,
        )
        sys.exit(1)

    results = analyze(
        root=root,
        changed_files=args.files,
        function_name=args.function,
    )

    # 단일 파일이면 dict로, 여러 파일이면 list로 출력
    output = results[0] if len(results) == 1 else results

    # --json이 지정되면 pretty를 무시하고 컴팩트 JSON 출력
    if args.json:
        indent = None
    else:
        indent = 2 if args.pretty else None
    print(json.dumps(output, ensure_ascii=False, indent=indent))


if __name__ == "__main__":
    main()
