import re
import math

import torch
from transformers import pipeline
import spaces

# ------------- Model (CPU-friendly); use device=0 + fp16 on GPU -------------
ZSHOT = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
    multi_label=True,
    device=-1,
    model_kwargs={"torch_dtype": torch.float32},
)

# ------------------ Taxonomy with descriptions (helps NLI) -------------------
TAXO = {
    "intent_type": [
        "objective: declares goals or aims",
        "principle: states guiding values",
        "strategy: outlines measures or actions",
        "obligation: mandates an action (shall/must)",
        "prohibition: forbids an action",
        "permission: allows an action (may)",
        "exception: states conditions where rules change",
        "definition: defines a term",
        "scope: states applicability or coverage",
    ],
    "disposition": [
        "restrictive: limits or constrains the topic",
        "cautionary: warns or urges care",
        "neutral: descriptive with no clear stance",
        "enabling: allows or facilitates the topic",
        "supportive: promotes or expands the topic",
    ],
    "rigidity": [
        "must: mandatory (shall/must)",
        "should: advisory (should)",
        "may: permissive (may/can)",
    ],
    "temporal": [
        "deadline: requires completion by a date or period",
        "schedule: sets a cadence (e.g., annually, quarterly)",
        "ongoing: continuing requirement without end date",
        "effective_date: specifies when rules start/apply",
    ],
    "scope": [
        "actor_specific: targets a group or entity (e.g., county governments, permit holders)",
        "geography_specific: targets a location or region",
        "subject_specific: targets a topic (e.g., permits, sanitation)",
        "nationwide: applies across the country",
    ],
    "enforcement": [
        "penalty: fines or sanctions for non-compliance",
        "remedy: corrective actions required",
        "monitoring: oversight or audits",
        "reporting: reports/returns required",
        "none_detected: no enforcement mechanisms present",
    ],
    "resourcing": [
        "funding: funds or budget allocations",
        "fees_levies: charges or levies",
        "capacity_hr: staffing or training",
        "infrastructure: capital works or equipment",
        "none_detected: no resourcing present",
    ],
    "impact": [
        "low: limited effect on regulated parties",
        "medium: moderate practical effect",
        "high: significant obligations or restrictions",
    ],
}

# ---------------- Axis-specific thresholds (calibrate later) -----------------
TAU = {
    "intent_type": 0.55,
    "disposition": 0.55,
    "rigidity": 0.60,
    "temporal": 0.62,
    "scope": 0.55,
    "enforcement": 0.50,
    "resourcing": 0.50,
    "impact": 0.60,
}
TAU_LOW = 0.40  # only for deciding if we can safely emit "none_detected"

# ------------------------- Cleaning & evidence rules -------------------------
def _clean(t: str) -> str:
    """Collapse line breaks and repeated whitespace into single spaces."""
    t = re.sub(r"[ \t]*\n[ \t]*", " ", str(t))
    t = re.sub(r"\s{2,}", " ", t).strip()
    return t

PAT = {
    "actor": r"\bCounty Government(?:s)?\b|\bAuthority\b|\bMinistry\b|\bAgency\b|\bBoard\b|\bCommission\b",
    "nationwide": r"\bKenya\b|\bnational\b|\bnationwide\b|\bacross the country\b|\bthe country\b",
    "objective": r"\b(Objective[s]?|Purpose)\b|(?:^|\.\s+)To [A-Za-z]",
    "imperative": r"(?:^|\.\s+)(Promote|Ensure|Encourage|Strengthen|Adopt)\b.*?(?:\.|;)",
    "modal_must": r"\bshall\b|\bmust\b",
    "modal_should": r"\bshould\b",
    "modal_may": r"\bmay\b|\bcan\b",
    "temporal": r"\bwithin \d+\s+(?:days?|months?|years?)\b|\bby \d{4}\b|\beffective\b",
    "enforcement": r"\bpenalt(?:y|ies)\b|\bfine(?:s)?\b|\brevocation\b|\bsuspension\b|\breport(?:ing)?\b|\bmonitor(?:ing)?\b",
    "resourcing": r"\bfund(?:ing)?\b|\blev(?:y|ies)\b|\bfee(?:s)?\b|\bbudget\b|\binfrastructure\b|\bcapacity\b|\btraining\b",
}

def _spans(text, pattern, max_spans=2):
    """Return up to max_spans sentence-level snippets that contain a pattern match."""
    spans = []
    for m in re.finditer(pattern, text, flags=re.I):
        # sentence-level extraction: expand each match to its surrounding sentence
        start = text.rfind('.', 0, m.start()) + 1
        end = text.find('.', m.end())
        if end == -1:
            end = len(text)
        snippet = text[start:end].strip()
        if snippet and snippet not in spans:
            spans.append(snippet)
        if len(spans) >= max_spans:
            break
    return spans

def _softmax(d):
    """Softmax over a dict's values; returns a dict with the same keys."""
    vals = list(d.values())
    if not vals:
        return {k: 0.0 for k in d}
    m = max(vals)
    exps = [math.exp(v - m) for v in vals]
    Z = sum(exps)
    return {k: (e / Z) for k, e in zip(d.keys(), exps)}

