Spaces:
Sleeping
Sleeping
| # app.py | |
import contextlib
import io
import json
import os
import re
import shutil
import time
import zipfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)
# =========================
# Global config
# =========================
# Local Hugging Face cache used for tokenizer/model downloads (created at import).
SPACE_CACHE = Path.home() / ".cache" / "huggingface"
SPACE_CACHE.mkdir(parents=True, exist_ok=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Greedy (deterministic) decoding. Sampling parameters such as temperature /
# top_p are ignored when do_sample=False (transformers only warns about them),
# so they are deliberately not set.
# The reply must hold the whole JSON object (labels plus one task per label with
# explanation and evidence); a reply cut off mid-JSON cannot be parsed at all,
# so leave generous headroom. Generation still stops early at EOS.
GEN_CONFIG = GenerationConfig(
    do_sample=False,
    max_new_tokens=512,
)
# Official UBS label set (strict)
OFFICIAL_LABELS = [
    "plan_contact",
    "schedule_meeting",
    "update_contact_info_non_postal",
    "update_contact_info_postal_address",
    "update_kyc_activity",
    "update_kyc_origin_of_assets",
    "update_kyc_purpose_of_businessrelation",
    "update_kyc_total_assets",
]
# Newline-separated form, used to prefill the UI's allowed-labels textbox.
OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
# Per-label keyword cues (static prompt context to improve recall).
# Used in two places: build_keyword_context() lists them in the user prompt,
# and keyword_fallback() searches the transcript for them (case-insensitive
# substring match) when the model returns no labels.
# NOTE(review): plain substring matching means short generic cues (e.g.
# "salary", "calendar", "remind") may over-trigger the fallback — tune if
# false positives become a problem.
LABEL_KEYWORDS: Dict[str, List[str]] = {
    "plan_contact": [
        "call back", "follow up", "reach out", "contact later", "check-in",
        "email them", "touch base", "remind", "send a note"
    ],
    "schedule_meeting": [
        "book a meeting", "set up a meeting", "schedule a call",
        "appointment", "calendar", "meeting next week", "meet on", "time slot"
    ],
    "update_contact_info_non_postal": [
        "phone change", "new phone", "email change", "new email",
        "update contact details", "update mobile", "alternate phone"
    ],
    "update_contact_info_postal_address": [
        "moved to", "new address", "postal address", "mailing address",
        "change of address", "residential address"
    ],
    "update_kyc_activity": [
        "activity update", "economic activity", "employment status",
        "occupation", "job change", "business activity"
    ],
    "update_kyc_origin_of_assets": [
        "source of funds", "origin of assets", "where money comes from",
        "inheritance", "salary", "business income", "asset origin"
    ],
    "update_kyc_purpose_of_businessrelation": [
        "purpose of relationship", "why the account", "reason for banking",
        "investment purpose", "relationship purpose"
    ],
    "update_kyc_total_assets": [
        "total assets", "net worth", "assets under ownership",
        "portfolio size", "how much you own"
    ],
}
# =========================
# Instructions (string-safe; concatenated)
# =========================
# System message passed to ModelWrapper.generate by run_single and run_batch.
# It is never run through str.format, so the literal braces of the JSON schema
# below need no escaping.
# NOTE(review): the two fragments of the "Case-insensitive during reasoning"
# rule concatenate to "reasoning,  but" (double space) — harmless, could be tidied.
SYSTEM_PROMPT = (
    "You are a precise banking assistant that extracts ACTIONABLE TASKS from "
    "client–advisor transcripts. Be conservative with hallucinations but "
    "prioritise RECALL: if unsure and the transcript plausibly implies an "
    "action, include the label and explain briefly.\n\n"
    "Output STRICT JSON only:\n\n"
    "{\n"
    ' "labels": ["<Label1>", "..."],\n'
    ' "tasks": [\n'
    ' {"label": "<Label1>", "explanation": "<why>", "evidence": "<quoted text/snippet>"}\n'
    " ]\n"
    "}\n\n"
    "Rules:\n"
    "- Use ONLY allowed labels supplied to you. Case-insensitive during reasoning, "
    " but output the canonical label text exactly.\n"
    "- If none truly apply, return empty lists.\n"
    "- Keep explanations concise; put the minimal evidence snippet that justifies the task.\n"
)
# User message template, filled with str.format(transcript=...,
# allowed_labels_list=..., keyword_context=...). Any literal brace added to this
# template must be doubled ({{ / }}) or .format() will raise.
USER_PROMPT_TEMPLATE = (
    "Transcript (cleaned):\n"
    "```\n{transcript}\n```\n\n"
    "Allowed Labels (canonical; use only these):\n"
    "{allowed_labels_list}\n\n"
    "Context cues (keywords/phrases that often indicate each label):\n"
    "{keyword_context}\n\n"
    "Instructions:\n"
    "- Identify EVERY concrete task implied by the conversation.\n"
    "- Choose ONE label from Allowed Labels for each task (or none if truly inapplicable).\n"
    "- Return STRICT JSON only in the exact schema described by the system prompt.\n"
)
# =========================
# Utilities
# =========================
def _now_ms() -> int:
    """Return a monotonic timestamp in whole milliseconds.

    Only the difference between two calls is meaningful; the value is used
    for latency diagnostics. A monotonic clock is not affected by wall-clock
    adjustments (e.g. NTP corrections), so measured durations never go
    negative or jump.
    """
    return time.monotonic_ns() // 1_000_000
def normalize_labels(labels: List[str]) -> List[str]:
    """Trim, drop blank/non-string entries, and de-duplicate labels.

    The first occurrence of each label wins, so input order is preserved.
    """
    seen = set()
    result: List[str] = []
    for raw in labels:
        if not isinstance(raw, str):
            continue
        label = raw.strip()
        if label and label not in seen:
            seen.add(label)
            result.append(label)
    return result
def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
    """Map each lower-cased label to its canonical spelling.

    When two labels collide case-insensitively, the later one wins.
    """
    mapping: Dict[str, str] = {}
    for label in allowed:
        mapping[label.lower()] = label
    return mapping
def robust_json_extract(text: str) -> Dict[str, Any]:
    """Best-effort extraction of the JSON object in a model reply.

    Strategy:
      1. Parse the span from the first ``{`` to the last ``}`` (or the whole
         text when there is no such span). Retry once with trailing commas
         before ``}`` / ``]`` removed.
      2. If that yields no dict, scan every ``{`` with
         ``json.JSONDecoder.raw_decode`` and keep the LAST decoded object that
         has a ``labels`` or ``tasks`` key. This recovers the answer when the
         reply also contains other JSON-like text, e.g. an echoed schema.

    Args:
        text: Raw model output (may be empty or None).

    Returns:
        The parsed dict, or ``{"labels": [], "tasks": []}`` if nothing usable
        is found. Non-dict JSON (lists, numbers) is never returned, so callers
        can safely call ``.get`` on the result.
    """
    if not text:
        return {"labels": [], "tasks": []}

    start, end = text.find("{"), text.rfind("}")
    candidate = text[start:end + 1] if (start != -1 and end > start) else text
    without_trailing_commas = re.sub(r",\s*]", "]", re.sub(r",\s*}", "}", candidate))
    for attempt in (candidate, without_trailing_commas):
        try:
            obj = json.loads(attempt)
        except (ValueError, RecursionError):
            continue
        if isinstance(obj, dict):
            return obj

    # Fallback: decode each embedded object; the model's answer is the last
    # one that looks like our schema.
    decoder = json.JSONDecoder()
    best: Optional[Dict[str, Any]] = None
    pos = text.find("{")
    while pos != -1:
        try:
            obj, stop = decoder.raw_decode(text, pos)
        except (ValueError, RecursionError):
            pos = text.find("{", pos + 1)
            continue
        if isinstance(obj, dict) and ("labels" in obj or "tasks" in obj):
            best = obj
        pos = text.find("{", stop)
    return best if best is not None else {"labels": [], "tasks": []}
def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
    """Keep only labels/tasks whose label is in ``allowed`` (case-insensitive).

    Labels are rewritten to their canonical spelling from ``allowed``. The
    returned ``labels`` list is the de-duplicated union of the predicted
    labels followed by the labels of the kept tasks, in first-seen order.
    Order-preserving merging keeps the output deterministic across runs.

    Args:
        pred: Parsed model output, expected as ``{"labels": [...], "tasks": [...]}``.
            A non-dict value is treated as an empty prediction. A bare string
            under ``labels`` is treated as a single label.
        allowed: Canonical label names.

    Returns:
        ``{"labels": [...], "tasks": [...]}`` containing only allowed labels.
        Each kept task is a shallow copy with its ``label`` canonicalised.
    """
    if not isinstance(pred, dict):
        pred = {}
    allowed_map = canonicalize_map(allowed)

    raw_labels = pred.get("labels") or []
    if isinstance(raw_labels, str):
        raw_labels = [raw_labels]
    elif not isinstance(raw_labels, (list, tuple)):
        raw_labels = []
    filt_labels = []
    for lab in raw_labels:
        key = str(lab).strip().lower()
        if key in allowed_map:
            filt_labels.append(allowed_map[key])

    raw_tasks = pred.get("tasks") or []
    if not isinstance(raw_tasks, (list, tuple)):
        raw_tasks = []
    filt_tasks = []
    for task in raw_tasks:
        if not isinstance(task, dict):
            continue
        key = str(task.get("label", "")).strip().lower()
        if key in allowed_map:
            new_task = dict(task)
            new_task["label"] = allowed_map[key]
            filt_tasks.append(new_task)

    merged = normalize_labels(filt_labels + [t["label"] for t in filt_tasks])
    return {"labels": merged, "tasks": filt_tasks}
# =========================
# Default pre-processing (toggleable)
# =========================
# Leading boilerplate: removed from the start of the text up to the first
# blank line (or to the end of the text when there is none).
_DISCLAIMER_PATTERNS = [
    r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
    r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
    r"(?is)^\s*this message \(including any attachments\).+?(?:\n{2,}|$)",
]
# Trailing signatures/footers: everything from the matching line to the end of
# the text is dropped. The \b after "ubs ag" stops ordinary sentences such as
# "UBS agreed to ..." from truncating the rest of the transcript.
_FOOTER_PATTERNS = [
    r"(?is)\n+kind regards[^\n]*\n.*$", r"(?is)\n+best regards[^\n]*\n.*$",
    r"(?is)\n+sent from my.*$", r"(?is)\n+ubs ag\b.*$",
]
# Per-line timestamps and speaker prefixes (matched case-insensitively).
_TIMESTAMP_SPEAKER = [
    r"\[\d{1,2}:\d{2}(:\d{2})?\]",  # [00:01] or [00:01:02]
    r"^\s*(advisor|client)\s*:\s*",  # Advisor: / Client:
    r"^\s*(speaker\s*\d+)\s*:\s*",  # Speaker 1:
]
# Compiled once at import instead of on every call / line.
_DISCLAIMER_RES = [re.compile(p) for p in _DISCLAIMER_PATTERNS]
_FOOTER_RES = [re.compile(p) for p in _FOOTER_PATTERNS]
_TIMESTAMP_SPEAKER_RES = [re.compile(p, re.IGNORECASE) for p in _TIMESTAMP_SPEAKER]


def clean_transcript(text: str) -> str:
    """Strip timestamps, speaker tags, leading disclaimers and trailing footers.

    Steps, in order:
      1. Per line, remove ``[mm:ss]`` timestamps and ``Advisor:`` /
         ``Client:`` / ``Speaker N:`` prefixes.
      2. Remove leading disclaimer blocks.
      3. Remove trailing signature/footer blocks.
      4. Collapse runs of spaces/tabs to one space and 3+ newlines to a single
         blank line, then strip.

    Falsy input (``""`` / ``None``) is returned unchanged.
    """
    if not text:
        return text
    lines = []
    for line in text.splitlines():
        for rx in _TIMESTAMP_SPEAKER_RES:
            line = rx.sub("", line)
        lines.append(line)
    s = "\n".join(lines)
    for rx in _DISCLAIMER_RES:
        s = rx.sub("", s).strip()
    for rx in _FOOTER_RES:
        s = rx.sub("", s)
    s = re.sub(r"[ \t]+", " ", s)
    s = re.sub(r"\n{3,}", "\n\n", s).strip()
    return s
def read_text_file_any(file_input) -> str:
    """Return the text of an uploaded transcript, or ``""`` on any failure.

    Accepts a filesystem path (``str`` / ``Path``, as produced by
    ``gr.File(type="filepath")``) or a file-like object whose ``read()``
    returns ``bytes`` or ``str``. Bytes and file contents are decoded as UTF-8
    with undecodable bytes dropped.
    """
    if not file_input:
        return ""
    if isinstance(file_input, (str, Path)):
        try:
            return Path(file_input).read_text(encoding="utf-8", errors="ignore")
        except (OSError, ValueError):  # missing/unreadable file, bad path
            return ""
    try:
        data = file_input.read()
    except (AttributeError, OSError, TypeError, ValueError):
        return ""
    if isinstance(data, (bytes, bytearray)):
        return bytes(data).decode("utf-8", errors="ignore")
    # Text-mode file objects already return str.
    return data if isinstance(data, str) else ""
def read_json_file_any(file_input) -> Optional[dict]:
    """Parse an uploaded JSON file into a dict; return ``None`` on any failure.

    Accepts a filesystem path (``str`` / ``Path``) or a file-like object whose
    ``read()`` returns ``bytes`` or ``str``. Valid JSON that is not an object
    (e.g. a list) also yields ``None``, matching the return annotation.
    """
    if not file_input:
        return None
    try:
        if isinstance(file_input, (str, Path)):
            raw = Path(file_input).read_text(encoding="utf-8", errors="ignore")
        else:
            raw = file_input.read()
            if isinstance(raw, (bytes, bytearray)):
                raw = bytes(raw).decode("utf-8", errors="ignore")
        obj = json.loads(raw)
    except (AttributeError, OSError, TypeError, ValueError, RecursionError):
        return None
    return obj if isinstance(obj, dict) else None
def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
    """Keep at most the last ``max_tokens`` tokens of ``text``.

    Truncation drops the BEGINNING of the text and keeps its tail.

    Args:
        tokenizer: HF-style tokenizer supporting
            ``tokenizer(text, add_special_tokens=False)["input_ids"]`` and
            ``tokenizer.decode(ids, skip_special_tokens=True)``.
        text: Text to truncate.
        max_tokens: Token budget. Coerced to ``int`` (UI sliders may deliver
            floats); a value ``<= 0`` disables truncation.

    Returns:
        ``text`` unchanged when it fits (or truncation is disabled), otherwise
        the decoded tail.
    """
    budget = int(max_tokens)
    if budget <= 0 or not text:
        return text
    toks = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(toks) <= budget:
        return text
    return tokenizer.decode(toks[-budget:], skip_special_tokens=True)
# =========================
# HF model wrapper
# =========================
class ModelWrapper:
    """Tokenizer + causal LM for one Hugging Face repo, loaded on demand.

    Construct, then call :meth:`load` once before :meth:`generate`.
    """

    def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
        self.repo_id = repo_id
        self.hf_token = hf_token  # passed to from_pretrained (gated/private repos)
        self.load_in_4bit = load_in_4bit  # honoured only when DEVICE == "cuda"
        self.tokenizer = None
        self.model = None

    def load(self) -> None:
        """Instantiate tokenizer and model for ``repo_id`` (cached under SPACE_CACHE).

        On CUDA with ``load_in_4bit`` the weights are NF4-quantised through
        bitsandbytes; otherwise fp16 is used on CUDA and fp32 on CPU.
        """
        qcfg = None
        if self.load_in_4bit and DEVICE == "cuda":
            qcfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        tok = AutoTokenizer.from_pretrained(
            self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
            trust_remote_code=True, use_fast=True,
        )
        if tok.pad_token is None and tok.eos_token:
            # generate() needs a pad id; reuse EOS when the tokenizer has none.
            tok.pad_token = tok.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
            low_cpu_mem_usage=True, quantization_config=qcfg,
            attn_implementation="sdpa",
        )
        self.tokenizer = tok
        self.model = model

    def _encode(self, system_prompt: str, user_prompt: str) -> Dict[str, Any]:
        """Tokenise the prompt into ``generate`` keyword arguments on the model device."""
        device = self.model.device
        # apply_chat_template exists on every modern tokenizer; what matters is
        # whether this tokenizer actually ships a chat template.
        if getattr(self.tokenizer, "chat_template", None):
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
            try:
                encoded = self.tokenizer.apply_chat_template(
                    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt",
                )
            except Exception:
                # Some chat templates reject a standalone "system" turn and
                # raise a template error; fold the instructions into the user
                # turn and retry (a second failure propagates).
                merged = [{"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}]
                encoded = self.tokenizer.apply_chat_template(
                    merged, tokenize=True, add_generation_prompt=True, return_tensors="pt",
                )
            # Depending on the transformers version this is a tensor or a
            # BatchEncoding-like mapping.
            input_ids = encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]
            input_ids = input_ids.to(device)
            # Single unpadded sequence: attend to every position. Passing the
            # mask explicitly avoids ambiguity when pad_token == eos_token.
            return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}
        enc = self.tokenizer(
            f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n",
            return_tensors="pt",
        )
        return {key: value.to(device) for key, value in enc.items()}

    def _stop_token_ids(self) -> Optional[List[int]]:
        """Tokenizer EOS id plus any EOS ids declared in the model's generation config.

        Instruct models may define extra stop tokens (such as an end-of-turn
        token) in their own generation config. Including them makes generation
        stop at the end of the assistant turn rather than running to
        max_new_tokens.
        """
        ids: List[int] = []
        model_cfg = getattr(self.model, "generation_config", None)
        for source in (self.tokenizer.eos_token_id, getattr(model_cfg, "eos_token_id", None)):
            if source is None:
                continue
            for tid in (source if isinstance(source, (list, tuple)) else [source]):
                if tid not in ids:
                    ids.append(tid)
        return ids or None

    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Greedily generate a reply and return ONLY the newly generated text.

        Raises:
            RuntimeError: if :meth:`load` has not been called yet.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("ModelWrapper.load() must be called before generate().")
        inputs = self._encode(system_prompt, user_prompt)
        prompt_len = inputs["input_ids"].shape[-1]
        autocast = (
            torch.autocast(device_type="cuda", dtype=torch.float16)
            if DEVICE == "cuda"
            else contextlib.nullcontext()
        )
        with autocast:
            out_ids = self.model.generate(
                **inputs,
                generation_config=GEN_CONFIG,
                eos_token_id=self._stop_token_ids(),
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # A causal LM's output starts with the prompt tokens. Decode only the
        # continuation, so the prompt (which contains a JSON schema example)
        # cannot leak into the parsed answer.
        return self.tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
# Loaded models, keyed by "<repo_id>::<4bit|full>".
_MODEL_CACHE: Dict[str, ModelWrapper] = {}


def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
    """Return a loaded ModelWrapper for ``repo_id``, loading it on first use.

    The cache key combines the repo with the effective precision. 4-bit only
    counts on CUDA, matching ModelWrapper.load. The token is not part of the
    key. If loading raises, nothing is cached and the exception propagates.
    """
    precision = "4bit" if load_in_4bit and DEVICE == "cuda" else "full"
    cache_key = f"{repo_id}::{precision}"
    wrapper = _MODEL_CACHE.get(cache_key)
    if wrapper is None:
        wrapper = ModelWrapper(repo_id, hf_token, load_in_4bit)
        wrapper.load()
        _MODEL_CACHE[cache_key] = wrapper
    return wrapper
# =========================
# Official evaluation (from README)
# =========================
def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
    """Official weighted multi-label score in [0, 1] (higher is better).

    Each sample's labels become a binary vector over OFFICIAL_LABELS. Per
    sample, a false negative costs 2 and a false positive costs 1. The sample
    score is ``1 - weighted_errors / max_err``, where
    ``max_err = 2 * (#true labels) + 1 * (#labels not in the truth)``, which
    is the cost of predicting exactly the wrong set. The result is the mean
    over samples, clipped to [0, 1].

    Args:
        y_true: Ground-truth label lists, one per sample.
        y_pred: Predicted label lists, aligned index-by-index with ``y_true``.

    Returns:
        The mean per-sample score as a Python float.

    Raises:
        ValueError: if the two lists differ in length, or any sample is not a
            list, contains a non-string, a duplicate, or a label outside
            OFFICIAL_LABELS.
    """
    ALLOWED_LABELS = OFFICIAL_LABELS
    LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}

    def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
        # Validate one sample's labels and return them in their original order;
        # raises ValueError on any violation (no silent cleanup).
        if not isinstance(sample_labels, list):
            raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
        seen, uniq = set(), []
        for label in sample_labels:
            if not isinstance(label, str):
                raise ValueError(f"{sample_name} contains non-string: {label} (type: {type(label)})")
            if label in seen:
                raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
            if label not in ALLOWED_LABELS:
                raise ValueError(f"{sample_name} contains invalid label: '{label}'. Allowed: {ALLOWED_LABELS}")
            seen.add(label); uniq.append(label)
        return uniq

    if len(y_true) != len(y_pred):
        raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")
    n_samples = len(y_true)
    n_labels = len(OFFICIAL_LABELS)
    # Rows = samples, columns = OFFICIAL_LABELS positions.
    y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
    y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)
    for i, sample_labels in enumerate(y_true):
        for label in _process_sample_labels(sample_labels, f"y_true[{i}]"):
            y_true_binary[i, LABEL_TO_IDX[label]] = 1
    for i, sample_labels in enumerate(y_pred):
        for label in _process_sample_labels(sample_labels, f"y_pred[{i}]"):
            y_pred_binary[i, LABEL_TO_IDX[label]] = 1
    fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)  # penalty 2x
    fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)  # penalty 1x
    weighted = 2.0 * fn + 1.0 * fp
    # max_err >= n_labels > 0 for every sample, so the np.where fallback of
    # 1.0 below never actually triggers.
    max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))
    per_sample = np.where(max_err > 0, 1.0 - (weighted / max_err), 1.0)
    # NOTE(review): with zero samples np.mean(per_sample) is nan (with a
    # RuntimeWarning), and this clamp then evaluates to 1.0.
    return float(max(0.0, min(1.0, np.mean(per_sample))))
