#!/usr/bin/env python3
"""
design_md_search.py - CLI tool for searching and browsing DESIGN.md files
"""

import argparse
import os
import re
import sys
import textwrap
from pathlib import Path

# Base directory for all DESIGN.md files. Each site lives in its own
# subdirectory: <DESIGN_MD_DIR>/<site>/DESIGN.md. The location can be
# overridden with the DESIGN_MD_DIR environment variable; the original
# hard-coded path remains the default so existing setups keep working.
_DEFAULT_DESIGN_MD_DIR = "/home/jay/workspace/resources/design-md"
DESIGN_MD_DIR = Path(
    os.environ.get("DESIGN_MD_DIR") or _DEFAULT_DESIGN_MD_DIR
).expanduser()

# Category mapping (hardcoded): display name -> site directory names.
CATEGORIES = {
    "AI": [
        "claude", "cohere", "minimax", "mistral.ai", "nvidia", "ollama",
        "opencode.ai", "replicate", "together.ai", "x.ai", "elevenlabs",
        "runwayml", "cursor", "lovable",
    ],
    "DevTool": [
        "airtable", "cal", "figma", "framer", "linear.app", "mintlify",
        "miro", "posthog", "raycast", "resend", "sanity", "sentry",
        "supabase", "vercel", "voltagent", "warp", "webflow", "zapier",
        "hashicorp", "expo",
    ],
    "Infra": ["clickhouse", "mongodb"],
    "Design": ["pinterest", "spotify"],
    "Fintech": ["coinbase", "kraken", "revolut", "stripe", "wise"],
    "Enterprise": ["ibm", "intercom", "notion", "superhuman", "uber"],
    "Automotive": ["bmw", "ferrari", "lamborghini", "renault", "tesla", "spacex"],
    "Other": ["airbnb", "apple", "clay", "composio"],
}

# Recommendation mappings (hardcoded): use-case keyword -> ordered list of
# suggested site directory names (listed in this order, ranked from 1).
RECOMMENDATIONS = {
    "dashboard": ["linear.app", "posthog", "sentry", "airtable"],
    "fintech": ["stripe", "revolut", "wise", "coinbase", "kraken"],
    "minimal": ["notion", "cal", "superhuman", "resend"],
    "dark": ["x.ai", "cursor", "warp", "linear.app", "spacex"],
    "landing": ["vercel", "framer", "webflow", "mintlify"],
    "ai": ["claude", "cohere", "mistral.ai", "opencode.ai"],
    "luxury": ["ferrari", "lamborghini", "bmw", "apple"],
    "saas": ["zapier", "intercom", "airtable", "posthog"],
}

# Reverse mapping: site -> category. Built with a comprehension so the
# loop variables do not leak into the module namespace.
SITE_TO_CATEGORY = {
    site: cat for cat, sites in CATEGORIES.items() for site in sites
}


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def separator(char="=", width=72):
    """Return a horizontal rule: *char* repeated *width* times."""
    rule = char * width
    return rule


def get_all_sites():
    """Return the sorted names of site directories that contain a DESIGN.md.

    Only direct subdirectories of DESIGN_MD_DIR are considered. Returns an
    empty list when DESIGN_MD_DIR is missing or is not a directory, rather
    than letting ``iterdir()`` raise NotADirectoryError.
    """
    if not DESIGN_MD_DIR.is_dir():
        return []
    # Sort the Path objects themselves (same ordering as before), then keep
    # only subdirectories that actually hold a DESIGN.md file.
    return [
        entry.name
        for entry in sorted(DESIGN_MD_DIR.iterdir())
        if entry.is_dir() and (entry / "DESIGN.md").exists()
    ]


def get_design_md_path(site):
    """Return where *site*'s DESIGN.md is expected to be (it may not exist)."""
    return DESIGN_MD_DIR.joinpath(site, "DESIGN.md")


def read_design_md(site):
    """Return the UTF-8 text of *site*'s DESIGN.md, or None if it does not exist."""
    design_file = get_design_md_path(site)
    return design_file.read_text(encoding="utf-8") if design_file.exists() else None


