STEM-Question-Generator / level_classifier_tool_2.py
bhardwaj08sarthak's picture
Upload 5 files
79418f8 verified
raw
history blame
9.36 kB
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable, Optional, Literal, Callable, Any
import math
import torch
from transformers import AutoTokenizer, AutoModel
#import tensorflow
# Names of the supported strategies for collapsing the per-phrase similarities
# of one level into a single score (see `_aggregate_sims`).
Agg = Literal["mean", "max", "topk_mean"]
# --------------------------- Embedding backend ---------------------------
@dataclass
class HFEmbeddingBackend:
    """
    Minimal Hugging Face transformers encoder for sentence-level embeddings.

    Uses attention-mask-weighted mean pooling over ``last_hidden_state`` and
    L2-normalizes the result.

    The tokenizer and model are loaded when an instance is created, for that
    instance's ``model_name``. Loaded pairs are cached per (model_name, device)
    and shared between instances, so constructing several backends for the
    same model does not reload the weights.

    Attributes
    ----------
    model_name : str
        Hugging Face model identifier passed to ``from_pretrained``.
    device : str
        ``"cuda"`` when CUDA is available, otherwise ``"cpu"`` (class attribute).
    TOK, MODEL
        Tokenizer and model for ``model_name`` (set per instance in
        ``__post_init__``).
    """

    model_name: str = "google/embeddinggemma-300m"

    # The two names below are deliberately un-annotated: the dataclass
    # machinery only turns annotated names into init fields, so these remain
    # plain class attributes.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    _MODEL_CACHE = {}  # (model_name, device) -> (tokenizer, model)

    def __post_init__(self) -> None:
        """Load (or fetch from the shared cache) the tokenizer/model for ``model_name``."""
        cache = HFEmbeddingBackend._MODEL_CACHE
        key = (self.model_name, self.device)
        loaded = cache.get(key)
        if loaded is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModel.from_pretrained(self.model_name)
            model.to(self.device).eval()
            loaded = (tokenizer, model)
            cache[key] = loaded
        self.TOK, self.MODEL = loaded

    def encode(self, texts: Iterable[str], batch_size: int = 32) -> "tuple[torch.Tensor, list[str]]":
        """
        Embed ``texts`` in batches.

        Returns (embeddings, texts_list). Embeddings have shape [N, D], are
        unit-normalized and are moved to CPU. For empty input the embeddings
        tensor has shape [0, hidden_size].

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A negative step would make the batching loop a silent no-op and
            # return zero embeddings for non-empty input.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        texts_list = list(texts)
        if not texts_list:
            return torch.empty((0, self.MODEL.config.hidden_size)), []  # type: ignore

        all_out = []
        with torch.inference_mode():
            for i in range(0, len(texts_list), batch_size):
                batch = texts_list[i:i + batch_size]
                enc = self.TOK(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)  # type: ignore
                out = self.MODEL(**enc)
                last = out.last_hidden_state  # [B, T, H]
                mask = enc["attention_mask"].unsqueeze(-1)  # [B, T, 1]
                # Mean pool over positions where the attention mask is 1.
                summed = (last * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1)  # avoid division by zero
                pooled = summed / counts
                # L2 normalize so a dot product equals cosine similarity.
                pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
                all_out.append(pooled.cpu())

        # texts_list is non-empty and batch_size >= 1, so at least one batch ran.
        embs = torch.cat(all_out, dim=0)
        return embs, texts_list
# --------------------------- Utilities ---------------------------
def _normalize_whitespace(s: str) -> str:
    """Collapse every whitespace run in ``s`` to one space and drop leading/trailing whitespace."""
    # str.split() with no separator already discards leading/trailing whitespace
    # and treats consecutive whitespace as a single delimiter.
    tokens = s.split()
    return " ".join(tokens)
def _default_preprocess(s: str) -> str:
    """
    Default text preprocessing: whitespace normalization only.

    Kept simple and deterministic on purpose; callers that need more can
    supply their own callable where one is accepted.
    """
    cleaned = _normalize_whitespace(s)
    return cleaned
@dataclass
class PhraseIndex:
    """Anchor phrases and their embeddings, grouped by level (as produced by `build_phrase_index`)."""
    # level -> list of anchor phrases (stored after preprocessing by `build_phrase_index`)
    phrases_by_level: Dict[str, List[str]]
    # level -> embedding tensor; row i is the embedding of phrases_by_level[level][i]
    embeddings_by_level: Dict[str, "Any"] # torch.Tensor, but keep Any to avoid hard dep at import time
    # `model_name` of the backend that produced the embeddings
    model_name: str
def build_phrase_index(
    backend: HFEmbeddingBackend,
    phrases_by_level: Dict[str, Iterable[str]],
) -> PhraseIndex:
    """
    Encode every anchor phrase once and group the embeddings by level.

    All phrases are preprocessed, concatenated into one flat list, encoded
    with a single ``backend.encode`` call, and the resulting rows are then
    sliced back into per-level tensors. A level with no phrases gets an empty
    ``[0, D]`` tensor.
    """
    normalized: Dict[str, List[str]] = {}
    for level, phrases in phrases_by_level.items():
        normalized[level] = [_default_preprocess(phrase) for phrase in phrases]

    # One flat list, in level order, so the whole index needs only one encode call.
    flat_texts: List[str] = [text for texts in normalized.values() for text in texts]
    vectors, _ = backend.encode(flat_texts)

    # Walk the levels in the same order to recover each level's row range.
    per_level: Dict[str, "Any"] = {}
    offset = 0
    for level, texts in normalized.items():
        stop = offset + len(texts)
        if stop > offset:
            per_level[level] = vectors[offset:stop]
        else:
            per_level[level] = torch.empty((0, vectors.shape[1]))  # type: ignore
        offset = stop

    return PhraseIndex(
        phrases_by_level={level: list(texts) for level, texts in normalized.items()},
        embeddings_by_level=per_level,
        model_name=backend.model_name,
    )
def _aggregate_sims(
    sims: "Any", agg: Agg, topk: int
) -> float:
    """
    Reduce a 1D tensor of similarities to one float.

    ``agg`` selects the reduction: ``"mean"``, ``"max"``, or ``"topk_mean"``
    (mean of the ``topk`` largest values, capped at the number of values).
    An empty tensor yields NaN regardless of ``agg``; any other ``agg`` value
    raises ``ValueError``.
    """
    count = sims.numel()
    if count == 0:
        return float("nan")

    if agg == "mean":
        reduced = sims.mean()
    elif agg == "max":
        reduced = sims.max()
    elif agg == "topk_mean":
        reduced = torch.topk(sims, min(topk, count)).values.mean()
    else:
        raise ValueError(f"Unknown agg: {agg}")
    return float(reduced.item())
# --------------------------- Public API ---------------------------
def classify_levels_phrases(
    question: str,
    blooms_phrases: Dict[str, Iterable[str]],
    dok_phrases: Dict[str, Iterable[str]],
    *,
    model_name: str = "google/embeddinggemma-300m",
    agg: Agg = "max",
    topk: int = 5,
    preprocess: Optional[Callable[[str], str]] = None,
    backend: Optional[HFEmbeddingBackend] = None,
    prebuilt_bloom_index: Optional[PhraseIndex] = None,
    prebuilt_dok_index: Optional[PhraseIndex] = None,
    return_phrase_matches: bool = True,
) -> Dict[str, Any]:
    """
    Rate a question on Bloom's taxonomy and on DOK (Depth of Knowledge) by
    cosine similarity between its embedding and per-level anchor phrases.

    Parameters
    ----------
    question : str
        The input question or prompt.
    blooms_phrases : dict[str, Iterable[str]]
        Level -> anchor phrases for Bloom's. Not used when
        `prebuilt_bloom_index` is supplied.
    dok_phrases : dict[str, Iterable[str]]
        Level -> anchor phrases for DOK. Not used when `prebuilt_dok_index`
        is supplied.
    model_name : str
        Hugging Face model name used to construct a backend when `backend`
        is not supplied.
    agg : {"mean","max","topk_mean"}
        How the phrase similarities of one level are combined into a score.
    topk : int
        Only relevant for `agg="topk_mean"`.
    preprocess : Optional[Callable[[str], str]]
        Applied to the question only. Defaults to whitespace normalization.
    backend : Optional[HFEmbeddingBackend]
        Embedding backend to use; one is constructed when omitted.
    prebuilt_bloom_index, prebuilt_dok_index : Optional[PhraseIndex]
        Precomputed phrase embeddings, reused instead of re-encoding.
    return_phrase_matches : bool
        When True, include the top (at most 5) matching phrases per level.

    Returns
    -------
    dict
        {
          "question": <preprocessed question>,
          "model_name": <backend.model_name>,
          "blooms": {
              "scores": {level: float, ...},      # NaN for levels without phrases
              "best_level": str,                   # "" when no level has a valid score
              "best_score": float,                 # -inf when no level has a valid score
              "top_phrases": {level: [(phrase, sim_float), ...], ...}  # None if not return_phrase_matches
          },
          "dok": { ...same shape as "blooms"... },
          "config": {"agg": agg, "topk": topk if agg=='topk_mean' else None}
        }
    """
    if not preprocess:
        preprocess = _default_preprocess
    cleaned_question = preprocess(question)

    # Resolve the encoder first: it is needed both for missing indices and for the question.
    encoder = backend if backend else HFEmbeddingBackend(model_name=model_name)

    blooms_idx = prebuilt_bloom_index if prebuilt_bloom_index else build_phrase_index(encoder, blooms_phrases)
    dok_idx = prebuilt_dok_index if prebuilt_dok_index else build_phrase_index(encoder, dok_phrases)

    encoded, _ = encoder.encode([cleaned_question])
    query = encoded[0:1]  # keep the batch axis: [1, D]

    def _rank_taxonomy(index: PhraseIndex) -> Dict[str, Any]:
        """Score every level of one taxonomy and pick the best-scoring level."""
        level_scores: Dict[str, float] = {}
        level_matches: Dict[str, List[Tuple[str, float]]] = {}

        for level, anchor_texts in index.phrases_by_level.items():
            anchor_embs = index.embeddings_by_level[level]  # [N, D]
            if anchor_embs.numel() == 0:
                level_scores[level] = float("nan")
                level_matches[level] = []
                continue

            # Embeddings are unit-length, so the dot product is the cosine similarity.
            sims = (query @ anchor_embs.T).squeeze(0)
            level_scores[level] = _aggregate_sims(sims, agg, topk)

            if not return_phrase_matches:
                continue
            strongest = torch.topk(sims, min(5, sims.numel()))
            level_matches[level] = [
                (anchor_texts[int(pos)], float(sim.item()))
                for sim, pos in zip(strongest.values, strongest.indices)
            ]

        # Highest non-NaN score wins; the first level seen wins ties.
        winner, winner_score = None, -float("inf")
        for level, score in level_scores.items():
            if not isinstance(score, float) or math.isnan(score):
                continue
            if score > winner_score:
                winner, winner_score = level, score

        return {
            "scores": level_scores,
            "best_level": winner or "",
            "best_score": winner_score,
            "top_phrases": level_matches if return_phrase_matches else None,
        }

    blooms_result = _rank_taxonomy(blooms_idx)
    dok_result = _rank_taxonomy(dok_idx)

    return {
        "question": cleaned_question,
        "model_name": encoder.model_name,
        "blooms": blooms_result,
        "dok": dok_result,
        "config": {
            "agg": agg,
            "topk": topk if agg == "topk_mean" else None,
        },
    }