# =========================
# Fallback: keyword heuristics if model returns empty
# =========================
def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
    """Predict labels by case-insensitive substring search for LABEL_KEYWORDS cues.

    For each allowed label, the first of its keywords (in LABEL_KEYWORDS
    order) that occurs in ``text`` produces one task. The task's evidence is
    the first occurrence plus up to 40 characters of context on each side.
    Labels without configured keywords are never predicted.

    Returns:
        ``{"labels": [...], "tasks": [...]}`` in the same shape as
        restrict_to_allowed's output.
    """
    labels: List[str] = []
    tasks: List[Dict[str, Any]] = []
    for lab in allowed:
        for kw in LABEL_KEYWORDS.get(lab, []):
            # Search the original text rather than a lower-cased copy, so match
            # offsets index `text` directly even when lower-casing would change
            # the string's length.
            match = re.search(re.escape(kw), text, re.IGNORECASE)
            if match is None:
                continue
            start = max(0, match.start() - 40)
            end = min(len(text), match.end() + 40)
            labels.append(lab)
            tasks.append({
                "label": lab,
                "explanation": "Keyword match in transcript.",
                "evidence": text[start:end].strip(),
            })
            break  # one task per label: the first matching keyword wins
    return {"labels": normalize_labels(labels), "tasks": tasks}
