from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable, Optional, Literal, Callable, Any

import torch
from transformers import AutoTokenizer, AutoModel

# Aggregation modes for combining per-phrase similarities into one level score.
Agg = Literal["mean", "max", "topk_mean"]


@dataclass
class HFEmbeddingBackend:
    """
    Minimal Hugging Face Transformers encoder for sentence-level embeddings.

    Mean-pools over last_hidden_state and L2-normalizes the result.
    """
    model_name: str = "google/embeddinggemma-300m"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self) -> None:
        # Load per instance (not at class-definition time) so a non-default
        # model_name is actually honored and nothing heavy runs at import.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)
        self.model.to(self.device).eval()

    def encode(self, texts: Iterable[str], batch_size: int = 32) -> Tuple[torch.Tensor, List[str]]:
        """
        Returns (embeddings, texts_list). Embeddings have shape [N, D] and are unit-normalized.
        """
        texts_list = list(texts)
        if not texts_list:
            return torch.empty((0, self.model.config.hidden_size)), []

        all_out = []
        with torch.inference_mode():
            for i in range(0, len(texts_list), batch_size):
                batch = texts_list[i:i + batch_size]
                enc = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
                out = self.model(**enc)
                last = out.last_hidden_state
                mask = enc["attention_mask"].unsqueeze(-1)

                # Mean-pool over non-padding tokens only.
                summed = (last * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1)
                pooled = summed / counts

                # Unit-normalize so dot products downstream are cosine similarities.
                pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
                all_out.append(pooled.cpu())

        # texts_list is non-empty here, so at least one batch was encoded.
        embs = torch.cat(all_out, dim=0)
        return embs, texts_list
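

# Example usage (a minimal sketch; assumes the default checkpoint is
# downloadable or already cached locally):
#
#     be = HFEmbeddingBackend()
#     embs, texts = be.encode(["What is photosynthesis?", "Design a fair test."])
#     embs.shape       # torch.Size([2, hidden_size])
#     embs @ embs.T    # pairwise cosine similarities (rows are unit-normalized)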


def _normalize_whitespace(s: str) -> str:
    # Collapse runs of whitespace to single spaces and trim the ends,
    # e.g. "  a \t b " -> "a b".
    return " ".join(s.strip().split())


def _default_preprocess(s: str) -> str:
    # Default preprocessing for questions and anchor phrases:
    # whitespace normalization only.
    return _normalize_whitespace(s)


@dataclass
class PhraseIndex:
    """Precomputed anchor-phrase embeddings, grouped by level."""
    phrases_by_level: Dict[str, List[str]]
    embeddings_by_level: Dict[str, torch.Tensor]
    model_name: str


def build_phrase_index(
    backend: HFEmbeddingBackend,
    phrases_by_level: Dict[str, Iterable[str]],
) -> PhraseIndex:
    """
    Pre-encode all anchor phrases per level into a searchable index.
    """
    cleaned: Dict[str, List[str]] = {
        lvl: [_default_preprocess(p) for p in phrases]
        for lvl, phrases in phrases_by_level.items()
    }

    # Flatten all phrases into one batch, remembering each level's slice.
    all_texts: List[str] = []
    spans: List[Tuple[str, int, int]] = []
    cur = 0
    for lvl, plist in cleaned.items():
        start = cur
        all_texts.extend(plist)
        cur += len(plist)
        spans.append((lvl, start, cur))

    embs, _ = backend.encode(all_texts)

    embeddings_by_level: Dict[str, torch.Tensor] = {}
    for lvl, start, end in spans:
        embeddings_by_level[lvl] = embs[start:end] if end > start else torch.empty((0, embs.shape[1]))

    return PhraseIndex(
        phrases_by_level={lvl: list(pl) for lvl, pl in cleaned.items()},
        embeddings_by_level=embeddings_by_level,
        model_name=backend.model_name,
    )
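

# Example (illustrative placeholder phrases, not a vetted anchor set):
#
#     be = HFEmbeddingBackend()
#     bloom_index = build_phrase_index(be, {
#         "remember": ["define the term", "list the steps"],
#         "create": ["design a new solution", "compose an original piece"],
#     })
#     bloom_index.embeddings_by_level["remember"].shape  # [2, hidden_size]
#
# Building the index once and passing it via `prebuilt_bloom_index` to the
# classifier below avoids re-encoding the same anchors for every question.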


def _aggregate_sims(sims: torch.Tensor, agg: Agg, topk: int) -> float:
    """
    Aggregate a 1D tensor of similarities into a single score.
    """
    if sims.numel() == 0:
        return float("nan")
    if agg == "mean":
        return float(sims.mean().item())
    if agg == "max":
        return float(sims.max().item())
    if agg == "topk_mean":
        k = min(topk, sims.numel())
        topk_vals, _ = torch.topk(sims, k)
        return float(topk_vals.mean().item())
    raise ValueError(f"Unknown agg: {agg}")
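

# Worked example: for sims = tensor([0.2, 0.9, 0.4]),
#   agg="mean"                  -> 0.5
#   agg="max"                   -> 0.9
#   agg="topk_mean" with topk=2 -> (0.9 + 0.4) / 2 = 0.65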


def classify_levels_phrases(
    question: str,
    blooms_phrases: Dict[str, Iterable[str]],
    dok_phrases: Dict[str, Iterable[str]],
    *,
    model_name: str = "google/embeddinggemma-300m",
    agg: Agg = "max",
    topk: int = 5,
    preprocess: Optional[Callable[[str], str]] = None,
    backend: Optional[HFEmbeddingBackend] = None,
    prebuilt_bloom_index: Optional[PhraseIndex] = None,
    prebuilt_dok_index: Optional[PhraseIndex] = None,
    return_phrase_matches: bool = True,
) -> Dict[str, Any]:
    """
    Score a question against Bloom's taxonomy and DOK (Depth of Knowledge)
    levels using cosine similarity to level-specific anchor phrases.

    Parameters
    ----------
    question : str
        The input question or prompt.
    blooms_phrases : dict[str, Iterable[str]]
        Mapping of level -> anchor phrases for Bloom's taxonomy.
    dok_phrases : dict[str, Iterable[str]]
        Mapping of level -> anchor phrases for DOK.
    model_name : str
        Hugging Face model name for text embeddings. Ignored when `backend`
        is provided.
    agg : {"mean", "max", "topk_mean"}
        Aggregation over phrase similarities within a level.
    topk : int
        Used only when `agg="topk_mean"`.
    preprocess : Optional[Callable[[str], str]]
        Preprocessing function for the question string. Defaults to
        whitespace normalization.
    backend : Optional[HFEmbeddingBackend]
        Injected embedding backend. If not given, one is constructed.
    prebuilt_bloom_index, prebuilt_dok_index : Optional[PhraseIndex]
        If provided, reuse precomputed phrase embeddings to avoid re-encoding.
    return_phrase_matches : bool
        If True, include the top contributing phrases per level; otherwise
        "top_phrases" is None in the result.

    Returns
    -------
    dict
        {
            "question": ...,
            "model_name": ...,
            "blooms": {
                "scores": {level: float, ...},
                "best_level": str,
                "best_score": float,
                "top_phrases": {level: [(phrase, sim), ...], ...} or None,
            },
            "dok": {... same shape as "blooms" ...},
            "config": {"agg": agg, "topk": topk if agg == "topk_mean" else None},
        }
    """
    preprocess = preprocess or _default_preprocess
    question_clean = preprocess(question)

    be = backend or HFEmbeddingBackend(model_name=model_name)

    # Build (or reuse) the per-level phrase embedding indexes.
    bloom_index = prebuilt_bloom_index or build_phrase_index(be, blooms_phrases)
    dok_index = prebuilt_dok_index or build_phrase_index(be, dok_phrases)

    q_emb, _ = be.encode([question_clean])
    q_emb = q_emb[0:1]  # keep a [1, D] shape for the matrix products below

    def _score_block(index: PhraseIndex) -> Tuple[Dict[str, float], Dict[str, List[Tuple[str, float]]]]:
        scores: Dict[str, float] = {}
        top_contribs: Dict[str, List[Tuple[str, float]]] = {}
        for lvl, phrases in index.phrases_by_level.items():
            embs = index.embeddings_by_level[lvl]
            if embs.numel() == 0:
                scores[lvl] = float("nan")
                top_contribs[lvl] = []
                continue
            # Rows are unit-normalized, so this matmul yields cosine similarities.
            sims = (q_emb @ embs.T).squeeze(0)
            scores[lvl] = _aggregate_sims(sims, agg, topk)
            if return_phrase_matches:
                k = min(5, sims.numel())
                vals, idxs = torch.topk(sims, k)
                top_contribs[lvl] = [(phrases[int(i)], float(v.item())) for v, i in zip(vals, idxs)]
        return scores, top_contribs

    bloom_scores, bloom_top = _score_block(bloom_index)
    dok_scores, dok_top = _score_block(dok_index)

    def _best(scores: Dict[str, float]) -> Tuple[str, float]:
        # Pick the highest-scoring level, skipping NaN scores from empty levels.
        best_lvl, best_val = None, -float("inf")
        for lvl, val in scores.items():
            if isinstance(val, float) and not math.isnan(val) and val > best_val:
                best_lvl, best_val = lvl, val
        return best_lvl or "", best_val

    best_bloom, best_bloom_val = _best(bloom_scores)
    best_dok, best_dok_val = _best(dok_scores)

    return {
        "question": question_clean,
        "model_name": be.model_name,
        "blooms": {
            "scores": bloom_scores,
            "best_level": best_bloom,
            "best_score": best_bloom_val,
            "top_phrases": bloom_top if return_phrase_matches else None,
        },
        "dok": {
            "scores": dok_scores,
            "best_level": best_dok,
            "best_score": best_dok_val,
            "top_phrases": dok_top if return_phrase_matches else None,
        },
        "config": {
            "agg": agg,
            "topk": topk if agg == "topk_mean" else None,
        },
    }
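

if __name__ == "__main__":
    # End-to-end sketch. The anchor phrases are illustrative placeholders;
    # substitute a curated phrase set for real use.
    blooms = {
        "remember": ["define the term", "list the main steps"],
        "analyze": ["compare and contrast", "identify the underlying assumptions"],
        "create": ["design an original experiment", "propose a new solution"],
    }
    dok = {
        "1_recall": ["recall a fact", "follow a simple procedure"],
        "3_strategic_thinking": ["justify your reasoning", "draw conclusions from evidence"],
    }
    result = classify_levels_phrases(
        "Design an experiment to test how light affects plant growth.",
        blooms,
        dok,
        agg="topk_mean",
        topk=2,
    )
    print(result["blooms"]["best_level"], result["blooms"]["best_score"])
    print(result["dok"]["best_level"], result["dok"]["best_score"])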