def extract_section(content, section_pattern):
    """Return one section of a DESIGN.md document, heading line included.

    Capture starts at the first line matching *section_pattern*
    (``re.match``, case-insensitive). It ends just before the next
    *numbered* level-2 heading (``## <digits>.``). Unnumbered headings and
    ``###`` sub-headings stay inside the section. The captured text is
    joined with newlines and stripped. Returns "" when nothing matches.
    """
    collected = []
    capturing = False
    for raw in content.splitlines():
        if re.match(section_pattern, raw, re.IGNORECASE):
            capturing = True
        elif capturing and re.match(r"^##\s+\d+\.", raw):
            break
        if capturing:
            collected.append(raw)
    if not collected:
        return ""
    return "\n".join(collected).strip()


def extract_key_characteristics(content):
    """Return the stripped "- " bullet lines found in section 1 (Visual Theme).

    Every dash bullet anywhere in section 1 is returned, not only the ones
    under a "Key Characteristics" sub-heading. Returns [] when section 1 is
    absent.
    """
    section = extract_section(content, r"^##\s+1\.\s+Visual Theme")
    if not section:
        return []
    stripped_lines = (line.strip() for line in section.splitlines())
    return [text for text in stripped_lines if text.startswith("- ")]


def extract_color_palette(content):
    """Return section 2 (Color Palette) including its heading, or "" if absent."""
    palette_heading = r"^##\s+2\.\s+Color Palette"
    return extract_section(content, palette_heading)


def extract_typography(content):
    """Return up to 8 font-related lines from section 3 (Typography).

    A stripped line qualifies if it starts with "- **" or contains "font" or
    "typeface" (case-insensitive). It must also be shorter than 200
    characters. Lines are returned stripped, in document order.
    """
    section = extract_section(content, r"^##\s+3\.\s+Typography")

    def _looks_like_font_line(text):
        lowered = text.lower()
        return text.startswith("- **") or "font" in lowered or "typeface" in lowered

    candidates = (line.strip() for line in section.splitlines())
    fonts = [text for text in candidates if _looks_like_font_line(text) and len(text) < 200]
    return fonts[:8]


def extract_theme_summary(content):
    """Return the first prose paragraph of section 1 (Visual Theme).

    The section heading line is removed first. The rest is split into
    paragraphs on blank lines (empty or whitespace-only). The first
    non-empty paragraph that does not start with "#" is returned, with its
    whitespace collapsed and shortened to at most 280 characters ("..."
    marks a cut). Returns "" when section 1 is missing or has no such
    paragraph.
    """
    section = extract_section(content, r"^##\s+1\.\s+Visual Theme")
    # extract_section always returns the matched heading as its first line.
    # Drop just that line, so prose written directly under the heading (no
    # blank line in between) is not discarded together with the heading.
    _heading, _newline, body = section.partition("\n")
    for para in re.split(r"\n\s*\n", body):
        para = para.strip()
        if para and not para.startswith("#"):
            return textwrap.shorten(para, width=280, placeholder="...")
    return ""


def site_category(site):
    """Return the category for *site*; sites not in CATEGORIES map to "Other"."""
    try:
        return SITE_TO_CATEGORY[site]
    except KeyError:
        return "Other"


def sites_in_category(category_name):
    """Return the site list of *category_name*, compared case-insensitively.

    Unknown categories yield an empty list. For a known category, the list
    object stored in CATEGORIES is returned, not a copy.
    """
    wanted = category_name.lower()
    return next(
        (members for name, members in CATEGORIES.items() if name.lower() == wanted),
        [],
    )


# ---------------------------------------------------------------------------
# Command: list
# ---------------------------------------------------------------------------

def cmd_list(_args):  # noqa: ARG001
    """Print every on-disk site grouped by category, followed by a total.

    Categories with no site on disk are omitted. Sites on disk that no
    CATEGORIES entry mentions are printed under "[Uncategorized]".
    *_args* exists only to match the argparse handler signature.
    """
    on_disk = get_all_sites()
    on_disk_lookup = set(on_disk)
    rule = separator()
    thin_rule = separator("-", 72)

    print(rule)
    print("  DESIGN.md Site Catalog")
    print(rule)
    print()

    listed = 0
    catalogued = set()
    for category, members in CATEGORIES.items():
        catalogued.update(members)
        present = [name for name in members if name in on_disk_lookup]
        if not present:
            continue
        print(f"[{category}]  ({len(present)} sites)")
        print(thin_rule)
        for idx, name in enumerate(present, start=1):
            print(f"  {idx:>2}. {name}")
        print()
        listed += len(present)

    # Directories on disk that no category lists.
    leftovers = [name for name in on_disk if name not in catalogued]
    if leftovers:
        print("[Uncategorized]")
        print(thin_rule)
        for name in leftovers:
            print(f"       {name}")
        print()
        listed += len(leftovers)

    print(rule)
    print(f"  Total: {listed} sites")
    print(rule)


