redesigned modules
- utils/coherence_bbscore.py +260 -0
- utils/encoding_input.py +12 -0
- utils/generation_streaming.py +99 -0
- utils/loading_embeddings.py +55 -0
- utils/model_generation.py +211 -0
- utils/retrieve_n_rerank.py +79 -0
- utils/sentiment_analysis.py +218 -0
utils/coherence_bbscore.py
ADDED
@@ -0,0 +1,260 @@
# pip install sentence-transformers (if not already)
import math
import os
import re
import unicodedata
from typing import List, Dict, Any, Optional

import numpy as np

# get the reranked results (without scores) when none are passed in
from retrieve_n_rerank import retrieve_and_rerank

try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None

# -----------------------------
# Text utilities
# -----------------------------
def _norm(t: str) -> str:
    if t is None:
        return ""
    t = unicodedata.normalize("NFKC", str(t))
    t = re.sub(r"\s*\n\s*", " ", t)
    t = re.sub(r"\s{2,}", " ", t)
    return t.strip()

def split_sentences(text: str) -> List[str]:
    t = _norm(text)
    parts = re.split(r"(?<=[\.\?\!])\s+(?=[A-Z“\"'])", t)
    return [p.strip() for p in parts if p.strip()]

# -----------------------------
# Embeddings wrapper
# -----------------------------
class Embedder:
    def __init__(self, model_name: str = "BAAI/bge-m3", device: str = "cpu"):
        if SentenceTransformer is None:
            raise RuntimeError("Install sentence-transformers to enable coherence scoring.")
        self.model = SentenceTransformer(model_name, device=device)

    def encode(self, sentences: List[str]) -> np.ndarray:
        if not sentences:
            return np.zeros((0, 768), dtype=np.float32)
        X = self.model.encode(sentences, normalize_embeddings=True, batch_size=32, show_progress_bar=False)
        return np.asarray(X, dtype=np.float32)

def _cos(a: np.ndarray, b: np.ndarray) -> float:
    # embeddings are L2-normalized, so the dot product equals cosine similarity
    return float(np.dot(a, b))

def _normalize(v: np.ndarray) -> np.ndarray:
    v = np.asarray(v, dtype=np.float32)
    n = np.linalg.norm(v) + 1e-8
    return v / n

# -----------------------------
# Brownian-bridge style metric
# -----------------------------
def bb_coherence(sentences: List[str], E: np.ndarray) -> Dict[str, Any]:
    """
    Brownian-bridge-inspired coherence:
    - Build a main-idea vector (intro + outro + centroid)
    - Compare per-sentence similarity to a target curve that is high at the ends, lower mid-document
    - Map the maximum bridge deviation to a (0, 1] score (higher = more coherent)
    """
    n = len(sentences)
    if n == 0:
        return {"bbscore": 0.0, "sims": [], "off_idx": [], "rep_pairs": [], "sim_matrix": None}

    k = max(1, min(3, n // 5))
    v_first = E[:k].mean(axis=0)
    v_last = E[-k:].mean(axis=0)
    v_all = E.mean(axis=0)
    v_main = _normalize(0.4 * v_first + 0.4 * v_last + 0.2 * v_all)

    sims = np.array([_cos(v_main, E[i]) for i in range(n)], dtype=np.float32)
    t = np.linspace(0.0, 1.0, num=n, dtype=np.float32)
    q = 1.0 - 4.0 * t * (1.0 - t)  # peaks at the ends
    q = q / (q.mean() + 1e-8) * (sims.mean() if sims.size else 0.0)

    r = sims - q
    r_centered = r - r.mean()
    cumsum = np.cumsum(r_centered)
    B = cumsum - t * (cumsum[-1] if n > 1 else 0.0)
    denom = (np.std(r_centered) * math.sqrt(n)) + 1e-8
    ks = float(np.max(np.abs(B)) / denom)
    bbscore = float(1.0 / (1.0 + ks))

    # Off-topic: sims < mean - 1 sigma
    off_thr = float(sims.mean() - sims.std())
    off_idx = [i for i, s in enumerate(sims) if s < off_thr]

    # Repetition: very high pairwise similarity, skipping adjacent sentences
    S = E @ E.T if n > 1 else np.zeros((1, 1), dtype=np.float32)  # cosine due to normalization
    rep_pairs = []
    if n > 1:
        for i in range(n):
            for j in range(i + 2, n):  # skip adjacent
                if S[i, j] >= 0.92:  # threshold tunable
                    rep_pairs.append((i, j, float(S[i, j])))

    return {"bbscore": round(bbscore, 3), "sims": sims, "off_idx": off_idx, "rep_pairs": rep_pairs, "sim_matrix": S}

# -----------------------------
# Zero-shot labeler (optional)
# -----------------------------
def zshot_label(text: str, topic: str = "the main topic") -> Dict[str, float]:
    """
    Optional: zero-shot verdict to complement the rule-based label.
    Labels: Coherent, Off topic, Repeated
    """
    try:
        from transformers import pipeline
    except Exception:
        return {}
    topic = topic or "the main topic"  # guard against an explicit None topic_hint
    clf = pipeline("zero-shot-classification",
                   model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
                   multi_label=True)
    labels = ["Coherent", "Off topic", "Repeated"]
    res = clf(_norm(text), labels, hypothesis_template=f"This passage is {{}} with respect to {topic}.")
    return {lbl: float(score) for lbl, score in zip(res["labels"], res["scores"])}

# -----------------------------
# Decision logic + reasons
# -----------------------------
def decide_label_with_reasons(
    text: str,
    topic_hint: Optional[str],
    bb: Dict[str, Any],
    sentences: List[str],
    zshot_scores: Optional[Dict[str, float]] = None,
    thresholds: Optional[Dict[str, float]] = None
) -> Dict[str, Any]:
    """
    Returns:
      {
        "label": "Coherent" | "Off topic" | "Repeated",
        "reasons": ["...", "..."],
        "evidence": {"off_topic_examples": [...], "repeated_examples": [...]},
        "bbscore": 0.74
      }
    """
    thr = thresholds or {
        "bb_coherent_min": 0.65,      # >= coherent
        "off_topic_ratio_max": 0.20,  # <= coherent
        "repeat_pairs_min": 1         # >= repeated (if any)
    }
    n = max(1, len(sentences))
    off_ratio = len(bb["off_idx"]) / n
    has_repeat = len(bb["rep_pairs"]) >= thr["repeat_pairs_min"]
    bbscore = bb["bbscore"]

    # Rule-based primary decision
    if off_ratio > thr["off_topic_ratio_max"] and bbscore < thr["bb_coherent_min"]:
        label = "Off topic"
    elif has_repeat and bbscore >= 0.5:
        label = "Repeated"
    elif bbscore >= thr["bb_coherent_min"] and off_ratio <= thr["off_topic_ratio_max"] and not has_repeat:
        label = "Coherent"
    else:
        # Tie-breaker using zero-shot if provided
        if zshot_scores:
            label = max(zshot_scores.items(), key=lambda kv: kv[1])[0]
        else:
            # fallback: prefer coherence if bbscore is okay, else off-topic
            label = "Coherent" if bbscore >= 0.6 else "Off topic"

    # Reasons
    reasons = [f"BBScore={bbscore:.3f}."]
    if bb["off_idx"]:
        reasons.append(f"Off-topic fraction={off_ratio:.2f} ({len(bb['off_idx'])}/{n} sentences below main-idea similarity).")
    if has_repeat:
        top_rep = sorted(bb["rep_pairs"], key=lambda x: x[2], reverse=True)[:2]
        reasons.append(f"Repeated content detected (top sim={top_rep[0][2]:.2f}).")

    if zshot_scores:
        top = sorted(zshot_scores.items(), key=lambda kv: kv[1], reverse=True)[:2]
        reasons.append("Zero-shot support: " + ", ".join([f"{k}={v:.2f}" for k, v in top]))

    # Evidence snippets
    ev_off = [f'{i}: "{sentences[i]}"' for i in bb["off_idx"][:2]]
    ev_rep = []
    for (i, j, sim) in sorted(bb["rep_pairs"], key=lambda x: x[2], reverse=True)[:2]:
        ev_rep.append(f'({i},{j}) sim={sim:.2f}: "{sentences[i]}", "{sentences[j]}"')

    return {
        "label": label,
        "reasons": reasons,
        "evidence": {"off_topic_examples": ev_off, "repeated_examples": ev_rep},
        "bbscore": bbscore
    }

def _display_title(meta: Dict[str, Any], fallback: str) -> str:
    if meta.get("title"):
        return str(meta["title"]).strip()
    src = meta.get("source") or meta.get("path")
    if src:
        base = os.path.basename(str(src))
        return re.sub(r"\.pdf$", "", base, flags=re.I)
    return meta.get("doc_id") or fallback

def _page_label(meta: Dict[str, Any]) -> str:
    return str(meta.get("page_label") or meta.get("page") or "?")

def to_std_doc(item: Any, idx: int = 0) -> Dict[str, Any]:
    """
    Accepts a LangChain Document or dict; returns a standard dict:
      {title, page_label, text}
    """
    if hasattr(item, "page_content"):  # LangChain Document
        meta = getattr(item, "metadata", {}) or {}
        return {
            "title": _display_title(meta, f"doc{idx+1}"),
            "page_label": _page_label(meta),
            "text": _norm(item.page_content),
        }
    elif isinstance(item, dict):
        meta = item.get("metadata", {}) or {}
        title = item.get("title") or _display_title(meta, item.get("doc_id", f"doc{idx+1}"))
        page = item.get("page_label") or _page_label(meta)
        text = _norm(item.get("text") or item.get("page_content", ""))
        return {"title": title, "page_label": page, "text": text}
    else:
        raise TypeError(f"Unsupported doc type at index {idx}: {type(item)}")

def coherence_assessment_std(
    std_doc: Dict[str, Any],
    embedder,
    topic_hint: Optional[str] = None,
    run_zero_shot: bool = False,
    thresholds: Optional[Dict[str, float]] = None
) -> Dict[str, Any]:
    """Coherence assessment over a standardized dict ({title, page_label, text})."""
    text = std_doc.get("text", "")
    sents = split_sentences(text)
    if not sents:
        return {"title": std_doc.get("title", "Document"), "label": "Off topic", "bbscore": 0.0,
                "reasons": ["Empty text."], "evidence": {}}
    E = embedder.encode(sents)
    bb = bb_coherence(sents, E)
    zshot_scores = zshot_label(text, topic_hint) if run_zero_shot else None
    decision = decide_label_with_reasons(text, topic_hint, bb, sents, zshot_scores, thresholds)
    return {
        "title": std_doc.get("title", "Document"),
        "page_label": std_doc.get("page_label", "?"),
        "label": decision["label"],
        "bbscore": decision["bbscore"],
        "reasons": decision["reasons"],
        "evidence": decision["evidence"],
    }

# Get the coherence report
def coherence_report(embedder="BAAI/bge-m3",  # a sentence-embedding model, matching Embedder's default
                     input_text=None,
                     reranked_results=None,
                     run_zero_shot=True):
    embedder = Embedder(embedder) if isinstance(embedder, str) else embedder
    if reranked_results is None:
        # NOTE: retrieve_and_rerank also expects a vectorstore; callers normally
        # pass reranked_results directly (see generation_streaming.py)
        reranked_results = retrieve_and_rerank(input_text)
    if not reranked_results:
        return []
    # Convert reranked_results to standardized documents
    std_results = [to_std_doc(doc, i) for i, doc in enumerate(reranked_results)]
    reports = [coherence_assessment_std(d, embedder, topic_hint=input_text, run_zero_shot=run_zero_shot)
               for d in std_results]
    return reports
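
A minimal usage sketch for this module (illustrative only, not part of the commit; the sample sentences are made up):

# Sketch: score a short passage directly with bb_coherence
from coherence_bbscore import Embedder, split_sentences, bb_coherence

embedder = Embedder("BAAI/bge-m3")  # any sentence-transformers model should work here
text = ("Water permits are issued by the Authority. "
        "Permit holders must report abstraction volumes annually. "
        "An unrelated aside about road construction.")
sents = split_sentences(text)
bb = bb_coherence(sents, embedder.encode(sents))
print(bb["bbscore"], bb["off_idx"])  # score in (0, 1], plus indices of off-topic sentences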
utils/encoding_input.py
ADDED
@@ -0,0 +1,12 @@
# Methods to encode text
import numpy as np
from langchain_community.embeddings import HuggingFaceEmbeddings

def encode_text(text, embedding_model='sentence-transformers/all-MiniLM-L6-v2', as_array=True):
    """Encodes the input text using the provided embedding model."""
    # NOTE: the model is instantiated on every call; cache it if encoding in a loop
    embedding_model = HuggingFaceEmbeddings(model_name=embedding_model)
    encoded_input = embedding_model.embed_query(text)
    if as_array:
        return np.array(encoded_input)
    else:
        return encoded_input
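
A quick sketch of the expected output shape (illustrative only): all-MiniLM-L6-v2 produces 384-dimensional vectors.

# Sketch: encode a query and inspect the vector
from encoding_input import encode_text

vec = encode_text("What penalties apply to unlicensed water abstraction?")
print(vec.shape)  # (384,) for all-MiniLM-L6-v2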
utils/generation_streaming.py
ADDED
@@ -0,0 +1,99 @@
import requests
import time
import json

# encode the text
from encoding_input import encode_text

# retrieve and rerank the documents
from retrieve_n_rerank import retrieve_and_rerank

# sentiment analysis on reranked documents
from sentiment_analysis import get_sentiment

# coherence assessment reports
from coherence_bbscore import coherence_report

# Get the vectorstore
from loading_embeddings import get_vectorstore
vectorstore = get_vectorstore()

# build messages for model generation
from model_generation import build_messages

API_KEY = "sk-do-"  # truncated placeholder; load the real key from an env var or secret store
MODEL = "llama3.3-70b-instruct"

def generate_response_stream(query: str, enable_sentiment: bool, enable_coherence: bool):
    # encoded_input = encode_text(query)  # not needed here: retrieve_and_rerank encodes internally

    reranked_results = retrieve_and_rerank(
        query_text=query,
        vectorstore=vectorstore,
        k=50,              # number of initial documents to retrieve
        rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
        top_m=20,          # number of documents to return after reranking
        min_score=0.5,     # minimum score for reranked documents
        only_docs=False    # return both documents and scores
    )
    top_docs = [doc for doc, score in reranked_results]

    if not top_docs:
        yield "No relevant documents found."
        return

    sentiment_rollup = get_sentiment(top_docs) if enable_sentiment else {}
    coherence_report_ = coherence_report(reranked_results=top_docs, input_text=query) if enable_coherence else ""

    messages = build_messages(
        query=query,
        top_docs=top_docs,
        task_mode="verbatim_sentiment",
        sentiment_rollup=sentiment_rollup,
        coherence_report=coherence_report_,
    )

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }

    data = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.2,
        "stream": True,
        "max_tokens": 2000
    }

    collected = ""  # Accumulate content to show

    with requests.post("https://inference.do-ai.run/v1/chat/completions", headers=headers, json=data, stream=True) as r:
        if r.status_code != 200:
            yield f"[ERROR] API returned status {r.status_code}: {r.text}"
            return

        for line in r.iter_lines(decode_unicode=True):
            if not line or line.strip() == "data: [DONE]":
                continue
            if line.startswith("data: "):
                line = line[len("data: "):]

            try:
                chunk = json.loads(line)
                delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                if delta:
                    collected += delta
                    yield collected  # yield the full text so far (progressive display)
                    time.sleep(0.01)  # slight throttle to improve smoothness
            except Exception as e:
                print("Streaming decode error:", e)
                continue
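
Because the generator yields the accumulated text so far (not just the delta), a caller that wants only the new suffix has to diff consecutive yields. An illustrative console sketch (not part of the commit):

# Sketch: consume the stream and print only the newly arrived suffix
previous = ""
for partial in generate_response_stream("What are the permit requirements?",
                                        enable_sentiment=True, enable_coherence=True):
    print(partial[len(previous):], end="", flush=True)
    previous = partial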
utils/loading_embeddings.py
ADDED
@@ -0,0 +1,55 @@
# Loading embeddings from storage
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

# download into the data directory
data_path = os.path.join(Path(os.getcwd()).parent, "data")
# local FAISS folder (NOTE: not used by the functions below, which default to "faiss_index")
local_folder = os.path.join(data_path, 'faiss_index')

def download_faiss_index(repo_id="kaburia/epic-a-embeddings", local_folder="faiss_index"):
    os.makedirs(local_folder, exist_ok=True)

    index_faiss_path = os.path.join(local_folder, "index.faiss")
    index_pkl_path = os.path.join(local_folder, "index.pkl")

    if not os.path.exists(index_faiss_path):
        print("Downloading index.faiss from Hugging Face Dataset...")
        hf_hub_download(
            repo_id=repo_id,
            filename="index.faiss",
            repo_type="dataset",
            local_dir=local_folder,
            local_dir_use_symlinks=False,
        )

    if not os.path.exists(index_pkl_path):
        print("Downloading index.pkl from Hugging Face Dataset...")
        hf_hub_download(
            repo_id=repo_id,
            filename="index.pkl",
            repo_type="dataset",
            local_dir=local_folder,
            local_dir_use_symlinks=False,
        )

def load_vectorstore(index_path="faiss_index"):
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    db = FAISS.load_local(
        index_path,
        embeddings=embedding_model,
        allow_dangerous_deserialization=True
    )
    return db

# download and load the vectorstore
def get_vectorstore(repo_id="kaburia/epic-a-embeddings", local_folder="faiss_index"):
    download_faiss_index(repo_id=repo_id, local_folder=local_folder)
    return load_vectorstore(index_path=local_folder)
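
A usage sketch (illustrative only); after the first run the index is served from the local folder:

# Sketch: load the vectorstore and run a quick similarity search
from loading_embeddings import get_vectorstore

vs = get_vectorstore()  # downloads index.faiss / index.pkl on first call
for doc in vs.similarity_search("water permit penalties", k=3):
    print(doc.metadata.get("source"), "-", doc.page_content[:80])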
utils/model_generation.py
ADDED
@@ -0,0 +1,211 @@
import json
import os
import requests
from typing import List, Dict, Any, Union


PROMPT_TEMPLATES = {
    "verbatim_sentiment": {
        "system": (
            "You are a compliance-grade policy analyst assistant. "
            "Your job is to return a precise, fact-grounded response. "
            "Avoid hallucinations. Base everything strictly on the content provided. "
            "If the coherence and/or sentiment analysis is not enabled, do not mention it in the response."
        ),
        "user_template": """
Query: {query}

Deliverables:
1) **Quoted Policy Excerpts**: Quote key policy content directly. Cite the source using filename and page.
2) **Sentiment Summary**: Use the sentiment JSON to explain tone, gaps, penalties, or enforcement clarity in plain English.
3) **Coherence Assessment**: Summarize the coherence report below. Highlight:
   - Whether the answer was mostly on-topic or off-topic
   - The sections that were coherent, off-topic, and repeated

Topic hint: {topic_hint}

Sentiment JSON (rolled up across top docs):
{sentiment_json}

Coherence report:
{coherence_report}

Context Sources:
{context_block}
"""
    },

    "abstractive_summary": {
        "system": (
            "You are a policy analyst summarizing government documents for a general audience. "
            "Your response should paraphrase clearly, avoiding quotes unless absolutely necessary. "
            "Highlight high-level goals, enforcement strategies, and important deadlines or penalties."
        ),
        "user_template": """Query: {query}

Summarize the answer in natural, non-technical language. Emphasize clarity and coverage. Avoid quoting unless the phrase is legally binding.

Topic hint: {topic_hint}

Context DOCS:
{context_block}
"""
    },

    "followup_reasoning": {
        "system": (
            "You are an assistant that explains policy documents interactively, reasoning step-by-step. "
            "Always cite document IDs and indicate if certain info is absent."
        ),
        "user_template": """User query: {query}

Explain the answer step-by-step. Add follow-up questions that a reader might ask, and try to answer them using the documents below.

Topic: {topic_hint}

DOCS:
{context_block}
"""
    },

    # Add more templates as needed
}


# --- LLM client ---
def get_do_completion(api_key, model_name, messages, temperature=0.2, max_tokens=800):
    url = "https://inference.do-ai.run/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens
    }
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=90)
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.HTTPError as e:
        print(f"HTTP error occurred: {e}")
        print(f"Response body: {e.response.text if e.response is not None else ''}")
        return None
    except requests.exceptions.RequestException as e:
        print(f"Request error: {e}")
        return None
    except json.JSONDecodeError as e:
        print(f"Failed to decode JSON: {e}")
        print(f"Response text: {resp.text if 'resp' in locals() else ''}")
        return None


# --- Prompt context builder ---
def _clip(text: str, max_chars: int = 1400) -> str:
    """Trim content to limit prompt size."""
    if not text:
        return ""
    text = str(text).strip()
    return text[:max_chars] + ("..." if len(text) > max_chars else "")


def build_context_block(top_docs: List[Dict[str, Any]]) -> str:
    """
    Formats each document with a real citation:
    - Extracts the file name from the 'source' path
    - Uses 'page_label' or falls back to 'page'
    - Emits blocks headed <<<SOURCE: {filename}, p. {page_label}>>>
    """
    blocks = []
    for i, item in enumerate(top_docs):
        if hasattr(item, "page_content"):
            text = item.page_content
            meta = getattr(item, "metadata", {})
        else:
            text = item.get("text") or item.get("page_content", "")
            meta = item.get("metadata", {})

        # Get the file name from the path
        full_path = meta.get("source", "")
        filename = os.path.basename(full_path) if full_path else f"Document_{i+1}"

        # Prefer page_label if available, else fall back to the raw page
        page_label = meta.get("page_label") or meta.get("page") or "unknown"

        citation = f"{filename}, p. {page_label}"

        blocks.append(f"<<<SOURCE: {citation}>>>\n{_clip(text)}\n</SOURCE>")

    return "\n".join(blocks)


# --- Message builder ---
def build_messages(
    query: str,
    top_docs: List[Dict[str, Any]],
    task_mode: str,
    sentiment_rollup: Dict[str, List[str]],
    coherence_report: str = "",
    topic_hint: str = "energy policy"
) -> List[Dict[str, str]]:
    template = PROMPT_TEMPLATES.get(task_mode)
    if not template:
        raise ValueError(f"Unknown task mode: {task_mode}")

    context_block = build_context_block(top_docs)
    sentiment_json = json.dumps(sentiment_rollup or {}, ensure_ascii=False)

    user_prompt = template["user_template"].format(
        query=query,
        topic_hint=topic_hint,
        sentiment_json=sentiment_json,
        context_block=context_block,
        coherence_report=coherence_report
    )

    return [
        {"role": "system", "content": template["system"]},
        {"role": "user", "content": user_prompt}
    ]


# --- Generation orchestrator ---
def generate_policy_answer(
    api_key: str,
    model_name: str,
    query: str,
    top_docs: List[Union[Dict[str, Any], Any]],
    sentiment_rollup: Dict[str, List[str]],
    coherence_report: str = "",
    task_mode: str = "verbatim_sentiment",
    temperature: float = 0.2,
    max_tokens: int = 2000
) -> str:
    if not top_docs:
        return "No documents available to answer."

    messages = build_messages(
        query=query,
        top_docs=top_docs,
        task_mode=task_mode,
        sentiment_rollup=sentiment_rollup,
        coherence_report=coherence_report
    )
    resp = get_do_completion(api_key, model_name, messages, temperature=temperature, max_tokens=max_tokens)
    if resp is None:
        return "Upstream model error. No response."
    try:
        return resp["choices"][0]["message"]["content"].strip()
    except Exception:
        return json.dumps(resp, indent=2)
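
An illustrative sketch of the non-streaming path (the API key and the top_docs/sentiment_rollup variables are placeholders, not part of the commit):

# Sketch: one-shot answer generation
from model_generation import generate_policy_answer

answer = generate_policy_answer(
    api_key="YOUR_DO_API_KEY",          # placeholder
    model_name="llama3.3-70b-instruct",
    query="What penalties exist for non-compliance?",
    top_docs=top_docs,                  # reranked LangChain Documents from retrieve_and_rerank
    sentiment_rollup=sentiment_rollup,  # output of get_sentiment(top_docs), or {}
)
print(answer)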
utils/retrieve_n_rerank.py
ADDED
@@ -0,0 +1,79 @@
# load the encoded text and vectorstore
from encoding_input import encode_text
from loading_embeddings import get_vectorstore
from sentence_transformers import CrossEncoder
import numpy as np
import faiss

def search_vectorstore(encoded_text, vectorstore, k=5, with_score=False):
    """
    Vector similarity search with optional distance/score return.

    Args:
        encoded_text (np.ndarray | list): 1-D query vector.
        vectorstore (langchain.vectorstores.faiss.FAISS): your store.
        k (int): number of neighbors.
        with_score (bool): toggle score output.

    Returns:
        list: docs or (doc, score) tuples.
    """
    q = np.asarray(encoded_text, dtype="float32").reshape(1, -1)

    # ---- Use raw FAISS for full control and consistent behavior ----
    index = vectorstore.index  # faiss.Index
    distances, idxs = index.search(q, k)  # (1, k) each
    docs = [vectorstore.docstore.search(
        vectorstore.index_to_docstore_id[i]) for i in idxs[0]]

    # Return with or without scores
    return list(zip(docs, distances[0])) if with_score else docs

def rerank_cross_encoder(query_text, docs, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", top_m=20, min_score=None):
    """
    Returns the top_m (doc, score) pairs sorted by score descending.
    If min_score is set, filters out pairs below it.
    docs: a list of Document objects.
    """
    ce = CrossEncoder(model_name)
    # Create pairs of (query_text, document_content)
    pairs = [(query_text, doc.page_content) for doc in docs]  # use doc.page_content for the text
    scores = ce.predict(pairs)  # higher is better

    # Pair original documents with their scores and sort
    scored_documents = sorted(zip(docs, scores.tolist()), key=lambda x: x[1], reverse=True)

    # Apply the minimum score filter if specified
    if min_score is not None:
        scored_documents = [r for r in scored_documents if r[1] >= min_score]

    # Return the top_m reranked (Document, score) tuples
    return scored_documents[:top_m]


# retrieval and reranking function
def retrieve_and_rerank(query_text, vectorstore, k=50,
                        rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
                        top_m=20, min_score=None,
                        only_docs=True):
    # Step 1: Encode the query text
    encoded_query = encode_text(query_text)

    # Step 2: Retrieve relevant documents from the vectorstore
    retrieved_docs = search_vectorstore(encoded_query, vectorstore, k=k)

    # If no documents are retrieved, return an empty list
    # (checked before unpacking, to avoid indexing an empty list)
    if not retrieved_docs:
        return []

    # Keep only the documents if (doc, score) tuples were returned
    retrieved_docs = [doc for doc, _ in retrieved_docs] if isinstance(retrieved_docs[0], tuple) else retrieved_docs

    # Step 3: Rerank the retrieved documents
    reranked_docs = rerank_cross_encoder(query_text, retrieved_docs, model_name=rerank_model, top_m=top_m, min_score=min_score)

    # If only_docs is True, return just the documents
    if only_docs:
        return [doc for doc, _ in reranked_docs]
    # Otherwise, return the reranked documents with their scores
    return reranked_docs
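
A usage sketch (illustrative only):

# Sketch: retrieve, rerank, and inspect scores
from loading_embeddings import get_vectorstore
from retrieve_n_rerank import retrieve_and_rerank

vs = get_vectorstore()
hits = retrieve_and_rerank("sanitation funding", vs, k=50, top_m=5, min_score=0.5, only_docs=False)
for doc, score in hits:
    print(f"{score:.3f}  {doc.metadata.get('source')}")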
utils/sentiment_analysis.py
ADDED
@@ -0,0 +1,218 @@
import re
import math
import torch
from transformers import pipeline

# ------------- Model (CPU-friendly); use device=0 + fp16 on GPU -------------
ZSHOT = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
    multi_label=True,
    device=-1,
    model_kwargs={"torch_dtype": torch.float32}
)

# ------------------ Taxonomy with descriptions (helps NLI) -------------------
TAXO = {
    "intent_type": [
        "objective: declares goals or aims",
        "principle: states guiding values",
        "strategy: outlines measures or actions",
        "obligation: mandates an action (shall/must)",
        "prohibition: forbids an action",
        "permission: allows an action (may)",
        "exception: states conditions where rules change",
        "definition: defines a term",
        "scope: states applicability or coverage"
    ],
    "disposition": [
        "restrictive: limits or constrains the topic",
        "cautionary: warns or urges care",
        "neutral: descriptive with no clear stance",
        "enabling: allows or facilitates the topic",
        "supportive: promotes or expands the topic"
    ],
    "rigidity": [
        "must: mandatory (shall/must)",
        "should: advisory (should)",
        "may: permissive (may/can)"
    ],
    "temporal": [
        "deadline: requires completion by a date or period",
        "schedule: sets a cadence (e.g., annually, quarterly)",
        "ongoing: continuing requirement without end date",
        "effective_date: specifies when rules start/apply"
    ],
    "scope": [
        "actor_specific: targets a group or entity (e.g., county governments, permit holders)",
        "geography_specific: targets a location or region",
        "subject_specific: targets a topic (e.g., permits, sanitation)",
        "nationwide: applies across the country"
    ],
    "enforcement": [
        "penalty: fines or sanctions for non-compliance",
        "remedy: corrective actions required",
        "monitoring: oversight or audits",
        "reporting: reports/returns required",
        "none_detected: no enforcement mechanisms present"
    ],
    "resourcing": [
        "funding: funds or budget allocations",
        "fees_levies: charges or levies",
        "capacity_hr: staffing or training",
        "infrastructure: capital works or equipment",
        "none_detected: no resourcing present"
    ],
    "impact": [
        "low: limited effect on regulated parties",
        "medium: moderate practical effect",
        "high: significant obligations or restrictions"
    ]
}

# ---------------- Axis-specific thresholds (calibrate later) -----------------
TAU = {
    "intent_type": 0.55, "disposition": 0.55, "rigidity": 0.60,
    "temporal": 0.62, "scope": 0.55,
    "enforcement": 0.50, "resourcing": 0.50, "impact": 0.60
}
TAU_LOW = 0.40  # only for deciding whether we can safely emit "none_detected"

# ------------------------- Cleaning & evidence rules -------------------------
def _clean(t: str) -> str:
    t = re.sub(r"[ \t]*\n[ \t]*", " ", str(t))
    t = re.sub(r"\s{2,}", " ", t).strip()
    return t

PAT = {
    "actor": r"\bCounty Government(?:s)?\b|\bAuthority\b|\bMinistry\b|\bAgency\b|\bBoard\b|\bCommission\b",
    "nationwide": r"\bKenya\b|\bnational\b|\bnationwide\b|\bacross the country\b|\bthe country\b",
    "objective": r"\b(Objective[s]?|Purpose)\b|(?:^|\.\s+)To [A-Za-z]",
    "imperative": r"(?:^|\.\s+)(Promote|Ensure|Encourage|Strengthen|Adopt)\b.*?(?:\.|;)",
    "modal_must": r"\bshall\b|\bmust\b",
    "modal_should": r"\bshould\b",
    "modal_may": r"\bmay\b|\bcan\b",
    "temporal": r"\bwithin \d+\s+(day|days|month|months|year|years)\b|\bby \d{4}\b|\beffective\b",
    "enforcement": r"\bpenalt(y|ies)\b|\bfine(s)?\b|\brevocation\b|\bsuspension\b|\breport(ing)?\b|\bmonitor(ing)?\b",
    "resourcing": r"\bfund(?:ing)?\b|\blev(?:y|ies)\b|\bfee(?:s)?\b|\bbudget\b|\binfrastructure\b|\bcapacity\b|\btraining\b"
}

def _spans(text, pattern, max_spans=2):
    spans = []
    for m in re.finditer(pattern, text, flags=re.I):
        # sentence-level extraction around the match
        start = text.rfind('.', 0, m.start()) + 1
        end = text.find('.', m.end())
        if end == -1:
            end = len(text)
        snippet = text[start:end].strip()
        if snippet and snippet not in spans:
            spans.append(snippet)
        if len(spans) >= max_spans:
            break
    return spans

def _softmax(d):
    vals = list(d.values())
    if not vals:
        return {k: 0.0 for k in d}
    m = max(vals)
    exps = [math.exp(v - m) for v in vals]
    Z = sum(exps)
    return {k: (e / Z) for k, e in zip(d.keys(), exps)}

# -------------------- Main: classify + explanations + % ----------------------
def classify_and_explain(text: str, topic: str = "water and sanitation", per_axis_top_k=2):
    text = _clean(text)
    if not text:
        return {"decision_summary": "No operative decision; empty passage.",
                "labels": {ax: [] for ax in TAXO},
                "percents_raw": {ax: {} for ax in TAXO},
                "percents_norm": {ax: {} for ax in TAXO},
                "why": [], "text_preview": ""}

    # Topic-aware hypotheses (improves stance/intent)
    def hyp(axis):
        base = "This passage {} regarding " + topic + "."
        return {
            "intent_type": base.format("states a {}"),
            "disposition": base.format("is {}"),
            "rigidity": "Compliance in this passage is {}.",
            "temporal": base.format("specifies a {} aspect"),
            "scope": base.format("is {} in applicability"),
            "enforcement": base.format("includes {} for compliance"),
            "resourcing": base.format("provides {}"),
            "impact": base.format("has {} impact")
        }[axis]

    # Single batched call if the pipeline supports it; else per-axis fallback
    tasks = [{"sequences": text, "candidate_labels": labels, "hypothesis_template": hyp(axis)}
             for axis, labels in TAXO.items()]
    try:
        results = ZSHOT(tasks)
    except Exception:  # batched dict input is not supported by all pipeline versions
        results = [ZSHOT(text, labels, hypothesis_template=hyp(axis))
                   for axis, labels in TAXO.items()]

    labels_out, perc_raw, perc_norm, why = {}, {}, {}, []

    for (axis, labels), r in zip(TAXO.items(), results):
        # raw scores, keyed by the short label before the ":" description
        raw = {lbl.split(":")[0].strip(): float(s) for lbl, s in zip(r["labels"], r["scores"])}
        perc_raw[axis] = {k: round(raw[k] * 100, 1) for k in raw}  # independent sigmoids
        norm = _softmax(raw)
        perc_norm[axis] = {k: round(norm[k] * 100, 1) for k in norm}  # sums to ~100%

        # select labels by threshold
        keep = [k for k, s in raw.items() if s >= TAU[axis]]
        keep = sorted(keep, key=lambda k: raw[k], reverse=True)[:per_axis_top_k]
        # only emit none_detected when everything else is weak and there is no heuristic evidence
        if not keep and "none_detected" in raw:
            if max([v for k, v in raw.items() if k != "none_detected"] or [0.0]) < TAU_LOW:
                keep = ["none_detected"]

        labels_out[axis] = keep

        # compact "why" with evidence for the top choice
        if keep and keep[0] != "none_detected":
            if axis == "intent_type":
                ev = _spans(text, PAT["objective"]) or _spans(text, PAT["imperative"])
                why.append({"axis": axis, "label": keep[0], "reason": "functional cues", "evidence": ev[:2]})
            elif axis == "disposition":
                ev = _spans(text, PAT["imperative"])
                why.append({"axis": axis, "label": keep[0], "reason": "promotional/allowing framing", "evidence": ev[:2]})
            elif axis == "rigidity":
                pat = {"must": PAT["modal_must"], "should": PAT["modal_should"], "may": PAT["modal_may"]}[keep[0]]
                why.append({"axis": axis, "label": keep[0], "reason": "modal verb", "evidence": _spans(text, pat)[:2]})
            elif axis == "temporal":
                why.append({"axis": axis, "label": keep[0], "reason": "time expressions", "evidence": _spans(text, PAT["temporal"])[:2]})
            elif axis == "scope":
                ev = _spans(text, PAT["nationwide"]) or _spans(text, PAT["actor"])
                why.append({"axis": axis, "label": keep[0], "reason": "applicability cues", "evidence": ev[:2]})
            elif axis == "enforcement":
                why.append({"axis": axis, "label": keep[0], "reason": "compliance hooks", "evidence": _spans(text, PAT["enforcement"])[:2]})
            elif axis == "resourcing":
                why.append({"axis": axis, "label": keep[0], "reason": "resourcing hooks", "evidence": _spans(text, PAT["resourcing"])[:2]})

    # Decision summary: imperative sentences + scope/enforcement notes; never fabricate
    summary_bits = []
    # pull full imperative sentences
    imp_sents = _spans(text, PAT["imperative"], max_spans=3)
    if imp_sents:
        summary_bits.append("Strategies: " + " ".join(imp_sents))
    if "nationwide" in labels_out.get("scope", []):
        summary_bits.append("Applies nationwide.")
    if labels_out.get("enforcement") == ["none_detected"]:
        summary_bits.append("Enforcement: none detected in this passage.")
    if labels_out.get("resourcing") == ["none_detected"]:
        summary_bits.append("Resourcing: none detected in this passage.")
    decision_summary = " ".join(summary_bits) if summary_bits else "No operative decision beyond high-level description detected."

    return {
        "decision_summary": decision_summary,
        "labels": labels_out,
        "percents_raw": perc_raw,    # model confidences per label (0-100, do NOT sum to 100)
        "percents_norm": perc_norm,  # normalized per axis (sums to ~100)
        "why": why,
        "text_preview": text[:300] + ("..." if len(text) > 300 else "")
    }

# Get the sentiment for all the docs
def get_sentiment(texts):
    return [classify_and_explain(doc.page_content) for doc in texts]
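
A usage sketch on a single passage (the sample sentence is made up):

# Sketch: classify one passage and inspect the rigidity axis
from sentiment_analysis import classify_and_explain

out = classify_and_explain("County Governments shall report annually on water levies within 90 days.")
print(out["labels"]["rigidity"])  # e.g. ["must"], backed by the modal-verb evidence in out["why"]
print(out["decision_summary"])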