# =========================
# Inference helpers
# =========================
def build_keyword_context(allowed: List[str]) -> str:
    """Render one ``- <label>: <kw1>, <kw2>, ...`` line per allowed label.

    Labels with no entry in LABEL_KEYWORDS are shown with "(no default cues)".
    """
    def _line(label: str) -> str:
        cues = LABEL_KEYWORDS.get(label, [])
        rendered = ", ".join(cues) if cues else "(no default cues)"
        return f"- {label}: {rendered}"

    return "\n".join(_line(label) for label in allowed)
def _format_summary(labels: List[str], tasks: List[Dict[str, Any]]) -> str:
    """Render detected labels and tasks as the plain-text UI summary."""
    summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labels) if labels else "(none)")
    if not tasks:
        return summary + "\n\nTasks: (none)"
    lines = []
    for t in tasks:
        # Model output is untrusted: explanation/evidence may be missing or non-strings.
        explanation = t.get("explanation", "")
        evidence = t.get("evidence", "")
        explanation = "" if explanation is None else str(explanation)
        evidence = "" if evidence is None else str(evidence)
        ellipsis = "…" if len(evidence) > 140 else ""
        lines.append(f"• [{t['label']}] {explanation} | ev: {evidence[:140]}{ellipsis}")
    return summary + "\n\nTasks:\n" + "\n".join(lines)