# -------------------- Main: classify + explanations + % ----------------------
def classify_and_explain(text: str, topic: str = "water and sanitation", per_axis_top_k=2):
    text = _clean(text)
    if not text:
        return {"decision_summary": "No operative decision; empty passage.",
                "labels": {ax: [] for ax in TAXO},
                "percents_raw": {ax: {} for ax in TAXO},
                "percents_norm": {ax: {} for ax in TAXO},
                "why": [],
                "text_preview": ""}

    # Topic-aware hypotheses (improves stance/intent)
    def hyp(axis):
        base = "This passage {} regarding " + topic + "."
        return {
            "intent_type": base.format("states a {}"),
            "disposition": base.format("is {}"),
            "rigidity": "Compliance in this passage is {}.",
            "temporal": base.format("specifies a {} aspect"),
            "scope": base.format("is {} in applicability"),
            "enforcement": base.format("includes {} for compliance"),
            "resourcing": base.format("provides {}"),
            "impact": base.format("has {} impact"),
        }[axis]

    # Single batched call if supported; else per-axis fallback
    tasks = [{"sequences": text, "candidate_labels": labels, "hypothesis_template": hyp(axis)}
             for axis, labels in TAXO.items()]
    try:
        results = ZSHOT(tasks)
    except (TypeError, ValueError):  # batched-dict input is not accepted by all pipeline versions
        results = [ZSHOT(text, labels, hypothesis_template=hyp(axis))
                   for axis, labels in TAXO.items()]

    labels_out, perc_raw, perc_norm, why = {}, {}, {}, []
    for (axis, labels), r in zip(TAXO.items(), results):
        # raw scores: independent sigmoid per label (multi_label=True), keyed by label name
        raw = {lbl.split(":")[0].strip(): float(s) for lbl, s in zip(r["labels"], r["scores"])}
        perc_raw[axis] = {k: round(raw[k] * 100, 1) for k in raw}
        # softmax-normalized per axis (sums ~100%)
        norm = _softmax(raw)
        perc_norm[axis] = {k: round(norm[k] * 100, 1) for k in norm}

        # select labels by threshold
        keep = [k for k, s in raw.items() if s >= TAU[axis]]
        keep = sorted(keep, key=lambda k: raw[k], reverse=True)[:per_axis_top_k]
        # only emit none_detected when everything else is weak and no heuristic evidence
        if not keep and "none_detected" in raw:
            if max([v for k, v in raw.items() if k != "none_detected"] or [0.0]) < TAU_LOW:
                keep = ["none_detected"]
        labels_out[axis] = keep

        # compact "why" with evidence for the top choice
        if keep and keep[0] != "none_detected":
            if axis == "intent_type":
                ev = _spans(text, PAT["objective"]) or _spans(text, PAT["imperative"])
                why.append({"axis": axis, "label": keep[0],
                            "reason": "functional cues", "evidence": ev[:2]})
            elif axis == "disposition":
                ev = _spans(text, PAT["imperative"])
                why.append({"axis": axis, "label": keep[0],
                            "reason": "promotional/allowing framing", "evidence": ev[:2]})
            elif axis == "rigidity":
                pat = {"must": PAT["modal_must"], "should": PAT["modal_should"], "may": PAT["modal_may"]}[keep[0]]
                why.append({"axis": axis, "label": keep[0],
                            "reason": "modal verb", "evidence": _spans(text, pat)[:2]})
            elif axis == "temporal":
                why.append({"axis": axis, "label": keep[0],
                            "reason": "time expressions", "evidence": _spans(text, PAT["temporal"])[:2]})
            elif axis == "scope":
                ev = _spans(text, PAT["nationwide"]) or _spans(text, PAT["actor"])
                why.append({"axis": axis, "label": keep[0],
                            "reason": "applicability cues",
"evidence": ev[:2]}) elif axis == "enforcement": why.append({"axis": axis, "label": keep[0], "reason": "compliance hooks", "evidence": _spans(text, PAT["enforcement"])[:2]}) elif axis == "resourcing": why.append({"axis": axis, "label": keep[0], "reason": "resourcing hooks", "evidence": _spans(text, PAT["resourcing"])[:2]}) # Decision summary: imperative lines + problem statements; never fabricate summary_bits = [] imperatives = re.findall(PAT["imperative"], text, flags=re.I) # pull full imperative sentences imp_sents = _spans(text, PAT["imperative"], max_spans=3) if imp_sents: summary_bits.append("Strategies: " + " ".join(imp_sents)) if "nationwide" in labels_out.get("scope", []): summary_bits.append("Applies nationwide.") if labels_out.get("enforcement") == ["none_detected"]: summary_bits.append("Enforcement: none detected in this passage.") if labels_out.get("resourcing") == ["none_detected"]: summary_bits.append("Resourcing: none detected in this passage.") decision_summary = " ".join(summary_bits) if summary_bits else "No operative decision beyond high-level description detected." return { "decision_summary": decision_summary, "labels": labels_out, "percents_raw": perc_raw, # model confidences per label (0–100, do NOT sum to 100) "percents_norm": perc_norm, # normalized per axis (sums to ~100) "why": why, "text_preview": text[:300] + ("..." if len(text) > 300 else "") } # Get the sentiment for all the docs @spaces.GPU(duration=120) def get_sentiment(texts): return [classify_and_explain(texts[i].page_content) for i in range(len(texts))]