RishiRP committed on
Commit 62c9ed8 · verified · 1 Parent(s): f214cbe

Update app.py

Files changed (1)
  1. app.py +353 -880
app.py CHANGED
@@ -1,937 +1,410 @@
1
- # app.py — From Talk to Task (robust snapshot loader, revision pinning)
2
- # Keeps your full feature set: Single + Batch, preprocessing, metrics, UBS score, artifacts.
3
- # Key fix: models are downloaded atomically via snapshot_download at a pinned revision
4
- # and then loaded from local dir to avoid partial shard errors (e.g., *-00003-of-00003.safetensors).
5
-
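For reference, the snapshot-first loading that the removed header describes follows a standard huggingface_hub pattern. A minimal sketch, assuming huggingface_hub and transformers are installed; the repo id and revision below are placeholders:

    from huggingface_hub import snapshot_download
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Fetch every required file at a pinned revision first, so loading never sees
    # a half-downloaded shard, then point transformers at the local directory.
    local_dir = snapshot_download(
        repo_id="Qwen/Qwen2.5-7B-Instruct",   # placeholder repo id
        revision="main",                      # pin a commit hash here to freeze the weights
        allow_patterns=["*.json", "*.safetensors", "*.model", "tokenizer.*"],
    )
    tok = AutoTokenizer.from_pretrained(local_dir)
    model = AutoModelForCausalLM.from_pretrained(local_dir, device_map="auto")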
6
- import os, io, re, sys, time, json, zipfile, statistics
 
7
  from pathlib import Path
8
- from typing import List, Dict, Tuple, Union, Optional
9
 
10
  import gradio as gr
11
- import pandas as pd
12
- import torch
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
14
- from huggingface_hub import snapshot_download # <-- robust model fetch
15
 
16
- # ========= ZeroGPU support =========
17
- try:
18
- import spaces # available on HF Spaces
19
- except Exception:
20
- class _DummySpaces:
21
- def GPU(self, *args, **kwargs):
22
- def deco(f): return f
23
- return deco
24
- spaces = _DummySpaces()
25
-
26
- # ========= Persistent cache for Spaces =========
27
- # Ensures model files survive restarts and prevents re-downloading shards.
28
- os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
29
-
30
- # ========= Auth token =========
31
- HF_TOKEN = (
32
- os.getenv("HF_TOKEN")
33
- or os.getenv("HUGGINGFACE_HUB_TOKEN")
34
- or os.getenv("HUGGINGFACEHUB_API_TOKEN")
35
- )
36
 
37
- # Console warning at startup (helps when logs are open)
38
- if not HF_TOKEN:
39
- print(
40
- "[WARN] HF_TOKEN is not set. Gated models will fail. "
41
- "Set it in Space → Settings → Variables and secrets.",
42
- file=sys.stderr
 
43
  )
44
 
45
- # ========= Labels & metrics =========
46
- ALLOWED_LABELS = [
47
- "plan_contact",
48
- "schedule_meeting",
49
- "update_contact_info_non_postal",
50
- "update_contact_info_postal_address",
51
- "update_kyc_activity",
52
- "update_kyc_origin_of_assets",
53
- "update_kyc_purpose_of_businessrelation",
54
- "update_kyc_total_assets",
55
- ]
56
- LABEL_TO_IDX = {l: i for i, l in enumerate(ALLOWED_LABELS)}
57
- FN_PENALTY = 2.0
58
- FP_PENALTY = 1.0
59
 
60
- def safe_json_load(s: str):
61
- """Best-effort JSON extractor; returns {'labels': []} shape on fallback."""
62
- try:
63
- return json.loads(s)
64
- except Exception:
65
- pass
66
- m = re.search(r"\{.*\}", s, re.S)
67
- if m:
68
- try:
69
- return json.loads(m.group(0))
70
- except Exception:
71
- pass
72
- return {"labels": [], "notes": "WARN: model output not valid JSON; fallback used"}
73
-
74
- def _coerce_labels_list(x):
75
- if isinstance(x, list):
76
- out = []
77
- for it in x:
78
- if isinstance(it, str): out.append(it)
79
- elif isinstance(it, dict):
80
- for k in ("label", "value", "task", "category", "name"):
81
- v = it.get(k)
82
- if isinstance(v, str):
83
- out.append(v); break
84
- else:
85
- if isinstance(it.get("labels"), list):
86
- out += [s for s in it["labels"] if isinstance(s, str)]
87
- # dedupe keep order
88
- seen = set(); norm = []
89
- for s in out:
90
- if s not in seen:
91
- norm.append(s); seen.add(s)
92
- return norm
93
- if isinstance(x, dict):
94
- for k in ("expected_labels", "labels", "targets", "y_true"):
95
- if k in x: return _coerce_labels_list(x[k])
96
- if "one_hot" in x and isinstance(x["one_hot"], dict):
97
- return [k for k, v in x["one_hot"].items() if v]
98
- return []
99
-
100
- def classic_metrics(pred_labels, exp_labels):
101
- pred = set([str(x) for x in (pred_labels or []) if isinstance(x, (str,int,float,bool))])
102
- gold = set([str(x) for x in (exp_labels or []) if isinstance(x, (str,int,float,bool))])
103
- if not pred and not gold:
104
- return True, 1.0, 1.0, 1.0, 1.0
105
- inter = pred & gold; union = pred | gold
106
- exact = (sorted(pred) == sorted(gold))
107
- precision = (len(inter) / (len(pred) if pred else 1e-9))
108
- recall = (len(inter) / (len(gold) if gold else 1e-9))
109
- f1 = 0.0 if len(inter) == 0 else 2*len(inter) / (len(pred)+len(gold)+1e-9)
110
- hamming = (len(inter) / (len(union) if union else 1e-9))
111
- return exact, precision, recall, f1, hamming
112
-
113
- def ubs_score_one(true_labels, pred_labels) -> float:
114
- tset = [l for l in (true_labels or []) if l in LABEL_TO_IDX]
115
- pset = [l for l in (pred_labels or []) if l in LABEL_TO_IDX]
116
- n_labels = len(ALLOWED_LABELS)
117
- tpos = set(tset); ppos = set(pset)
118
- fn = sum(1 for l in ALLOWED_LABELS if (l in tpos and l not in ppos))
119
- fp = sum(1 for l in ALLOWED_LABELS if (l not in tpos and l in ppos))
120
- weighted = FN_PENALTY*fn + FP_PENALTY*fp
121
- t_count = len(tpos)
122
- max_err = FN_PENALTY*t_count + FP_PENALTY*(n_labels - t_count)
123
- score = 1.0 if max_err == 0 else (1.0 - (weighted / max_err))
124
- return float(max(0.0, min(1.0, score)))
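A quick worked example of the score above, using two labels from ALLOWED_LABELS; the arithmetic simply mirrors the penalties defined earlier (FN_PENALTY=2.0, FP_PENALTY=1.0, 8 allowed labels):

    # One missed gold label (FN) and one spurious prediction (FP).
    true_labels = ["schedule_meeting", "update_kyc_total_assets"]   # t_count = 2
    pred_labels = ["schedule_meeting", "plan_contact"]

    weighted = 2.0 * 1 + 1.0 * 1         # FN_PENALTY*fn + FP_PENALTY*fp = 3.0
    max_err  = 2.0 * 2 + 1.0 * (8 - 2)   # worst case for this gold set = 10.0
    score    = 1.0 - weighted / max_err  # 0.7
    assert abs(ubs_score_one(true_labels, pred_labels) - 0.7) < 1e-9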
125
-
126
- # ========= Lightweight preprocessing =========
127
- EMAIL_RX = re.compile(r'\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b', re.I)
128
- TIME_RX = re.compile(r'\b(\d{1,2}:\d{2}\b|\b\d{1,2}\s?(am|pm)\b|\bafternoon\b|\bmorning\b|\bevening\b)', re.I)
129
- DATE_RX = re.compile(r'\b(jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)\b|\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b|\b20\d{2}\b', re.I)
130
- MEET_RX = re.compile(r'\b(meet(ing)?|call|appointment|schedule|invite|agenda|online|in[- ]?person|phone|zoom|teams)\b', re.I)
131
- MODAL_RX = re.compile(r'\b(online|in[- ]?person|phone|zoom|teams)\b', re.I)
132
- SMALLTALK_RX = re.compile(r'^\s*(user|advisor):\s*(thanks( you)?|thank you|anything else|have a great day|you too)\b', re.I)
133
-
134
- TYPO_FIXES = [
135
- (re.compile(r'\bschedulin\s*g\b', re.I), 'scheduling'),
136
- (re.compile(r'\beeting\b', re.I), 'meeting'),
137
- (re.compile(r'\bdi?i?gtal\b', re.I), 'digital'),
138
- (re.compile(r'\bdigi\s+tal\b', re.I), 'digital'),
139
- (re.compile(r'\bspread\s*sheet\b', re.I), 'spreadsheet'),
140
- (re.compile(r'\bseats\b', re.I), 'sheets'),
141
- (re.compile(r'\bver(s|z)ion meters\b', re.I), 'version metrics'),
142
  ]
143
 
144
- def normalize_text(text: str, fix_typos: bool = True) -> str:
145
- t = text.replace('\r\n', '\n')
146
- t = re.sub(r'^\s*Speaker\s*1\s*:\s*', 'USER: ', t, flags=re.I | re.M)
147
- t = re.sub(r'^\s*Speaker\s*2\s*:\s*', 'ADVISOR: ', t, flags=re.I | re.M)
148
- t = re.sub(r'[ \t]+', ' ', t)
149
- t = re.sub(r'\n{3,}', '\n\n', t)
150
- if fix_typos:
151
- for rx, rep in TYPO_FIXES:
152
- t = rx.sub(rep, t)
153
- return t.strip()
154
-
155
- def extract_cues(text: str):
156
- emails = EMAIL_RX.findall(text)
157
- email_new, email_old = (emails[-1], emails[-2]) if len(emails)>=2 else ((emails[-1], None) if emails else (None, None))
158
- has_time = bool(TIME_RX.search(text))
159
- has_date = bool(DATE_RX.search(text))
160
- has_meet = bool(MEET_RX.search(text))
161
- modality = None
162
- m = MODAL_RX.search(text)
163
- if m:
164
- modality = m.group(0).upper().replace('IN PERSON','IN_PERSON').replace('IN-PERSON','IN_PERSON')
165
- meeting_confirmed = (has_meet and (has_time or has_date))
166
- tm = TIME_RX.search(text)
167
- norm_tm = tm.group(0) if tm else None
168
- return {
169
- "email_new": email_new,
170
- "email_old": email_old,
171
- "contact_pref": "EMAIL" if email_new else None,
172
- "meeting_time_fragment": norm_tm,
173
- "meeting_modality": modality,
174
- "meeting_confirmed": meeting_confirmed
175
- }
176
-
177
- def build_cues_header(cues: dict) -> str:
178
- has_any = any([cues.get("email_new"), cues.get("email_old"), cues.get("contact_pref"), cues.get("meeting_confirmed")])
179
- if not has_any:
180
- return ""
181
- lines = ["[DETECTED_CUES]"]
182
- if cues.get("email_new"): lines.append(f"EMAIL_NEW: {cues['email_new']}")
183
- if cues.get("email_old"): lines.append(f"EMAIL_OLD: {cues['email_old']}")
184
- if cues.get("contact_pref"): lines.append(f"CONTACT_PREF: {cues['contact_pref']}")
185
- if cues.get("meeting_confirmed"):
186
- mod = cues.get("meeting_modality") or ""
187
- tm = cues.get("meeting_time_fragment") or ""
188
- lines.append(f"MEETING: {(tm + ' ' + mod).strip()} CONFIRMED")
189
- lines.append("[/DETECTED_CUES]")
190
- return "\n".join(lines)
191
-
192
- def find_cue_lines(lines):
193
- idx = set()
194
- for i, ln in enumerate(lines):
195
- if EMAIL_RX.search(ln) or (MEET_RX.search(ln) and (TIME_RX.search(ln) or DATE_RX.search(ln))):
196
- idx.add(i)
197
- return sorted(idx)
198
-
199
- def prune_by_window(lines, cue_idx, window=3, strip_smalltalk=False):
200
- n = len(lines); keep = set()
201
- for k in cue_idx:
202
- lo, hi = max(0, k-window), min(n-1, k+window)
203
- keep.update(range(lo,hi+1))
204
- out=[]
205
- for i, ln in enumerate(lines):
206
- if i in keep:
207
- if strip_smalltalk and SMALLTALK_RX.search(ln): continue
208
- out.append(ln)
209
- return out
210
-
211
- def shrink_to_token_cap_by_lines(text: str, soft_cap_tokens: int, tokenizer,
212
- min_lines_keep: int = 30,
213
- apply_only_if_ratio: float = 1.15) -> str:
214
- ids = tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids
215
- est = len(ids)
216
- threshold = int(soft_cap_tokens * apply_only_if_ratio)
217
- if est <= threshold: return text
218
- parts = text.splitlines()
219
- if len(parts) <= min_lines_keep: return text
220
-
221
- keep_flags=[]
222
- for ln in parts:
223
- is_header = ln.startswith("[DETECTED_CUES]") or ln.startswith("[/DETECTED_CUES]") \
224
- or ln.startswith("EMAIL_") or ln.startswith("CONTACT_") or ln.startswith("MEETING:")
225
- is_cue = bool(EMAIL_RX.search(ln) or MEET_RX.search(ln) or DATE_RX.search(ln) or TIME_RX.search(ln))
226
- keep_flags.append(is_header or is_cue)
227
-
228
- pruned = [ln for ln, keep in zip(parts, keep_flags) if keep]
229
- if len(pruned) < min_lines_keep:
230
- pad_needed = min_lines_keep - len(pruned)
231
- non_cue_lines = [ln for ln, keep in zip(parts, keep_flags) if not keep]
232
- pruned = pruned + non_cue_lines[:pad_needed]
233
-
234
- candidate = "\n".join(pruned)
235
- cand_tokens = len(tokenizer(candidate, return_tensors=None, add_special_tokens=False).input_ids)
236
- if cand_tokens > threshold:
237
- mid = len(parts)//2
238
- half = max(min_lines_keep//2, 50)
239
- slice_parts = parts[max(0, mid-half): min(len(parts), mid+half)]
240
- candidate2 = "\n".join(slice_parts)
241
- candidate2_tokens = len(tokenizer(candidate2, return_tensors=None, add_special_tokens=False).input_ids)
242
- candidate = candidate if cand_tokens <= candidate2_tokens else candidate2
243
-
244
- if len(candidate.splitlines()) < min_lines_keep: return text
245
- return candidate
246
-
247
- def enforce_rules(labels, transcript_text):
248
- labels = set(labels or [])
249
- if (TIME_RX.search(transcript_text) or DATE_RX.search(transcript_text)) and MEET_RX.search(transcript_text):
250
- labels.add("schedule_meeting"); labels.discard("plan_contact")
251
- if EMAIL_RX.search(transcript_text) and re.search(r'\b(update|new|set|change|confirm(ed)?|for all communication)\b', transcript_text, re.I):
252
- labels.add("update_contact_info_non_postal")
253
- kyc_rx = re.compile(r'\b(kyc|aml|compliance|employer|occupation|purpose of (relationship|account)|source of (wealth|funds)|net worth|total assets)\b', re.I)
254
- if "update_kyc_activity" in labels and not kyc_rx.search(transcript_text):
255
- labels.discard("update_kyc_activity")
256
- return sorted(labels)
257
-
258
- # ========= Revision pinning =========
259
- # Map repo_id -> default revision (None -> "main").
260
- MODEL_REVISIONS: Dict[str, Optional[str]] = {
261
- "mistralai/Mistral-7B-Instruct-v0.2": None, # set an env var to pin a commit if desired
262
- "Qwen/Qwen2.5-7B-Instruct": None,
263
  "HuggingFaceH4/zephyr-7b-beta": None,
 
264
  "tiiuae/falcon-7b-instruct": None,
 
265
  }
266
 
267
- def _slug_repo_id(repo_id: str) -> str:
268
- return re.sub(r"[^A-Za-z0-9]", "_", repo_id).upper()
269
-
270
- def resolve_revision(repo_id: str) -> str:
271
- """Order: env var MODEL_REVISION__<ORG_MODEL> > dict default > 'main'."""
272
- env_key = f"MODEL_REVISION__{_slug_repo_id(repo_id)}"
273
- env_rev = os.getenv(env_key, "").strip()
274
- if env_rev:
275
- return env_rev
276
- default_rev = MODEL_REVISIONS.get(repo_id)
277
- return (default_rev.strip() if isinstance(default_rev, str) and default_rev.strip() else "main")
278
-
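For illustration, the override key produced by _slug_repo_id for one of the presets, with a placeholder commit hash; the second assert shows the "main" fallback when neither an env var nor a dict default is set:

    import os

    # "HuggingFaceH4/zephyr-7b-beta" slugs to HUGGINGFACEH4_ZEPHYR_7B_BETA.
    os.environ["MODEL_REVISION__HUGGINGFACEH4_ZEPHYR_7B_BETA"] = "0123456789abcdef"  # placeholder hash
    assert resolve_revision("HuggingFaceH4/zephyr-7b-beta") == "0123456789abcdef"
    assert resolve_revision("tiiuae/falcon-7b-instruct") == "main"  # no pin anywhere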
279
- def ensure_local_dir(repo_id: str) -> str:
280
- """Download a pinned snapshot to cache and return its local path."""
281
- rev = resolve_revision(repo_id)
282
- local_dir = snapshot_download(
283
- repo_id=repo_id,
284
- revision=rev,
285
- allow_patterns=[
286
- "*.json", "*.safetensors", "*.bin", "*.model",
287
- "tokenizer.*", "config.json", "generation_config.json", "*.py"
288
- ],
289
- resume_download=True,
290
- local_dir=None, # use HF cache under HF_HOME
291
- local_dir_use_symlinks=False,
292
- token=HF_TOKEN,
293
  )
294
- return local_dir
295
 
296
- # ========= HF model wrapper (loads from local snapshot) =========
297
  class HFModel:
298
- def __init__(self, repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
299
  self.repo_id = repo_id
300
- self.revision = resolve_revision(repo_id)
301
- # Always load from a complete local snapshot to avoid partial shards
302
- self.local_dir = ensure_local_dir(repo_id)
303
 
304
- self.tokenizer = AutoTokenizer.from_pretrained(
305
- self.local_dir, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN
306
- )
307
- torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}.get(dtype, torch.bfloat16)
308
 
309
- self.model = None
310
- if load_4bit:
311
- try:
312
- q = BitsAndBytesConfig(
313
- load_in_4bit=True, bnb_4bit_use_double_quant=True,
314
- bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_quant_type="nf4"
315
- )
316
- self.model = AutoModelForCausalLM.from_pretrained(
317
- self.local_dir, device_map="auto", trust_remote_code=trust_remote_code,
318
- quantization_config=q, torch_dtype=torch_dtype, token=HF_TOKEN
319
- )
320
- except Exception as e:
321
- print(f"[WARN] 4-bit load failed for {repo_id}@{self.revision}: {e}\nFalling back to normal load...", file=sys.stderr)
322
- if self.model is None:
323
  self.model = AutoModelForCausalLM.from_pretrained(
324
- self.local_dir, device_map="auto", trust_remote_code=trust_remote_code,
325
- torch_dtype=torch_dtype, token=HF_TOKEN
326
  )
327
-
328
- self.max_context = getattr(self.model.config, "max_position_embeddings", None) \
329
- or getattr(self.model.config, "max_sequence_length", None) or 8192
330
-
331
- def apply_chat_template(self, system_text: str, user_text: str) -> str:
332
- if getattr(self.tokenizer, "chat_template", None):
333
- messages = [{"role":"system","content":system_text},
334
- {"role":"user","content":user_text}]
335
- return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
336
- return ("### System\n" + system_text.strip() + "\n\n" +
337
- "### User\n" + user_text.strip() + "\n\n" +
338
- "### Assistant\n")
339
 
340
  @torch.inference_mode()
341
- def generate_json(self, system_text: str, user_text: str, max_new_tokens: int = 256):
342
- prompt = self.apply_chat_template(system_text, user_text)
343
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
344
- t0 = time.perf_counter()
345
- out = self.model.generate(
346
- **inputs, max_new_tokens=max_new_tokens,
347
- do_sample=False, temperature=None, top_p=None,
348
- eos_token_id=self.tokenizer.eos_token_id
349
  )
350
- latency_ms = int((time.perf_counter() - t0) * 1000)
351
- text = self.tokenizer.decode(out[0], skip_special_tokens=True)
352
- if text.startswith(prompt): text = text[len(prompt):]
353
- return latency_ms, text, prompt
354
-
355
- # Cache now includes revision implicitly via HFModel (we also add revision to key)
356
- MODEL_CACHE: Dict[Tuple[str, bool, str, bool, str], HFModel] = {}
357
- def get_model(repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
358
- rev = resolve_revision(repo_id)
359
- key = (repo_id, bool(load_4bit), dtype, bool(trust_remote_code), rev)
360
- if key not in MODEL_CACHE:
361
- MODEL_CACHE[key] = HFModel(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
362
- return MODEL_CACHE[key]
363
-
364
- # ========= ZeroGPU functions =========
365
- @spaces.GPU(duration=180, secrets=["HF_TOKEN"]) # pass token into ZeroGPU job
366
- def gpu_generate(repo_id: str, system_text: str, user_text: str,
367
- load_4bit: bool, dtype: str, trust_remote_code: bool):
368
- token_seen = bool(os.getenv("HF_TOKEN"))
369
- hf = get_model(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
370
- lat, txt, prmpt = hf.generate_json(system_text.strip(), user_text.strip(), max_new_tokens=256)
371
- return lat, txt, prmpt, token_seen
372
-
373
- @spaces.GPU(duration=15, secrets=["HF_TOKEN"])
374
- def gpu_check_token():
375
- return bool(os.getenv("HF_TOKEN"))
376
-
377
- # ========= ZIP helpers =========
378
- def _read_zip_bytes(dataset_zip: Union[bytes, str, dict, None]) -> bytes:
379
- if dataset_zip is None: raise ValueError("No ZIP provided")
380
- if isinstance(dataset_zip, bytes): return dataset_zip
381
- if isinstance(dataset_zip, str):
382
- with open(dataset_zip, "rb") as f: return f.read()
383
- if isinstance(dataset_zip, dict) and "path" in dataset_zip:
384
- with open(dataset_zip["path"], "rb") as f: return f.read()
385
- path = getattr(dataset_zip, "name", None)
386
- if path and os.path.exists(path):
387
- with open(path, "rb") as f: return f.read()
388
- raise ValueError("Unsupported file object from Gradio")
389
-
390
- def parse_zip(zip_bytes: bytes) -> Dict[str, Tuple[str, List[str]]]:
391
- zf = zipfile.ZipFile(io.BytesIO(zip_bytes))
392
- samples = {}
393
- for n in zf.namelist():
394
- p = Path(n)
395
- if p.suffix.lower() == ".txt":
396
- samples.setdefault(p.stem, ["", []])[0] = zf.read(n).decode("utf-8", "replace")
397
- elif p.suffix.lower() == ".json":
398
- try:
399
- js = json.loads(zf.read(n).decode("utf-8", "replace"))
400
- except Exception:
401
- js = []
402
- samples.setdefault(p.stem, ["", []])[1] = _coerce_labels_list(js)
403
- return samples
404
-
405
- # ========= Prompts =========
406
- DEFAULT_SYSTEM = (
407
- "You are a task extraction assistant. "
408
- "Always output valid JSON with a field \"labels\" (list of strings). "
409
- "Use only from this set: " + json.dumps(ALLOWED_LABELS) + ". "
410
- "Return JSON only."
411
- )
412
- DEFAULT_CONTEXT = (
413
- "- plan_contact: conversation without a concrete meeting (no date/time)\n"
414
- "- schedule_meeting: explicit date/time/modality confirmation\n"
415
- "- update_contact_info_non_postal: changes to email/phone\n"
416
- "- update_contact_info_postal_address: changes to mailing address\n"
417
- "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)"
418
- )
419
-
420
- # ========= Preprocess + build input =========
421
- def prepare_input_text(raw_txt: str, soft_cap: int, preprocess: bool, pre_window: int,
422
- add_cues: bool, strip_smalltalk: bool, tokenizer) -> Tuple[str, int, int]:
423
- before = len(tokenizer(raw_txt, return_tensors=None, add_special_tokens=False).input_ids)
424
- proc_text = raw_txt
425
- if preprocess:
426
- t_norm = normalize_text(proc_text, fix_typos=True)
427
- lines = [ln.strip() for ln in t_norm.splitlines() if ln.strip()]
428
- cue_lines = find_cue_lines(lines)
429
- if cue_lines:
430
- kept = prune_by_window(lines, cue_lines, window=pre_window, strip_smalltalk=strip_smalltalk)
431
- else:
432
- kept = [ln for ln in lines if not (strip_smalltalk and SMALLTALK_RX.search(ln))]
433
- t_kept = "\n".join(kept)
434
- cues = extract_cues(t_kept)
435
- header = build_cues_header(cues) if add_cues else ""
436
- proc_text = (header + "\n\n" + t_kept).strip() if header else t_kept
437
- proc_text = shrink_to_token_cap_by_lines(proc_text, soft_cap, tokenizer)
438
- if len(proc_text.splitlines()) < 30:
439
- proc_text = t_norm
440
- after = len(tokenizer(proc_text, return_tensors=None, add_special_tokens=False).input_ids)
441
- return proc_text, before, after
442
-
443
- def explain_params_markdown() -> str:
444
- return (
445
- "**Parameter help** \n"
446
- "- **Soft token cap**: target max input size; we prune long transcripts toward this size to save latency. \n"
447
- "- **Enable preprocessing**: normalizes speaker tags, fixes obvious typos, and focuses on cue lines. \n"
448
- "- **Window ± lines around cues**: how many lines we keep around detected cues (dates/emails/‘meeting’, etc.). \n"
449
- "- **Add cues header**: inserts a short summary block (email, meeting signal) above the transcript to guide the model. \n"
450
- "- **Strip smalltalk**: removes lines like ‘thanks, bye’ to keep only useful content. \n"
451
- "- **Load in 4-bit (GPU only)**: memory-saving quantization; has no effect on CPU Spaces."
452
  )
453
 
454
- # ========= Single mode =========
455
- def single_mode(
456
- preset_model: str, custom_model: str,
457
- system_text: str, context_text: str,
458
- transcript_text: str, transcript_file,
459
- expected_labels_json,
460
- soft_cap: int, preprocess: bool, pre_window: int, add_cues: bool, strip_smalltalk: bool,
461
- load_4bit: bool, dtype: str, trust_remote_code: bool
462
- ):
463
- repo_id = custom_model.strip() or preset_model.strip()
464
- if not repo_id:
465
- return "Please choose a model.", "", "", "", None, None, None, ""
466
-
467
- txt = (transcript_text or "").strip()
468
- if transcript_file and hasattr(transcript_file, "name") and os.path.exists(transcript_file.name):
469
- with open(transcript_file.name, "r", encoding="utf-8", errors="replace") as f:
470
- txt = f.read()
471
- if not txt:
472
- return "Please paste a transcript or upload a .txt file.", "", "", "", None, None, None, ""
473
-
474
- exp = []
475
- if expected_labels_json and hasattr(expected_labels_json, "name") and os.path.exists(expected_labels_json.name):
476
- try:
477
- with open(expected_labels_json.name, "r", encoding="utf-8", errors="replace") as f:
478
- exp = _coerce_labels_list(json.load(f))
479
- except Exception:
480
- exp = []
481
-
482
- # tokenizer for preprocessing — from local snapshot to avoid streaming
483
  try:
484
- local_dir = ensure_local_dir(repo_id)
485
- dummy_tok = AutoTokenizer.from_pretrained(local_dir, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN)
486
  except Exception as e:
487
- msg = (f"Failed to load tokenizer for `{repo_id}`. "
488
- "If gated, accept license and set HF_TOKEN in Space → Settings → Secrets.\n\nError: " + str(e))
489
- return msg, "", "", "", None, None, None, banner_text()
490
-
491
- proc_text, tok_before, tok_after = prepare_input_text(
492
- txt, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
493
- )
494
- system = (system_text or DEFAULT_SYSTEM).strip()
495
- user = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()
496
 
497
  try:
498
- latency_ms, raw_text, _prompt, gpu_token_seen = gpu_generate(
499
- repo_id, system, user, load_4bit, dtype, trust_remote_code
500
- )
501
  except Exception as e:
502
- msg = (f"Failed to run `{repo_id}`. If gated, accept license and set HF_TOKEN.\n\nError: {e}")
503
- return msg, "", "", "", None, None, None, banner_text()
504
-
505
- out = safe_json_load(raw_text)
506
- pred_labels = enforce_rules(out.get("labels", []), proc_text)
507
- exact, prec, rec, f1, ham = classic_metrics(pred_labels, exp)
508
- ubs = ubs_score_one(exp, pred_labels) if exp else None
509
-
510
- kpi1 = f"**F1**\n\n{f1:.3f}" if exp else "**F1**\n\n—"
511
- kpi2 = f"**UBS score**\n\n{ubs:.3f}" if ubs is not None else "**UBS score**\n\n—"
512
- kpi3 = f"**Latency (ms)**\n\n{latency_ms}"
513
-
514
- zbuf = io.BytesIO()
515
- with zipfile.ZipFile(zbuf, "w", zipfile.ZIP_DEFLATED) as zout:
516
- zout.writestr("PREPROCESSED.txt", proc_text)
517
- zout.writestr("MODEL_OUTPUT.raw.txt", raw_text)
518
- final_json = {
519
- "labels": pred_labels,
520
- "diagnostics": {
521
- "model_name": repo_id,
522
- "latency_ms": latency_ms,
523
- "token_in_est_before": tok_before,
524
- "token_in_est_after": tok_after,
525
- "preprocess": preprocess,
526
- "pre_window": pre_window,
527
- "pre_add_cues_header": add_cues if preprocess else False,
528
- "pre_strip_smalltalk": strip_smalltalk if preprocess else False,
529
- "pre_soft_token_cap": soft_cap if preprocess else None,
530
- "model_calls": 1
531
- },
532
- "evaluation": None if not exp else {
533
- "exact_match": exact, "precision": prec, "recall": rec,
534
- "f1": f1, "hamming": ham, "ubs_score": ubs
535
- }
536
- }
537
- zout.writestr("FINAL.json", json.dumps(final_json, ensure_ascii=False, indent=2))
538
- zbuf.seek(0); zbuf.name = "artifacts_single.zip"
539
-
540
- row = pd.DataFrame([{
541
- "model": repo_id,
542
- "latency_ms": latency_ms,
543
- "token_before": tok_before,
544
- "token_after": tok_after,
545
- "model_calls": 1,
546
- "pred_labels": json.dumps(pred_labels, ensure_ascii=False),
547
- "exp_labels": json.dumps(exp, ensure_ascii=False),
548
- "exact_match": exact if exp else None,
549
- "precision": round(prec,6) if exp else None,
550
- "recall": round(rec,6) if exp else None,
551
- "f1": round(f1,6) if exp else None,
552
- "hamming": round(ham,6) if exp else None,
553
- "ubs_score": round(ubs,6) if ubs is not None else None
554
- }])
555
-
556
- csv_buf = io.BytesIO(row.to_csv(index=False).encode("utf-8")); csv_buf.name = "results_single.csv"
557
-
558
- return (
559
- "Done.",
560
- kpi1, kpi2, kpi3,
561
- row, csv_buf, zbuf,
562
- banner_text(gpu_token_seen)
563
- )
564
-
565
- # ========= Batch mode =========
566
- def run_batch_ui(models_list, custom_models_str, instructions_text, context_text, dataset_zip,
567
- soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
568
- repeats, max_total_runs, load_4bit, dtype, trust_remote_code):
569
-
570
- models = [m for m in (models_list or [])]
571
- models += [m.strip() for m in (custom_models_str or "").split(",") if m.strip()]
572
- if not models:
573
- return pd.DataFrame(), None, None, "Please pick at least one model.", banner_text()
574
-
575
- if not dataset_zip:
576
- return pd.DataFrame(), None, None, "Please upload a ZIP with *.txt (+ optional matching *.json).", banner_text()
577
 
578
  try:
579
- zip_bytes = _read_zip_bytes(dataset_zip)
580
- samples = parse_zip(zip_bytes)
581
  except Exception as e:
582
- return pd.DataFrame(), None, None, f"Failed to read ZIP: {e}", banner_text()
583
-
584
- rows = []; total_runs = 0
585
- all_artifacts = io.BytesIO()
586
- zout = zipfile.ZipFile(all_artifacts, "w", zipfile.ZIP_DEFLATED)
587
- last_gpu_token_seen = None
588
 
589
- for repo_id in models:
590
- # tokenizer for preprocessing (auth check) — also from local snapshot
591
  try:
592
- local_dir = ensure_local_dir(repo_id)
593
- dummy_tok = AutoTokenizer.from_pretrained(local_dir, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN)
594
  except Exception as e:
595
- # gated or missing token; record a summary row and continue
596
- rows.append({
597
- "timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
598
- "sample_id": None,
599
- "model": repo_id,
600
- "is_summary": True,
601
- "run_index": None,
602
- "preprocess": preprocess,
603
- "pre_window": pre_window,
604
- "add_cues_header": add_cues,
605
- "strip_smalltalk": strip_smalltalk,
606
- "soft_cap": soft_cap,
607
- "median_latency_ms": None,
608
- "latency_ms": None,
609
- "token_before": None,
610
- "token_after": None,
611
- "model_calls": None,
612
- "pred_labels": "[]",
613
- "exp_labels": "[]",
614
- "exact_match": None,
615
- "precision": None,
616
- "recall": None,
617
- "f1": None,
618
- "hamming": None,
619
- "ubs_score": None,
620
- })
621
- continue
622
-
623
- for sample_id, (transcript_text, exp_labels) in samples.items():
624
- if not transcript_text.strip(): continue
625
- latencies = []; last_pred = None
626
- for r in range(1, repeats+1):
627
- if total_runs >= max_total_runs: break
628
- proc_text, before_tok, after_tok = prepare_input_text(
629
- transcript_text, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
630
- )
631
- system_text = (instructions_text or DEFAULT_SYSTEM).strip()
632
- user_text = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()
633
-
634
- try:
635
- latency_ms, raw_text, _prompt, token_seen = gpu_generate(
636
- repo_id, system_text, user_text, load_4bit, dtype, trust_remote_code
637
- )
638
- last_gpu_token_seen = token_seen
639
- except Exception as e:
640
- base = f"{repo_id.replace('/','_')}/{sample_id}/error_r{r}"
641
- zout.writestr(base + "/ERROR.txt", f"Failed to run model via @spaces.GPU. If gated, accept license and set HF_TOKEN.\n\n{e}")
642
- total_runs += 1
643
- continue
644
-
645
- out = safe_json_load(raw_text)
646
- pred_labels = enforce_rules(out.get("labels", []), proc_text)
647
-
648
- exact, prec, rec, f1, ham = classic_metrics(pred_labels, exp_labels)
649
- ubs = ubs_score_one(exp_labels, pred_labels)
650
-
651
- rows.append({
652
- "timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
653
- "sample_id": sample_id,
654
- "model": repo_id,
655
- "is_summary": False,
656
- "run_index": r,
657
- "preprocess": preprocess,
658
- "pre_window": pre_window,
659
- "add_cues_header": add_cues,
660
- "strip_smalltalk": strip_smalltalk,
661
- "soft_cap": soft_cap,
662
- "latency_ms": latency_ms,
663
- "token_before": before_tok,
664
- "token_after": after_tok,
665
- "model_calls": 1,
666
- "pred_labels": json.dumps(pred_labels, ensure_ascii=False),
667
- "exp_labels": json.dumps(exp_labels, ensure_ascii=False),
668
- "exact_match": exact,
669
- "precision": round(prec, 6),
670
- "recall": round(rec, 6),
671
- "f1": round(f1, 6),
672
- "hamming": round(ham, 6),
673
- "ubs_score": round(ubs, 6),
674
- })
675
-
676
- base = f"{repo_id.replace('/','_')}/{sample_id}/pre{int(preprocess)}_win{pre_window}_cues{int(add_cues)}_small{int(strip_smalltalk)}_cap{soft_cap}_r{r}"
677
- zout.writestr(base + "/PREPROCESSED.txt", proc_text)
678
- zout.writestr(base + "/MODEL_OUTPUT.raw.txt", raw_text)
679
- final_json = {
680
- "labels": pred_labels,
681
- "diagnostics": {
682
- "model_name": repo_id,
683
- "latency_ms": latency_ms,
684
- "token_in_est_before": before_tok,
685
- "token_in_est_after": after_tok,
686
- "preprocess": preprocess,
687
- "pre_window": pre_window,
688
- "pre_add_cues_header": add_cues if preprocess else False,
689
- "pre_strip_smalltalk": strip_smalltalk if preprocess else False,
690
- "pre_soft_token_cap": soft_cap if preprocess else None,
691
- "model_calls": 1
692
- }
693
- }
694
- zout.writestr(base + "/FINAL.json", json.dumps(final_json, ensure_ascii=False, indent=2))
695
-
696
- latencies.append(latency_ms)
697
- last_pred = pred_labels
698
- total_runs += 1
699
-
700
- if latencies:
701
- med = int(statistics.median(latencies))
702
- exact, prec, rec, f1, ham = classic_metrics(last_pred, exp_labels) if last_pred is not None else (None,)*5
703
- ubs = ubs_score_one(exp_labels, last_pred) if last_pred is not None else None
704
- rows.append({
705
- "timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
706
- "sample_id": sample_id,
707
- "model": repo_id,
708
- "is_summary": True,
709
- "run_index": None,
710
- "preprocess": preprocess,
711
- "pre_window": pre_window,
712
- "add_cues_header": add_cues,
713
- "strip_smalltalk": strip_smalltalk,
714
- "soft_cap": soft_cap,
715
- "median_latency_ms": med,
716
- "latency_ms": None,
717
- "token_before": None,
718
- "token_after": None,
719
- "model_calls": None,
720
- "pred_labels": json.dumps(last_pred or [], ensure_ascii=False),
721
- "exp_labels": json.dumps(exp_labels or [], ensure_ascii=False),
722
- "exact_match": exact,
723
- "precision": round(prec, 6) if prec is not None else None,
724
- "recall": round(rec, 6) if rec is not None else None,
725
- "f1": round(f1, 6) if f1 is not None else None,
726
- "hamming": round(ham, 6) if ham is not None else None,
727
- "ubs_score": round(ubs, 6) if ubs is not None else None,
728
- })
729
-
730
- if total_runs >= max_total_runs:
731
- break
732
-
733
- zout.close()
734
- df = pd.DataFrame(rows)
735
- if df.empty:
736
- return pd.DataFrame(), None, None, "No runs executed (empty dataset / exceeded cap / gated models).", banner_text(last_gpu_token_seen)
737
-
738
- csv_pair = ("results.csv", df.to_csv(index=False).encode("utf-8"))
739
- zip_pair = ("artifacts.zip", all_artifacts.getvalue())
740
- return df, csv_pair, zip_pair, "Done.", banner_text(last_gpu_token_seen)
741
-
742
- # ========= UI helpers =========
743
- OPEN_MODEL_PRESETS = [
744
- "mistralai/Mistral-7B-Instruct-v0.2",
745
- "Qwen/Qwen2.5-7B-Instruct",
746
- "HuggingFaceH4/zephyr-7b-beta",
747
- "tiiuae/falcon-7b-instruct",
748
- ]
749
 
750
- def banner_text(gpu_token_seen: bool | None = None) -> str:
751
- app_seen = bool(HF_TOKEN)
752
- lines = []
753
- if not app_seen:
754
- lines.append("🟡 **HF_TOKEN not detected in App** — gated models will fail unless you set it in **Settings → Variables and secrets**.")
755
- else:
756
- lines.append("🟢 **HF_TOKEN detected in App**.")
757
- if gpu_token_seen is None:
758
- lines.append("ℹ️ ZeroGPU token status: click **Run** or **Check ZeroGPU token** to verify.")
759
- else:
760
- lines.append("🟢 **HF_TOKEN detected inside ZeroGPU job.**" if gpu_token_seen else "🔴 **HF_TOKEN missing inside ZeroGPU job** (add `secrets=[\"HF_TOKEN\"]` to @spaces.GPU).")
761
- lines.append("✅ Tip: use **Open models** (no license gating): " + ", ".join(OPEN_MODEL_PRESETS))
762
- # Show pin info for transparency
763
- try:
764
- revs = [f"{m}@{resolve_revision(m)}" for m in OPEN_MODEL_PRESETS]
765
- lines.append("📌 Pinned revisions: " + ", ".join(revs))
766
- except Exception:
767
- pass
768
- return "\n\n".join(lines)
769
-
770
- # ========= UI (dark red) =========
771
- DARK_RED_CSS = """
772
- :root, .gradio-container {
773
- --color-background: #0b0b0d;
774
- --color-foreground: #e6e6e6;
775
- --color-primary: #e11d48;
776
- --color-secondary: #111216;
777
- --color-border: #1f2024;
778
- --color-muted: #9ca3af;
779
- }
780
- .gradio-container { background: var(--color-background) !important; color: var(--color-foreground) !important; }
781
- .gr-box, .gr-panel, .gr-group, .gr-form, .wrap.svelte-1ipelgc {
782
- background: var(--color-secondary) !important;
783
- border: 1px solid var(--color-border) !important;
784
- border-radius: 10px !important;
785
- }
786
- button, .gr-button {
787
- border-radius: 10px !important;
788
- border: 1px solid var(--color-primary) !important;
789
- background: linear-gradient(180deg, #b91c1c, #7f1d1d) !important;
790
- color: white !important;
791
- }
792
- .kpi {
793
- border: 1px solid #e11d48; border-radius: 10px; padding: 12px; text-align: center;
794
- background: #1a0f10; font-size: 18px;
795
- }
796
- """
797
 
798
- with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo:
799
- gr.Markdown("## 🟥 From Talk to Task — Batch & Single Task Extraction")
800
- help_md = (
801
- "This tool extracts **task labels** from transcripts using Hugging Face models. \n"
802
- "1) Pick a model (or paste a custom repo id). \n"
803
- "2) Provide **Instructions** and **Context**, then supply a transcript (single) or a ZIP (batch). \n"
804
- "3) Adjust parameters (soft token cap, preprocessing). \n"
805
- "4) Run and review **latency**, **precision/recall/F1**, **UBS score**, and download artifacts."
806
  )
807
- gr.Markdown(help_md)
808
 
809
- # Status banner (token presence + revisions)
810
- banner = gr.Markdown(banner_text())
811
 
812
- check_btn = gr.Button("Check ZeroGPU token")
813
- def _check_token():
814
- try:
815
- present = gpu_check_token()
816
- except Exception:
817
- present = None
818
- return banner_text(present)
819
- check_btn.click(_check_token, outputs=banner)
 
 
 
820
 
821
  with gr.Tabs():
822
- # Single
823
- with gr.TabItem("Single Transcript (default)"):
824
- with gr.Row():
825
- with gr.Column():
826
- preset_model = gr.Dropdown(choices=OPEN_MODEL_PRESETS, value=OPEN_MODEL_PRESETS[0],
827
- label="Model (Open presets — no gating)")
828
- custom_model = gr.Textbox(label="Custom model repo id (overrides preset)",
829
- placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct")
830
- instructions = gr.Textbox(label="Instructions (System)", lines=8, value=DEFAULT_SYSTEM)
831
- context = gr.Textbox(label="Context (User prefix before transcript)", lines=6, value=DEFAULT_CONTEXT)
832
- with gr.Column():
833
- transcript_text = gr.Textbox(label="Paste transcript text", lines=14, placeholder="Paste your transcript here...")
834
- transcript_file = gr.File(label="...or upload a single transcript .txt", file_types=[".txt"], file_count="single", type="filepath")
835
- expected_labels_json = gr.File(label="(Optional) Expected labels JSON for metrics", file_types=[".json"], file_count="single", type="filepath")
836
-
837
- with gr.Row():
838
- with gr.Column():
839
- soft_cap_s = gr.Slider(1024, 32768, value=8192, step=512, label="Soft token cap")
840
- preprocess_s = gr.Checkbox(value=True, label="Enable preprocessing")
841
- pre_window_s = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
842
- add_cues_s = gr.Checkbox(value=True, label="Add cues header")
843
- strip_smalltalk_s = gr.Checkbox(value=False, label="Strip smalltalk")
844
- gr.Markdown(explain_params_markdown())
845
- with gr.Column():
846
- load_4bit_s = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")
847
- dtype_s = gr.Dropdown(choices=["bfloat16","float16","float32"], value="bfloat16", label="Compute dtype")
848
- trust_remote_code_s = gr.Checkbox(value=True, label="Trust remote code")
849
-
850
- run_single_btn = gr.Button("Run (Single)")
851
- kpi1 = gr.Markdown(elem_classes=["kpi"]); kpi2 = gr.Markdown(elem_classes=["kpi"]); kpi3 = gr.Markdown(elem_classes=["kpi"])
852
- single_table = gr.Dataframe(label="Single run — metrics & diagnostics", interactive=False)
853
- single_csv = gr.File(label="Download CSV", interactive=False)
854
- single_zip = gr.File(label="Download Artifacts ZIP", interactive=False)
855
- single_status = gr.Markdown("")
856
 
857
  def _run_single(*args):
858
- status, m1, m2, m3, df, csv_buf, zip_buf, btxt = single_mode(*args)
859
- return m1 or "", m2 or "", m3 or "", (df if isinstance(df, pd.DataFrame) else pd.DataFrame()), csv_buf, zip_buf, (status or ""), (btxt or banner_text())
860
 
861
- run_single_btn.click(
862
  _run_single,
863
- inputs=[preset_model, custom_model, instructions, context,
864
- transcript_text, transcript_file, expected_labels_json,
865
- soft_cap_s, preprocess_s, pre_window_s, add_cues_s, strip_smalltalk_s,
866
- load_4bit_s, dtype_s, trust_remote_code_s],
867
- outputs=[kpi1, kpi2, kpi3, single_table, single_csv, single_zip, single_status, banner]
868
  )
869
 
870
- # Batch
871
- with gr.TabItem("Batch (ZIP of many transcripts)"):
872
- with gr.Row():
873
- with gr.Column():
874
- models_list = gr.Checkboxgroup(
875
- choices=OPEN_MODEL_PRESETS, value=[OPEN_MODEL_PRESETS[0]],
876
- label="Models (Open presets — select one or more)"
877
- )
878
- custom_models = gr.Textbox(label="Custom model repo ids (comma-separated)",
879
- placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct, Qwen/Qwen2.5-7B-Instruct")
880
- instructions_b = gr.Textbox(label="Instructions (System)", lines=8, value=DEFAULT_SYSTEM)
881
- context_b = gr.Textbox(label="Context (User prefix before transcript)", lines=6, value=DEFAULT_CONTEXT)
882
- with gr.Column():
883
- dataset_zip = gr.File(
884
- label="Upload ZIP of transcripts (*.txt) + expected (*.json)",
885
- file_types=[".zip"], file_count="single", type="filepath"
886
- )
887
- gr.Markdown("Zip must contain pairs like `ID.txt` and optional `ID.json` with expected labels (same base filename).")
888
-
889
- with gr.Row():
890
- with gr.Column():
891
- soft_cap = gr.Slider(1024, 32768, value=8192, step=512, label="Soft token cap")
892
- preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
893
- pre_window = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
894
- add_cues = gr.Checkbox(value=True, label="Add cues header")
895
- strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
896
- gr.Markdown(explain_params_markdown())
897
- with gr.Column():
898
- repeats = gr.Slider(1, 6, value=3, step=1, label="Repeats per config")
899
- max_total_runs = gr.Slider(1, 200, value=40, step=1, label="Max total runs")
900
- load_4bit = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")
901
- dtype = gr.Dropdown(choices=["bfloat16","float16","float32"], value="bfloat16", label="Compute dtype")
902
- trust_remote_code = gr.Checkbox(value=True, label="Trust remote code")
903
-
904
- run_btn = gr.Button("Run Batch")
905
- kpi_b1 = gr.Markdown(elem_classes=["kpi"]); kpi_b2 = gr.Markdown(elem_classes=["kpi"]); kpi_b3 = gr.Markdown(elem_classes=["kpi"])
906
- table = gr.Dataframe(label="Batch results (per run + summary rows)", interactive=False)
907
- csv_dl = gr.File(label="Download CSV", interactive=False)
908
- zip_dl = gr.File(label="Download Artifacts ZIP", interactive=False)
909
- status = gr.Markdown("")
910
 
911
  def _run_batch(*args):
912
- df, csv_pair, zip_pair, msg, btxt = run_batch_ui(*args)
913
- m1 = m2 = m3 = ""
914
- if isinstance(df, pd.DataFrame) and not df.empty:
915
- summaries = df[df["is_summary"] == True]
916
- if not summaries.empty:
917
- last = summaries.iloc[-1]
918
- f1 = last.get("f1"); ubs = last.get("ubs_score"); med = last.get("median_latency_ms")
919
- m1 = f"**F1 (last summary)**\n\n{f1:.3f}" if pd.notna(f1) else "**F1 (last summary)**\n\n—"
920
- m2 = f"**UBS (last summary)**\n\n{ubs:.3f}" if pd.notna(ubs) else "**UBS (last summary)**\n\n—"
921
- m3 = f"**Median latency (ms)**\n\n{int(med) if pd.notna(med) else '—'}"
922
- csv_buf = zip_buf = None
923
- if isinstance(csv_pair, tuple):
924
- name, data = csv_pair; csv_buf = io.BytesIO(data); csv_buf.name = name
925
- if isinstance(zip_pair, tuple):
926
- name, data = zip_pair; zip_buf = io.BytesIO(data); zip_buf.name = name
927
- return m1, m2, m3, (df if isinstance(df, pd.DataFrame) else pd.DataFrame()), csv_buf, zip_buf, (msg or ""), (btxt or banner_text())
928
 
929
- run_btn.click(
930
  _run_batch,
931
- inputs=[models_list, custom_models, instructions_b, context_b, dataset_zip,
932
- soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
933
- repeats, max_total_runs, load_4bit, dtype, trust_remote_code],
934
- outputs=[kpi_b1, kpi_b2, kpi_b3, table, csv_dl, zip_dl, status, banner]
 
935
  )
936
 
937
- demo.launch()
1
+ # app.py
2
+ # From Talk to Task Batch & Single Task Extraction
3
+ # Works on CPU / GPU / ZeroGPU. Uses a writable HF cache path (no /data).
4
+ # If you want to use gated models (e.g., mistralai/Mistral-7B-Instruct-v0.2),
5
+ # accept the license on HF and set HF_TOKEN in Space → Settings → Secrets.
6
+
7
+ import os
8
+ import io
9
+ import re
10
+ import sys
11
+ import time
12
+ import json
13
+ import zipfile
14
  from pathlib import Path
15
+ from typing import List, Dict, Tuple, Optional
16
 
17
  import gradio as gr
18
 
19
+ # ====== Robust, writable HF cache ======
20
+ # Avoid /data (read-only in Spaces). Prefer $HOME or /tmp.
21
+ HOME = Path(os.environ.get("HOME", "/home/user"))
22
+ CACHE_DIR = HOME / ".cache" / "huggingface"
23
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
24
+ os.environ.setdefault("HF_HOME", str(CACHE_DIR))
25
+ os.environ.setdefault("TRANSFORMERS_CACHE", str(CACHE_DIR))
26
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") # faster downloads when available
27
+
28
+ HF_TOKEN = os.environ.get("HF_TOKEN", "").strip() or None
29
 
30
+ # ====== Transformers safe import ======
31
+ try:
32
+ import torch
33
+ from transformers import (
34
+ AutoTokenizer,
35
+ AutoModelForCausalLM,
36
+ BitsAndBytesConfig,
37
  )
38
+ except Exception as e:
39
+ raise RuntimeError(
40
+ "Failed to import transformers/torch. "
41
+ "Make sure requirements.txt includes: transformers>=4.41, torch, accelerate"
42
+ ) from e
43
 
44
+ DTYPE_FALLBACK = torch.float32
45
+ if torch.cuda.is_available():
46
+ DTYPE_FALLBACK = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
47
 
48
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
49
+
50
+ # ====== ZeroGPU (optional) ======
51
+ # If you’re running on ZeroGPU, Spaces injects a runtime; we keep a safe shim.
52
+ try:
53
+ import spaces # noqa: F401
54
+ ON_ZERO_GPU = True
55
+ except Exception:
56
+ ON_ZERO_GPU = False
57
+
58
+ # ====== UI presets ======
59
+ OPEN_MODEL_PRESETS = [
60
+ # choose truly open, ungated options first
61
+ "HuggingFaceH4/zephyr-7b-beta",
62
+ "Qwen/Qwen2.5-7B-Instruct",
63
+ "tiiuae/falcon-7b-instruct",
64
+ # You can still type a custom gated model repo id below if you have access.
65
  ]
66
 
67
+ PINNED_REVISIONS = {
68
+ # None means "main"
69
  "HuggingFaceH4/zephyr-7b-beta": None,
70
+ "Qwen/Qwen2.5-7B-Instruct": None,
71
  "tiiuae/falcon-7b-instruct": None,
72
+ # "mistralai/Mistral-7B-Instruct-v0.2": None, # gated — use only if token + license ok
73
  }
74
 
75
+ SYSTEM_INSTRUCTIONS = (
76
+ "You are a task extraction assistant. Always output valid JSON with a field "
77
+ '"labels" (list of strings). Use only from this set: '
78
+ '["plan_contact","schedule_meeting","update_contact_info_non_postal",'
79
+ '"update_contact_info_postal_address","update_kyc_activity","update_kyc_origin_of_assets",'
80
+ '"update_kyc_purpose_of_businessrelation","update_kyc_total_assets"]. '
81
+ "Return JSON only."
82
+ )
83
+
84
+ CONTEXT_GUIDE = """\
85
+ - plan_contact: conversation without a concrete meeting (no date/time)
86
+ - schedule_meeting: explicit date/time/modality confirmation
87
+ - update_contact_info_non_postal: changes to email/phone
88
+ - update_contact_info_postal_address: changes to mailing address
89
+ - update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)
90
+ """
91
+
92
+ # ====== Utility ======
93
+ def _json_only(text: str) -> str:
94
+ """
95
+ Try to extract the first JSON object from text.
96
+ """
97
+ text = text.strip()
98
+ if text.startswith("{") and text.endswith("}"):
99
+ return text
100
+ m = re.search(r"\{.*\}", text, re.DOTALL)
101
+ return m.group(0) if m else '{"labels": []}'
102
+
103
+ def safe_json_loads(s: str) -> dict:
104
+ try:
105
+ return json.loads(s)
106
+ except Exception:
107
+ return {"labels": []}
108
+
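A small usage sketch of the two helpers above, showing the fallback on chatty model output; the example strings are made up:

    raw = 'Sure! Here is the result: {"labels": ["schedule_meeting"]} Hope this helps.'
    print(safe_json_loads(_json_only(raw)))             # {'labels': ['schedule_meeting']}
    print(safe_json_loads(_json_only("no json here")))  # {'labels': []}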
109
+ def build_prompt(system: str, context: str, transcript: str) -> str:
110
+ return (
111
+ f"### System:\n{system}\n\n"
112
+ f"### Context:\n{context}\n\n"
113
+ f"### Transcript:\n{transcript}\n\n"
114
+ "### Output:\nReturn JSON only."
115
  )
 
116
 
117
+ # ====== Model wrapper ======
118
  class HFModel:
119
+ def __init__(
120
+ self,
121
+ repo_id: str,
122
+ revision: Optional[str] = None,
123
+ load_in_4bit: bool = False,
124
+ trust_remote_code: bool = True,
125
+ dtype: Optional[torch.dtype] = None,
126
+ token: Optional[str] = None,
127
+ ) -> None:
128
  self.repo_id = repo_id
129
+ self.revision = revision or "main"
130
+ self.trust_remote_code = trust_remote_code
131
+ self.token = token
132
+ self.dtype = dtype or DTYPE_FALLBACK
133
+ self.load_in_4bit = load_in_4bit and (DEVICE == "cuda")
134
+ self.tokenizer = None
135
+ self.model = None
136
 
137
+ def load(self):
138
+ quant_cfg = None
139
+ if self.load_in_4bit:
140
+ quant_cfg = BitsAndBytesConfig(load_in_4bit=True)
141
+ try:
142
+ self.tokenizer = AutoTokenizer.from_pretrained(
143
+ self.repo_id,
144
+ revision=self.revision,
145
+ token=self.token,
146
+ cache_dir=str(CACHE_DIR),
147
+ trust_remote_code=self.trust_remote_code,
148
+ use_fast=True,
149
+ )
150
+ except Exception as e:
151
+ raise RuntimeError(
152
+ f"Failed to load tokenizer for {self.repo_id} "
153
+ "(If gated, accept license and set HF_TOKEN in Space → Settings → Secrets)."
154
+ ) from e
155
 
156
+ try:
157
  self.model = AutoModelForCausalLM.from_pretrained(
158
+ self.repo_id,
159
+ revision=self.revision,
160
+ token=self.token,
161
+ cache_dir=str(CACHE_DIR),
162
+ trust_remote_code=self.trust_remote_code,
163
+ torch_dtype=self.dtype,
164
+ device_map="auto" if DEVICE == "cuda" else None,
165
+ quantization_config=quant_cfg,
166
+ low_cpu_mem_usage=True,
167
  )
168
+ if DEVICE == "cpu":
169
+ self.model = self.model.to(DEVICE)
170
+ except Exception as e:
171
+ raise RuntimeError(
172
+ f"Failed to load model weights for {self.repo_id}. "
173
+ "Check license, token, and hardware availability."
174
+ ) from e
175
 
176
  @torch.inference_mode()
177
+ def generate(self, prompt: str, max_new_tokens: int = 256, temperature: float = 0.1) -> str:
178
+ tok = self.tokenizer
179
+ mdl = self.model
180
+ if tok.pad_token is None:
181
+ tok.pad_token = tok.eos_token
182
+
183
+ inputs = tok(prompt, return_tensors="pt").to(mdl.device)
184
+ out = mdl.generate(
185
+ **inputs,
186
+ max_new_tokens=max_new_tokens,
187
+ do_sample=temperature > 0,
188
+ temperature=temperature,
189
+ top_p=0.9,
190
+ pad_token_id=tok.eos_token_id,
191
+ eos_token_id=tok.eos_token_id,
192
  )
193
+ text = tok.decode(out[0], skip_special_tokens=True)
194
+ gen = text[len(prompt):].strip() if text.startswith(prompt) else text
195
+ return _json_only(gen)
196
+
197
+ # ====== Model cache (per Space worker) ======
198
+ _MODEL_CACHE: Dict[Tuple[str, Optional[str], bool], HFModel] = {}
199
+
200
+ def get_model(repo_id: str, revision: Optional[str], load_in_4bit: bool) -> HFModel:
201
+ key = (repo_id, revision, load_in_4bit)
202
+ if key in _MODEL_CACHE:
203
+ return _MODEL_CACHE[key]
204
+ model = HFModel(
205
+ repo_id=repo_id,
206
+ revision=revision,
207
+ load_in_4bit=load_in_4bit,
208
+ token=HF_TOKEN,
209
  )
210
+ model.load()
211
+ _MODEL_CACHE[key] = model
212
+ return model
213
+
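For reference, the cached loader above can be exercised outside the UI roughly like this; the repo id is one of the open presets and the printed labels are only an illustrative possibility:

    # One-off smoke test of get_model + generate.
    m = get_model("HuggingFaceH4/zephyr-7b-beta", revision=None, load_in_4bit=False)
    prompt = build_prompt(SYSTEM_INSTRUCTIONS, CONTEXT_GUIDE, "USER: can we meet Tuesday at 10am?")
    print(safe_json_loads(m.generate(prompt)))  # e.g. {'labels': ['schedule_meeting']}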
214
+ # ====== Single transcript inference ======
215
+ def run_single(
216
+ model_choice: str,
217
+ custom_repo_id: str,
218
+ system: str,
219
+ context: str,
220
+ transcript: str,
221
+ soft_token_cap: int,
222
+ preprocess: bool,
223
+ lines_window: int,
224
+ add_header: bool,
225
+ strip_smalltalk: bool,
226
+ load_in_4bit: bool,
227
+ ) -> Tuple[str, str, str, str]:
228
+ """
229
+ Returns (repo_id_used, revision, raw_json, debug_log)
230
+ """
231
+ debug = []
232
+ t0 = time.perf_counter()
233
+
234
+ repo = (custom_repo_id or model_choice).strip()
235
+ rev = PINNED_REVISIONS.get(repo, None)
236
+ debug.append(f"Repo: {repo} | Revision: {rev or 'main'} | 4bit: {load_in_4bit} | Device: {DEVICE}")
237
+
238
+ # Lightweight "preprocess"
239
+ if preprocess:
240
+ # basic cleanup
241
+ lines = [ln.rstrip() for ln in transcript.splitlines()]
242
+ if strip_smalltalk:
243
+ lines = [ln for ln in lines if not re.search(r"\b(thanks?|bye|ok(ay)?)\b", ln, re.I)]
244
+ transcript = "\n".join(lines[-32768:]) # hard cap
245
+ if add_header:
246
+ transcript = f"[EMAIL/MESSAGE SIGNAL]\n{transcript}"
247
+
248
+ # Soft token cap (truncate by char approximation)
249
+ if soft_token_cap and soft_token_cap > 0:
250
+ approx_chars = int(soft_token_cap * 4) # naive 4 chars/token
251
+ if len(transcript) > approx_chars:
252
+ transcript = transcript[-approx_chars:]
253
+
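The 4-chars-per-token cut above is only an approximation. A sketch of an exact variant, assuming a loaded tokenizer were available at this point (in this file the tokenizer is only created later, inside get_model):

    # Hypothetical exact cap: keep only the last `soft_token_cap` tokens.
    ids = tokenizer(transcript, add_special_tokens=False).input_ids
    if len(ids) > soft_token_cap:
        transcript = tokenizer.decode(ids[-soft_token_cap:])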
254
+ prompt = build_prompt(system or SYSTEM_INSTRUCTIONS, context or CONTEXT_GUIDE, transcript)
255
 
256
  try:
257
+ model = get_model(repo, rev, load_in_4bit)
258
+ raw = model.generate(prompt, max_new_tokens=256, temperature=0.1)
259
+ data = safe_json_loads(raw)
260
+ out_json = json.dumps(data, ensure_ascii=False)
261
+ debug.append(f"Generation OK in {time.perf_counter()-t0:.2f}s")
262
+ return repo, (rev or "main"), out_json, "\n".join(debug)
263
  except Exception as e:
264
+ debug.append(f"ERROR: {e}")
265
+ return repo, (rev or "main"), json.dumps({"labels": []}), "\n".join(debug)
266
+
267
+ # ====== Batch (ZIP of many .txt files) ======
268
+ def run_batch(
269
+ model_choice: str,
270
+ custom_repo_id: str,
271
+ system: str,
272
+ context: str,
273
+ zip_file: Optional[io.BytesIO],
274
+ soft_token_cap: int,
275
+ preprocess: bool,
276
+ lines_window: int,
277
+ add_header: bool,
278
+ strip_smalltalk: bool,
279
+ load_in_4bit: bool,
280
+ ) -> Tuple[str, str, str, str]:
281
+ """
282
+ Accepts a ZIP of .txt files. Returns (repo_id, revision, csv_like, debug)
283
+ """
284
+ debug = []
285
+ repo = (custom_repo_id or model_choice).strip()
286
+ rev = PINNED_REVISIONS.get(repo, None)
287
+
288
+ if not zip_file:
289
+ return repo, (rev or "main"), "filename,labels\n", "No ZIP provided."
290
 
291
  try:
292
+ z = zipfile.ZipFile(zip_file)
293
+ names = [n for n in z.namelist() if n.lower().endswith(".txt")]
294
+ debug.append(f"Files detected: {len(names)}")
295
  except Exception as e:
296
+ return repo, (rev or "main"), "filename,labels\n", f"Bad ZIP: {e}"
297
 
298
  try:
299
+ model = get_model(repo, rev, load_in_4bit)
 
300
  except Exception as e:
301
+ return repo, (rev or "main"), "filename,labels\n", f"Model load error: {e}"
302
 
303
+ rows = ["filename,labels"]
304
+ for name in names:
305
  try:
306
+ txt = z.read(name).decode("utf-8", errors="replace")
307
+ _, _, labels_json, _ = run_single(
308
+ model_choice, custom_repo_id, system, context, txt,
309
+ soft_token_cap, preprocess, lines_window, add_header,
310
+ strip_smalltalk, load_in_4bit
311
+ )
312
+ labels = safe_json_loads(labels_json).get("labels", [])
313
+ rows.append(f"{name},{json.dumps(labels, ensure_ascii=False)}")
314
  except Exception as e:
315
+ rows.append(f"{name},[] # error: {e}")
316
 
317
+ return repo, (rev or "main"), "\n".join(rows), "\n".join(debug)
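Note that the labels column above is itself a JSON list containing commas, so the joined string is CSV-like rather than strictly valid CSV. A sketch of a stricter variant using Python's csv module; `results` here is a hypothetical list of (filename, labels) pairs:

    import csv, io, json

    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["filename", "labels"])
    for name, labels in results:  # hypothetical (filename, labels) pairs
        writer.writerow([name, json.dumps(labels, ensure_ascii=False)])
    csv_text = buf.getvalue()  # properly quoted CSV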
318
 
319
+ # ====== Gradio UI ======
320
+ with gr.Blocks(title="From Talk to Task — Batch & Single Task Extraction") as demo:
321
+ gr.Markdown(
322
+ """
323
+ # From Talk to Task Batch & Single Task Extraction
324
+
325
+ **Tip:** Use **open models** first (no gating). If you pick a gated model, make sure
326
+ you have accepted its license _and_ set `HF_TOKEN` in **Settings → Secrets**.
327
+
328
+ **Pinned revisions:** {}
329
+ """.format(
330
+ ", ".join([f"{k}@{v or 'main'}" for k, v in PINNED_REVISIONS.items()])
331
+ )
332
  )
 
333
 
334
+ with gr.Row():
335
+ model_choice = gr.Dropdown(
336
+ OPEN_MODEL_PRESETS,
337
+ label="Model (Open presets — no gating)",
338
+ value=OPEN_MODEL_PRESETS[0],
339
+ )
340
+ custom_repo_id = gr.Textbox(
341
+ label="Custom model repo id (overrides preset)",
342
+ placeholder="e.g. mistralai/Mistral-7B-Instruct-v0.2 (requires license + HF_TOKEN)"
343
+ )
344
 
345
+ system = gr.Textbox(label="Instructions (System)", value=SYSTEM_INSTRUCTIONS, lines=5)
346
+ context = gr.Textbox(label="Context (User prefix before transcript)", value=CONTEXT_GUIDE, lines=6)
347
+
348
+ with gr.Row():
349
+ soft_cap = gr.Slider(1024, 32768, value=8192, step=1, label="Soft token cap")
350
+ preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
351
+ lines_window = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
352
+ with gr.Row():
353
+ add_header = gr.Checkbox(value=True, label="Add cues header")
354
+ strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
355
+ load_4bit = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")
356
 
357
  with gr.Tabs():
358
+ with gr.Tab("Single Transcript (default)"):
359
+ transcript = gr.Textbox(label="Paste transcript text", lines=12, placeholder="Paste your transcript here...")
360
+ run_btn = gr.Button("Run (Single)", variant="primary")
361
+ repo_used = gr.Textbox(label="Repo used", interactive=False)
362
+ rev_used = gr.Textbox(label="Revision", interactive=False)
363
+ json_out = gr.Code(label="JSON Output", language="json")
364
+ debug_out = gr.Textbox(label="Diagnostics", lines=6)
365
 
366
  def _run_single(*args):
367
+ r, v, j, d = run_single(*args)
368
+ return r, v, j, d
369
 
370
+ run_btn.click(
371
  _run_single,
372
+ inputs=[
373
+ model_choice, custom_repo_id, system, context, transcript,
374
+ soft_cap, preprocess, lines_window, add_header, strip_smalltalk, load_4bit
375
+ ],
376
+ outputs=[repo_used, rev_used, json_out, debug_out],
377
  )
378
 
379
+ with gr.Tab("Batch (ZIP of many transcripts)"):
380
+ zip_in = gr.File(label="Upload ZIP of .txt transcripts", file_types=[".zip"])
381
+ run_batch_btn = gr.Button("Run (Batch)", variant="primary")
382
+ repo_used_b = gr.Textbox(label="Repo used", interactive=False)
383
+ rev_used_b = gr.Textbox(label="Revision", interactive=False)
384
+ csv_out = gr.Code(label="CSV (filename,labels)", language="text")
385
+ debug_out_b = gr.Textbox(label="Diagnostics", lines=6)
386
 
387
  def _run_batch(*args):
388
+ r, v, c, d = run_batch(*args)
389
+ return r, v, c, d
390
 
391
+ run_batch_btn.click(
392
  _run_batch,
393
+ inputs=[
394
+ model_choice, custom_repo_id, system, context, zip_in,
395
+ soft_cap, preprocess, lines_window, add_header, strip_smalltalk, load_4bit
396
+ ],
397
+ outputs=[repo_used_b, rev_used_b, csv_out, debug_out_b],
398
  )
399
 
400
+ gr.Markdown(
401
+ f"""
402
+ - **HF_TOKEN detected:** {"✅ yes" if HF_TOKEN else "⚠️ no (only needed for gated models)"}
403
+ - **Device:** {DEVICE}
404
+ - **Cache dir:** `{CACHE_DIR}`
405
+ """
406
+ )
407
+
408
+ if __name__ == "__main__":
409
+ # Gradio 5 default port/host are fine in Spaces; keep `debug` false for speed
410
+ demo.launch()