kaburia committed
Commit ef26a79 · 1 Parent(s): 1208e23

redesigned modules

utils/coherence_bbscore.py ADDED
@@ -0,0 +1,260 @@
+ # pip install sentence-transformers (if not already)
+ import math, os, re, unicodedata
+ import numpy as np
+ from typing import List, Dict, Any, Optional, Tuple
+
+ # get the reranked results with no scores
+ from retrieve_n_rerank import retrieve_and_rerank
+
+ try:
+     from sentence_transformers import SentenceTransformer
+ except Exception:
+     SentenceTransformer = None
+
+ # -----------------------------
+ # Text utilities
+ # -----------------------------
+ def _norm(t: str) -> str:
+     if t is None:
+         return ""
+     t = unicodedata.normalize("NFKC", str(t))
+     t = re.sub(r"\s*\n\s*", " ", t)
+     t = re.sub(r"\s{2,}", " ", t)
+     return t.strip()
+
+ def split_sentences(text: str) -> List[str]:
+     t = _norm(text)
+     parts = re.split(r"(?<=[\.\?\!])\s+(?=[A-Z“\"'])", t)
+     return [p.strip() for p in parts if p.strip()]
+
+ # -----------------------------
+ # Embeddings wrapper
+ # -----------------------------
+ class Embedder:
+     def __init__(self, model_name: str = "BAAI/bge-m3", device: str = "cpu"):
+         if SentenceTransformer is None:
+             raise RuntimeError("Install sentence-transformers to enable coherence scoring.")
+         self.model = SentenceTransformer(model_name, device=device)
+
+     def encode(self, sentences: List[str]) -> np.ndarray:
+         if not sentences:
+             return np.zeros((0, 768), dtype=np.float32)
+         X = self.model.encode(sentences, normalize_embeddings=True, batch_size=32, show_progress_bar=False)
+         return np.asarray(X, dtype=np.float32)
+
+ def _cos(a: np.ndarray, b: np.ndarray) -> float:
+     return float(np.dot(a, b))
+
+ def _normalize(v: np.ndarray) -> np.ndarray:
+     v = np.asarray(v, dtype=np.float32)
+     n = np.linalg.norm(v) + 1e-8
+     return v / n
+
+ # -----------------------------
+ # Brownian-bridge style metric
+ # -----------------------------
+ def bb_coherence(sentences: List[str], E: np.ndarray) -> Dict[str, Any]:
+     """
+     Brownian-bridge–inspired coherence:
+       - Build a main-idea vector (intro + outro + centroid)
+       - Compare per-sentence similarity to a target curve that is high at the ends, lower mid-document
+       - Map the maximum bridge deviation -> (0, 1] score (higher = more coherent)
+     """
+     n = len(sentences)
+     if n == 0:
+         return {"bbscore": 0.0, "sims": [], "off_idx": [], "rep_pairs": [], "sim_matrix": None}
+
+     k = max(1, min(3, n // 5))
+     v_first = E[:k].mean(axis=0)
+     v_last = E[-k:].mean(axis=0)
+     v_all = E.mean(axis=0)
+     v_main = _normalize(0.4 * v_first + 0.4 * v_last + 0.2 * v_all)
+
+     sims = np.array([_cos(v_main, E[i]) for i in range(n)], dtype=np.float32)
+     t = np.linspace(0.0, 1.0, num=n, dtype=np.float32)
+     q = 1.0 - 4.0 * t * (1.0 - t)  # peaks at ends
+     q = q / (q.mean() + 1e-8) * (sims.mean() if sims.size else 0.0)
+
+     r = sims - q
+     r_centered = r - r.mean()
+     cumsum = np.cumsum(r_centered)
+     B = cumsum - t * (cumsum[-1] if n > 1 else 0.0)
+     denom = (np.std(r_centered) * math.sqrt(n)) + 1e-8
+     ks = float(np.max(np.abs(B)) / denom)
+     bbscore = float(1.0 / (1.0 + ks))
+
+     # Off-topic: sims < mean - 1σ
+     off_thr = float(sims.mean() - sims.std())
+     off_idx = [i for i, s in enumerate(sims) if s < off_thr]
+
+     # Repetition: very high pairwise similarity, skipping adjacent sentences
+     S = E @ E.T if n > 1 else np.zeros((1, 1), dtype=np.float32)  # cosine due to normalization
+     rep_pairs = []
+     if n > 1:
+         for i in range(n):
+             for j in range(i + 2, n):  # skip adjacent
+                 if S[i, j] >= 0.92:  # threshold tunable
+                     rep_pairs.append((i, j, float(S[i, j])))
+
+     return {"bbscore": round(bbscore, 3), "sims": sims, "off_idx": off_idx, "rep_pairs": rep_pairs, "sim_matrix": S}
+
+ # -----------------------------
+ # Zero-shot labeler (optional)
+ # -----------------------------
+ def zshot_label(text: str, topic: str = "the main topic") -> Dict[str, float]:
+     """
+     Optional: zero-shot verdict to complement the rule-based label.
+     Labels: Coherent, Off topic, Repeated
+     """
+     try:
+         from transformers import pipeline
+     except Exception:
+         return {}
+     clf = pipeline("zero-shot-classification",
+                    model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
+                    multi_label=True)
+     labels = ["Coherent", "Off topic", "Repeated"]
+     res = clf(_norm(text), labels, hypothesis_template=f"This passage is {{}} with respect to {topic}.")
+     return {lbl: float(score) for lbl, score in zip(res["labels"], res["scores"])}
+
+ # -----------------------------
+ # Decision logic + reasons
+ # -----------------------------
+ def decide_label_with_reasons(
+     text: str,
+     topic_hint: Optional[str],
+     bb: Dict[str, Any],
+     sentences: List[str],
+     zshot_scores: Optional[Dict[str, float]] = None,
+     thresholds: Optional[Dict[str, float]] = None
+ ) -> Dict[str, Any]:
+     """
+     Returns:
+         {
+             "label": "Coherent" | "Off topic" | "Repeated",
+             "reasons": ["...", "..."],
+             "evidence": {"off_topic_examples": [...], "repeated_examples": [...]},
+             "bbscore": 0.74
+         }
+     """
+     thr = thresholds or {
+         "bb_coherent_min": 0.65,      # >= coherent
+         "off_topic_ratio_max": 0.20,  # <= coherent
+         "repeat_pairs_min": 1         # >= repeated (if any)
+     }
+     n = max(1, len(sentences))
+     off_ratio = len(bb["off_idx"]) / n
+     has_repeat = len(bb["rep_pairs"]) >= thr["repeat_pairs_min"]
+     bbscore = bb["bbscore"]
+
+     # Rule-based primary decision
+     if off_ratio > thr["off_topic_ratio_max"] and bbscore < thr["bb_coherent_min"]:
+         label = "Off topic"
+     elif has_repeat and bbscore >= 0.5:
+         label = "Repeated"
+     elif bbscore >= thr["bb_coherent_min"] and off_ratio <= thr["off_topic_ratio_max"] and not has_repeat:
+         label = "Coherent"
+     else:
+         # Tie-breaker using zero-shot if provided
+         if zshot_scores:
+             label = max(zshot_scores.items(), key=lambda kv: kv[1])[0]
+         else:
+             # fallback: prefer coherence if bbscore is okay, else off-topic
+             label = "Coherent" if bbscore >= 0.6 else "Off topic"
+
+     # Reasons
+     reasons = [f"BBScore={bbscore:.3f}."]
+     if bb["off_idx"]:
+         reasons.append(f"Off-topic fraction={off_ratio:.2f} ({len(bb['off_idx'])}/{n} sentences below main-idea similarity).")
+     if has_repeat:
+         top_rep = sorted(bb["rep_pairs"], key=lambda x: x[2], reverse=True)[:2]
+         reasons.append(f"Repeated content detected (top sim={top_rep[0][2]:.2f}).")
+
+     if zshot_scores:
+         top = sorted(zshot_scores.items(), key=lambda kv: kv[1], reverse=True)[:2]
+         reasons.append("Zero-shot support: " + ", ".join([f"{k}={v:.2f}" for k, v in top]))
+
+     # Evidence snippets
+     ev_off = [f'{i}: "{sentences[i]}"' for i in bb["off_idx"][:2]]
+     ev_rep = []
+     for (i, j, sim) in sorted(bb["rep_pairs"], key=lambda x: x[2], reverse=True)[:2]:
+         ev_rep.append(f'({i},{j}) sim={sim:.2f}: "{sentences[i]}", "{sentences[j]}"')
+
+     return {
+         "label": label,
+         "reasons": reasons,
+         "evidence": {"off_topic_examples": ev_off, "repeated_examples": ev_rep},
+         "bbscore": bbscore
+     }
+
+ def _display_title(meta: Dict[str, Any], fallback: str) -> str:
+     if meta.get("title"):
+         return str(meta["title"]).strip()
+     src = meta.get("source") or meta.get("path")
+     if src:
+         base = os.path.basename(str(src))
+         return re.sub(r"\.pdf$", "", base, flags=re.I)
+     return meta.get("doc_id") or fallback
+
+ def _page_label(meta: Dict[str, Any]) -> str:
+     return str(meta.get("page_label") or meta.get("page") or "?")
+
+ def to_std_doc(item: Any, idx: int = 0) -> Dict[str, Any]:
+     """
+     Accepts a LangChain Document or dict; returns a standard dict:
+     {title, page_label, text}
+     """
+     if hasattr(item, "page_content"):  # LangChain Document
+         meta = getattr(item, "metadata", {}) or {}
+         return {
+             "title": _display_title(meta, f"doc{idx+1}"),
+             "page_label": _page_label(meta),
+             "text": _norm(item.page_content),
+         }
+     elif isinstance(item, dict):
+         meta = item.get("metadata", {}) or {}
+         title = item.get("title") or _display_title(meta, item.get("doc_id", f"doc{idx+1}"))
+         page = item.get("page_label") or _page_label(meta)
+         text = _norm(item.get("text") or item.get("page_content", ""))
+         return {"title": title, "page_label": page, "text": text}
+     else:
+         raise TypeError(f"Unsupported doc type at index {idx}: {type(item)}")
+
+ def coherence_assessment_std(
+     std_doc: Dict[str, Any],
+     embedder,
+     topic_hint: Optional[str] = None,
+     run_zero_shot: bool = False,
+     thresholds: Optional[Dict[str, float]] = None
+ ) -> Dict[str, Any]:
+     """Same as coherence_assessment, but expects a standardized dict."""
+     text = std_doc.get("text", "")
+     sents = split_sentences(text)
+     if not sents:
+         return {"title": std_doc.get("title", "Document"), "label": "Off topic", "bbscore": 0.0,
+                 "reasons": ["Empty text."], "evidence": {}}
+     E = embedder.encode(sents)
+     bb = bb_coherence(sents, E)
+     zshot_scores = zshot_label(text, topic_hint or "the main topic") if run_zero_shot else None
+     decision = decide_label_with_reasons(text, topic_hint, bb, sents, zshot_scores, thresholds)
+     return {
+         "title": std_doc.get("title", "Document"),
+         "page_label": std_doc.get("page_label", "?"),
+         "label": decision["label"],
+         "bbscore": decision["bbscore"],
+         "reasons": decision["reasons"],
+         "evidence": decision["evidence"],
+     }
+
+ # Get the coherence report
+ def coherence_report(embedder="BAAI/bge-m3",  # sentence-embedding model; matches the Embedder default
+                      input_text=None,
+                      reranked_results=None,
+                      run_zero_shot=True):
+     embedder = Embedder(embedder) if isinstance(embedder, str) else embedder
+     if reranked_results is None:
+         reranked_results = retrieve_and_rerank(input_text)
+     if not reranked_results:
+         return []
+     # Convert reranked_results to standardized documents
+     std_results = [to_std_doc(doc, i) for i, doc in enumerate(reranked_results)]
+     reports = [coherence_assessment_std(d, embedder, topic_hint=input_text, run_zero_shot=run_zero_shot)
+                for d in std_results]
+     return reports
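
A minimal usage sketch for the coherence module above (not part of the commit; it assumes utils/ is on the import path and that top_docs is an existing list of reranked LangChain Documents, e.g. from retrieve_and_rerank):

    from coherence_bbscore import coherence_report, Embedder

    embedder = Embedder("BAAI/bge-m3")  # same model as the Embedder default
    reports = coherence_report(
        embedder=embedder,
        input_text="What does the policy say about water permits?",  # illustrative query
        reranked_results=top_docs,  # assumed to exist; documents only, no scores
        run_zero_shot=False,
    )
    for r in reports:
        print(r["title"], r["page_label"], r["label"], r["bbscore"])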
utils/encoding_input.py ADDED
@@ -0,0 +1,12 @@
+ # Methods to encode text
+ import numpy as np
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+
+ def encode_text(text, embedding_model='sentence-transformers/all-MiniLM-L6-v2', as_array=True):
+     """Encodes the input text using the provided embedding model."""
+     embedding_model = HuggingFaceEmbeddings(model_name=embedding_model)
+     encoded_input = embedding_model.embed_query(text)
+     if as_array:
+         return np.array(encoded_input)
+     else:
+         return encoded_input
utils/generation_streaming.py ADDED
@@ -0,0 +1,99 @@
+ # from langchain_community.embeddings import HuggingFaceEmbeddings
+ # from langchain_community.embeddings import CrossEncoder
+ import requests
+ import numpy as np
+ import time
+ import json
+
+ # encode the text
+ from encoding_input import encode_text
+
+ # retrieve and rerank the documents
+ from retrieve_n_rerank import retrieve_and_rerank
+
+ # sentiment analysis on reranked documents
+ from sentiment_analysis import get_sentiment
+
+ # coherence assessment reports
+ from coherence_bbscore import coherence_report
+
+ # Get the vectorstore
+ from loading_embeddings import get_vectorstore
+ vectorstore = get_vectorstore()
+
+ # build messages for model generation
+ from model_generation import build_messages
+
+ API_KEY = "sk-do-"
+ MODEL = "llama3.3-70b-instruct"
+
+ def generate_response_stream(query: str, enable_sentiment: bool, enable_coherence: bool):
+
+     # encoded_input = encode_text(query)
+
+     reranked_results = retrieve_and_rerank(
+         query_text=query,
+         vectorstore=vectorstore,
+         k=50,  # number of initial documents to retrieve
+         rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
+         top_m=20,  # number of documents to return after reranking
+         min_score=0.5,  # minimum score for reranked documents
+         only_docs=False  # return both documents and scores
+     )
+     top_docs = [doc for doc, score in reranked_results]
+
+     if not top_docs:
+         yield "No relevant documents found."
+         return
+
+     sentiment_rollup = get_sentiment(top_docs) if enable_sentiment else {}
+     coherence_report_ = coherence_report(reranked_results=top_docs, input_text=query) if enable_coherence else ""
+
+     messages = build_messages(
+         query=query,
+         top_docs=top_docs,
+         task_mode="verbatim_sentiment",
+         sentiment_rollup=sentiment_rollup,
+         coherence_report=coherence_report_,
+     )
+
+     headers = {
+         "Authorization": f"Bearer {API_KEY}",
+         "Content-Type": "application/json"
+     }
+
+     data = {
+         "model": MODEL,
+         "messages": messages,
+         "temperature": 0.2,
+         "stream": True,
+         "max_tokens": 2000
+     }
+
+     collected = ""  # Accumulate content to show
+
+     with requests.post("https://inference.do-ai.run/v1/chat/completions", headers=headers, json=data, stream=True) as r:
+         if r.status_code != 200:
+             yield f"[ERROR] API returned status {r.status_code}: {r.text}"
+             return
+
+         for line in r.iter_lines(decode_unicode=True):
+             if not line or line.strip() == "data: [DONE]":
+                 continue
+             if line.startswith("data: "):
+                 line = line[len("data: "):]
+
+             try:
+                 chunk = json.loads(line)
+                 delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
+                 if delta:
+                     collected += delta
+                     yield collected  # yield the accumulated text progressively
+                     time.sleep(0.01)  # slight throttle to improve smoothness
+             except Exception as e:
+                 print("Streaming decode error:", e)
+                 continue
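
A minimal driver sketch for the streaming generator above (not part of the commit; note that importing the module loads the vectorstore, the query string is illustrative, and each yielded value is the full text accumulated so far):

    from generation_streaming import generate_response_stream

    stream = generate_response_stream(
        "Summarize the enforcement provisions on sanitation.",  # illustrative query
        enable_sentiment=True,
        enable_coherence=True,
    )
    final_text = ""
    for partial in stream:
        final_text = partial  # each yield replaces the previous partial answer
    print(final_text)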
utils/loading_embeddings.py ADDED
@@ -0,0 +1,55 @@
+ # Loading embeddings from storage
+ import os
+ from pathlib import Path
+ from huggingface_hub import hf_hub_download
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+
+ # download into the data directory
+ data_path = os.path.join(Path(os.getcwd()).parent, "data")
+ # make the faiss local folder
+ local_folder = os.path.join(data_path, 'faiss_index')
+
+ def download_faiss_index(repo_id="kaburia/epic-a-embeddings", local_folder="faiss_index"):
+     os.makedirs(local_folder, exist_ok=True)
+
+     index_faiss_path = os.path.join(local_folder, "index.faiss")
+     index_pkl_path = os.path.join(local_folder, "index.pkl")
+
+     if not os.path.exists(index_faiss_path):
+         print("Downloading index.faiss from Hugging Face Dataset...")
+         hf_hub_download(
+             repo_id=repo_id,
+             filename="index.faiss",
+             repo_type="dataset",
+             local_dir=local_folder,
+             local_dir_use_symlinks=False,
+         )
+
+     if not os.path.exists(index_pkl_path):
+         print("Downloading index.pkl from Hugging Face Dataset...")
+         hf_hub_download(
+             repo_id=repo_id,
+             filename="index.pkl",
+             repo_type="dataset",
+             local_dir=local_folder,
+             local_dir_use_symlinks=False,
+         )
+
+ def load_vectorstore(index_path="faiss_index"):
+     embedding_model = HuggingFaceEmbeddings(
+         model_name="sentence-transformers/all-MiniLM-L6-v2"
+     )
+     db = FAISS.load_local(
+         index_path,
+         embeddings=embedding_model,
+         allow_dangerous_deserialization=True
+     )
+     return db
+
+ # download and load the vectorstore
+ def get_vectorstore(repo_id="kaburia/epic-a-embeddings", local_folder="faiss_index"):
+     download_faiss_index(repo_id=repo_id, local_folder=local_folder)
+     return load_vectorstore(index_path=local_folder)
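
A minimal usage sketch for the loader above (not part of the commit; it downloads the index into ./faiss_index on first use):

    from loading_embeddings import get_vectorstore

    vs = get_vectorstore()  # fetches index.faiss / index.pkl from the HF dataset if missing
    hits = vs.similarity_search("county water levies", k=5)  # standard LangChain FAISS query
    for doc in hits:
        print(doc.metadata.get("source"), doc.metadata.get("page"))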
utils/model_generation.py ADDED
@@ -0,0 +1,211 @@
+ import json
+ import requests
+ from typing import List, Dict, Any, Union
+ import time
+ import numpy as np
+ import os
+
+ PROMPT_TEMPLATES = {
+     "verbatim_sentiment": {
+         "system": (
+             "You are a compliance-grade policy analyst assistant. "
+             "Your job is to return a precise, fact-grounded response. "
+             "Avoid hallucinations. Base everything strictly on the content provided. "
+             "If coherence and/or sentiment analysis is not enabled, do not mention it in the response."
+         ),
+         "user_template": """
+ Query: {query}
+
+ Deliverables:
+ 1) **Quoted Policy Excerpts**: Quote key policy content directly. Cite the source using filename and page.
+ 2) **Sentiment Summary**: Use the sentiment JSON to explain tone, gaps, penalties, or enforcement clarity in plain English.
+ 3) **Coherence Assessment**: Summarize the coherence report below. Highlight:
+ - Whether the answer was mostly on-topic or off-topic
+ - Point out the sections that were coherent, off-topic, or repeated
+
+ Topic hint: {topic_hint}
+
+ Sentiment JSON (rolled-up across top docs):
+ {sentiment_json}
+
+ Coherence report:
+ {coherence_report}
+
+ Context Sources:
+ {context_block}
+ """
+     },
+
+     "abstractive_summary": {
+         "system": (
+             "You are a policy analyst summarizing government documents for a general audience. "
+             "Your response should paraphrase clearly, avoiding quotes unless absolutely necessary. "
+             "Highlight high-level goals, enforcement strategies, and important deadlines or penalties."
+         ),
+         "user_template": """Query: {query}
+
+ Summarize the answer in natural, non-technical language. Emphasize clarity and coverage. Avoid quoting unless the phrase is legally binding.
+
+ Topic hint: {topic_hint}
+
+ Context DOCS:
+ {context_block}
+ """
+     },
+
+     "followup_reasoning": {
+         "system": (
+             "You are an assistant that explains policy documents interactively, reasoning step-by-step. "
+             "Always cite document IDs and indicate if certain info is absent."
+         ),
+         "user_template": """User query: {query}
+
+ Explain the answer step-by-step. Add follow-up questions that a reader might ask, and try to answer them using the documents below.
+
+ Topic: {topic_hint}
+
+ DOCS:
+ {context_block}
+ """
+     },
+
+     # Add more templates as needed
+ }
+
+ # --- LLM client ---
+ def get_do_completion(api_key, model_name, messages, temperature=0.2, max_tokens=800):
+     url = "https://inference.do-ai.run/v1/chat/completions"
+     headers = {
+         "Authorization": f"Bearer {api_key}",
+         "Content-Type": "application/json"
+     }
+     data = {
+         "model": model_name,
+         "messages": messages,
+         "temperature": temperature,
+         "max_tokens": max_tokens
+     }
+     try:
+         resp = requests.post(url, headers=headers, json=data, timeout=90)
+         resp.raise_for_status()
+         return resp.json()
+     except requests.exceptions.HTTPError as e:
+         print(f"HTTP error occurred: {e}")
+         print(f"Response body: {e.response.text if e.response is not None else ''}")
+         return None
+     except requests.exceptions.RequestException as e:
+         print(f"Request error: {e}")
+         return None
+     except json.JSONDecodeError as e:
+         print(f"Failed to decode JSON: {e}")
+         print(f"Response text: {resp.text if 'resp' in locals() else ''}")
+         return None
+
+ # --- Prompt context builder ---
+ def _clip(text: str, max_chars: int = 1400) -> str:
+     """Trim content to limit prompt size."""
+     if not text:
+         return ""
+     text = str(text).strip()
+     return text[:max_chars] + ("..." if len(text) > max_chars else "")
+
+ def build_context_block(top_docs: List[Dict[str, Any]]) -> str:
+     """
+     Formats each document with a real citation:
+       - Extracts the file name from the 'source' path
+       - Uses 'page_label' or falls back to 'page'
+       - Returns: <<<SOURCE: {filename}, p. {page_label}>>>
+     """
+     blocks = []
+     for i, item in enumerate(top_docs):
+         if hasattr(item, "page_content"):
+             text = item.page_content
+             meta = getattr(item, "metadata", {})
+         else:
+             text = item.get("text") or item.get("page_content", "")
+             meta = item.get("metadata", {})
+
+         # Get the file name from the path
+         full_path = meta.get("source", "")
+         filename = os.path.basename(full_path) if full_path else f"Document_{i+1}"
+
+         # Prefer page_label if available, else fall back to the raw page
+         page_label = meta.get("page_label") or meta.get("page") or "unknown"
+
+         citation = f"{filename}, p. {page_label}"
+
+         blocks.append(f"<<<SOURCE: {citation}>>>\n{_clip(text)}\n</SOURCE>")
+
+     return "\n".join(blocks)
+
+ # --- Message builder ---
+ def build_messages(
+     query: str,
+     top_docs: List[Dict[str, Any]],
+     task_mode: str,
+     sentiment_rollup: Dict[str, List[str]],
+     coherence_report: str = "",
+     topic_hint: str = "energy policy"
+ ) -> List[Dict[str, str]]:
+     template = PROMPT_TEMPLATES.get(task_mode)
+     if not template:
+         raise ValueError(f"Unknown task mode: {task_mode}")
+
+     context_block = build_context_block(top_docs)
+     sentiment_json = json.dumps(sentiment_rollup or {}, ensure_ascii=False)
+
+     user_prompt = template["user_template"].format(
+         query=query,
+         topic_hint=topic_hint,
+         sentiment_json=sentiment_json,
+         context_block=context_block,
+         coherence_report=coherence_report
+     )
+
+     return [
+         {"role": "system", "content": template["system"]},
+         {"role": "user", "content": user_prompt}
+     ]
+
+ # --- Generation orchestrator ---
+ def generate_policy_answer(
+     api_key: str,
+     model_name: str,
+     query: str,
+     top_docs: List[Union[Dict[str, Any], Any]],
+     sentiment_rollup: Dict[str, List[str]],
+     coherence_report: str = "",
+     task_mode: str = "verbatim_sentiment",
+     temperature: float = 0.2,
+     max_tokens: int = 2000
+ ) -> str:
+     if not top_docs:
+         return "No documents available to answer."
+
+     messages = build_messages(
+         query=query,
+         top_docs=top_docs,
+         task_mode=task_mode,
+         sentiment_rollup=sentiment_rollup,
+         coherence_report=coherence_report
+     )
+     resp = get_do_completion(api_key, model_name, messages, temperature=temperature, max_tokens=max_tokens)
+     if resp is None:
+         return "Upstream model error. No response."
+     try:
+         return resp["choices"][0]["message"]["content"].strip()
+     except Exception:
+         return json.dumps(resp, indent=2)
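
A minimal end-to-end sketch for the builders above (not part of the commit; the API key is a placeholder and top_docs is assumed to come from retrieve_and_rerank):

    from model_generation import generate_policy_answer

    answer = generate_policy_answer(
        api_key="sk-do-...",  # placeholder, not a real key
        model_name="llama3.3-70b-instruct",
        query="What penalties apply to unlicensed water abstraction?",  # illustrative query
        top_docs=top_docs,  # assumed to exist
        sentiment_rollup={},  # or the output of get_sentiment(top_docs)
        task_mode="verbatim_sentiment",
    )
    print(answer)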
utils/retrieve_n_rerank.py ADDED
@@ -0,0 +1,79 @@
+ # load the encoded text and vectorstore
+ from encoding_input import encode_text
+ from loading_embeddings import get_vectorstore
+ from sentence_transformers import CrossEncoder
+ import numpy as np
+ import faiss
+
+ def search_vectorstore(encoded_text, vectorstore, k=5, with_score=False):
+     """
+     Vector similarity search with optional distance/score return.
+
+     Args:
+         encoded_text (np.ndarray | list): 1-D query vector.
+         vectorstore (langchain.vectorstores.faiss.FAISS): the store to search.
+         k (int): number of neighbors.
+         with_score (bool): toggle score output.
+
+     Returns:
+         list: docs or (doc, score) tuples.
+     """
+     q = np.asarray(encoded_text, dtype="float32").reshape(1, -1)
+
+     # ---- Use raw FAISS for full control and consistent behavior ----
+     index = vectorstore.index  # faiss.Index
+     distances, idxs = index.search(q, k)  # (1, k) each
+     docs = [vectorstore.docstore.search(
+         vectorstore.index_to_docstore_id[i]) for i in idxs[0]]
+
+     # Return with or without scores
+     return list(zip(docs, distances[0])) if with_score else docs
+
+ def rerank_cross_encoder(query_text, docs, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", top_m=20, min_score=None):
+     """
+     Returns top_m (doc, score) sorted by score desc. If min_score is set, filters below it.
+     docs: a list of Document objects.
+     """
+     ce = CrossEncoder(model_name)
+     # Create pairs of (query_text, document_content)
+     pairs = [(query_text, doc.page_content) for doc in docs]  # use doc.page_content for the text
+     scores = ce.predict(pairs)  # higher is better
+
+     # Pair original documents with their scores and sort
+     scored_documents = sorted(zip(docs, scores.tolist()), key=lambda x: x[1], reverse=True)
+
+     # Apply minimum score filter if specified
+     if min_score is not None:
+         scored_documents = [r for r in scored_documents if r[1] >= min_score]
+
+     # Return the top_m reranked (Document, score) tuples
+     return scored_documents[:top_m]
+
+ # retrieval and reranking function
+ def retrieve_and_rerank(query_text, vectorstore, k=50,
+                         rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
+                         top_m=20, min_score=None,
+                         only_docs=True):
+     # Step 1: Encode the query text
+     encoded_query = encode_text(query_text)
+
+     # Step 2: Retrieve relevant documents from the vectorstore
+     retrieved_docs = search_vectorstore(encoded_query, vectorstore, k=k)
+
+     # If no documents are retrieved, return an empty list
+     if not retrieved_docs:
+         return []
+
+     # Keep only the documents (search may return (doc, score) tuples)
+     retrieved_docs = [doc for doc, _ in retrieved_docs] if isinstance(retrieved_docs[0], tuple) else retrieved_docs
+
+     # Step 3: Rerank the retrieved documents
+     reranked_docs = rerank_cross_encoder(query_text, retrieved_docs, model_name=rerank_model, top_m=top_m, min_score=min_score)
+
+     # If only_docs is True, return just the documents
+     if only_docs:
+         return [doc for doc, _ in reranked_docs]
+     # Otherwise, return the reranked documents with their scores
+     return reranked_docs
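
A minimal retrieval round trip for the functions above (not part of the commit; it reuses the loader from loading_embeddings and an illustrative query):

    from loading_embeddings import get_vectorstore
    from retrieve_n_rerank import retrieve_and_rerank

    vs = get_vectorstore()
    results = retrieve_and_rerank("county water levies", vs, k=50, top_m=10,
                                  min_score=0.3, only_docs=False)
    for doc, score in results:
        print(round(score, 3), doc.metadata.get("source"))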
utils/sentiment_analysis.py ADDED
@@ -0,0 +1,218 @@
+ import re, math, torch
+ from transformers import pipeline
+
+ # ------------- Model (CPU-friendly); use device=0 + fp16 on GPU -------------
+ ZSHOT = pipeline(
+     "zero-shot-classification",
+     model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
+     multi_label=True,
+     device=-1,
+     model_kwargs={"torch_dtype": torch.float32}
+ )
+
+ # ------------------ Taxonomy with descriptions (helps NLI) -------------------
+ TAXO = {
+     "intent_type": [
+         "objective: declares goals or aims",
+         "principle: states guiding values",
+         "strategy: outlines measures or actions",
+         "obligation: mandates an action (shall/must)",
+         "prohibition: forbids an action",
+         "permission: allows an action (may)",
+         "exception: states conditions where rules change",
+         "definition: defines a term",
+         "scope: states applicability or coverage"
+     ],
+     "disposition": [
+         "restrictive: limits or constrains the topic",
+         "cautionary: warns or urges care",
+         "neutral: descriptive with no clear stance",
+         "enabling: allows or facilitates the topic",
+         "supportive: promotes or expands the topic"
+     ],
+     "rigidity": [
+         "must: mandatory (shall/must)",
+         "should: advisory (should)",
+         "may: permissive (may/can)"
+     ],
+     "temporal": [
+         "deadline: requires completion by a date or period",
+         "schedule: sets a cadence (e.g., annually, quarterly)",
+         "ongoing: continuing requirement without end date",
+         "effective_date: specifies when rules start/apply"
+     ],
+     "scope": [
+         "actor_specific: targets a group or entity (e.g., county governments, permit holders)",
+         "geography_specific: targets a location or region",
+         "subject_specific: targets a topic (e.g., permits, sanitation)",
+         "nationwide: applies across the country"
+     ],
+     "enforcement": [
+         "penalty: fines or sanctions for non-compliance",
+         "remedy: corrective actions required",
+         "monitoring: oversight or audits",
+         "reporting: reports/returns required",
+         "none_detected: no enforcement mechanisms present"
+     ],
+     "resourcing": [
+         "funding: funds or budget allocations",
+         "fees_levies: charges or levies",
+         "capacity_hr: staffing or training",
+         "infrastructure: capital works or equipment",
+         "none_detected: no resourcing present"
+     ],
+     "impact": [
+         "low: limited effect on regulated parties",
+         "medium: moderate practical effect",
+         "high: significant obligations or restrictions"
+     ]
+ }
+
+ # ---------------- Axis-specific thresholds (calibrate later) -----------------
+ TAU = {
+     "intent_type": 0.55, "disposition": 0.55, "rigidity": 0.60,
+     "temporal": 0.62, "scope": 0.55,
+     "enforcement": 0.50, "resourcing": 0.50, "impact": 0.60
+ }
+ TAU_LOW = 0.40  # only for deciding if we can safely emit "none_detected"
+
+ # ------------------------- Cleaning & evidence rules -------------------------
+ def _clean(t: str) -> str:
+     t = re.sub(r"[ \t]*\n[ \t]*", " ", str(t))
+     t = re.sub(r"\s{2,}", " ", t).strip()
+     return t
+
+ PAT = {
+     "actor": r"\bCounty Government(?:s)?\b|\bAuthority\b|\bMinistry\b|\bAgency\b|\bBoard\b|\bCommission\b",
+     "nationwide": r"\bKenya\b|\bnational\b|\bnationwide\b|\bacross the country\b|\bthe country\b",
+     "objective": r"\b(Objective[s]?|Purpose)\b|(?:^|\.\s+)To [A-Za-z]",
+     "imperative": r"(?:^|\.\s+)(Promote|Ensure|Encourage|Strengthen|Adopt)\b.*?(?:\.|;)",
+     "modal_must": r"\bshall\b|\bmust\b",
+     "modal_should": r"\bshould\b",
+     "modal_may": r"\bmay\b|\bcan\b",
+     "temporal": r"\bwithin \d+\s+(day|days|month|months|year|years)\b|\bby \d{4}\b|\beffective\b",
+     "enforcement": r"\bpenalt(y|ies)\b|\bfine(s)?\b|\brevocation\b|\bsuspension\b|\breport(ing)?\b|\bmonitor(ing)?\b",
+     "resourcing": r"\bfund(?:ing)?\b|\blevy|levies|fee(s)?\b|\bbudget\b|\binfrastructure\b|\bcapacity\b|\btraining\b"
+ }
+
+ def _spans(text, pattern, max_spans=2):
+     spans = []
+     for m in re.finditer(pattern, text, flags=re.I):
+         # sentence-level extraction
+         start = text.rfind('.', 0, m.start()) + 1
+         end = text.find('.', m.end())
+         if end == -1:
+             end = len(text)
+         snippet = text[start:end].strip()
+         if snippet and snippet not in spans:
+             spans.append(snippet)
+         if len(spans) >= max_spans:
+             break
+     return spans
+
+ def _softmax(d):
+     vals = list(d.values())
+     if not vals:
+         return {k: 0.0 for k in d}
+     m = max(vals)
+     exps = [math.exp(v - m) for v in vals]
+     Z = sum(exps)
+     return {k: (e / Z) for k, e in zip(d.keys(), exps)}
+
+ # -------------------- Main: classify + explanations + % ----------------------
+ def classify_and_explain(text: str, topic: str = "water and sanitation", per_axis_top_k=2):
+     text = _clean(text)
+     if not text:
+         return {"decision_summary": "No operative decision; empty passage.",
+                 "labels": {ax: [] for ax in TAXO},
+                 "percents_raw": {ax: {} for ax in TAXO},
+                 "percents_norm": {ax: {} for ax in TAXO},
+                 "why": [], "text_preview": ""}
+
+     # Topic-aware hypotheses (improves stance/intent)
+     def hyp(axis):
+         base = "This passage {} regarding " + topic + "."
+         return {
+             "intent_type": base.format("states a {}"),
+             "disposition": base.format("is {}"),
+             "rigidity": "Compliance in this passage is {}.",
+             "temporal": base.format("specifies a {} aspect"),
+             "scope": base.format("is {} in applicability"),
+             "enforcement": base.format("includes {} for compliance"),
+             "resourcing": base.format("provides {}"),
+             "impact": base.format("has {} impact")
+         }[axis]
+
+     # Single call if supported; else per-axis fallback
+     tasks = [{"sequences": text, "candidate_labels": labels, "hypothesis_template": hyp(axis)}
+              for axis, labels in TAXO.items()]
+     try:
+         results = ZSHOT(tasks)
+     except TypeError:
+         results = [ZSHOT(text, labels, hypothesis_template=hyp(axis))
+                    for axis, labels in TAXO.items()]
+
+     labels_out, perc_raw, perc_norm, why = {}, {}, {}, []
+
+     for (axis, labels), r in zip(TAXO.items(), results):
+         # raw scores
+         raw = {lbl.split(":")[0].strip(): float(s) for lbl, s in zip(r["labels"], r["scores"])}
+         perc_raw[axis] = {k: round(raw[k] * 100, 1) for k in raw}  # independent sigmoid
+         norm = _softmax(raw)
+         perc_norm[axis] = {k: round(norm[k] * 100, 1) for k in norm}  # sums ~100%
+
+         # select labels by threshold
+         keep = [k for k, s in raw.items() if s >= TAU[axis]]
+         keep = sorted(keep, key=lambda k: raw[k], reverse=True)[:per_axis_top_k]
+         # only emit none_detected when everything else is weak and no heuristic evidence
+         if not keep and "none_detected" in raw:
+             if max([v for k, v in raw.items() if k != "none_detected"] or [0.0]) < TAU_LOW:
+                 keep = ["none_detected"]
+
+         labels_out[axis] = keep
+
+         # compact "why" with evidence for the top choice
+         if keep and keep[0] != "none_detected":
+             if axis == "intent_type":
+                 ev = _spans(text, PAT["objective"]) or _spans(text, PAT["imperative"])
+                 why.append({"axis": axis, "label": keep[0], "reason": "functional cues", "evidence": ev[:2]})
+             elif axis == "disposition":
+                 ev = _spans(text, PAT["imperative"])
+                 why.append({"axis": axis, "label": keep[0], "reason": "promotional/allowing framing", "evidence": ev[:2]})
+             elif axis == "rigidity":
+                 pat = {"must": PAT["modal_must"], "should": PAT["modal_should"], "may": PAT["modal_may"]}[keep[0]]
+                 why.append({"axis": axis, "label": keep[0], "reason": "modal verb", "evidence": _spans(text, pat)[:2]})
+             elif axis == "temporal":
+                 why.append({"axis": axis, "label": keep[0], "reason": "time expressions", "evidence": _spans(text, PAT["temporal"])[:2]})
+             elif axis == "scope":
+                 ev = _spans(text, PAT["nationwide"]) or _spans(text, PAT["actor"])
+                 why.append({"axis": axis, "label": keep[0], "reason": "applicability cues", "evidence": ev[:2]})
+             elif axis == "enforcement":
+                 why.append({"axis": axis, "label": keep[0], "reason": "compliance hooks", "evidence": _spans(text, PAT["enforcement"])[:2]})
+             elif axis == "resourcing":
+                 why.append({"axis": axis, "label": keep[0], "reason": "resourcing hooks", "evidence": _spans(text, PAT["resourcing"])[:2]})
+
+     # Decision summary: imperative lines + problem statements; never fabricate
+     summary_bits = []
+     imperatives = re.findall(PAT["imperative"], text, flags=re.I)
+     # pull full imperative sentences
+     imp_sents = _spans(text, PAT["imperative"], max_spans=3)
+     if imp_sents:
+         summary_bits.append("Strategies: " + " ".join(imp_sents))
+     if "nationwide" in labels_out.get("scope", []):
+         summary_bits.append("Applies nationwide.")
+     if labels_out.get("enforcement") == ["none_detected"]:
+         summary_bits.append("Enforcement: none detected in this passage.")
+     if labels_out.get("resourcing") == ["none_detected"]:
+         summary_bits.append("Resourcing: none detected in this passage.")
+     decision_summary = " ".join(summary_bits) if summary_bits else "No operative decision beyond high-level description detected."
+
+     return {
+         "decision_summary": decision_summary,
+         "labels": labels_out,
+         "percents_raw": perc_raw,    # model confidences per label (0–100, do NOT sum to 100)
+         "percents_norm": perc_norm,  # normalized per axis (sums to ~100)
+         "why": why,
+         "text_preview": text[:300] + ("..." if len(text) > 300 else "")
+     }
+
+ # Get the sentiment for all the docs
+ def get_sentiment(texts):
+     return [classify_and_explain(doc.page_content) for doc in texts]
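
A minimal single-passage sketch for the classifier above (not part of the commit; the passage text is illustrative only):

    from sentiment_analysis import classify_and_explain

    out = classify_and_explain(
        "County Governments shall ensure access to safe water and may levy fees to fund infrastructure.",
        topic="water and sanitation",
    )
    print(out["decision_summary"])
    print(out["labels"]["rigidity"], out["percents_norm"].get("rigidity"))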