#!/usr/bin/env python
"""
modular_graph_and_candidates.py
================================
Create **one** rich view that combines

1. the *dependency graph* between existing **modular_*.py** implementations in
   🤗 Transformers (blue/🟡), **and**
2. the list of *missing* modular models (full-red nodes) **plus** similarity
   edges (full-red links) between highly overlapping modelling files - the
   output of *find_modular_candidates.py* - so you can immediately spot good
   refactor opportunities.

─── Usage ───
```bash
python modular_graph_and_candidates.py /path/to/transformers \
    --multimodal          # keep only models whose modelling code mentions
                          # "pixel_values" ≥ 3 times
    --sim-threshold 0.5   # similarity cutoff (default 0.50)
    --out graph.html      # output HTML file name
```

Colour legend in the generated HTML:

* 🟡 **base model** - has modular shards *imported* by others but no parent
* 🔵 **derived modular model** - has a `modular_*.py` and inherits from ≥ 1 model
* 🔴 **candidate** - no `modular_*.py` yet (and/or very similar to another)
* red edges = high-similarity links (potential to factorise)
"""
from __future__ import annotations

import argparse
import ast
import json
import re
import subprocess
from collections import Counter, defaultdict
from datetime import datetime
from itertools import combinations
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np
import spaces  # noqa: F401  (Hugging Face Spaces integration; unused in this module)
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
# ──────────────────────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────────────────────
SIM_DEFAULT = 0.5        # similarity threshold
PIXEL_MIN_HITS = 3       # min "pixel_values" mentions for the --multimodal filter
HTML_DEFAULT = "d3_modular_graph.html"

# ──────────────────────────────────────────────────────────────────────────────
# 1) Helpers to analyse *modelling* files (for similarity & multimodal filter)
# ──────────────────────────────────────────────────────────────────────────────
def _strip_source(code: str) -> str:
    """Remove docstrings, comments and import lines to keep only the core code."""
    code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code)  # docstrings
    code = re.sub(r"#.*", "", code)                       # comments
    return "\n".join(ln for ln in code.splitlines()
                     if not re.match(r"\s*(from|import)\s+", ln))
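

# For instance (illustrative only), stripping a tiny snippet:
#     _strip_source('import os\n"""doc"""\nx = 1  # note')
# drops the import line, the docstring and the trailing comment, leaving just
# the `x = 1` statement (plus the blank line the docstring occupied).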
def _tokenise(code: str) -> Set[str]:
    """Extract identifiers with a regex - more robust than `tokenize` for malformed code."""
    return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code))
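

# Example: keywords and identifiers alike become tokens, duplicates collapse:
#     _tokenise("def f(x): return x + y")  # -> {"def", "f", "x", "return", "y"}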
def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Return token bags for every `modeling_*.py` plus a pixel-value counter."""
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()):
        for py in mdl_dir.rglob("modeling_*.py"):
            try:
                text = py.read_text(encoding="utf-8")
                pixel_hits[mdl_dir.name] += text.count("pixel_values")
                bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
            except Exception as e:
                print(f"⚠️ Skipped {py}: {e}")
    return bags, pixel_hits
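

# Shape sketch (hypothetical repo with bert/ and clip/ model dirs):
#     bags, hits = build_token_bags(models_root)
#     bags  # {"bert": [<token set per modeling file>], "clip": [...]}
#     hits  # {"bert": 0, "clip": 42}   - counts of "pixel_values" mentions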
def _jaccard(a: Set[str], b: Set[str]) -> float:
    return 0.0 if (not a or not b) else len(a & b) / len(a | b)
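

# Quick sanity check: two bags sharing half their identifiers score 2/4 = 0.5:
#     _jaccard({"a", "b", "c"}, {"b", "c", "d"})  # -> 0.5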
def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float = 0.1) -> Dict[Tuple[str, str], float]:
    """Pairwise Jaccard similarity between the largest token bag of each model."""
    largest = {m: max(ts, key=len) for m, ts in bags.items() if ts}
    out: Dict[Tuple[str, str], float] = {}
    for m1, m2 in combinations(sorted(largest.keys()), 2):
        s = _jaccard(largest[m1], largest[m2])
        if s >= thr:
            out[(m1, m2)] = s
    return out
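

# Hypothetical usage sketch: two models whose biggest modeling files overlap
# heavily surface as a pair above the threshold:
#     bags = {"bert": [{"attn", "mlp", "norm"}], "roberta": [{"attn", "mlp", "rope"}]}
#     similarity_clusters(bags, thr=0.4)  # -> {("bert", "roberta"): 0.5}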
def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float = 0.1) -> Dict[Tuple[str, str], float]:
    """Cosine-similarity clusters over CodeBERT embeddings of the modeling files."""
    model = SentenceTransformer("microsoft/codebert-base", trust_remote_code=True)
    try:
        cfg = model[0].auto_model.config
        pos_limit = int(getattr(cfg, "n_positions", getattr(cfg, "max_position_embeddings", 1024)))
    except Exception:
        pos_limit = 1024
    seq_len = min(pos_limit, 2048)
    model.max_seq_length = seq_len
    model[0].max_seq_length = seq_len
    model[0].tokenizer.model_max_length = seq_len
    texts = {}
    for name in tqdm(missing, desc="Reading modeling files"):
        if any(skip in name.lower() for skip in ["mobilebert", "lxmert"]):
            print(f"Skipping {name} (causes GPU abort)")
            continue
        code = ""
        for py in (models_root / name).rglob("modeling_*.py"):
            try:
                code += _strip_source(py.read_text(encoding="utf-8")) + "\n"
            except Exception:
                continue
        texts[name] = code.strip() or " "
    names = list(texts)
    all_embeddings = []
    print(f"Encoding embeddings for {len(names)} models...")
    batch_size = 4  # small batches keep GPU memory in check
    # ── two-stage caching: temp (for resume) + permanent (for reuse) ──────────
    temp_cache_path = Path("temp_embeddings.npz")    # for resuming computation
    final_cache_path = Path("embeddings_cache.npz")  # for permanent storage
    start_idx = 0
    emb_dim = model.get_sentence_embedding_dimension() or 768
    # Try the permanent cache first.
    if final_cache_path.exists():
        try:
            cached = np.load(final_cache_path, allow_pickle=True)
            cached_names = list(cached["names"])
            if names == cached_names:  # exact match - reuse the final cache
                print(f"✓ Using final embeddings cache ({len(cached_names)} models)")
                return compute_similarities_from_cache(thr)
        except Exception as e:
            print(f"⚠️ Failed to load final cache: {e}")
    # Otherwise try to resume from the temp cache.
    if temp_cache_path.exists():
        try:
            cached = np.load(temp_cache_path, allow_pickle=True)
            cached_names = list(cached["names"])
            if names[:len(cached_names)] == cached_names:
                loaded = cached["embeddings"].astype(np.float32)
                all_embeddings.append(loaded)
                start_idx = len(cached_names)
                print(f"🔄 Resuming from temp cache: {start_idx}/{len(names)} models")
        except Exception as e:
            print(f"⚠️ Failed to load temp cache: {e}")
    # ──────────────────────────────────────────────────────────────────────────
    for i in tqdm(range(start_idx, len(names), batch_size), desc="Batches", leave=False):
        batch_names = names[i:i + batch_size]
        batch_texts = [texts[name] for name in batch_names]
        try:
            print(f"Processing batch: {batch_names}")
            emb = model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
        except Exception as e:
            print(f"⚠️ GPU worker error for batch {batch_names}: {type(e).__name__}: {e}")
            emb = np.zeros((len(batch_names), emb_dim), dtype=np.float32)
        all_embeddings.append(emb)
        # Save to the temp cache after each batch (for resume).
        try:
            cur = np.vstack(all_embeddings).astype(np.float32)
            np.savez(
                temp_cache_path,
                embeddings=cur,
                names=np.array(names[:i + len(batch_names)], dtype=object),
            )
        except Exception as e:
            print(f"⚠️ Failed to write temp cache: {e}")
        if (i - start_idx) % (3 * batch_size) == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print(f"🧹 Cleared GPU cache after batch {(i - start_idx) // batch_size + 1}")
    embeddings = np.vstack(all_embeddings).astype(np.float32)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12
    embeddings = embeddings / norms
    print("Computing pairwise similarities...")
    sims_mat = embeddings @ embeddings.T
    out = {}
    matrix_size = embeddings.shape[0]
    processed_names = names[:matrix_size]
    for i in range(matrix_size):
        for j in range(i + 1, matrix_size):
            s = float(sims_mat[i, j])
            if s >= thr:
                out[(processed_names[i], processed_names[j])] = s
    # Persist to the final cache once complete.
    try:
        np.savez(final_cache_path, embeddings=embeddings, names=np.array(names, dtype=object))
        print(f"💾 Final embeddings saved to {final_cache_path}")
        # Clean up the temp cache.
        if temp_cache_path.exists():
            temp_cache_path.unlink()
            print("🧹 Cleaned up temp cache")
    except Exception as e:
        print(f"⚠️ Failed to save final cache: {e}")
    return out
def compute_similarities_from_cache(threshold: float) -> Dict[Tuple[str, str], float]:
    """Compute similarities from cached embeddings without reprocessing."""
    embeddings_path = Path("embeddings_cache.npz")
    if not embeddings_path.exists():
        return {}
    try:
        cached = np.load(embeddings_path, allow_pickle=True)
        embeddings = cached["embeddings"].astype(np.float32)
        names = list(cached["names"])
        # Normalise embeddings
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12
        embeddings = embeddings / norms
        # Compute cosine similarities
        sims_mat = embeddings @ embeddings.T
        out = {}
        for i in range(len(names)):
            for j in range(i + 1, len(names)):
                s = float(sims_mat[i, j])
                if s >= threshold:
                    out[(names[i], names[j])] = s
        print(f"⚡ Computed {len(out)} similarities from cache (threshold: {threshold})")
        return out
    except Exception as e:
        print(f"⚠️ Failed to compute from cache: {e}")
        return {}
def filter_similarities_by_threshold(similarities: Dict[Tuple[str, str], float], threshold: float) -> Dict[Tuple[str, str], float]:
    return {pair: score for pair, score in similarities.items() if score >= threshold}
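

# Illustrative: re-filter previously computed scores at a stricter cutoff
# without touching the embeddings again:
#     sims = {("a", "b"): 0.92, ("a", "c"): 0.41}
#     filter_similarities_by_threshold(sims, 0.5)  # -> {("a", "b"): 0.92}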
def filter_graph_by_threshold(graph_data: dict, threshold: float) -> dict:
    filtered_links = []
    for link in graph_data["links"]:
        if link.get("cand", False):
            try:
                score = float(link["label"].rstrip("%")) / 100.0
                if score >= threshold:
                    filtered_links.append(link)
            except (ValueError, AttributeError):
                filtered_links.append(link)
        else:
            filtered_links.append(link)
    return {
        "nodes": graph_data["nodes"],
        "links": filtered_links,
        **{k: v for k, v in graph_data.items() if k not in ["nodes", "links"]},
    }
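

# The link dicts this operates on look like (names illustrative; see
# build_graph_json below):
#     {"source": "llama", "target": "mistral", "label": "87.5%", "cand": True}      # similarity edge
#     {"source": "llama", "target": "helium", "label": "3 imports", "cand": False}  # dependency edge
# Only `cand` edges are thresholded; their "%"-labels are parsed back to floats.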
# ──────────────────────────────────────────────────────────────────────────────
# 2) Scan *modular_*.py* files to build an import-dependency graph
#    - only **modeling_*** imports are considered (skip configuration / processing)
# ──────────────────────────────────────────────────────────────────────────────
def modular_files(models_root: Path) -> List[Path]:
    return list(models_root.rglob("modular_*.py"))
def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str, str]]]:
    """Return {derived_model: [{source, imported_class}, ...]}.

    Only `modeling_*` imports are kept; anything coming from configuration/
    processing/image-processing utils is ignored so the visual graph focuses
    strictly on modelling code. Edges whose source is not a model directory
    are excluded.
    """
    model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
    deps: Dict[str, List[Dict[str, str]]] = defaultdict(list)
    for fp in modular_files:
        derived = fp.parent.name
        try:
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except Exception as e:
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue
        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            mod = node.module
            # Keep only *modeling_* imports, drop everything else.
            if ("modeling_" not in mod or
                    "configuration_" in mod or
                    "processing_" in mod or
                    "image_processing" in mod or
                    "modeling_attn_mask_utils" in mod):
                continue
            parts = re.split(r"[./]", mod)
            src = next((p for p in parts if p not in {"", "models", "transformers"}), "")
            if not src or src == derived or src not in model_names:
                continue
            for alias in node.names:
                deps[derived].append({"source": src, "imported_class": alias.name})
    return dict(deps)
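

# Sketch of the returned shape: if a hypothetical `modular_gemma.py` contains
#     from ..llama.modeling_llama import LlamaMLP, LlamaAttention
# the result would include
#     {"gemma": [{"source": "llama", "imported_class": "LlamaMLP"},
#                {"source": "llama", "imported_class": "LlamaAttention"}]}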
def get_missing_models(models_root: Path, multimodal: bool = False) -> Tuple[List[str], Dict[str, List[Set[str]]], Dict[str, int]]:
    """Get the list of models that are missing a modular implementation."""
    bags, pix_hits = build_token_bags(models_root)
    mod_files = modular_files(models_root)
    models_with_modular = {p.parent.name for p in mod_files}
    missing = [m for m in bags if m not in models_with_modular]
    if multimodal:
        missing = [m for m in missing if pix_hits[m] >= PIXEL_MIN_HITS]
    return missing, bags, pix_hits
def compute_similarities(models_root: Path, missing: List[str], bags: Dict[str, List[Set[str]]],
                         threshold: float, sim_method: str) -> Dict[Tuple[str, str], float]:
    # Similarities are computed down to a low floor (0.1); the user-facing
    # `threshold` is applied later by filter_graph_by_threshold().
    min_threshold = 0.1
    if sim_method == "jaccard":
        return similarity_clusters({m: bags[m] for m in missing}, min_threshold)
    embeddings_path = Path("embeddings_cache.npz")
    if embeddings_path.exists():
        cached_sims = compute_similarities_from_cache(min_threshold)
        if cached_sims:
            return cached_sims
    return embedding_similarity_clusters(models_root, missing, min_threshold)
def build_graph_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Return the {nodes, links} dict that D3 needs."""
    # Fast path: build entirely from cached embeddings.
    embeddings_cache = Path("embeddings_cache.npz")
    print(f"🔍 Cache file exists: {embeddings_cache.exists()}, sim_method: {sim_method}")
    if sim_method == "embedding" and embeddings_cache.exists():
        try:
            # Compute from the cache without re-reading the repo's modeling files.
            cached_sims = compute_similarities_from_cache(0.1)
            print(f"🔍 Got {len(cached_sims)} cached similarities")
            if cached_sims:
                cached_data = np.load(embeddings_cache, allow_pickle=True)
                missing = list(cached_data["names"])
                # Modular dependencies still come from the repo itself.
                models_root = transformers_dir / "src/transformers/models"
                mod_files = modular_files(models_root)
                deps = dependency_graph(mod_files, models_root)
                nodes = set(missing)  # start with the cached models
                links = []
                # Dependency links: one per (source, derived) pair, labelled
                # with the number of imported classes.
                for drv, lst in deps.items():
                    for src_name, cnt in Counter(d["source"] for d in lst).items():
                        links.append({
                            "source": src_name,
                            "target": drv,
                            "label": f"{cnt} imports",
                            "cand": False,
                        })
                        nodes.update({src_name, drv})
                # Similarity links.
                for (a, b), s in cached_sims.items():
                    links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})
                # Classify nodes.
                targets = {lk["target"] for lk in links if not lk["cand"]}
                sources = {lk["source"] for lk in links if not lk["cand"]}
                nodelist = []
                for n in sorted(nodes):
                    if n in missing and n not in sources and n not in targets:
                        cls = "cand"
                    elif n in sources and n not in targets:
                        cls = "base"
                    else:
                        cls = "derived"
                    nodelist.append({"id": n, "cls": cls, "sz": 1})
                graph = {"nodes": nodelist, "links": links}
                print(f"⚡ Built graph from cache: {len(nodelist)} nodes, {len(links)} links")
                if threshold > 0.1:
                    graph = filter_graph_by_threshold(graph, threshold)
                return graph
        except Exception as e:
            print(f"⚠️ Cache-only build failed: {e}, falling back to full build")
    # Full build with repository access.
    models_root = transformers_dir / "src/transformers/models"
    # Missing models and their token bags.
    missing, bags, pix_hits = get_missing_models(models_root, multimodal)
    # Dependency graph.
    mod_files = modular_files(models_root)
    deps = dependency_graph(mod_files, models_root)
    # Similarities.
    sims = compute_similarities(models_root, missing, bags, threshold, sim_method)
    # ---- assemble nodes & links ----
    nodes: Set[str] = set()
    links: List[dict] = []
    for drv, lst in deps.items():
        for src_name, cnt in Counter(d["source"] for d in lst).items():
            links.append({
                "source": src_name,
                "target": drv,
                "label": f"{cnt} imports",
                "cand": False,
            })
            nodes.update({src_name, drv})
    for (a, b), s in sims.items():
        links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})
        nodes.update({a, b})
    nodes.update(missing)
    deg = Counter()
    for lk in links:
        deg[lk["source"]] += 1
        deg[lk["target"]] += 1
    max_deg = max(deg.values() or [1])
    targets = {lk["target"] for lk in links if not lk["cand"]}
    sources = {lk["source"] for lk in links if not lk["cand"]}
    missing_only = [m for m in missing if m not in sources and m not in targets]
    nodes.update(missing_only)
    nodelist = []
    for n in sorted(nodes):
        if n in missing_only:
            cls = "cand"
        elif n in sources and n not in targets:
            cls = "base"
        else:
            cls = "derived"
        nodelist.append({"id": n, "cls": cls, "sz": 1 + 2 * (deg[n] / max_deg)})
    graph = {"nodes": nodelist, "links": links}
    if threshold > 0.1:
        graph = filter_graph_by_threshold(graph, threshold)
    return graph
# ──────────────────────────────────────────────────────────────────────────────
# Timeline functions for chronological visualisation
# ──────────────────────────────────────────────────────────────────────────────
def get_model_creation_dates(transformers_dir: Path) -> Dict[str, datetime]:
    """Get a creation date for every model directory from the earliest commit that added it."""
    models_root = transformers_dir / "src/transformers/models"
    creation_dates: Dict[str, datetime] = {}
    if not models_root.exists():
        return creation_dates

    def run_git(args: list[str]) -> subprocess.CompletedProcess:
        return subprocess.run(
            ["git"] + args,
            cwd=transformers_dir,
            capture_output=True,
            text=True,
            timeout=120,
        )

    # Ensure full history; shallow clones make every path look newly added "today".
    shallow = run_git(["rev-parse", "--is-shallow-repository"])
    if shallow.returncode == 0 and shallow.stdout.strip() == "true":
        # Best-effort unshallow; if it fails, we still proceed.
        run_git(["fetch", "--unshallow", "--tags", "--prune"])  # ignore return code
        # Fallback if the server forbids --unshallow
        run_git(["fetch", "--depth=100000", "--tags", "--prune"])
    for model_dir in models_root.iterdir():
        if not model_dir.is_dir():
            continue
        rel = f"src/transformers/models/{model_dir.name}/"
        # Earliest commit that ADDED something under this directory.
        # Use a stable delimiter to avoid locale/spacing issues.
        proc = run_git([
            "log",
            "--reverse",          # oldest → newest
            "--diff-filter=A",    # additions only
            "--date=short",       # YYYY-MM-DD
            "--format=%H|%ad",    # hash|date
            "--",
            rel,
        ])
        if proc.returncode != 0 or not proc.stdout.strip():
            # Fallback: earliest commit touching any tracked file under the dir.
            # This catches files that were moved (renamed) rather than added.
            ls = run_git(["ls-files", rel])
            files = [ln for ln in ls.stdout.splitlines() if ln.strip()]
            best_date: datetime | None = None
            if files:
                for fp in files:
                    proc_file = run_git([
                        "log",
                        "--reverse",
                        "--diff-filter=A",
                        "--date=short",
                        "--format=%H|%ad",
                        "--",
                        fp,
                    ])
                    line = proc_file.stdout.splitlines()[0].strip() if proc_file.stdout else ""
                    if line and "|" in line:
                        _, d = line.split("|", 1)
                        try:
                            dt = datetime.strptime(d.strip(), "%Y-%m-%d")
                            if best_date is None or dt < best_date:
                                best_date = dt
                        except ValueError:
                            pass
            if best_date is not None:
                creation_dates[model_dir.name] = best_date
                print(f"✓ {model_dir.name}: {best_date.strftime('%Y-%m-%d')}")
            else:
                print(f"✗ {model_dir.name}: no add commit found")
            continue
        first_line = proc.stdout.splitlines()[0].strip()  # oldest add
        if "|" in first_line:
            _, date_str = first_line.split("|", 1)
            try:
                creation_dates[model_dir.name] = datetime.strptime(date_str.strip(), "%Y-%m-%d")
                print(f"✓ {model_dir.name}: {date_str.strip()}")
            except ValueError:
                print(f"✗ {model_dir.name}: bad date format: {date_str!r}")
        else:
            print(f"✗ {model_dir.name}: unexpected log format: {first_line!r}")
    return creation_dates
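

# The per-directory query above is equivalent to running, for example:
#     git log --reverse --diff-filter=A --date=short --format='%H|%ad' \
#         -- src/transformers/models/bert/
# and keeping the first (oldest) line; "bert" here is just an illustration.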
def build_timeline_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Build a chronological timeline with modular connections."""
    # Standard dependency graph for the connections.
    graph = build_graph_json(transformers_dir, threshold, multimodal, sim_method)
    # Creation dates for chronological positioning.
    creation_dates = get_model_creation_dates(transformers_dir)
    # Enhance nodes with chronological data.
    for node in graph["nodes"]:
        model_name = node["id"]
        if model_name in creation_dates:
            creation_date = creation_dates[model_name]
            node.update({
                "date": creation_date.isoformat(),
                "year": creation_date.year,
                "timestamp": creation_date.timestamp(),
            })
        else:
            # No git date found: timestamp 0 marks the node as undated so the
            # JS layout can push it past the end of the time axis.
            node.update({
                "date": None,
                "year": None,
                "timestamp": 0,
            })
    # Timeline metadata.
    dated = [n for n in graph["nodes"] if n["timestamp"] > 0]
    if dated:
        graph["timeline_meta"] = {
            "min_year": min(n["year"] for n in dated),
            "max_year": max(n["year"] for n in dated),
            "total_models": len(graph["nodes"]),
            "dated_models": len(dated),
        }
    else:
        graph["timeline_meta"] = {
            "min_year": 2018,
            "max_year": 2024,
            "total_models": len(graph["nodes"]),
            "dated_models": 0,
        }
    return graph
def generate_html(graph: dict) -> str:
    """Return the full HTML string with inlined CSS/JS + graph JSON."""
    js = JS.replace("__GRAPH_DATA__", json.dumps(graph, separators=(",", ":")))
    return HTML.replace("__CSS__", CSS).replace("__JS__", js)

def generate_timeline_html(timeline: dict) -> str:
    """Return the full HTML string for the chronological timeline visualisation."""
    js = TIMELINE_JS.replace("__TIMELINE_DATA__", json.dumps(timeline, separators=(",", ":")))
    return TIMELINE_HTML.replace("__TIMELINE_CSS__", TIMELINE_CSS).replace("__TIMELINE_JS__", js)
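

# Minimal sketch of how the timeline view could be written out. main() below
# only writes the dependency graph; presumably the hosting app calls these
# helpers itself:
#     timeline = build_timeline_json(Path("~/transformers").expanduser())
#     Path("timeline.html").write_text(generate_timeline_html(timeline), encoding="utf-8")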
# ──────────────────────────────────────────────────────────────────────────────
# 3) HTML (D3.js) boilerplate - CSS + JS templates (unchanged design)
# ──────────────────────────────────────────────────────────────────────────────
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
:root{
  --bg:#ffffff;
  --text:#222222;
  --muted:#555555;
  --outline:#ffffff;
}
@media (prefers-color-scheme: dark){
  :root{
    --bg:#0b0d10;
    --text:#e8e8e8;
    --muted:#c8c8c8;
    --outline:#000000;
  }
}
body{ margin:0; font-family:'Inter',Arial,sans-serif; background:var(--bg); overflow:hidden; }
svg{ width:100vw; height:100vh; }
.link{ stroke:#999; stroke-opacity:.6; }
.link.cand{ stroke:#e63946; stroke-width:2.5; }
.node-label{
  fill:var(--text);
  pointer-events:none;
  text-anchor:middle;
  font-weight:600;
  paint-order:stroke fill;
  stroke:var(--outline);
  stroke-width:3px;
}
.link-label{
  fill:var(--muted);
  pointer-events:none;
  text-anchor:middle;
  font-size:12px;
  paint-order:stroke fill;
  stroke:var(--bg);
  stroke-width:2px;
}
.node.base image{ width:60px; height:60px; transform:translate(-30px,-30px); }
.node.derived circle{ fill:#1f77b4; }
.node.cand circle, .node.cand path{ fill:#e63946; }
#legend{
  position:fixed; top:18px; left:18px;
  background:rgba(255,255,255,.92);
  padding:18px 28px; border-radius:10px; border:1.5px solid #bbb;
  font-size:22px; box-shadow:0 2px 8px rgba(0,0,0,.08);
}
@media (prefers-color-scheme: dark){
  #legend{ background:rgba(20,22,25,.92); color:#e8e8e8; border-color:#444; }
}
"""
| JS = """ | |
| function updateVisibility() { | |
| const show = document.getElementById('toggleRed').checked; | |
| svg.selectAll('.link.cand').style('display', show ? null : 'none'); | |
| svg.selectAll('.node.cand').style('display', show ? null : 'none'); | |
| svg.selectAll('.link-label').filter(d => d.cand).style('display', show ? null : 'none'); | |
| } | |
| document.getElementById('toggleRed').addEventListener('change', updateVisibility); | |
| const graph = __GRAPH_DATA__; | |
| const W = innerWidth, H = innerHeight; | |
| const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform))); | |
| const g = svg.append('g'); | |
| const link = g.selectAll('line') | |
| .data(graph.links) | |
| .join('line') | |
| .attr('class', d => d.cand ? 'link cand' : 'link'); | |
| const linkLbl = g.selectAll('text.link-label') | |
| .data(graph.links) | |
| .join('text') | |
| .attr('class', 'link-label') | |
| .text(d => d.label); | |
| const node = g.selectAll('g.node') | |
| .data(graph.nodes) | |
| .join('g') | |
| .attr('class', d => `node ${d.cls}`) | |
| .call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd)); | |
| const baseSel = node.filter(d => d.cls === 'base'); | |
| baseSel.append('circle').attr('r', d => 22*d.sz).attr('fill', '#ffbe0b'); | |
| node.filter(d => d.cls !== 'base').append('circle').attr('r', d => 20*d.sz); | |
| node.append('text') | |
| .attr('class','node-label') | |
| .attr('dy','-2.4em') | |
| .style('font-size', d => d.cls === 'base' ? '160px' : '120px') | |
| .style('font-weight', d => d.cls === 'base' ? 'bold' : 'normal') | |
| .text(d => d.id); | |
| const sim = d3.forceSimulation(graph.nodes) | |
| .force('link', d3.forceLink(graph.links).id(d => d.id).distance(520)) | |
| .force('charge', d3.forceManyBody().strength(-1200)) | |
| .force('center', d3.forceCenter(W / 2, H / 2)) | |
| .force('collide', d3.forceCollide(d => 50)); | |
| sim.on('tick', () => { | |
| link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y) | |
| .attr('x2', d=>d.target.x).attr('y2', d=>d.target.y); | |
| linkLbl.attr('x', d=> (d.source.x+d.target.x)/2) | |
| .attr('y', d=> (d.source.y+d.target.y)/2); | |
| node.attr('transform', d=>`translate(${d.x},${d.y})`); | |
| }); | |
| function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; } | |
| function dragged(e,d){ d.fx=e.x; d.fy=e.y; } | |
| function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; } | |
| """ | |
| HTML = """ | |
| <!DOCTYPE html> | |
| <html lang='en'><head><meta charset='UTF-8'> | |
| <title>Transformers modular graph</title> | |
| <style>__CSS__</style></head><body> | |
| <div id='legend'> | |
| π‘ base<br>π΅ modular<br>π΄ candidate<br>red edgeΒ = high embedding similarity<br><br> | |
| <label><input type="checkbox" id="toggleRed" checked> Show candidates edges and nodes</label> | |
| </div> | |
| <svg id='dependency'></svg> | |
| <script src='https://d3js.org/d3.v7.min.js'></script> | |
| <script>__JS__</script></body></html> | |
| """ | |
# ──────────────────────────────────────────────────────────────────────────────
# Timeline HTML templates
# ──────────────────────────────────────────────────────────────────────────────
TIMELINE_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
:root{
  --bg:#ffffff;
  --text:#222222;
  --muted:#555555;
  --outline:#ffffff;
  --timeline-line:#dee2e6;
  --base-color:#ffbe0b;
  --derived-color:#1f77b4;
  --candidate-color:#e63946;
}
@media (prefers-color-scheme: dark){
  :root{
    --bg:#0b0d10;
    --text:#e8e8e8;
    --muted:#c8c8c8;
    --outline:#000000;
    --timeline-line:#343a40;
  }
}
body{
  margin:0;
  font-family:'Inter',Arial,sans-serif;
  background:var(--bg);
  overflow:hidden;
}
svg{ width:100vw; height:100vh; }
/* Link styles for chronological flow */
.link{
  stroke:#4a90e2;
  stroke-opacity:0.6;
  stroke-width:1.5;
  fill:none;
  transition: stroke-opacity 0.3s ease;
}
.link.cand{
  stroke:var(--candidate-color);
  stroke-width:2.5;
  stroke-opacity:0.8;
  stroke-dasharray: 4,4;
}
.link:hover{
  stroke-opacity:1;
  stroke-width:3;
}
/* Node label styling */
.node-label{
  fill:var(--text);
  pointer-events:none;
  text-anchor:middle;
  font-weight:600;
  font-size:50px;
  paint-order:stroke fill;
  stroke:var(--outline);
  stroke-width:3px;
  cursor:default;
}
/* Node styling with a clear visual hierarchy */
.node.base circle{
  fill:var(--base-color);
  stroke:#d4a000;
  stroke-width:2;
}
.node.derived circle{
  fill:var(--derived-color);
  stroke:#1565c0;
  stroke-width:2;
}
.node.cand circle{
  fill:var(--candidate-color);
  stroke:#c62828;
  stroke-width:2;
}
.node circle{
  transition: r 0.3s ease, stroke-width 0.3s ease;
  cursor:grab;
}
.node:hover circle{
  r:22;
  stroke-width:3;
}
.node:active{
  cursor:grabbing;
}
/* Timeline axis styling */
.timeline-axis {
  stroke: var(--timeline-line);
  stroke-width: 3px;
  stroke-opacity: 0.8;
}
.timeline-tick {
  stroke: var(--timeline-line);
  stroke-width: 2px;
  stroke-opacity: 0.6;
}
.timeline-month-tick {
  stroke: var(--timeline-line);
  stroke-width: 1px;
  stroke-opacity: 0.4;
}
.timeline-label {
  fill: var(--muted);
  font-size: 40px;
  font-weight: 600;
  text-anchor: middle;
}
.timeline-month-label {
  fill: var(--muted);
  font-size: 35px;
  font-weight: 400;
  text-anchor: middle;
  opacity: 0.7;
}
.modular-milestone {
  stroke: #ff6b35;
  stroke-width: 3px;
  stroke-opacity: 0.8;
  stroke-dasharray: 5,5;
}
.modular-milestone-label {
  fill: #ff6b35;
  font-size: 35px;
  font-weight: 600;
  text-anchor: middle;
}
/* Controls panel (base styling) */
#controls{
  position:fixed; top:20px; left:20px;
  background:rgba(255,255,255,.95);
  padding:20px 26px; border-radius:12px; border:1.5px solid #e0e0e0;
  font-size:24px; box-shadow:0 4px 16px rgba(0,0,0,.12);
  z-index: 100;
  backdrop-filter: blur(8px);
  max-width: 280px;
}
@media (prefers-color-scheme: dark){
  #controls{
    background:rgba(20,22,25,.95);
    color:#e8e8e8;
    border-color:#404040;
  }
}
#controls label{
  display:flex;
  align-items:center;
  margin-top:10px;
  cursor:pointer;
}
#controls input[type="checkbox"]{
  margin-right:8px;
  cursor:pointer;
}
/* Compact controls (intentionally overrides the base styling above) */
#controls{
  position:absolute; top:12px; left:12px;
  width:200px; max-width:42vw;
  padding:10px 12px;
  font-size:12px; line-height:1.25;
  background:rgba(0,0,0,.04);
  border:1px solid rgba(0,0,0,.08);
  border-radius:10px; border-left-width:2px;
  backdrop-filter:saturate(140%) blur(6px);
  z-index:5;
}
#controls b, #controls strong{ font-weight:600; }
#controls > div:first-child{ margin-bottom:6px; font-size:12px; }
#controls label{ display:block; margin-top:6px; }
#controls input[type="text"]{
  margin-top:8px; width:100%;
  padding:4px 6px; font-size:12px;
  border-radius:6px; border:1px solid #ccc; background:transparent;
}
#controls .hint{ margin-top:8px; font-size:11px; color:var(--muted); }
/* Slightly smaller on narrow embeds */
@media (max-width: 900px){
  #controls{ width:180px; font-size:11px; padding:8px 10px; }
  #controls input[type="text"]{ font-size:11px; padding:3px 6px; }
}
"""
| TIMELINE_JS = """ | |
| function updateVisibility() { | |
| const show = document.getElementById('toggleRed').checked; | |
| svg.selectAll('.link.cand').style('display', show ? null : 'none'); | |
| svg.selectAll('.node.cand').style('display', show ? null : 'none'); | |
| } | |
| document.getElementById('toggleRed').addEventListener('change', updateVisibility); | |
| const timeline = __TIMELINE_DATA__; | |
| const W = innerWidth, H = innerHeight; | |
| // Create SVG with zoom behavior | |
| const svg = d3.select('#timeline-svg'); | |
| const g = svg.append('g'); | |
| // Enhanced timeline configuration for maximum horizontal spread | |
| const MARGIN = { top: 60, right: 200, bottom: 120, left: 200 }; | |
| const CONTENT_HEIGHT = H - MARGIN.top - MARGIN.bottom; | |
| const VERTICAL_LANES = 4; // Number of horizontal lanes for better organization | |
| const zoomBehavior = d3.zoom() | |
| .scaleExtent([0.1, 8]) | |
| .on('zoom', handleZoom); | |
| svg.call(zoomBehavior); | |
| svg.on("click", function(event) { | |
| if (event.target.tagName === "svg") { | |
| node.select("circle").style("opacity", 1); | |
| link.style("opacity", 1); | |
| g.selectAll(".node-label").style("opacity", 1); | |
| } | |
| }); | |
| // Time scale for chronological positioning with much wider spread | |
| const timeExtent = d3.extent(timeline.nodes.filter(d => d.timestamp > 0), d => d.timestamp); | |
| let timeScale; | |
| if (timeExtent[0] && timeExtent[1]) { | |
| // Much wider timeline for maximum horizontal spread | |
| const timeWidth = Math.max(W * 8, 8000); | |
| timeScale = d3.scaleTime() | |
| .domain(timeExtent.map(t => new Date(t * 1000))) | |
| .range([MARGIN.left, timeWidth - MARGIN.right]); | |
| // Timeline axis at the bottom | |
| const timelineG = g.append('g').attr('class', 'timeline'); | |
| const timelineY = H - 80; | |
| timelineG.append('line') | |
| .attr('class', 'timeline-axis') | |
| .attr('x1', MARGIN.left) | |
| .attr('y1', timelineY) | |
| .attr('x2', timeWidth - MARGIN.right) | |
| .attr('y2', timelineY); | |
| // Enhanced year markers with better spacing | |
| const years = d3.timeYear.range(new Date(timeExtent[0] * 1000), new Date(timeExtent[1] * 1000 + 365*24*60*60*1000)); | |
| const months = d3.timeMonth.range(new Date(timeExtent[0] * 1000), new Date(timeExtent[1] * 1000 + 365*24*60*60*1000)); | |
| timelineG.selectAll('.timeline-tick') | |
| .data(years) | |
| .join('line') | |
| .attr('class', 'timeline-tick') | |
| .attr('x1', d => timeScale(d)) | |
| .attr('y1', timelineY - 15) | |
| .attr('x2', d => timeScale(d)) | |
| .attr('y2', timelineY + 15); | |
| timelineG.selectAll('.timeline-month-tick') | |
| .data(months) | |
| .join('line') | |
| .attr('class', 'timeline-month-tick') | |
| .attr('x1', d => timeScale(d)) | |
| .attr('y1', timelineY - 8) | |
| .attr('x2', d => timeScale(d)) | |
| .attr('y2', timelineY + 8); | |
| timelineG.selectAll('.timeline-label') | |
| .data(years) | |
| .join('text') | |
| .attr('class', 'timeline-label') | |
| .attr('x', d => timeScale(d)) | |
| .attr('y', timelineY + 30) | |
| .text(d => d.getFullYear()); | |
| timelineG.selectAll('.timeline-month-label') | |
| .data(months.filter((d, i) => i % 3 === 0)) | |
| .join('text') | |
| .attr('class', 'timeline-month-label') | |
| .attr('x', d => timeScale(d)) | |
| .attr('y', timelineY + 45) | |
| .text(d => d.toLocaleDateString('en', { month: 'short' })); | |
| // Modular logic milestone marker - May 31, 2024 | |
| const modularDate = new Date(2024, 4, 31); | |
| timelineG.append('line') | |
| .attr('class', 'modular-milestone') | |
| .attr('x1', timeScale(modularDate)) | |
| .attr('y1', MARGIN.top) | |
| .attr('x2', timeScale(modularDate)) | |
| .attr('y2', H - MARGIN.bottom); | |
| timelineG.append('text') | |
| .attr('class', 'modular-milestone-label') | |
| .attr('x', timeScale(modularDate)) | |
| .attr('y', MARGIN.top - 10) | |
| .attr('text-anchor', 'middle') | |
| .text('Modular Logic Added'); | |
| } | |
| function handleZoom(event) { | |
| const { transform } = event; | |
| g.attr('transform', transform); | |
| } | |
| // Enhanced curved links for better chronological flow visualization | |
| const link = g.selectAll('path.link') | |
| .data(timeline.links) | |
| .join('path') | |
| .attr('class', d => d.cand ? 'link cand' : 'link') | |
| .attr('fill', 'none') | |
| .attr('stroke-width', d => d.cand ? 2.5 : 1.5); | |
| const linkedByIndex = {}; | |
| timeline.links.forEach(d => { | |
| const s = typeof d.source === 'object' ? d.source.id : d.source; | |
| const t = typeof d.target === 'object' ? d.target.id : d.target; | |
| linkedByIndex[`${s},${t}`] = true; | |
| linkedByIndex[`${t},${s}`] = true; | |
| }); | |
| function isConnected(a, b) { | |
| return linkedByIndex[`${a.id},${b.id}`] || a.id === b.id; | |
| } | |
// Nodes with improved positioning strategy
const node = g.selectAll('g.node')
  .data(timeline.nodes)
  .join('g')
  .attr('class', d => `node ${d.cls}`)
  .call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));
node.on("click", function(event, d) {
  event.stopPropagation();
  node.select("circle").style("opacity", o => isConnected(d, o) ? 1 : 0.1);
  g.selectAll(".node-label").style("opacity", o => isConnected(d, o) ? 1 : 0.1);
  link.style("opacity", o => (o.source.id === d.id || o.target.id === d.id) ? 1 : 0.1);
});
const baseSel = node.filter(d => d.cls === 'base');
baseSel.append('circle').attr('r', 20).attr('fill', '#ffbe0b');
node.filter(d => d.cls !== 'base').append('circle').attr('r', 18);
node.append('text')
  .attr('class', 'node-label')
  .attr('dy', '-2.2em')
  .text(d => d.id);
// Organise nodes into chronological lanes for better vertical distribution
timeline.nodes.forEach((d, i) => {
  if (d.timestamp > 0) {
    // Assign a lane based on position among nodes from a similar timeframe
    const yearNodes = timeline.nodes.filter(n =>
      n.timestamp > 0 &&
      Math.abs(n.timestamp - d.timestamp) < 365 * 24 * 60 * 60
    );
    d.lane = yearNodes.indexOf(d) % VERTICAL_LANES;
  } else {
    d.lane = i % VERTICAL_LANES;
  }
});
// Force simulation tuned for a horizontal chronological layout
const sim = d3.forceSimulation(timeline.nodes)
  .force('link', d3.forceLink(timeline.links).id(d => d.id)
    .distance(d => d.cand ? 200 : 300)
    .strength(d => d.cand ? 0.1 : 0.3))
  .force('charge', d3.forceManyBody().strength(-1600))
  .force('collide', d3.forceCollide(70).strength(1));
// Strong chronological X positioning for proper horizontal spread
if (timeScale) {
  sim.force('chronological', d3.forceX(d => {
    if (d.timestamp > 0) {
      return timeScale(new Date(d.timestamp * 1000));
    }
    // Place undated models past the end of the axis
    return timeScale.range()[1] + 100;
  }).strength(0.75));
}
// Organised Y positioning using lanes instead of random spread
sim.force('lanes', d3.forceY(d => {
  const centerY = H / 2 - 100;                          // position above the timeline
  const laneHeight = (H - 200) / (VERTICAL_LANES + 1);  // leave room for the axis
  return centerY - ((H - 200) / 2) + (d.lane + 1) * laneHeight;
}).strength(0.7));
// Center force to prevent rightward drift
sim.force('center', d3.forceCenter(timeScale ? (timeScale.range()[0] + timeScale.range()[1]) / 2 : W / 2, H / 2 - 100).strength(0.1));
// Path generator: gently curved arcs that follow the chronological flow
function linkPath(d) {
  const sourceX = d.source.x || 0;
  const sourceY = d.source.y || 0;
  const targetX = d.target.x || 0;
  const targetY = d.target.y || 0;
  // Arc radius proportional to the node distance gives a shallow curve
  const dx = targetX - sourceX;
  const dy = targetY - sourceY;
  const dr = Math.sqrt(dx * dx + dy * dy) * 0.3;
  return `M${sourceX},${sourceY}A${dr},${dr} 0 0,1 ${targetX},${targetY}`;
}
function idOf(x){ return typeof x === 'object' ? x.id : x; }
function neighborsOf(id){
  const out = new Set([id]);
  Object.keys(linkedByIndex).forEach(k => {
    const [a, b] = k.split(',');
    if (a === id) out.add(b);
    if (b === id) out.add(a);
  });
  return out;
}
// Highlight matches + neighbors; an empty query resets the view
function applySearch(q){
  q = (q || '').trim().toLowerCase();
  if (!q){
    node.select("circle").style("opacity", 1);
    g.selectAll(".node-label").style("opacity", 1);
    link.style("opacity", 1);
    return;
  }
  const matches = new Set(timeline.nodes.filter(n => n.id.toLowerCase().includes(q)).map(n => n.id));
  const keep = new Set();
  matches.forEach(m => neighborsOf(m).forEach(x => keep.add(x)));
  node.select("circle").style("opacity", d => keep.has(d.id) ? 1 : 0.08);
  g.selectAll(".node-label").style("opacity", d => keep.has(d.id) ? 1 : 0.08);
  link.style("opacity", d => {
    const s = idOf(d.source), t = idOf(d.target);
    return (keep.has(s) && keep.has(t)) ? 1 : 0.08;
  });
}
// Wire up the search box
document.getElementById('searchBox').addEventListener('input', e => applySearch(e.target.value));
sim.on('tick', () => {
  link.attr('d', linkPath);
  node.attr('transform', d => `translate(${d.x},${d.y})`);
});
function dragStart(e, d) {
  if (!e.active) sim.alphaTarget(.3).restart();
  d.fx = d.x;
  d.fy = d.y;
}
function dragged(e, d) {
  d.fx = e.x;
  d.fy = e.y;
}
function dragEnd(e, d) {
  if (!e.active) sim.alphaTarget(0);
  d.fx = d.fy = null;
}
// Initialise
updateVisibility();
// Auto-fit the timeline view to the horizontal spread
setTimeout(() => {
  if (timeScale && timeExtent[0] && timeExtent[1]) {
    const timeWidth = timeScale.range()[1] - timeScale.range()[0];
    const scale = Math.min((W * 0.9) / timeWidth, 1);
    const translateX = (W - timeWidth * scale) / 2;
    const translateY = 0;
    svg.transition()
      .duration(2000)
      .call(zoomBehavior.transform,
        d3.zoomIdentity.translate(translateX, translateY).scale(scale));
  }
}, 1500);
"""
| TIMELINE_HTML = """ | |
| <!DOCTYPE html> | |
| <html lang='en'><head><meta charset='UTF-8'> | |
| <title>Transformers Chronological Timeline</title> | |
| <style>__TIMELINE_CSS__</style></head><body> | |
| <div id='controls'> | |
| <div style='font-weight:600;'>Chronological Timeline</div> | |
| <div style="margin:4px 0 6px 0;">π‘ base Β· π΅ modular Β· π΄ candidate</div> | |
| <label><input type="checkbox" id="toggleRed" checked> Show candidates</label> | |
| <input id="searchBox" type="text" placeholder="Search modelβ¦"> | |
| <div class="hint">Positioned by creation date. Scroll & zoom to explore.</div> | |
| </div> | |
| <svg id='timeline-svg'></svg> | |
| <script src='https://d3js.org/d3.v7.min.js'></script> | |
| <script>__TIMELINE_JS__</script></body></html> | |
| """ | |
# ──────────────────────────────────────────────────────────────────────────────
# HTML writer
# ──────────────────────────────────────────────────────────────────────────────
def write_html(graph_data: dict, path: Path):
    path.write_text(generate_html(graph_data), encoding="utf-8")

# ──────────────────────────────────────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────────────────────────────────────
def main():
    ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates")
    ap.add_argument("transformers", help="Path to local 🤗 transformers repo root")
    ap.add_argument("--multimodal", action="store_true",
                    help="filter to models with ≥ 3 'pixel_values' mentions")
    ap.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT)
    ap.add_argument("--out", default=HTML_DEFAULT)
    ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard",
                    help="Similarity method: 'jaccard' or 'embedding'")
    args = ap.parse_args()
    graph = build_graph_json(
        transformers_dir=Path(args.transformers).expanduser().resolve(),
        threshold=args.sim_threshold,
        multimodal=args.multimodal,
        sim_method=args.sim_method,
    )
    write_html(graph, Path(args.out).expanduser())

if __name__ == "__main__":
    main()