def _score_single(gt_json_text: str, gt_json_file, pred_labels: List[str]) -> str:
    """Score one prediction against optional ground truth; "" if none was supplied.

    Truth is read from the uploaded file first, then from the pasted JSON.
    evaluate_predictions raises on duplicates and on labels outside
    OFFICIAL_LABELS, so the truth is de-duplicated. Only official predicted
    labels are scored, because a user-edited allowed-label list may yield
    others; those are reported as ignored.
    """
    if not (gt_json_file or (gt_json_text and gt_json_text.strip())):
        return ""
    truth_obj = read_json_file_any(gt_json_file) if gt_json_file else None
    if not truth_obj and gt_json_text:
        try:
            truth_obj = json.loads(gt_json_text)
        except ValueError:
            truth_obj = None
    if not (isinstance(truth_obj, dict) and isinstance(truth_obj.get("labels"), list)):
        return "Ground truth JSON missing or invalid; expected {'labels': [...]}."

    true_labels = list(dict.fromkeys(x for x in truth_obj["labels"] if x in OFFICIAL_LABELS))
    scored_preds = [p for p in pred_labels if p in OFFICIAL_LABELS]
    ignored = [p for p in pred_labels if p not in OFFICIAL_LABELS]
    try:
        score = evaluate_predictions([true_labels], [scored_preds])
    except ValueError as e:
        return f"Scoring error: {e}"

    truth_set, pred_set = set(true_labels), set(scored_preds)
    tp = len(truth_set & pred_set)
    fp = len(pred_set - truth_set)
    fn = len(truth_set - pred_set)
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    # precision + recall == 0 means tp == 0 with both FPs and FNs: F1 is 0, not 1.
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    metrics = (
        f"Weighted score: {score:.3f}\n"
        f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}\n"
        f"TP={tp} FP={fp} FN={fn}\n"
        f"Truth: {', '.join(true_labels)}"
    )
    if ignored:
        metrics += f"\nIgnored (not official labels, not scored): {', '.join(ignored)}"
    return metrics