# ---------------------------------------------------------------------------
# Command: search
# ---------------------------------------------------------------------------

def _match_snippet(line, pattern, width=120):
    """Return a one-line display snippet of *line* that keeps the match visible.

    Whitespace is collapsed. If the line fits in *width* it is returned
    whole. If ``textwrap.shorten`` keeps the first *pattern* match, or the
    collapsed text has no match, the shortened text is returned. Otherwise a
    window around the match is returned, with "..." on each cut side. The
    result is at most *width* characters.
    """
    shortened = textwrap.shorten(line, width=width, placeholder="...")
    text = " ".join(line.split())  # same whitespace collapsing shorten applies
    if shortened == text:
        return shortened
    found = pattern.search(text)
    if found is None or found.end() <= len(shortened) - len("..."):
        return shortened
    # Begin a little before the match so some leading context is visible.
    start = max(0, found.start() - width // 3)
    end = min(len(text), start + width - 6)  # room for "..." on both sides
    if end == len(text):
        start = max(0, end - (width - 3))  # only the left side gets cut
    prefix = "..." if start > 0 else ""
    suffix = "..." if end < len(text) else ""
    return prefix + text[start:end] + suffix


def cmd_search(args):
    """Search DESIGN.md files (optionally one category) for a keyword.

    The keyword is matched literally (regex metacharacters escaped) and
    case-insensitively, line by line. For each site with hits, up to five
    matching lines are printed with line numbers, each trimmed so the match
    stays visible. Exits with status 1 for an empty keyword or an unknown
    category.
    """
    keyword = args.keyword
    category_filter = args.category

    # An empty pattern matches every line and would dump every file.
    if not keyword:
        print("[ERROR] Search keyword must not be empty.")
        sys.exit(1)

    all_sites = get_all_sites()

    if category_filter:
        filtered = sites_in_category(category_filter)
        if not filtered:
            print(f"[ERROR] Unknown category: '{category_filter}'")
            print("Available categories: " + ", ".join(CATEGORIES.keys()))
            sys.exit(1)
        allowed = set(filtered)  # built once, not once per candidate site
        sites_to_search = [s for s in all_sites if s in allowed]
    else:
        sites_to_search = all_sites

    pattern = re.compile(re.escape(keyword), re.IGNORECASE)
    max_shown = 5  # matching lines printed per site

    print(separator())
    print(f"  Search: \"{keyword}\"" + (f"  [category: {category_filter}]" if category_filter else ""))
    print(separator())
    print()

    match_count = 0
    site_count = 0

    for site in sites_to_search:
        content = read_design_md(site)
        if content is None:
            continue

        # (1-based line number, display snippet) for every matching line.
        matches = [
            (lineno, _match_snippet(line, pattern))
            for lineno, line in enumerate(content.splitlines(), 1)
            if pattern.search(line)
        ]
        if not matches:
            continue

        site_count += 1
        cat = site_category(site)
        print(f"  {site}  [{cat}]  ({len(matches)} match{'es' if len(matches) != 1 else ''})")
        print(separator("-", 72))
        for lineno, ctx in matches[:max_shown]:
            print(f"    L{lineno:>4}: {ctx}")
        if len(matches) > max_shown:
            print(f"          ... and {len(matches) - max_shown} more match(es)")
        print()
        match_count += len(matches)

    print(separator())
    if site_count == 0:
        print(f"  No matches found for \"{keyword}\"")
    else:
        print(f"  Found {match_count} match(es) across {site_count} site(s)")
    print(separator())


# ---------------------------------------------------------------------------
# Command: show
# ---------------------------------------------------------------------------

def cmd_show(args):
    """Print the full DESIGN.md of one site, resolving loosely typed names.

    Resolution order:
    1. the exact directory name;
    2. a case-insensitive exact match;
    3. a unique case-insensitive substring match.

    Exits with status 1 when the name is ambiguous or unknown (listing the
    candidates or all sites), or when the file is missing.
    """
    site = args.site

    all_sites = get_all_sites()
    if site not in all_sites:
        wanted = site.lower()
        # An exact case-insensitive hit wins outright, so a name that also
        # happens to be a substring of other site names is not "ambiguous".
        candidates = [s for s in all_sites if s.lower() == wanted]
        if not candidates:
            # Fall back to partial (substring), case-insensitive matching.
            candidates = [s for s in all_sites if wanted in s.lower()]
        if len(candidates) == 1:
            site = candidates[0]
        elif len(candidates) > 1:
            print(f"[ERROR] Ambiguous site name '{site}'. Did you mean one of:")
            for c in candidates:
                print(f"  - {c}")
            sys.exit(1)
        else:
            print(f"[ERROR] Site '{site}' not found.")
            print()
            print("Available sites:")
            for s in all_sites:
                print(f"  {s}")
            sys.exit(1)

    content = read_design_md(site)
    if content is None:
        print(f"[ERROR] DESIGN.md not found for site: {site}")
        sys.exit(1)

    cat = site_category(site)
    print(separator())
    print(f"  DESIGN.md  --  {site}  [{cat}]")
    print(separator())
    print()
    print(content)
    print()
    print(separator())

# ---------------------------------------------------------------------------
# Command: compare
# ---------------------------------------------------------------------------

def print_two_columns(left_lines, right_lines, col_width=34, gutter=" | "):
    """Print *left_lines* and *right_lines* side by side, one row per line.

    The left cell is left-justified to *col_width* and followed by *gutter*.
    The shorter list is padded with empty strings, so every entry of the
    longer list gets a row.
    """
    row_count = max(len(left_lines), len(right_lines))
    padded_left = list(left_lines) + [""] * (row_count - len(left_lines))
    padded_right = list(right_lines) + [""] * (row_count - len(right_lines))
    for left, right in zip(padded_left, padded_right):
        print(f"{left:<{col_width}}{gutter}{right}")


# Keywords that mark a Typography line as font-related in `compare`.
_FONT_KEYWORDS = ("font", "typeface", "headline", "body", "sans", "serif", "mono", "display")
# A hex colour such as #fff or #1A2B3C; with re.search it also hits the
# first 3-6 digits of longer hex runs.
_HEX_COLOR_RE = re.compile(r"#[0-9A-Fa-f]{3,6}")
# A Markdown ATX heading: 1-6 '#' followed by whitespace or end of line.
# Used instead of startswith("#") so lines that *begin* with a hex colour
# (e.g. "#FF5733 primary") are not mistaken for headings.
_MD_HEADING_RE = re.compile(r"#{1,6}(?:\s|$)")


def _print_compare_title(title):
    """Print a blank line, the indented section *title*, and a 72-char rule."""
    print()
    print(f"  {title}")
    print(separator("-", 72))


def _print_compare_columns(site_a, site_b, left, right, col_width, gutter):
    """Print a "[ site ]" header row, a rule row, then *left*/*right* side by side.

    Every row is indented by two spaces to line up with the section titles.
    """
    print(f"  {'[ ' + site_a + ' ]':<{col_width}}{gutter}{'[ ' + site_b + ' ]'}")
    print(f"  {separator('-', col_width)}{gutter}{separator('-', col_width)}")
    for i in range(max(len(left), len(right))):
        left_cell = left[i] if i < len(left) else ""
        right_cell = right[i] if i < len(right) else ""
        print(f"  {left_cell:<{col_width}}{gutter}{right_cell}")


def _wrap_bullets(bullets, width):
    """Wrap each bullet to *width*; add a blank row after each bullet.

    Returns ["(none found)"] when *bullets* is empty.
    """
    rows = []
    for bullet in bullets:
        rows.extend(textwrap.wrap(bullet, width=width))
        rows.append("")
    return rows or ["(none found)"]


def _compare_color_lines(text, width, limit=20):
    """Return up to *limit* colour-like lines of a Color Palette section.

    A non-blank, non-heading line qualifies if it contains a hex colour or
    starts with "- **". Each kept line is shortened to *width* characters.
    Returns ["(no color data)"] when nothing qualifies.
    """
    lines = []
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped or _MD_HEADING_RE.match(stripped):
            continue
        if _HEX_COLOR_RE.search(stripped) or stripped.startswith("- **"):
            lines.append(textwrap.shorten(stripped, width=width, placeholder="..."))
        if len(lines) >= limit:
            break
    return lines or ["(no color data)"]


def _compare_font_lines(content, width, limit=10):
    """Return up to *limit* distinct font-related lines from section 3 (Typography).

    A non-blank, non-heading line qualifies if it contains any of
    _FONT_KEYWORDS (case-insensitive). Lines are shortened to *width* and
    de-duplicated after shortening. Returns ["(no typography data)"] when
    nothing qualifies.
    """
    section = extract_section(content, r"^##\s+3\.\s+Typography")
    lines = []
    for line in section.splitlines():
        stripped = line.strip()
        if not stripped or _MD_HEADING_RE.match(stripped):
            continue
        lowered = stripped.lower()
        if any(kw in lowered for kw in _FONT_KEYWORDS):
            short = textwrap.shorten(stripped, width=width, placeholder="...")
            if short not in lines:
                lines.append(short)
        if len(lines) >= limit:
            break
    return lines or ["(no typography data)"]


def cmd_compare(args):
    """Print two sites' DESIGN.md highlights side by side.

    Sections shown:
    - the Visual Theme summary;
    - the section-1 dash bullets;
    - up to 20 colour-like lines from the Color Palette;
    - up to 10 font-related lines from Typography.

    Both names must be exact site directory names. Otherwise an error is
    printed for each missing name and the process exits with status 1.
    """
    site_a = args.site1
    site_b = args.site2

    all_sites = get_all_sites()
    missing = [site for site in (site_a, site_b) if site not in all_sites]
    if missing:
        for name in missing:
            print(f"[ERROR] Site '{name}' not found.")
        sys.exit(1)

    content_a = read_design_md(site_a)
    content_b = read_design_md(site_b)

    cat_a = site_category(site_a)
    cat_b = site_category(site_b)

    col_width = 34
    gutter = " | "

    def columns(left, right):
        _print_compare_columns(site_a, site_b, left, right, col_width, gutter)

    print(separator())
    print(f"  Comparing: {site_a} [{cat_a}]  vs  {site_b} [{cat_b}]")
    print(separator())

    # --- Visual Theme summary ---
    _print_compare_title("VISUAL THEME SUMMARY")
    columns(
        textwrap.wrap(extract_theme_summary(content_a) or "(no summary)", width=col_width),
        textwrap.wrap(extract_theme_summary(content_b) or "(no summary)", width=col_width),
    )

    # --- Key Characteristics (all dash bullets of section 1) ---
    _print_compare_title("KEY CHARACTERISTICS")
    columns(
        _wrap_bullets(extract_key_characteristics(content_a), col_width),
        _wrap_bullets(extract_key_characteristics(content_b), col_width),
    )

    # --- Color Palette (at most 20 colour-like lines per site) ---
    _print_compare_title("COLOR PALETTE  (first colors listed)")
    columns(
        _compare_color_lines(extract_color_palette(content_a), col_width),
        _compare_color_lines(extract_color_palette(content_b), col_width),
    )

    # --- Typography (at most 10 font-related lines per site) ---
    _print_compare_title("TYPOGRAPHY")
    columns(
        _compare_font_lines(content_a, col_width),
        _compare_font_lines(content_b, col_width),
    )

    print()
    print(separator())


# ---------------------------------------------------------------------------
# Command: recommend
# ---------------------------------------------------------------------------

def cmd_recommend(args):
    """Print recommended sites for a use-case, optionally filtered by category.

    The use-case is matched case-insensitively against RECOMMENDATIONS keys,
    the same way --category is matched. For each recommended site with a
    DESIGN.md, up to three wrapped lines of its theme summary and its first
    three section-1 bullets are printed. Exits with status 1 for an unknown
    use-case or category.
    """
    requested = args.usecase
    use_case = requested.lower()
    category_filter = args.category

    if use_case not in RECOMMENDATIONS:
        print(f"[ERROR] Unknown use-case: '{requested}'")
        print()
        print("Available use-cases:")
        for uc, uc_sites in RECOMMENDATIONS.items():
            print(f"  {uc:<12}  ->  {', '.join(uc_sites)}")
        sys.exit(1)

    sites = RECOMMENDATIONS[use_case]

    if category_filter:
        cat_sites = set(sites_in_category(category_filter))
        if not cat_sites:
            print(f"[ERROR] Unknown category: '{category_filter}'")
            print("Available categories: " + ", ".join(CATEGORIES.keys()))
            sys.exit(1)
        sites = [s for s in sites if s in cat_sites]

    all_sites = get_all_sites()
    site_set = set(all_sites)

    print(separator())
    print(f"  Recommendations for: \"{use_case}\"" +
          (f"  [category: {category_filter}]" if category_filter else ""))
    print(separator())
    print()

    if not sites:
        print("  No recommendations match the given filters.")
        print()
        print(separator())
        return

    for rank, site in enumerate(sites, 1):
        available = site in site_set
        cat = site_category(site)
        status = "" if available else "  [DESIGN.md not found]"
        print(f"  {rank}. {site}  [{cat}]{status}")
        if available:
            content = read_design_md(site)
            if content:
                # Theme summary: at most the first 3 wrapped lines.
                summary = extract_theme_summary(content)
                if summary:
                    wrapped = textwrap.wrap(summary, width=66)
                    for line in wrapped[:3]:
                        print(f"       {line}")
                # Key characteristics: at most the first 3 bullets.
                chars = extract_key_characteristics(content)
                for char in chars[:3]:
                    short = textwrap.shorten(char, width=66, placeholder="...")
                    print(f"       {short}")
        print()

    print(separator())
    print(f"  {len(sites)} site(s) recommended for '{use_case}'")
    print(separator())


# ---------------------------------------------------------------------------
# CLI setup
# ---------------------------------------------------------------------------

def build_parser():
    """Build the argument parser for list/search/show/compare/recommend.

    Each sub-command stores its handler function in ``func``. Help text that
    names the catalogue size, the categories or the use-cases is derived
    from SITE_TO_CATEGORY, CATEGORIES and RECOMMENDATIONS, so it cannot
    drift from the data.
    """
    site_total = len(SITE_TO_CATEGORY)
    category_names = ", ".join(CATEGORIES)
    use_case_names = ", ".join(RECOMMENDATIONS)

    parser = argparse.ArgumentParser(
        prog="design_md_search.py",
        description=f"Search and browse DESIGN.md files for {site_total} sites.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent("""\
            examples:
              python3 design_md_search.py list
              python3 design_md_search.py search "dark theme"
              python3 design_md_search.py search "gradient" --category AI
              python3 design_md_search.py show claude
              python3 design_md_search.py compare claude vercel
              python3 design_md_search.py recommend dashboard
              python3 design_md_search.py recommend dark --category DevTool
        """),
    )

    subparsers = parser.add_subparsers(dest="command", metavar="COMMAND")
    subparsers.required = True

    # list
    p_list = subparsers.add_parser("list", help=f"List all {site_total} sites grouped by category")
    p_list.set_defaults(func=cmd_list)

    # search
    p_search = subparsers.add_parser("search", help="Search keyword across all DESIGN.md files")
    p_search.add_argument("keyword", help="Keyword to search (case-insensitive)")
    p_search.add_argument(
        "--category", "-c",
        metavar="CAT",
        help=f"Limit search to a category ({category_names})",
    )
    p_search.set_defaults(func=cmd_search)

    # show
    p_show = subparsers.add_parser("show", help="Show full DESIGN.md for a site")
    p_show.add_argument("site", help="Site name (e.g. claude, vercel, stripe)")
    p_show.set_defaults(func=cmd_show)

    # compare
    p_compare = subparsers.add_parser("compare", help="Compare two sites side-by-side")
    p_compare.add_argument("site1", help="First site name")
    p_compare.add_argument("site2", help="Second site name")
    p_compare.set_defaults(func=cmd_compare)

    # recommend
    p_recommend = subparsers.add_parser("recommend", help="Get site recommendations by use-case")
    p_recommend.add_argument(
        "usecase",
        help=f"Use-case ({use_case_names})",
    )
    p_recommend.add_argument(
        "--category", "-c",
        metavar="CAT",
        help=f"Filter recommendations by category ({category_names})",
    )
    p_recommend.set_defaults(func=cmd_recommend)

    return parser


def main():
    """Parse the command line and run the selected sub-command's handler."""
    namespace = build_parser().parse_args()
    handler = namespace.func
    handler(namespace)


if __name__ == "__main__":
    main()
