RishiRP committed on
Commit f066995 · verified · 1 Parent(s): 28f5fab

Create app.py

Files changed (1): app.py (+546, -0)
app.py ADDED
import io, re, time, json, zipfile, tempfile, statistics
from pathlib import Path
from typing import List, Dict, Tuple

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# ---------------- Constants / Labels ----------------

ALLOWED_LABELS = [
    "plan_contact",
    "schedule_meeting",
    "update_contact_info_non_postal",
    "update_contact_info_postal_address",
    "update_kyc_activity",
    "update_kyc_origin_of_assets",
    "update_kyc_purpose_of_businessrelation",
    "update_kyc_total_assets",
]
LABEL_TO_IDX = {l: i for i, l in enumerate(ALLOWED_LABELS)}
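
# Error weights for the UBS-style score below: a missed true label (FN) costs
# twice as much as a spurious predicted label (FP).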
FN_PENALTY = 2.0
FP_PENALTY = 1.0

# ---------------- Helpers ----------------

def safe_json_load(s: str):
    try:
        return json.loads(s)
    except Exception:
        pass
    m = re.search(r'\{.*\}', s, re.S)
    if m:
        try:
            return json.loads(m.group(0))
        except Exception:
            pass
    return {"labels": [], "notes": "WARN: model output not valid JSON; fallback used"}

def _coerce_labels_list(x):
    if isinstance(x, list):
        out = []
        for it in x:
            if isinstance(it, str):
                out.append(it)
            elif isinstance(it, dict):
                for k in ("label", "value", "task", "category", "name"):
                    v = it.get(k)
                    if isinstance(v, str):
                        out.append(v)
                        break
                else:
                    if isinstance(it.get("labels"), list):
                        out += [s for s in it["labels"] if isinstance(s, str)]
        # dedupe while preserving order
        seen = set()
        norm = []
        for s in out:
            if s not in seen:
                norm.append(s)
                seen.add(s)
        return norm
    if isinstance(x, dict):
        for k in ("expected_labels", "labels", "targets", "y_true"):
            if k in x:
                return _coerce_labels_list(x[k])
        if "one_hot" in x and isinstance(x["one_hot"], dict):
            return [k for k, v in x["one_hot"].items() if v]
    return []
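
# Set-based multi-label metrics. Note that "hamming" as computed here is the
# Jaccard index |pred ∩ gold| / |pred ∪ gold| rather than a Hamming distance.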
def classic_metrics(pred_labels, exp_labels):
    pred_labels = [str(x) for x in (pred_labels or []) if isinstance(x, (str, int, float, bool))]
    exp_labels = [str(x) for x in (exp_labels or []) if isinstance(x, (str, int, float, bool))]
    pred = set(pred_labels); gold = set(exp_labels)
    if not pred and not gold:
        return True, 1.0, 1.0, 1.0, 1.0
    inter = pred & gold; union = pred | gold
    exact = (pred == gold)
    precision = len(inter) / (len(pred) if pred else 1e-9)
    recall = len(inter) / (len(gold) if gold else 1e-9)
    f1 = 0.0 if len(inter) == 0 else 2 * len(inter) / (len(pred) + len(gold) + 1e-9)
    hamming = len(inter) / (len(union) if union else 1e-9)
    return exact, precision, recall, f1, hamming
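
# Weighted error score in [0, 1]: 1 - (FN_PENALTY*FN + FP_PENALTY*FP) / max_err,
# where max_err is the worst case of missing every true label and predicting
# every false one. Labels outside ALLOWED_LABELS are ignored.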
def ubs_score_one(true_labels, pred_labels) -> float:
    tset = [l for l in (true_labels or []) if l in LABEL_TO_IDX]
    pset = [l for l in (pred_labels or []) if l in LABEL_TO_IDX]
    n_labels = len(ALLOWED_LABELS)
    tpos = set(tset); ppos = set(pset)
    fn = sum(1 for l in ALLOWED_LABELS if (l in tpos and l not in ppos))
    fp = sum(1 for l in ALLOWED_LABELS if (l not in tpos and l in ppos))
    weighted = FN_PENALTY * fn + FP_PENALTY * fp
    t_count = len(tpos)
    max_err = FN_PENALTY * t_count + FP_PENALTY * (n_labels - t_count)
    score = 1.0 if max_err == 0 else (1.0 - (weighted / max_err))
    return float(max(0.0, min(1.0, score)))

# ---------------- Preprocess ----------------

EMAIL_RX = re.compile(r'\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b', re.I)
TIME_RX = re.compile(r'\b(\d{1,2}:\d{2}\b|\b\d{1,2}\s?(am|pm)\b|\bafternoon\b|\bmorning\b|\bevening\b)', re.I)
DATE_RX = re.compile(r'\b(jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)\b|\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b|\b20\d{2}\b', re.I)
MEET_RX = re.compile(r'\b(meet(ing)?|call|appointment|schedule|invite|agenda|online|in[- ]?person|phone|zoom|teams)\b', re.I)
MODAL_RX = re.compile(r'\b(online|in[- ]?person|phone|zoom|teams)\b', re.I)
SMALLTALK_RX = re.compile(r'^\s*(user|advisor):\s*(thanks( you)?|thank you|anything else|have a great day|you too)\b', re.I)
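
# Targeted repairs for recurring transcription typos; applied by normalize_text
# when fix_typos=True.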
TYPO_FIXES = [
    (re.compile(r'\bschedulin\s*g\b', re.I), 'scheduling'),
    (re.compile(r'\beeting\b', re.I), 'meeting'),
    (re.compile(r'\bdi?i?gtal\b', re.I), 'digital'),
    (re.compile(r'\bdigi\s+tal\b', re.I), 'digital'),
    (re.compile(r'\bspread\s*sheet\b', re.I), 'spreadsheet'),
    (re.compile(r'\bseats\b', re.I), 'sheets'),
    (re.compile(r'\bver(s|z)ion meters\b', re.I), 'version metrics'),
]

def normalize_text(text: str, fix_typos: bool = True) -> str:
    t = text.replace('\r\n', '\n')
    t = re.sub(r'^\s*Speaker\s*1\s*:\s*', 'USER: ', t, flags=re.I | re.M)
    t = re.sub(r'^\s*Speaker\s*2\s*:\s*', 'ADVISOR: ', t, flags=re.I | re.M)
    t = re.sub(r'[ \t]+', ' ', t)
    t = re.sub(r'\n{3,}', '\n\n', t)
    if fix_typos:
        for rx, rep in TYPO_FIXES:
            t = rx.sub(rep, t)
    return t.strip()
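
# Heuristic cue extraction: the last email seen is treated as the new address
# (second-to-last as the old one), and a meeting counts as confirmed when a
# meeting term co-occurs with a date or time.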
def extract_cues(text: str):
    emails = EMAIL_RX.findall(text)
    email_new, email_old = (emails[-1], emails[-2]) if len(emails) >= 2 else ((emails[-1], None) if emails else (None, None))
    has_time = bool(TIME_RX.search(text))
    has_date = bool(DATE_RX.search(text))
    has_meet = bool(MEET_RX.search(text))
    modality = None
    m = MODAL_RX.search(text)
    if m:
        modality = m.group(0).upper().replace('IN PERSON', 'IN_PERSON').replace('IN-PERSON', 'IN_PERSON')
    meeting_confirmed = (has_meet and (has_time or has_date))
    tm = TIME_RX.search(text)
    norm_tm = tm.group(0) if tm else None
    return {
        "email_new": email_new,
        "email_old": email_old,
        "contact_pref": "EMAIL" if email_new else None,
        "meeting_time_fragment": norm_tm,
        "meeting_modality": modality,
        "meeting_confirmed": meeting_confirmed,
    }
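
# Renders the detected cues as a compact [DETECTED_CUES] block that gets
# prepended to the transcript so the model sees them explicitly.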
def build_cues_header(cues: dict) -> str:
    has_any = any([cues.get("email_new"), cues.get("email_old"), cues.get("contact_pref"), cues.get("meeting_confirmed")])
    if not has_any:
        return ""
    lines = ["[DETECTED_CUES]"]
    if cues.get("email_new"): lines.append(f"EMAIL_NEW: {cues['email_new']}")
    if cues.get("email_old"): lines.append(f"EMAIL_OLD: {cues['email_old']}")
    if cues.get("contact_pref"): lines.append(f"CONTACT_PREF: {cues['contact_pref']}")
    if cues.get("meeting_confirmed"):
        mod = cues.get("meeting_modality") or ""
        tm = cues.get("meeting_time_fragment") or ""
        lines.append(f"MEETING: {(tm + ' ' + mod).strip()} CONFIRMED")
    lines.append("[/DETECTED_CUES]")
    return "\n".join(lines)

def find_cue_lines(lines):
    idx = set()
    for i, ln in enumerate(lines):
        if EMAIL_RX.search(ln) or (MEET_RX.search(ln) and (TIME_RX.search(ln) or DATE_RX.search(ln))):
            idx.add(i)
    return sorted(idx)

def prune_by_window(lines, cue_idx, window=3, strip_smalltalk=False):
    n = len(lines); keep = set()
    for k in cue_idx:
        lo, hi = max(0, k - window), min(n - 1, k + window)
        keep.update(range(lo, hi + 1))
    out = []
    for i, ln in enumerate(lines):
        if i in keep:
            if strip_smalltalk and SMALLTALK_RX.search(ln):
                continue
            out.append(ln)
    return out

# ---------------- HF Model wrapper ----------------

class HFModel:
    def __init__(self, repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
        self.repo_id = repo_id
        self.tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=trust_remote_code)
        quant = None
        torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}.get(dtype, torch.bfloat16)

        if load_4bit:
            quant = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_quant_type="nf4",
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                repo_id, device_map="auto", trust_remote_code=trust_remote_code,
                quantization_config=quant, torch_dtype=torch_dtype
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                repo_id, device_map="auto", trust_remote_code=trust_remote_code,
                torch_dtype=torch_dtype
            )

        self.max_context = getattr(self.model.config, "max_position_embeddings", None) \
            or getattr(self.model.config, "max_sequence_length", None) or 8192

    def encode_len(self, text: str) -> int:
        return len(self.tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids)

    def apply_chat_template(self, system_text: str, user_text: str) -> str:
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "system", "content": system_text},
                        {"role": "user", "content": user_text}]
            return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return ("### System\n" + system_text.strip() + "\n\n" +
                "### User\n" + user_text.strip() + "\n\n" +
                "### Assistant\n")
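
    # Greedy decoding (do_sample=False) for reproducible JSON output; only the
    # newly generated tokens are decoded and returned.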
    @torch.inference_mode()
    def generate_json(self, system_text: str, user_text: str, max_new_tokens: int = 256):
        prompt = self.apply_chat_template(system_text, user_text)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        t0 = time.perf_counter()
        out = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=None,
            top_p=None,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
        )
        latency_ms = int((time.perf_counter() - t0) * 1000)
        # Decode only the generated continuation; decoding the full sequence and
        # stripping the prompt string is brittle once special tokens are skipped.
        gen_ids = out[0][inputs["input_ids"].shape[1]:]
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        return latency_ms, text, prompt

# ---------------- Core pipeline ----------------
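
# Token-budget pruning: runs only when the token estimate exceeds the soft cap
# by more than apply_only_if_ratio; keeps header/cue lines, pads back to
# min_lines_keep, and falls back to a middle slice or the original text.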
def shrink_to_token_cap_by_lines(text: str, soft_cap_tokens: int, tokenizer,
                                 min_lines_keep: int = 30,
                                 apply_only_if_ratio: float = 1.15) -> str:
    ids = tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids
    est = len(ids)
    threshold = int(soft_cap_tokens * apply_only_if_ratio)
    if est <= threshold:
        return text
    parts = text.splitlines()
    if len(parts) <= min_lines_keep:
        return text

    # keep header + cue-like lines
    keep_flags = []
    for ln in parts:
        is_header = ln.startswith("[DETECTED_CUES]") or ln.startswith("[/DETECTED_CUES]") \
            or ln.startswith("EMAIL_") or ln.startswith("CONTACT_") or ln.startswith("MEETING:")
        is_cue = bool(EMAIL_RX.search(ln) or MEET_RX.search(ln) or DATE_RX.search(ln) or TIME_RX.search(ln))
        keep_flags.append(is_header or is_cue)

    pruned = [ln for ln, keep in zip(parts, keep_flags) if keep]
    if len(pruned) < min_lines_keep:
        pad_needed = min_lines_keep - len(pruned)
        non_cue_lines = [ln for ln, keep in zip(parts, keep_flags) if not keep]
        pruned = pruned + non_cue_lines[:pad_needed]

    candidate = "\n".join(pruned)
    cand_tokens = len(tokenizer(candidate, return_tensors=None, add_special_tokens=False).input_ids)
    if cand_tokens > threshold:
        mid = len(parts) // 2
        half = max(min_lines_keep // 2, 50)
        slice_parts = parts[max(0, mid - half): min(len(parts), mid + half)]
        candidate2 = "\n".join(slice_parts)
        candidate2_tokens = len(tokenizer(candidate2, return_tensors=None, add_special_tokens=False).input_ids)
        candidate = candidate if cand_tokens <= candidate2_tokens else candidate2

    if len(candidate.splitlines()) < min_lines_keep:
        return text
    return candidate
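
# Deterministic post-rules on top of the model output: a meeting term plus a
# date/time forces schedule_meeting (and drops plan_contact); an email plus
# change wording adds update_contact_info_non_postal; update_kyc_activity is
# dropped when no supporting KYC vocabulary appears in the transcript.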
def enforce_rules(labels, transcript_text):
    labels = set(labels or [])
    if (TIME_RX.search(transcript_text) or DATE_RX.search(transcript_text)) and MEET_RX.search(transcript_text):
        labels.add("schedule_meeting")
        labels.discard("plan_contact")
    if EMAIL_RX.search(transcript_text) and re.search(r'\b(update|new|set|change|confirm(ed)?|for all communication)\b', transcript_text, re.I):
        labels.add("update_contact_info_non_postal")
    kyc_rx = re.compile(r'\b(kyc|aml|compliance|employer|occupation|purpose of (relationship|account)|source of (wealth|funds)|net worth|total assets)\b', re.I)
    if "update_kyc_activity" in labels and not kyc_rx.search(transcript_text):
        labels.discard("update_kyc_activity")
    return sorted(labels)

# ---------------- Gradio app logic ----------------

MODEL_CACHE: Dict[str, HFModel] = {}

def get_model(repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
    if repo_id not in MODEL_CACHE:
        MODEL_CACHE[repo_id] = HFModel(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
    return MODEL_CACHE[repo_id]

def parse_zip(zip_bytes: bytes) -> Dict[str, Tuple[str, List[str]]]:
    """
    Returns mapping: sample_id -> (transcript_text, expected_labels[])
    Expects pairs: <id>.txt and <id>.json (the JSON is optional).
    """
    zf = zipfile.ZipFile(io.BytesIO(zip_bytes))
    names = zf.namelist()
    samples = {}
    for n in names:
        p = Path(n)
        if p.suffix.lower() == ".txt":
            sample_id = p.stem
            txt = zf.read(n).decode("utf-8", "replace")
            samples.setdefault(sample_id, ["", []])[0] = txt
        elif p.suffix.lower() == ".json":
            sample_id = p.stem
            try:
                js = json.loads(zf.read(n).decode("utf-8", "replace"))
            except Exception:
                js = []
            samples.setdefault(sample_id, ["", []])[1] = _coerce_labels_list(js)
    return samples
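
# Batch driver: for each (model, sample, repeat) it preprocesses, prompts,
# generates, parses JSON, applies post-rules, and scores, appending per-run rows
# plus a per-sample summary row and writing artifacts into an in-memory ZIP.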
def run_batch_ui(models_str, instructions_text, context_text, dataset_zip,
                 soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
                 repeats, max_total_runs, load_4bit, dtype, trust_remote_code):

    if not dataset_zip:
        return pd.DataFrame(), None, "Please upload a ZIP with *.txt (+ optional matching *.json)."

    models = [m.strip() for m in (models_str or "").split(",") if m.strip()]
    if not models:
        return pd.DataFrame(), None, "Please enter at least one model repo id (e.g., mistralai/Mistral-7B-Instruct-v0.2)."

    # Sliders may deliver floats; coerce the integer-valued parameters.
    soft_cap, pre_window = int(soft_cap), int(pre_window)
    repeats, max_total_runs = int(repeats), int(max_total_runs)

    try:
        samples = parse_zip(dataset_zip)
    except Exception as e:
        return pd.DataFrame(), None, f"Failed to read ZIP: {e}"

    rows = []
    total_runs = 0
    all_artifacts = io.BytesIO()
    zout = zipfile.ZipFile(all_artifacts, "w", zipfile.ZIP_DEFLATED)

    for repo_id in models:
        hf = get_model(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
        for sample_id, (transcript_text, exp_labels) in samples.items():
            if not transcript_text.strip():
                continue
            latencies = []
            last_pred = None
            for r in range(1, repeats + 1):
                if total_runs >= max_total_runs:
                    break

                # ---- Preprocess
                before_tok = hf.encode_len(transcript_text)
                proc_text = transcript_text
                if preprocess:
                    t_norm = normalize_text(proc_text, fix_typos=True)
                    lines = [ln.strip() for ln in t_norm.splitlines() if ln.strip()]
                    cue_lines = find_cue_lines(lines)
                    if cue_lines:
                        lines_kept = prune_by_window(lines, cue_lines, window=pre_window, strip_smalltalk=strip_smalltalk)
                    else:
                        lines_kept = [ln for ln in lines if not (strip_smalltalk and SMALLTALK_RX.search(ln))]
                    t_kept = "\n".join(lines_kept)
                    cues = extract_cues(t_kept)
                    header = build_cues_header(cues) if add_cues else ""
                    proc_text = (header + "\n\n" + t_kept).strip() if header else t_kept
                    proc_text = shrink_to_token_cap_by_lines(proc_text, soft_cap, hf.tokenizer)
                    if len(proc_text.splitlines()) < 30:
                        proc_text = t_norm
                after_tok = hf.encode_len(proc_text)

                system_text = instructions_text.strip()
                user_text = context_text.strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()

                t0 = time.perf_counter()
                latency_ms, raw_text, prompt = hf.generate_json(system_text, user_text, max_new_tokens=256)
                latency_ms = int((time.perf_counter() - t0) * 1000)  # includes tokenization overhead

                out = safe_json_load(raw_text)
                pred_labels = enforce_rules(out.get("labels", []), proc_text)

                exact, prec, rec, f1, ham = classic_metrics(pred_labels, exp_labels)
                ubs = ubs_score_one(exp_labels, pred_labels)

                rows.append({
                    "timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
                    "sample_id": sample_id,
                    "model": repo_id,
                    "is_summary": False,
                    "run_index": r,
                    "preprocess": preprocess,
                    "pre_window": pre_window,
                    "add_cues_header": add_cues,
                    "strip_smalltalk": strip_smalltalk,
                    "soft_cap": soft_cap,
                    "latency_ms": latency_ms,
                    "token_before": before_tok,
                    "token_after": after_tok,
                    "model_calls": 1,
                    "pred_labels": json.dumps(pred_labels, ensure_ascii=False),
                    "exp_labels": json.dumps(exp_labels, ensure_ascii=False),
                    "exact_match": exact,
                    "precision": round(prec, 6),
                    "recall": round(rec, 6),
                    "f1": round(f1, 6),
                    "hamming": round(ham, 6),
                    "ubs_score": round(ubs, 6),
                })

                # artifacts
                base = f"{repo_id.replace('/','_')}/{sample_id}/pre{int(preprocess)}_win{pre_window}_cues{int(add_cues)}_small{int(strip_smalltalk)}_cap{soft_cap}_r{r}"
                zout.writestr(base + "/PREPROCESSED.txt", proc_text)
                zout.writestr(base + "/MODEL_OUTPUT.raw.txt", raw_text)
                final_json = {
                    "labels": pred_labels,
                    "diagnostics": {
                        "model_name": repo_id,
                        "latency_ms": latency_ms,
                        "token_in_est_before": before_tok,
                        "token_in_est_after": after_tok,
                        "preprocess": preprocess,
                        "pre_window": pre_window,
                        "pre_add_cues_header": add_cues if preprocess else False,
                        "pre_strip_smalltalk": strip_smalltalk if preprocess else False,
                        "pre_soft_token_cap": soft_cap if preprocess else None,
                        "model_calls": 1
                    }
                }
                zout.writestr(base + "/FINAL.json", json.dumps(final_json, ensure_ascii=False, indent=2))

                latencies.append(latency_ms)
                last_pred = pred_labels
                total_runs += 1

            if latencies:
                med = int(statistics.median(latencies))
                exact, prec, rec, f1, ham = classic_metrics(last_pred, exp_labels) if last_pred is not None else (None,) * 5
                ubs = ubs_score_one(exp_labels, last_pred) if last_pred is not None else None
                rows.append({
                    "timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
                    "sample_id": sample_id,
                    "model": repo_id,
                    "is_summary": True,
                    "run_index": None,
                    "preprocess": preprocess,
                    "pre_window": pre_window,
                    "add_cues_header": add_cues,
                    "strip_smalltalk": strip_smalltalk,
                    "soft_cap": soft_cap,
                    "median_latency_ms": med,
                    "latency_ms": None,
                    "token_before": None,
                    "token_after": None,
                    "model_calls": None,
                    "pred_labels": json.dumps(last_pred or [], ensure_ascii=False),
                    "exp_labels": json.dumps(exp_labels or [], ensure_ascii=False),
                    "exact_match": exact,
                    "precision": round(prec, 6) if prec is not None else None,
                    "recall": round(rec, 6) if rec is not None else None,
                    "f1": round(f1, 6) if f1 is not None else None,
                    "hamming": round(ham, 6) if ham is not None else None,
                    "ubs_score": round(ubs, 6) if ubs is not None else None,
                })

        if total_runs >= max_total_runs:
            break

    zout.close()
    df = pd.DataFrame(rows)
    csv_bytes = df.to_csv(index=False).encode("utf-8")
    return df, ("results.csv", csv_bytes), all_artifacts.getvalue()

# ---------------- Gradio UI ----------------
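
# Component order in run_btn.click(inputs=[...]) below must match the positional
# parameters of run_batch_ui.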
with gr.Blocks(title="From Talk to Task — HF Space") as demo:
    gr.Markdown("# From Talk to Task — Batch Task Extraction (Hugging Face Space)")
    with gr.Row():
        models = gr.Textbox(label="Models (comma-separated HF repo IDs)", value="mistralai/Mistral-7B-Instruct-v0.2")
    with gr.Row():
        instructions = gr.Textbox(label="Instructions (System)", lines=8, value=(
            "You are a task extraction assistant. "
            "Always output valid JSON with a field \"labels\" (list of strings). "
            "Use only from this set: "
            + json.dumps(ALLOWED_LABELS)
            + ". Return JSON only."
        ))
    with gr.Row():
        context = gr.Textbox(label="Context (User prefix before transcript)", lines=6, value=(
            "- plan_contact: conversation without a concrete meeting (no date/time)\n"
            "- schedule_meeting: explicit date/time/modality confirmation\n"
            "- update_contact_info_non_postal: changes to email/phone\n"
            "- update_contact_info_postal_address: changes to mailing address\n"
            "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)"
        ))
    with gr.Row():
        # type="binary" so run_batch_ui receives the raw ZIP bytes that parse_zip expects
        dataset_zip = gr.File(label="Upload ZIP of transcripts (*.txt) + expected (*.json)", file_types=[".zip"], type="binary")

    gr.Markdown("### Parameters")
    with gr.Row():
        soft_cap = gr.Slider(1024, 32768, value=8192, step=512, label="Soft token cap")
        preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
        pre_window = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
        add_cues = gr.Checkbox(value=True, label="Add cues header")
        strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
    with gr.Row():
        repeats = gr.Slider(1, 6, value=4, step=1, label="Repeats per config")
        max_total_runs = gr.Slider(1, 200, value=40, step=1, label="Max total runs")

    gr.Markdown("### Model loading")
    with gr.Row():
        load_4bit = gr.Checkbox(value=True, label="Load in 4-bit (bitsandbytes, GPU)")
        dtype = gr.Dropdown(choices=["bfloat16", "float16", "float32"], value="bfloat16", label="Compute dtype")
        trust_remote_code = gr.Checkbox(value=True, label="Trust remote code")

    run_btn = gr.Button("Run Batch")
    with gr.Row():
        table = gr.Dataframe(label="Results", interactive=False, wrap=True, height=400)
    with gr.Row():
        csv_dl = gr.File(label="Download CSV", interactive=False)
        zip_dl = gr.File(label="Download Artifacts ZIP", interactive=False)
    status = gr.Markdown("")

    def _run(*args):
        df, csv_pair, zip_bytes = run_batch_ui(*args)
        if isinstance(df, pd.DataFrame) and not df.empty:
            # gr.File outputs expect file paths, so spill the bytes to temp files.
            csv_name, csv_data = csv_pair
            tmp = Path(tempfile.mkdtemp())
            csv_path = tmp / csv_name
            csv_path.write_bytes(csv_data)
            zip_path = tmp / "artifacts.zip"
            zip_path.write_bytes(zip_bytes)
            return df, str(csv_path), str(zip_path), "Done."
        else:
            return pd.DataFrame(), None, None, zip_bytes  # on error, the third return value carries the message

    run_btn.click(
        _run,
        inputs=[models, instructions, context, dataset_zip,
                soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
                repeats, max_total_runs, load_4bit, dtype, trust_remote_code],
        outputs=[table, csv_dl, zip_dl, status]
    )

demo.queue().launch()