def run_single(
    transcript_text: str,
    transcript_file,  # filepath or file-like
    gt_json_text: str,
    gt_json_file,  # filepath or file-like
    use_cleaning: bool,
    use_keyword_fallback: bool,
    allowed_labels_text: str,
    model_repo: str,
    use_4bit: bool,
    max_input_tokens: int,
    hf_token: str,
) -> Tuple[str, str, str, str, str, str, str]:
    """Extract tasks from one transcript and optionally score them.

    An uploaded transcript file takes precedence; pasted text is used when
    the file is missing or empty. ``allowed_labels_text`` (one label per line)
    falls back to OFFICIAL_LABELS when blank.

    Returns:
        ``(summary, json_out, diagnostics, raw_model_output, context_preview,
        instructions_preview, metrics)``. On an early failure (no transcript,
        model load or generation error) only the diagnostics slot is filled,
        holding the error message.
    """
    t0 = _now_ms()

    raw_text = read_text_file_any(transcript_file) if transcript_file else ""
    raw_text = (raw_text or transcript_text or "").strip()
    if not raw_text:
        return "", "", "No transcript provided.", "", "", "", ""
    text = clean_transcript(raw_text) if use_cleaning else raw_text

    user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
    allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)

    try:
        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
    except Exception as e:  # surface any load failure in the UI instead of crashing
        return "", "", f"Model load failed: {e}", "", "", "", ""

    max_input_tokens = int(max_input_tokens)  # UI sliders may deliver floats
    trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)

    allowed_list_str = "\n".join(f"- {l}" for l in allowed)
    keyword_ctx = build_keyword_context(allowed)
    user_prompt = USER_PROMPT_TEMPLATE.format(
        transcript=trunc,
        allowed_labels_list=allowed_list_str,
        keyword_context=keyword_ctx,
    )

    t1 = _now_ms()
    try:
        out = model.generate(SYSTEM_PROMPT, user_prompt)
    except Exception as e:
        return "", "", f"Generation error: {e}", "", "", "", ""
    t2 = _now_ms()

    filtered = restrict_to_allowed(robust_json_extract(out), allowed)
    if use_keyword_fallback and not filtered.get("labels"):
        fb = keyword_fallback(trunc, allowed)
        if fb["labels"]:
            filtered = fb

    diag = "\n".join([
        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
        f"Model: {model_repo}",
        f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
        f"Keyword fallback: {'Yes' if use_keyword_fallback else 'No'}",
        f"Tokens (input, approx): ≤ {max_input_tokens}",
        f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
        f"Allowed labels: {', '.join(allowed)}",
    ])
    context_preview = (
        "### Allowed Labels\n"
        + "\n".join(f"- {l}" for l in allowed)
        + "\n\n### Keyword cues per label\n"
        + keyword_ctx
    )
    instructions_preview = "```\n" + SYSTEM_PROMPT + "\n```"

    labs = filtered.get("labels", [])
    summary = _format_summary(labs, filtered.get("tasks", []))
    json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
    metrics = _score_single(gt_json_text, gt_json_file, labs)
    return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics
