RishiRP commited on
Commit
9949cc9
·
verified ·
1 Parent(s): 6acd2cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -89
app.py CHANGED
@@ -1,15 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
- Allowed Labels (canonical; use only these):
3
- {allowed_labels_list}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- Context cues (keywords/phrases that often indicate each label):
6
- {keyword_context}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- Instructions:
9
- 1) Identify EVERY concrete task implied by the conversation.
10
- 2) Choose ONE label from Allowed Labels for each task (or none if truly inapplicable).
11
- 3) Return STRICT JSON only in the exact schema described by the system prompt.
12
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # =========================
15
  # Utilities
@@ -55,8 +161,7 @@ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, A
55
  continue
56
  k = str(t.get("label", "")).strip().lower()
57
  if k in allowed_map:
58
- new_t = dict(t)
59
- new_t["label"] = allowed_map[k]
60
  filt_tasks.append(new_t)
61
  merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
62
  out["labels"] = merged
@@ -64,10 +169,8 @@ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, A
64
  return out
65
 
66
  # =========================
67
- # Default pre-processing
68
  # =========================
69
- # These are conservative; they remove boilerplate that appears in many files
70
- # and does not affect tasks. You can toggle this in the UI.
71
  _DISCLAIMER_PATTERNS = [
72
  r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
73
  r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
@@ -87,8 +190,7 @@ def clean_transcript(text: str) -> str:
87
  if not text:
88
  return text
89
  s = text
90
-
91
- # Remove common timestamps and speaker prefixes (line-wise)
92
  lines = []
93
  for ln in s.splitlines():
94
  ln2 = ln
@@ -96,16 +198,13 @@ def clean_transcript(text: str) -> str:
96
  ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
97
  lines.append(ln2)
98
  s = "\n".join(lines)
99
-
100
- # Remove top disclaimers
101
  for pat in _DISCLAIMER_PATTERNS:
102
  s = re.sub(pat, "", s).strip()
103
-
104
- # Remove trailing footers/signatures
105
  for pat in _FOOTER_PATTERNS:
106
  s = re.sub(pat, "", s)
107
-
108
- # Collapse repeated whitespace
109
  s = re.sub(r"[ \t]+", " ", s)
110
  s = re.sub(r"\n{3,}", "\n\n", s).strip()
111
  return s
@@ -194,8 +293,7 @@ _MODEL_CACHE: Dict[str, ModelWrapper] = {}
194
  def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
195
  key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
196
  if key not in _MODEL_CACHE:
197
- m = ModelWrapper(repo_id, hf_token, load_in_4bit)
198
- m.load()
199
  _MODEL_CACHE[key] = m
200
  return _MODEL_CACHE[key]
201
 
@@ -211,7 +309,6 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> fl
211
  def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
212
  if not isinstance(sample_labels, list):
213
  raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
214
- # dedupe
215
  seen, uniq = set(), []
216
  for label in sample_labels:
217
  if not isinstance(label, str):
@@ -219,7 +316,6 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> fl
219
  if label in seen:
220
  raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
221
  seen.add(label); uniq.append(label)
222
- # validity
223
  valid = []
224
  for label in uniq:
225
  if label not in ALLOWED_LABELS:
@@ -257,10 +353,7 @@ def build_keyword_context(allowed: List[str]) -> str:
257
  parts = []
258
  for lab in allowed:
259
  kws = LABEL_KEYWORDS.get(lab, [])
260
- if kws:
261
- parts.append(f"- {lab}: " + ", ".join(kws))
262
- else:
263
- parts.append(f"- {lab}: (no default cues)")
264
  return "\n".join(parts)
265
 
266
  def run_single(
@@ -276,29 +369,23 @@ def run_single(
276
 
277
  t0 = _now_ms()
278
 
279
- # Get transcript
280
  raw_text = read_text_from_file(transcript_file) if transcript_file else (transcript_text or "")
281
  raw_text = (raw_text or "").strip()
282
  if not raw_text:
283
  return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, indent=2)
284
 
285
- # Cleaning
286
  text = clean_transcript(raw_text) if use_cleaning else raw_text
287
 
288
- # Allowed labels
289
  user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
290
  allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
291
 
292
- # Model
293
  try:
294
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
295
  except Exception as e:
296
  return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
297
 
298
- # Truncate
299
  trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
300
 
301
- # Build prompt
302
  allowed_list_str = "\n".join(f"- {l}" for l in allowed)
303
  keyword_ctx = build_keyword_context(allowed)
304
  user_prompt = USER_PROMPT_TEMPLATE.format(
@@ -307,7 +394,6 @@ def run_single(
307
  keyword_context=keyword_ctx,
308
  )
309
 
310
- # Generate
311
  t1 = _now_ms()
312
  try:
313
  out = model.generate(SYSTEM_PROMPT, user_prompt)
@@ -315,11 +401,9 @@ def run_single(
315
  return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
316
  t2 = _now_ms()
317
 
318
- # Parse + filter
319
  parsed = robust_json_extract(out)
320
  filtered = restrict_to_allowed(parsed, allowed)
321
 
322
- # Diagnostics
323
  diag = "\n".join([
324
  f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
325
  f"Model: {model_repo}",
@@ -329,7 +413,6 @@ def run_single(
329
  f"Allowed labels: {', '.join(allowed)}",
330
  ])
331
 
332
- # Summary
333
  labs = filtered.get("labels", [])
334
  tasks = filtered.get("tasks", [])
335
  summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
@@ -350,11 +433,7 @@ def read_zip(fileobj: io.BytesIO, exdir: Path) -> List[Path]:
350
  exdir.mkdir(parents=True, exist_ok=True)
351
  with zipfile.ZipFile(fileobj) as zf:
352
  zf.extractall(exdir)
353
- out = []
354
- for p in exdir.rglob("*"):
355
- if p.is_file():
356
- out.append(p)
357
- return out
358
 
359
  def run_batch(
360
  zip_file: gr.File,
@@ -364,25 +443,27 @@ def run_batch(
364
  max_input_tokens: int,
365
  hf_token: str,
366
  limit_files: int,
367
- ) -> Tuple[str, str, str, pd.DataFrame, str]:
368
 
369
  if not zip_file:
370
- return ("No ZIP provided.", "", "", pd.DataFrame(), "")
371
 
372
  work = Path("/tmp/batch")
373
  if work.exists():
374
- for p in work.rglob("*"):
375
- try: p.unlink()
376
- except Exception: pass
377
- try: work.rmdir()
378
- except Exception: pass
 
 
 
 
379
  work.mkdir(parents=True, exist_ok=True)
380
 
381
- # Unzip
382
  data = zip_file.read()
383
  files = read_zip(io.BytesIO(data), work)
384
 
385
- # Gather pairs by stem
386
  txts: Dict[str, Path] = {}
387
  gts: Dict[str, Path] = {}
388
  for p in files:
@@ -395,15 +476,14 @@ def run_batch(
395
  if limit_files > 0:
396
  stems = stems[:limit_files]
397
  if not stems:
398
- return ("No .txt transcripts found in ZIP.", "", "", pd.DataFrame(), "")
399
 
400
- # Model
401
  try:
402
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
403
  except Exception as e:
404
- return (f"Model load failed: {e}", "", "", pd.DataFrame(), "")
405
 
406
- allowed = OFFICIAL_LABELS[:] # fixed for scoring
407
  allowed_list_str = "\n".join(f"- {l}" for l in allowed)
408
  keyword_ctx = build_keyword_context(allowed)
409
 
@@ -431,20 +511,17 @@ def run_batch(
431
  pred_labels = filtered.get("labels", [])
432
  y_pred.append(pred_labels)
433
 
434
- # Ground truth (optional)
435
  gt_labels = []
436
  if stem in gts:
437
  try:
438
  gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
439
- if isinstance(gt_obj, dict) and "labels" in gt_obj and isinstance(gt_obj["labels"], list):
440
  gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
441
  except Exception:
442
  pass
443
  y_true.append(gt_labels)
444
 
445
- # FP/FN counts for table
446
- gt_set = set(gt_labels)
447
- pr_set = set(pred_labels)
448
  tp = sorted(gt_set & pr_set)
449
  fp = sorted(pr_set - gt_set)
450
  fn = sorted(gt_set - pr_set)
@@ -457,8 +534,6 @@ def run_batch(
457
  "gen_ms": t1 - t0
458
  })
459
 
460
- # Metrics
461
- # If there is no ground truth in the ZIP, we still compute a table and skip score.
462
  have_truth = any(len(v) > 0 for v in y_true)
463
  score = evaluate_predictions(y_true, y_pred) if have_truth else None
464
 
@@ -472,7 +547,6 @@ def run_batch(
472
  f"Batch time: {_now_ms()-t_start} ms",
473
  ]
474
  if have_truth and score is not None:
475
- # Simple derived metrics
476
  total_tp = int(df["TP"].sum())
477
  total_fp = int(df["FP"].sum())
478
  total_fn = int(df["FN"].sum())
@@ -486,12 +560,11 @@ def run_batch(
486
  ]
487
  diag_str = "\n".join(diag)
488
 
489
- # CSV preview and data URL
490
- csv_buf = io.StringIO()
491
- df.to_csv(csv_buf, index=False)
492
- csv_data = csv_buf.getvalue()
493
 
494
- return ("Batch done.", diag_str, csv_data, df, csv_data)
495
 
496
  # =========================
497
  # UI
@@ -505,10 +578,8 @@ MODEL_CHOICES = [
505
  with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
506
  gr.Markdown("# Talk2Task — Task Extraction (UBS Challenge)")
507
  gr.Markdown(
508
- "This tool extracts challenge labels from transcripts. "
509
- "Use **Single** for quick tests; use **Batch** to score a ZIP with transcripts + truths. "
510
- "_Note: False negatives are penalised twice as much as false positives in the official metric; "
511
- "we bias for recall._"
512
  )
513
 
514
  with gr.Tab("Single transcript"):
@@ -520,9 +591,12 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
520
  type="filepath",
521
  )
522
  text = gr.Textbox(label="Or paste transcript", lines=14)
523
- use_cleaning = gr.Checkbox(label="Apply default cleaning (remove disclaimers, timestamps, footers)", value=True)
 
 
 
524
  labels_text = gr.Textbox(
525
- label="Allowed Labels (one per line; leave empty to use official list)",
526
  value="",
527
  lines=8,
528
  )
@@ -561,23 +635,15 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
561
 
562
  with gr.Row():
563
  status = gr.Textbox(label="Status", lines=1)
564
- diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=10)
565
-
566
- with gr.Row():
567
- df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, times)", interactive=False)
568
- csv_out = gr.File(label="Download CSV (click to save)", interactive=False)
569
 
570
- def _save_csv(csv_text: str) -> str:
571
- if not csv_text:
572
- return ""
573
- out_path = Path("/tmp/batch_results.csv")
574
- out_path.write_text(csv_text, encoding="utf-8")
575
- return str(out_path)
576
 
577
  run_batch_btn.click(
578
  fn=run_batch,
579
  inputs=[zip_in, use_cleaning_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
580
- outputs=[status, diag_b, csv_out, df_out, gr.Textbox(visible=False)],
581
  )
582
 
583
  if __name__ == "__main__":
 
1
+ # app.py
2
+ import os
3
+ import re
4
+ import io
5
+ import json
6
+ import time
7
+ import zipfile
8
+ from pathlib import Path
9
+ from typing import List, Dict, Any, Tuple, Optional
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import gradio as gr
14
+
15
+ import torch
16
+ from transformers import (
17
+ AutoTokenizer,
18
+ AutoModelForCausalLM,
19
+ BitsAndBytesConfig,
20
+ GenerationConfig,
21
+ )
22
 
23
+ # =========================
24
+ # Global config
25
+ # =========================
26
+ SPACE_CACHE = Path.home() / ".cache" / "huggingface"
27
+ SPACE_CACHE.mkdir(parents=True, exist_ok=True)
28
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
+
30
+ GEN_CONFIG = GenerationConfig(
31
+ temperature=0.2,
32
+ top_p=0.9,
33
+ do_sample=False,
34
+ max_new_tokens=256,
35
+ )
36
+
37
+ # Official UBS label set (strict)
38
+ OFFICIAL_LABELS = [
39
+ "plan_contact",
40
+ "schedule_meeting",
41
+ "update_contact_info_non_postal",
42
+ "update_contact_info_postal_address",
43
+ "update_kyc_activity",
44
+ "update_kyc_origin_of_assets",
45
+ "update_kyc_purpose_of_businessrelation",
46
+ "update_kyc_total_assets",
47
+ ]
48
 
49
+ # Per-label keyword cues (static prompt context to improve recall)
50
+ LABEL_KEYWORDS: Dict[str, List[str]] = {
51
+ "plan_contact": [
52
+ "call back", "follow up", "reach out", "contact later", "check-in",
53
+ "email them", "touch base", "remind", "send a note"
54
+ ],
55
+ "schedule_meeting": [
56
+ "book a meeting", "set up a meeting", "schedule a call",
57
+ "appointment", "calendar", "meeting next week", "meet on", "time slot"
58
+ ],
59
+ "update_contact_info_non_postal": [
60
+ "phone change", "new phone", "email change", "new email",
61
+ "update contact details", "update mobile", "alternate phone"
62
+ ],
63
+ "update_contact_info_postal_address": [
64
+ "moved to", "new address", "postal address", "mailing address",
65
+ "change of address", "residential address"
66
+ ],
67
+ "update_kyc_activity": [
68
+ "activity update", "economic activity", "employment status",
69
+ "occupation", "job change", "business activity"
70
+ ],
71
+ "update_kyc_origin_of_assets": [
72
+ "source of funds", "origin of assets", "where money comes from",
73
+ "inheritance", "salary", "business income", "asset origin"
74
+ ],
75
+ "update_kyc_purpose_of_businessrelation": [
76
+ "purpose of relationship", "why the account", "reason for banking",
77
+ "investment purpose", "relationship purpose"
78
+ ],
79
+ "update_kyc_total_assets": [
80
+ "total assets", "net worth", "assets under ownership",
81
+ "portfolio size", "how much you own"
82
+ ],
83
+ }
84
 
85
+ # =========================
86
+ # Instructions (string-safe; concatenated)
87
+ # =========================
88
+ SYSTEM_PROMPT = (
89
+ "You are a precise banking assistant that extracts ACTIONABLE TASKS from "
90
+ "client–advisor transcripts. Be conservative with hallucinations but "
91
+ "prioritise RECALL: if unsure and the transcript plausibly implies an "
92
+ "action, include the label and explain briefly.\n\n"
93
+ "Output STRICT JSON only:\n\n"
94
+ "{\n"
95
+ ' "labels": ["<Label1>", "..."],\n'
96
+ ' "tasks": [\n'
97
+ ' {"label": "<Label1>", "explanation": "<why>", "evidence": "<quoted text/snippet>"}\n'
98
+ " ]\n"
99
+ "}\n\n"
100
+ "Rules:\n"
101
+ "- Use ONLY allowed labels supplied to you. Case-insensitive during reasoning, "
102
+ " but output the canonical label text exactly.\n"
103
+ "- If none truly apply, return empty lists.\n"
104
+ "- Keep explanations concise; put the minimal evidence snippet that justifies the task.\n"
105
+ )
106
+
107
+ USER_PROMPT_TEMPLATE = (
108
+ "Transcript (cleaned):\n"
109
+ "```\n{transcript}\n```\n\n"
110
+ "Allowed Labels (canonical; use only these):\n"
111
+ "{allowed_labels_list}\n\n"
112
+ "Context cues (keywords/phrases that often indicate each label):\n"
113
+ "{keyword_context}\n\n"
114
+ "Instructions:\n"
115
+ "- Identify EVERY concrete task implied by the conversation.\n"
116
+ "- Choose ONE label from Allowed Labels for each task (or none if truly inapplicable).\n"
117
+ "- Return STRICT JSON only in the exact schema described by the system prompt.\n"
118
+ )
119
 
120
  # =========================
121
  # Utilities
 
161
  continue
162
  k = str(t.get("label", "")).strip().lower()
163
  if k in allowed_map:
164
+ new_t = dict(t); new_t["label"] = allowed_map[k]
 
165
  filt_tasks.append(new_t)
166
  merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
167
  out["labels"] = merged
 
169
  return out
170
 
171
  # =========================
172
+ # Default pre-processing (toggleable)
173
  # =========================
 
 
174
  _DISCLAIMER_PATTERNS = [
175
  r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
176
  r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
 
190
  if not text:
191
  return text
192
  s = text
193
+ # remove timestamps/speaker prefixes line-wise
 
194
  lines = []
195
  for ln in s.splitlines():
196
  ln2 = ln
 
198
  ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
199
  lines.append(ln2)
200
  s = "\n".join(lines)
201
+ # remove top disclaimers
 
202
  for pat in _DISCLAIMER_PATTERNS:
203
  s = re.sub(pat, "", s).strip()
204
+ # remove trailing footers
 
205
  for pat in _FOOTER_PATTERNS:
206
  s = re.sub(pat, "", s)
207
+ # collapse whitespace
 
208
  s = re.sub(r"[ \t]+", " ", s)
209
  s = re.sub(r"\n{3,}", "\n\n", s).strip()
210
  return s
 
293
  def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
294
  key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
295
  if key not in _MODEL_CACHE:
296
+ m = ModelWrapper(repo_id, hf_token, load_in_4bit); m.load()
 
297
  _MODEL_CACHE[key] = m
298
  return _MODEL_CACHE[key]
299
 
 
309
  def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
310
  if not isinstance(sample_labels, list):
311
  raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
 
312
  seen, uniq = set(), []
313
  for label in sample_labels:
314
  if not isinstance(label, str):
 
316
  if label in seen:
317
  raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
318
  seen.add(label); uniq.append(label)
 
319
  valid = []
320
  for label in uniq:
321
  if label not in ALLOWED_LABELS:
 
353
  parts = []
354
  for lab in allowed:
355
  kws = LABEL_KEYWORDS.get(lab, [])
356
+ parts.append(f"- {lab}: " + (", ".join(kws) if kws else "(no default cues)"))
 
 
 
357
  return "\n".join(parts)
358
 
359
  def run_single(
 
369
 
370
  t0 = _now_ms()
371
 
 
372
  raw_text = read_text_from_file(transcript_file) if transcript_file else (transcript_text or "")
373
  raw_text = (raw_text or "").strip()
374
  if not raw_text:
375
  return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, indent=2)
376
 
 
377
  text = clean_transcript(raw_text) if use_cleaning else raw_text
378
 
 
379
  user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
380
  allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
381
 
 
382
  try:
383
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
384
  except Exception as e:
385
  return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
386
 
 
387
  trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
388
 
 
389
  allowed_list_str = "\n".join(f"- {l}" for l in allowed)
390
  keyword_ctx = build_keyword_context(allowed)
391
  user_prompt = USER_PROMPT_TEMPLATE.format(
 
394
  keyword_context=keyword_ctx,
395
  )
396
 
 
397
  t1 = _now_ms()
398
  try:
399
  out = model.generate(SYSTEM_PROMPT, user_prompt)
 
401
  return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
402
  t2 = _now_ms()
403
 
 
404
  parsed = robust_json_extract(out)
405
  filtered = restrict_to_allowed(parsed, allowed)
406
 
 
407
  diag = "\n".join([
408
  f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
409
  f"Model: {model_repo}",
 
413
  f"Allowed labels: {', '.join(allowed)}",
414
  ])
415
 
 
416
  labs = filtered.get("labels", [])
417
  tasks = filtered.get("tasks", [])
418
  summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
 
433
  exdir.mkdir(parents=True, exist_ok=True)
434
  with zipfile.ZipFile(fileobj) as zf:
435
  zf.extractall(exdir)
436
+ return [p for p in exdir.rglob("*") if p.is_file()]
 
 
 
 
437
 
438
  def run_batch(
439
  zip_file: gr.File,
 
443
  max_input_tokens: int,
444
  hf_token: str,
445
  limit_files: int,
446
+ ) -> Tuple[str, str, pd.DataFrame, str]:
447
 
448
  if not zip_file:
449
+ return ("No ZIP provided.", "", pd.DataFrame(), "")
450
 
451
  work = Path("/tmp/batch")
452
  if work.exists():
453
+ for p in sorted(work.rglob("*"), reverse=True):
454
+ try:
455
+ p.unlink()
456
+ except Exception:
457
+ pass
458
+ try:
459
+ work.rmdir()
460
+ except Exception:
461
+ pass
462
  work.mkdir(parents=True, exist_ok=True)
463
 
 
464
  data = zip_file.read()
465
  files = read_zip(io.BytesIO(data), work)
466
 
 
467
  txts: Dict[str, Path] = {}
468
  gts: Dict[str, Path] = {}
469
  for p in files:
 
476
  if limit_files > 0:
477
  stems = stems[:limit_files]
478
  if not stems:
479
+ return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
480
 
 
481
  try:
482
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
483
  except Exception as e:
484
+ return (f"Model load failed: {e}", "", pd.DataFrame(), "")
485
 
486
+ allowed = OFFICIAL_LABELS[:]
487
  allowed_list_str = "\n".join(f"- {l}" for l in allowed)
488
  keyword_ctx = build_keyword_context(allowed)
489
 
 
511
  pred_labels = filtered.get("labels", [])
512
  y_pred.append(pred_labels)
513
 
 
514
  gt_labels = []
515
  if stem in gts:
516
  try:
517
  gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
518
+ if isinstance(gt_obj, dict) and isinstance(gt_obj.get("labels"), list):
519
  gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
520
  except Exception:
521
  pass
522
  y_true.append(gt_labels)
523
 
524
+ gt_set, pr_set = set(gt_labels), set(pred_labels)
 
 
525
  tp = sorted(gt_set & pr_set)
526
  fp = sorted(pr_set - gt_set)
527
  fn = sorted(gt_set - pr_set)
 
534
  "gen_ms": t1 - t0
535
  })
536
 
 
 
537
  have_truth = any(len(v) > 0 for v in y_true)
538
  score = evaluate_predictions(y_true, y_pred) if have_truth else None
539
 
 
547
  f"Batch time: {_now_ms()-t_start} ms",
548
  ]
549
  if have_truth and score is not None:
 
550
  total_tp = int(df["TP"].sum())
551
  total_fp = int(df["FP"].sum())
552
  total_fn = int(df["FN"].sum())
 
560
  ]
561
  diag_str = "\n".join(diag)
562
 
563
+ # save CSV for download
564
+ out_csv = Path("/tmp/batch_results.csv")
565
+ df.to_csv(out_csv, index=False, encoding="utf-8")
 
566
 
567
+ return ("Batch done.", diag_str, df, str(out_csv))
568
 
569
  # =========================
570
  # UI
 
578
  with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
579
  gr.Markdown("# Talk2Task — Task Extraction (UBS Challenge)")
580
  gr.Markdown(
581
+ "Extract challenge labels from transcripts. False negatives are penalised 2× more than false positives "
582
+ "in the official score, so the app biases for recall."
 
 
583
  )
584
 
585
  with gr.Tab("Single transcript"):
 
591
  type="filepath",
592
  )
593
  text = gr.Textbox(label="Or paste transcript", lines=14)
594
+ use_cleaning = gr.Checkbox(
595
+ label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
596
+ value=True,
597
+ )
598
  labels_text = gr.Textbox(
599
+ label="Allowed Labels (one per line; empty = official list)",
600
  value="",
601
  lines=8,
602
  )
 
635
 
636
  with gr.Row():
637
  status = gr.Textbox(label="Status", lines=1)
638
+ diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
 
 
 
 
639
 
640
+ df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
641
+ csv_out = gr.File(label="Download CSV", interactive=False)
 
 
 
 
642
 
643
  run_batch_btn.click(
644
  fn=run_batch,
645
  inputs=[zip_in, use_cleaning_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
646
+ outputs=[status, diag_b, df_out, csv_out],
647
  )
648
 
649
  if __name__ == "__main__":