from __future__ import annotations import os from dataclasses import dataclass, field from typing import Dict, List, Tuple, Iterable, Optional, Literal, Callable, Any import math import torch from transformers import AutoTokenizer, AutoModel #import tensorflow Agg = Literal["mean", "max", "topk_mean"] @dataclass class HFEmbeddingBackend: """ Minimal huggingface transformers encoder for sentence-level embeddings. Uses mean pooling over last_hidden_state and L2 normalizes the result. """ model_name: str = "google/embeddinggemma-300m" device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu") # Lazy-initialized in __post_init__ TOK: Any = field(init=False, repr=False) MODEL: Any = field(init=False, repr=False) def __post_init__(self): # Nuke Spaces ZeroGPU if present (still good to keep) os.environ.setdefault("SPACES_ZERO_DISABLED", "1") try: import sys, importlib for modname in ( "spaces.zero", "spaces.zero.torch.patching", "spaces.zero.torch", "spaces.zero.patch", "spaces.zero.patching" ): try: m = sys.modules.get(modname) or importlib.import_module(modname) except Exception: continue for attr in ("disable", "unpatch", "deactivate"): fn = getattr(m, attr, None) if callable(fn): try: fn() except Exception: pass except Exception: pass # Prefer simple math attention kernels try: torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False) except Exception: pass # Make eager attention the default everywhere os.environ.setdefault("TRANSFORMERS_ATTENTION_IMPLEMENTATION", "eager") # Load tokenizer/model with eager attention self.TOK = AutoTokenizer.from_pretrained(self.model_name) self.MODEL = AutoModel.from_pretrained(self.model_name, attn_implementation="eager") self.MODEL.to(self.device).eval() def encode(self, texts: Iterable[str], batch_size: int = 32) -> "Tuple[torch.Tensor, List[str]]": """ Returns (embeddings, texts_list). Embeddings have shape [N, D] and are unit-normalized. """ texts_list = list(texts) if not texts_list: return torch.empty((0, self.MODEL.config.hidden_size)), [] all_out = [] with torch.inference_mode(): for i in range(0, len(texts_list), batch_size): batch = texts_list[i:i + batch_size] enc = self.TOK(batch, padding=True, truncation=True, return_tensors="pt").to(self.device) # type: ignore out = self.MODEL(**enc) last = out.last_hidden_state # [B, T, H] mask = enc["attention_mask"].unsqueeze(-1) # [B, T, 1] # Mean pool summed = (last * mask).sum(dim=1) counts = mask.sum(dim=1).clamp(min=1) pooled = summed / counts # L2 normalize pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12) all_out.append(pooled.cpu()) embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, self.MODEL.config.hidden_size)) # type: ignore return embs, texts_list def _normalize_whitespace(s: str) -> str: return " ".join(s.strip().split()) def _default_preprocess(s: str) -> str: # Keep simple, deterministic preprocessing. Users can override with a custom callable. return _normalize_whitespace(s) @dataclass class PhraseIndex: phrases_by_level: Dict[str, List[str]] embeddings_by_level: Dict[str, "Any"] # torch.Tensor, but keep Any to avoid hard dep at import time model_name: str def build_phrase_index( backend: HFEmbeddingBackend, phrases_by_level: Dict[str, Iterable[str]], ) -> PhraseIndex: """ Pre-encode all anchor phrases per level into a searchable index. """ # Flatten texts while preserving level boundaries cleaned: Dict[str, List[str]] = {lvl: [_default_preprocess(p) for p in phrases] for lvl, phrases in phrases_by_level.items()} all_texts: List[str] = [] spans: List[Tuple[str, int, int]] = [] # (level, start, end) in the flat list cur = 0 for lvl, plist in cleaned.items(): start = cur all_texts.extend(plist) cur += len(plist) spans.append((lvl, start, cur)) embs, _ = backend.encode(all_texts) # Slice embeddings back into level buckets embeddings_by_level: Dict[str, "Any"] = {} for lvl, start, end in spans: embeddings_by_level[lvl] = embs[start:end] if end > start else torch.empty((0, embs.shape[1])) # type: ignore return PhraseIndex(phrases_by_level={lvl: list(pl) for lvl, pl in cleaned.items()}, embeddings_by_level=embeddings_by_level, model_name=backend.model_name) def _aggregate_sims( sims: "Any", agg: Agg, topk: int ) -> float: """ Aggregate a 1D tensor of similarities into a single score. """ if sims.numel() == 0: return float("nan") if agg == "mean": return float(sims.mean().item()) if agg == "max": return float(sims.max().item()) if agg == "topk_mean": k = min(topk, sims.numel()) topk_vals, _ = torch.topk(sims, k) return float(topk_vals.mean().item()) raise ValueError(f"Unknown agg: {agg}") # --------------------------- Public API --------------------------- def classify_levels_phrases( question: str, blooms_phrases: Dict[str, Iterable[str]], dok_phrases: Dict[str, Iterable[str]], *, model_name: str = "google/embeddinggemma-300m", agg: Agg = "max", topk: int = 5, preprocess: Optional[Callable[[str], str]] = None, backend: Optional[HFEmbeddingBackend] = None, prebuilt_bloom_index: Optional[PhraseIndex] = None, prebuilt_dok_index: Optional[PhraseIndex] = None, return_phrase_matches: bool = True, ) -> Dict[str, Any]: """ Score a question against Bloom's taxonomy and DOK (Depth of Knowledge) using cosine similarity to level-specific anchor phrases. Parameters ---------- question : str The input question or prompt. blooms_phrases : dict[str, Iterable[str]] Mapping level -> list of anchor phrases for Bloom's. dok_phrases : dict[str, Iterable[str]] Mapping level -> list of anchor phrases for DOK. model_name : str Hugging Face model name for text embeddings. Ignored when `backend` provided. agg : {"mean","max","topk_mean"} Aggregation over phrase similarities within a level. topk : int Used only when `agg="topk_mean"`. preprocess : Optional[Callable[[str], str]] Preprocessing function for the question string. Defaults to whitespace normalization. backend : Optional[HFEmbeddingBackend] Injected embedding backend. If not given, one is constructed. prebuilt_bloom_index, prebuilt_dok_index : Optional[PhraseIndex] If provided, reuse precomputed phrase embeddings to avoid re-encoding. return_phrase_matches : bool If True, returns per-level top contributing phrases. Returns ------- dict { "question": ..., "model_name": ..., "blooms": { "scores": {level: float, ...}, "best_level": str, "best_score": float, "top_phrases": {level: [(phrase, sim_float), ...], ...} # only if return_phrase_matches }, "dok": { "scores": {level: float, ...}, "best_level": str, "best_score": float, "top_phrases": {level: [(phrase, sim_float), ...], ...} # only if return_phrase_matches }, "config": {"agg": agg, "topk": topk if agg=='topk_mean' else None} } """ preprocess = preprocess or _default_preprocess question_clean = preprocess(question) # Prepare backend be = backend or HFEmbeddingBackend(model_name=model_name) # Build / reuse indices bloom_index = prebuilt_bloom_index or build_phrase_index(be, blooms_phrases) dok_index = prebuilt_dok_index or build_phrase_index(be, dok_phrases) # Encode question q_emb, _ = be.encode([question_clean]) q_emb = q_emb[0:1] # [1, D] def _score_block(index: PhraseIndex) -> Tuple[Dict[str, float], Dict[str, List[Tuple[str, float]]]]: scores: Dict[str, float] = {} top_contribs: Dict[str, List[Tuple[str, float]]] = {} for lvl, phrases in index.phrases_by_level.items(): embs = index.embeddings_by_level[lvl] # [N, D] if embs.numel() == 0: scores[lvl] = float("nan") top_contribs[lvl] = [] continue sims = (q_emb @ embs.T).squeeze(0) # cosine sim due to L2 norm scores[lvl] = _aggregate_sims(sims, agg, topk) if return_phrase_matches: k = min(5, sims.numel()) vals, idxs = torch.topk(sims, k) top_contribs[lvl] = [(phrases[int(i)], float(v.item())) for v, i in zip(vals, idxs)] return scores, top_contribs bloom_scores, bloom_top = _score_block(bloom_index) dok_scores, dok_top = _score_block(dok_index) def _best(scores: Dict[str, float]) -> Tuple[str, float]: # max with NaN-safe handling best_lvl, best_val = None, -float("inf") for lvl, val in scores.items(): if isinstance(val, float) and (not math.isnan(val)) and val > best_val: best_lvl, best_val = lvl, val return best_lvl or "", best_val best_bloom, best_bloom_val = _best(bloom_scores) best_dok, best_dok_val = _best(dok_scores) return { "question": question_clean, "model_name": be.model_name, "blooms": { "scores": bloom_scores, "best_level": best_bloom, "best_score": best_bloom_val, "top_phrases": bloom_top if return_phrase_matches else None, }, "dok": { "scores": dok_scores, "best_level": best_dok, "best_score": best_dok_val, "top_phrases": dok_top if return_phrase_matches else None, }, "config": { "agg": agg, "topk": topk if agg == "topk_mean" else None, }, }