# =========================
# Batch mode (ZIP with transcripts + truths)
# =========================
def read_zip_from_path(path: str, exdir: Path) -> List[Path]:
    """Extract the ZIP at ``path`` into ``exdir`` and return the files found there.

    macOS archive metadata (entries under ``__MACOSX/`` and AppleDouble
    ``._*`` files) is skipped. Those are binary resource forks that would
    otherwise be picked up as ``.txt`` / ``.json`` inputs. The returned list
    is sorted, so processing order is deterministic.

    Raises:
        zipfile.BadZipFile: if ``path`` is not a valid ZIP archive.
        OSError: if the archive cannot be read or ``exdir`` cannot be written.
    """
    def _is_macos_metadata(parts: Tuple[str, ...]) -> bool:
        # Finder-created archives add "__MACOSX/" trees and "._name" files.
        return "__MACOSX" in parts or bool(parts and parts[-1].startswith("._"))

    exdir = Path(exdir)
    exdir.mkdir(parents=True, exist_ok=True)
    # Open the archive from disk directly; no need to buffer it in memory.
    with zipfile.ZipFile(path) as zf:
        members = [
            info for info in zf.infolist()
            if not info.is_dir() and not _is_macos_metadata(tuple(info.filename.split("/")))
        ]
        # extractall() strips absolute paths and ".." components from names.
        zf.extractall(exdir, members=members)
    return sorted(
        p for p in exdir.rglob("*")
        if p.is_file() and not _is_macos_metadata(p.relative_to(exdir).parts)
    )
def _load_truth_labels(path: Path) -> Optional[List[str]]:
    """Return the official labels listed in a ground-truth JSON file.

    Unknown labels are dropped and duplicates removed, since
    evaluate_predictions rejects both. Returns ``None`` when the file is
    unreadable or not of the form ``{"labels": [...]}``, so callers can tell
    "no usable truth" apart from "truth with no labels".
    """
    try:
        obj = json.loads(path.read_text(encoding="utf-8", errors="ignore"))
    except (OSError, ValueError):
        return None
    if not (isinstance(obj, dict) and isinstance(obj.get("labels"), list)):
        return None
    return list(dict.fromkeys(x for x in obj["labels"] if x in OFFICIAL_LABELS))


def run_batch(
    zip_path,  # filepath string
    use_cleaning: bool,
    use_keyword_fallback: bool,
    model_repo: str,
    use_4bit: bool,
    max_input_tokens: int,
    hf_token: str,
    limit_files: int,
) -> Tuple[str, str, pd.DataFrame, str]:
    """Run extraction over every ``.txt`` transcript in a ZIP and score the results.

    Each ``<stem>.txt`` is paired with ``<stem>.json`` ground truth by file
    stem only, so equal stems in different folders collide and only one of
    each is used. Only transcripts with a valid truth file count towards the
    official score and the TP/FP/FN totals. A generation error on one file
    is recorded in its row and does not abort the batch.

    Returns:
        ``(status, diagnostics, per_file_dataframe, csv_path)``. On an early
        failure the status holds the error message and the other slots are
        empty.
    """
    if not zip_path:
        return ("No ZIP provided.", "", pd.DataFrame(), "")
    max_input_tokens = int(max_input_tokens)  # UI sliders may deliver floats
    limit_files = int(limit_files or 0)

    # Fresh extraction directory, so files from a previous batch never leak in.
    work = Path("/tmp/batch")
    shutil.rmtree(work, ignore_errors=True)
    try:
        files = read_zip_from_path(zip_path, work)
    except (zipfile.BadZipFile, OSError, RuntimeError) as e:
        return (f"Could not read ZIP: {e}", "", pd.DataFrame(), "")

    txts: Dict[str, Path] = {}
    gts: Dict[str, Path] = {}
    for p in files:
        suffix = p.suffix.lower()
        if suffix == ".txt":
            txts[p.stem] = p
        elif suffix == ".json":
            gts[p.stem] = p
    stems = sorted(txts)
    if limit_files > 0:
        stems = stems[:limit_files]
    if not stems:
        return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")

    try:
        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
    except Exception as e:  # surface any load failure in the UI instead of crashing
        return (f"Model load failed: {e}", "", pd.DataFrame(), "")

    allowed = OFFICIAL_LABELS[:]
    allowed_list_str = "\n".join(f"- {l}" for l in allowed)
    keyword_ctx = build_keyword_context(allowed)

    # y_true / y_pred only hold files that have usable ground truth.
    y_true: List[List[str]] = []
    y_pred: List[List[str]] = []
    rows = []
    n_errors = 0
    t_start = _now_ms()
    for stem in stems:
        raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
        text = clean_transcript(raw) if use_cleaning else raw
        trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
        user_prompt = USER_PROMPT_TEMPLATE.format(
            transcript=trunc,
            allowed_labels_list=allowed_list_str,
            keyword_context=keyword_ctx,
        )
        error = ""
        t0 = _now_ms()
        try:
            out = model.generate(SYSTEM_PROMPT, user_prompt)
        except Exception as e:  # one bad file must not abort the whole batch
            out, error = "", str(e)
            n_errors += 1
        t1 = _now_ms()

        filtered = restrict_to_allowed(robust_json_extract(out), allowed)
        if use_keyword_fallback and not filtered.get("labels"):
            fb = keyword_fallback(trunc, allowed)
            if fb["labels"]:
                filtered = fb
        pred_labels = filtered.get("labels", [])

        gt_labels = _load_truth_labels(gts[stem]) if stem in gts else None
        has_truth = gt_labels is not None
        if has_truth:
            y_true.append(gt_labels)
            y_pred.append(pred_labels)

        gt_set, pr_set = set(gt_labels or []), set(pred_labels)
        rows.append({
            "file": stem,
            "true_labels": ", ".join(gt_labels or []),
            "pred_labels": ", ".join(pred_labels),
            "TP": len(gt_set & pr_set),
            "FP": len(pr_set - gt_set),
            "FN": len(gt_set - pr_set),
            "gen_ms": t1 - t0,
            "has_truth": has_truth,
            "error": error,
        })

    df = pd.DataFrame(rows).sort_values(["FN", "FP", "file"])
    diag = [
        f"Processed files: {len(stems)}",
        f"Files with ground truth: {len(y_true)}",
        f"Generation errors: {n_errors}",
        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
        f"Model: {model_repo}",
        f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
        f"Keyword fallback: {'Yes' if use_keyword_fallback else 'No'}",
        f"Tokens (input, approx): ≤ {max_input_tokens}",
        f"Batch time: {_now_ms()-t_start} ms",
    ]
    if y_true:
        score = evaluate_predictions(y_true, y_pred)
        scored = df[df["has_truth"]]
        total_tp = int(scored["TP"].sum())
        total_fp = int(scored["FP"].sum())
        total_fn = int(scored["FN"].sum())
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 1.0
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 1.0
        # precision + recall == 0 means no TP with both FPs and FNs: F1 is 0, not 1.
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
        diag += [
            f"Official weighted score (0–1): {score:.3f}",
            f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}",
            f"Total TP={total_tp} FP={total_fp} FN={total_fn}",
        ]
    diag_str = "\n".join(diag)

    # Save the per-file table for download.
    out_csv = Path("/tmp/batch_results.csv")
    df.to_csv(out_csv, index=False, encoding="utf-8")
    return ("Batch done.", diag_str, df, str(out_csv))
# =========================
# UI
# =========================
MODEL_CHOICES = [
    "swiss-ai/Apertus-8B-Instruct-2509",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]
custom_css = """
:root { --radius: 14px; }
.gradio-container { font-family: Inter, ui-sans-serif, system-ui; }
.card { border: 1px solid rgba(255,255,255,.08); border-radius: var(--radius); padding: 14px 16px; background: rgba(255,255,255,.02); box-shadow: 0 1px 10px rgba(0,0,0,.12) inset; }
.header { font-weight: 700; font-size: 22px; margin-bottom: 4px; }
.subtle { color: rgba(255,255,255,.65); font-size: 14px; margin-bottom: 12px; }
hr.sep { border: none; border-top: 1px solid rgba(255,255,255,.08); margin: 10px 0 16px; }
.accordion-title { font-weight: 600; }
.gr-button { border-radius: 12px !important; }
"""


def _reset_labels() -> str:
    """Return the official label list (one per line) for the labels textbox."""
    return OFFICIAL_LABELS_TEXT


def _pack_context_md(allowed: str) -> str:
    """Render the allowed labels (one per line) and their keyword cues as Markdown."""
    allowed_list = [ln.strip() for ln in (allowed or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
    ctx = build_keyword_context(allowed_list)
    return (
        "### Allowed Labels\n"
        + "\n".join(f"- {l}" for l in allowed_list)
        + "\n\n### Keyword cues per label\n"
        + ctx
    )


with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
    gr.Markdown("<div class='header'>Talk2Task — Task Extraction (UBS Challenge)</div>")
    gr.Markdown("<div class='subtle'>False negatives are penalised 2× more than false positives in the official score. This UI biases for recall, shows the exact instructions & context, and supports single or batch evaluation.</div>")

    # Each "card" is a gr.Group carrying the .card CSS class. An opening <div>
    # in one gr.Markdown and a closing </div> in another does not wrap the
    # components rendered between them.
    with gr.Tab("Single transcript"):
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("<div class='header'>Transcript</div>")
                    file = gr.File(
                        label="Drag & drop transcript (.txt / .md / .json)",
                        file_types=[".txt", ".md", ".json"],
                        type="filepath",
                    )
                    text = gr.Textbox(label="Or paste transcript", lines=10)
                    gr.Markdown("<hr class='sep'/>")
                    gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
                    gt_file = gr.File(
                        label="Upload ground truth JSON (expects {'labels': [...]})",
                        file_types=[".json"],
                        type="filepath",
                    )
                    gt_text = gr.Textbox(
                        label="Or paste ground truth JSON",
                        lines=6,
                        placeholder='{"labels": ["schedule_meeting"]}',
                    )
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("<div class='header'>Preprocessing & heuristics</div>")
                    use_cleaning = gr.Checkbox(
                        label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
                        value=True,
                    )
                    use_keyword_fallback = gr.Checkbox(
                        label="Keyword fallback if model returns empty",
                        value=True,
                    )
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("<div class='header'>Allowed labels</div>")
                    labels_text = gr.Textbox(
                        label="Allowed Labels (one per line)",
                        value=OFFICIAL_LABELS_TEXT,  # prefilled
                        lines=8,
                    )
                    reset_btn = gr.Button("Reset to official labels")
            with gr.Column(scale=2):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("<div class='header'>Model & run</div>")
                    repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
                    use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
                    max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
                    hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN", ""))
                    run_btn = gr.Button("Run Extraction", variant="primary")
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("<div class='header'>Outputs</div>")
                    summary = gr.Textbox(label="Summary", lines=12)
                    json_out = gr.Code(label="Strict JSON Output", language="json")
                    diag = gr.Textbox(label="Diagnostics", lines=8)
                    raw = gr.Textbox(label="Raw Model Output", lines=8)
                    # Seventh output of run_single: score vs. the optional ground truth.
                    metrics = gr.Textbox(label="Scoring vs ground truth", lines=5)
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Instructions used (system prompt)", open=False):
                    instr_md = gr.Markdown("```\n" + SYSTEM_PROMPT + "\n```")
            with gr.Column():
                with gr.Accordion("Context used (allowed labels + keyword cues)", open=True):
                    context_md = gr.Markdown(_pack_context_md(OFFICIAL_LABELS_TEXT))

        reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
        run_btn.click(
            fn=run_single,
            inputs=[
                text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
                labels_text, repo, use_4bit, max_tokens, hf_token,
            ],
            outputs=[summary, json_out, diag, raw, context_md, instr_md, metrics],
        )

    with gr.Tab("Batch evaluation"):
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("<div class='header'>ZIP input</div>")
                    zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
                    use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
                    use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
            with gr.Column(scale=2):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("<div class='header'>Model & run</div>")
                    repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
                    use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
                    max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
                    hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN", ""))
                    limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
                    run_batch_btn = gr.Button("Run Batch", variant="primary")
        with gr.Group(elem_classes=["card"]):
            gr.Markdown("<div class='header'>Batch outputs</div>")
            status = gr.Textbox(label="Status", lines=1)
            diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
            df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
            csv_out = gr.File(label="Download CSV", interactive=False)
        run_batch_btn.click(
            fn=run_batch,
            inputs=[zip_in, use_cleaning_b, use_keyword_fallback_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
            outputs=[status, diag_b, df_out, csv_out],
        )

if __name__ == "__main__":
    demo.launch()