RishiRP committed on
Commit 60a93e1 · verified · 1 Parent(s): 9789731

Update app.py

Files changed (1):
  1. app.py +232 -176

app.py CHANGED
@@ -1,6 +1,8 @@
  # app.py
- # From Talk to Task — Multilingual (EN/FR/DE/IT)
- # Focus: ACCURACY evaluation against UBS Ground Truth + rich diagnostics + downloadable artifacts

  import os
  import io
@@ -18,7 +20,6 @@ import gradio as gr

  DEFAULT_REPO = "swiss-ai/Apertus-8B-Instruct-2509"

- # Default label set (can be overridden by uploading a Rules JSON with {"labels":[...]}).
  DEFAULT_LABEL_SET = [
      "plan_contact",
      "schedule_meeting",
@@ -30,29 +31,58 @@ DEFAULT_LABEL_SET = [
      "update_kyc_total_assets",
  ]
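For reference, a minimal sketch of the Rules JSON that can override this default label set. The filename and the two labels below are hypothetical; only the `{"labels": [...]}` shape is what `read_rules_labels` expects:

```python
import json

# Hypothetical override: restrict the allowed label set to two labels.
rules = {"labels": ["plan_contact", "schedule_meeting"]}

with open("rules.json", "w", encoding="utf-8") as f:
    json.dump(rules, f, ensure_ascii=False, indent=2)
```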

- SYSTEM_INSTRUCTIONS = (
-     "You are a task extraction assistant.\n"
-     "Input transcript language can be English, French, German, or Italian.\n"
-     "Output valid JSON ONLY (no prose) with a single field:\n"
      '"labels": a list of strings chosen ONLY from the allowed label set.\n'
-     "Do not invent other fields. Do not translate labels. Return JSON only."
  )

  CONTEXT_GUIDE = (
-     "- plan_contact: contact without firm date/time\n"
-     "- schedule_meeting: explicit date/time/modality confirmed\n"
      "- update_contact_info_non_postal: email/phone updates\n"
      "- update_contact_info_postal_address: mailing address updates\n"
      "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)\n"
  )

  # --------------------- WRITABLE HF CACHE -----------------------------

  HOME = Path(os.environ.get("HOME", "/home/user"))
  CACHE_DIR = HOME / ".cache" / "huggingface"
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
  os.environ.setdefault("HF_HOME", str(CACHE_DIR))
- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster downloads when supported

  HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None
@@ -63,21 +93,14 @@ try:
      from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  except Exception as e:
      raise RuntimeError(
-         "Missing deps. In requirements.txt include: transformers>=4.56.0, torch, accelerate, huggingface_hub, bitsandbytes, gradio"
      ) from e

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  GPU_NAME = torch.cuda.get_device_name(0) if DEVICE == "cuda" else "cpu"
- # Force fp16 on CUDA (T4 doesnt support bf16) for stable perf
  DTYPE_FALLBACK = torch.float16 if DEVICE == "cuda" else torch.float32

- # Optional ZeroGPU presence
- try:
-     import spaces  # noqa: F401
-     ON_ZERO_GPU = True
- except Exception:
-     ON_ZERO_GPU = False
-
  # -------------------------- HELPERS ---------------------------------

  RE_DISCLAIMER = re.compile(r"^\s*disclaimer\s*:", re.IGNORECASE)
@@ -115,20 +138,57 @@ def read_rules_labels(file_obj: Optional[gr.File]) -> Optional[List[str]]:
      except Exception:
          return None

- def build_prompt(system: str, context: str, transcript: str) -> str:
      return (
          f"### System\n{system}\n\n"
-         f"### Context\n{context}\n\n"
          f"### Transcript\n{transcript}\n\n"
-         "### Output\nReturn JSON only."
      )

  def prf1_accuracy(pred: List[str], gold: List[str]) -> Tuple[float, float, float, float, Dict[str, int]]:
-     """Micro P/R/F1 + Jaccard-like accuracy (intersection/union)."""
      pset, gset = set(pred), set(gold)
-     tp = len(pset & gset)
-     fp = len(pset - gset)
-     fn = len(gset - pset)
      prec = tp / (tp + fp) if (tp + fp) else 0.0
      rec = tp / (tp + fn) if (tp + fn) else 0.0
      f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
@@ -137,7 +197,6 @@ def prf1_accuracy(pred: List[str], gold: List[str]) -> Tuple[float, float, float
      return prec, rec, f1, acc, {"tp": tp, "fp": fp, "fn": fn, "pred_total": len(pset), "gold_total": len(gset)}

  def per_label_counts(pred: List[str], gold: List[str], all_labels: List[str]) -> Dict[str, Dict[str, int]]:
-     """TP/FP/FN per label."""
      pset, gset = set(pred), set(gold)
      out = {}
      for lab in all_labels:
@@ -148,48 +207,16 @@ def per_label_counts(pred: List[str], gold: List[str], all_labels: List[str]) ->
      return out

  def hamming_loss(pred: List[str], gold: List[str], all_labels: List[str]) -> float:
-     """Hamming loss over the label universe."""
      pset, gset = set(pred), set(gold)
      wrong = 0
      for lab in all_labels:
          in_p, in_g = (lab in pset), (lab in gset)
-         if in_p != in_g:
-             wrong += 1
      return wrong / max(1, len(all_labels))
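For concreteness, a quick worked example of the metric definitions above (micro P/R/F1, the Jaccard-style "accuracy" from the old docstring, and Hamming loss) on toy label sets; pure Python, no app imports needed:

```python
# Toy prediction vs. ground truth over the 8-label universe above.
pred = {"plan_contact", "schedule_meeting"}
gold = {"schedule_meeting", "update_kyc_total_assets"}

tp = len(pred & gold)   # 1 (schedule_meeting)
fp = len(pred - gold)   # 1 (plan_contact)
fn = len(gold - pred)   # 1 (update_kyc_total_assets)

precision = tp / (tp + fp)                           # 0.5
recall = tp / (tp + fn)                              # 0.5
f1 = 2 * precision * recall / (precision + recall)   # 0.5
jaccard = len(pred & gold) / len(pred | gold)        # 1/3 ≈ 0.333

n_labels = 8                    # size of DEFAULT_LABEL_SET
hamming = (fp + fn) / n_labels  # 2/8 = 0.25: labels where pred and gold disagree
print(precision, recall, f1, jaccard, hamming)
```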

- def read_single_ground_truth(file_obj: Optional[gr.File]) -> Optional[List[str]]:
-     if not file_obj:
-         return None
-     try:
-         data = json.loads(Path(file_obj.name).read_text(encoding="utf-8"))
-         labels = data.get("labels", [])
-         return [lab for lab in labels if isinstance(lab, str)]
-     except Exception:
-         return None
-
- def read_batch_ground_truth_zip(zip_file: Optional[gr.File]) -> Dict[str, List[str]]:
-     out: Dict[str, List[str]] = {}
-     if not zip_file:
-         return out
-     try:
-         with zipfile.ZipFile(zip_file.name) as z:
-             for name in z.namelist():
-                 if not name.lower().endswith(".json"):
-                     continue
-                 try:
-                     data = json.loads(z.read(name).decode("utf-8", errors="replace"))
-                     labs = [lab for lab in data.get("labels", []) if isinstance(lab, str)]
-                     out[Path(name).with_suffix("").name] = labs
-                 except Exception:
-                     pass
-     except Exception:
-         pass
-     return out
-
  def write_csv(path: Path, rows: List[List[str]]):
      with path.open("w", newline="", encoding="utf-8") as f:
-         w = csv.writer(f)
-         w.writerows(rows)

  # -------------------------- MODEL -----------------------------------
@@ -234,10 +261,10 @@ class HFModel:
          self.model = self.model.to(DEVICE)

      @torch.inference_mode()
-     def generate_json(self, prompt: str, max_new_tokens=32) -> Tuple[str, Dict[str, int]]:
          """
-         Deterministic generation, returns (json_text, token_stats)
-         token_stats: dict with prompt_tokens, output_tokens, total_tokens
          """
          tok = self.tokenizer
          mdl = self.model
@@ -246,15 +273,17 @@ class HFModel:
          templated = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
          inputs = tok([templated], return_tensors="pt", add_special_tokens=False).to(mdl.device)

-         out = mdl.generate(
-             **inputs,
              max_new_tokens=max_new_tokens,
-             do_sample=False,  # deterministic for classification
-             temperature=0.0,
-             top_p=1.0,
              pad_token_id=tok.eos_token_id,
              eos_token_id=tok.eos_token_id,
          )

          prompt_tokens = int(inputs.input_ids.shape[-1])
          output_tokens = int(out.shape[-1] - inputs.input_ids.shape[-1])
@@ -290,11 +319,21 @@ def preprocess_text(txt: str, add_header: bool, strip_smalltalk: bool) -> str:
      cleaned = "\n".join(lines[-32768:])
      return f"[EMAIL/MESSAGE SIGNAL]\n{cleaned}" if add_header else cleaned

  def run_single(
      custom_repo_id: str,
      rules_json: Optional[gr.File],
-     system: str,
-     context: str,
      transcript: str,
      soft_token_cap: int,
      preprocess: bool,
@@ -303,95 +342,118 @@ def run_single(
      load_in_4bit: bool,
      hourly_rate: float,
      gt_json_file: Optional[gr.File],
  ):
-     """
-     Returns: repo, revision, json_out, diagnostics_text, metrics_json
-     """
-     t0 = time.perf_counter()
      repo = (custom_repo_id or DEFAULT_REPO).strip()
      revision = "main"
-
-     # Resolve allowed labels
      allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET

-     # Preprocess
      effective_len = len(transcript)
      if preprocess:
          transcript = preprocess_text(transcript, add_header, strip_smalltalk)
          effective_len = len(transcript)

-     # Soft cap (~4 chars / token rough heuristic)
      cap_info = ""
      if soft_token_cap and soft_token_cap > 0:
          approx_chars = int(soft_token_cap * 4)
          if len(transcript) > approx_chars:
              transcript = transcript[-approx_chars:]
-             cap_info = f"(soft cap ~{soft_token_cap}t applied)"

-     prompt = build_prompt(system or SYSTEM_INSTRUCTIONS, context or CONTEXT_GUIDE, transcript)

      model = get_model(repo, revision, load_in_4bit)
-     gen_t0 = time.perf_counter()
-     raw_json, tok_stats = model.generate_json(prompt, max_new_tokens=32)
-     gen_latency = time.perf_counter() - gen_t0
      pred_labels = safe_json_labels(raw_json, allowed)

      total_latency = time.perf_counter() - t0
      est_cost = (total_latency / 3600.0) * max(0.0, float(hourly_rate or 0.0))

      # Ground truth
      gt_labels = read_single_ground_truth(gt_json_file)
-     detailed = {}
      pr = rc = f1 = acc = 0.0
-     ham = 0.0
-     missing = []
-     extra = []
-     per_label = {}
-
      if gt_labels is not None:
          pr, rc, f1, acc, counts = prf1_accuracy(pred_labels, gt_labels)
          ham = hamming_loss(pred_labels, gt_labels, allowed)
          per_label = per_label_counts(pred_labels, gt_labels, allowed)
          missing = sorted(list(set(gt_labels) - set(pred_labels)))
-         extra = sorted(list(set(pred_labels) - set(gt_labels)))
-         detailed = {
-             "tp": counts["tp"], "fp": counts["fp"], "fn": counts["fn"],
-             "missing_labels": missing, "extra_labels": extra,
-             "per_label": per_label
-         }
-
-     diagnostics = "\n".join([
-         f"Repo: {repo} | Rev: {revision}",
-         f"Device: {DEVICE} ({GPU_NAME}) | DType: {DTYPE_FALLBACK} | 4bit: {bool(load_in_4bit)}",
-         f"Allowed labels: {allowed}",
-         f"Effective text length (chars): {effective_len} {cap_info}",
-         f"Tokens prompt: {tok_stats['prompt_tokens']} | output: {tok_stats['output_tokens']} | total: {tok_stats['total_tokens']}",
-         f"Latency generation: {gen_latency:.2f}s | total: {total_latency:.2f}s",
-         f"Cost estimate (@{hourly_rate:.4f}/hr): ${est_cost:.6f}",
-     ])
-
-     metrics = {
          "labels_pred": pred_labels,
          "ground_truth_labels": gt_labels,
-         "precision": round(pr, 4),
-         "recall": round(rc, 4),
-         "f1": round(f1, 4),
-         "exact_match": 1.0 if gt_labels is not None and set(pred_labels) == set(gt_labels) else 0.0 if gt_labels is not None else None,
-         "hamming_loss": round(ham, 4) if gt_labels is not None else None,
-         "jaccard": round(prf1_accuracy(pred_labels, gt_labels)[3], 4) if gt_labels is not None else None,
-         "detailed": detailed or None,
          "token_stats": tok_stats,
          "latency_seconds": round(total_latency, 3),
          "estimated_cost_usd": round(est_cost, 6),
      }

-     return repo, revision, json.dumps({"labels": pred_labels}, ensure_ascii=False), diagnostics, json.dumps(metrics, indent=2)

  def run_batch(
      custom_repo_id: str,
      rules_json: Optional[gr.File],
-     system: str,
-     context: str,
      transcripts_zip: Optional[gr.File],
      gt_zip: Optional[gr.File],
      soft_token_cap: int,
@@ -400,46 +462,36 @@ def run_batch(
      strip_smalltalk: bool,
      load_in_4bit: bool,
      hourly_rate: float,
  ):
-     """
-     Batch: transcripts ZIP of *.txt, optional ground-truth ZIP of *.json matching filenames.
-     Returns: repo, revision, csv_text, diagnostics, summary_json, downloads (3 files)
-     """
      repo = (custom_repo_id or DEFAULT_REPO).strip()
      revision = "main"
-
      if not transcripts_zip:
          return repo, revision, "filename,labels\n", "No transcript ZIP provided.", "{}", None, None, None

      allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET
-
      try:
          z = zipfile.ZipFile(transcripts_zip.name)
          txt_names = [n for n in z.namelist() if n.lower().endswith(".txt")]
      except Exception as e:
          return repo, revision, "filename,labels\n", f"Bad transcript ZIP: {e}", "{}", None, None, None

-     gt_map = read_batch_ground_truth_zip(gt_zip)  # stem -> labels
      model = get_model(repo, revision, load_in_4bit)

      rows = [["filename","labels"]]
      per_sample_rows = [["filename","pred_labels","gold_labels","precision","recall","f1","exact_match","hamming_loss","missing","extra"]]
-
      totals = {"tp":0,"fp":0,"fn":0,"pred_total":0,"gold_total":0}
      label_global = {lab: {"tp":0,"fp":0,"fn":0} for lab in allowed}

-     total_prompt_tokens = 0
-     total_output_tokens = 0
-     total_secs = 0.0
-     n = 0
-     samples_with_gt = 0

      for name in txt_names:
          try:
              txt = z.read(name).decode("utf-8", errors="replace")
          except Exception:
-             rows.append([name, "[] # unreadable"])
-             continue

          if preprocess:
              txt = preprocess_text(txt, add_header, strip_smalltalk)
@@ -449,37 +501,39 @@ def run_batch(
              if len(txt) > approx_chars:
                  txt = txt[-approx_chars:]

-         prompt = build_prompt(system or SYSTEM_INSTRUCTIONS, context or CONTEXT_GUIDE, txt)

          t0 = time.perf_counter()
-         raw_json, tok_stats = model.generate_json(prompt, max_new_tokens=32)
          total_secs += (time.perf_counter() - t0)
          total_prompt_tokens += tok_stats["prompt_tokens"]
          total_output_tokens += tok_stats["output_tokens"]
          n += 1

-         pred = safe_json_labels(raw_json, allowed)
          rows.append([name, json.dumps(pred, ensure_ascii=False)])

          stem = Path(name).with_suffix("").name
          gold = gt_map.get(stem)
-
          if gold is not None:
-             samples_with_gt += 1
              pr, rc, f1, acc, counts = prf1_accuracy(pred, gold)
              ham = hamming_loss(pred, gold, allowed)
              missing = sorted(list(set(gold) - set(pred)))
              extra = sorted(list(set(pred) - set(gold)))
-
-             # aggregate
              for k in ["tp","fp","fn","pred_total","gold_total"]:
                  totals[k] += counts[k]
-             # per-label global
              pl = per_label_counts(pred, gold, allowed)
              for lab, c in pl.items():
                  for k in ["tp","fp","fn"]:
                      label_global[lab][k] += c[k]
-
              per_sample_rows.append([
                  name,
                  json.dumps(pred, ensure_ascii=False),
@@ -491,16 +545,12 @@ def run_batch(
                  json.dumps(extra, ensure_ascii=False),
              ])

-     # macro summary (micro over totals)
      tp, fp, fn = totals["tp"], totals["fp"], totals["fn"]
      prec = tp / (tp + fp) if (tp + fp) else 0.0
      rec = tp / (tp + fn) if (tp + fn) else 0.0
      f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0

-     hourly_rate = max(0.0, float(hourly_rate or 0.0))
-     est_cost = (total_secs / 3600.0) * hourly_rate
-
-     # coverage: did we ever predict each label at least once?
      coverage = {lab: 0 for lab in allowed}
      for r in rows[1:]:
          try:
@@ -513,7 +563,7 @@ def run_batch(

      summary = {
          "files_processed": n,
-         "files_with_ground_truth": samples_with_gt,
          "labels_allowed": allowed,
          "precision_micro": round(prec, 4),
          "recall_micro": round(rec, 4),
@@ -532,21 +582,24 @@ def run_batch(
          "estimated_cost_usd": round(est_cost, 6),
      }

-     diagnostics = (
-         f"Repo: {repo} | Rev: {revision} | Device: {DEVICE} ({GPU_NAME}) | "
-         f"DType: {DTYPE_FALLBACK} | 4bit: {bool(load_in_4bit)}\n"
-         f"Files processed: {n} (with GT: {samples_with_gt})\n"
-         f"Tokens prompt_total: {total_prompt_tokens} | output_total: {total_output_tokens}\n"
-         f"Latency total: {summary['latency_seconds_total']}s | avg: {summary['avg_latency_seconds']}s\n"
-         f"Cost estimate (@{hourly_rate:.4f}/hr): ${summary['estimated_cost_usd']}\n"
-         f"Allowed labels: {allowed}"
-     )

      # Write artifacts
      tmp_dir = Path("/tmp")
      pred_csv = tmp_dir / "predictions.csv"
      per_sample_csv = tmp_dir / "per_sample_metrics.csv"
      summary_json = tmp_dir / "summary_metrics.json"
      write_csv(pred_csv, rows)
      write_csv(per_sample_csv, per_sample_rows)
      summary_json.write_text(json.dumps(summary, indent=2), encoding="utf-8")
@@ -554,7 +607,7 @@ def run_batch(
      return (
          repo, revision,
          "\n".join([",".join(r) for r in rows]),
-         diagnostics,
          json.dumps(summary, indent=2),
          str(pred_csv), str(per_sample_csv), str(summary_json)
      )
@@ -566,33 +619,33 @@ with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics") as demo:
      f"""
  # From Talk to Task — Accuracy & Diagnostics (EN/FR/DE/IT)

- **Default model:** `{DEFAULT_REPO}` (recommended with GPU + 4-bit).
- Upload **UBS Ground Truth** to compute **precision / recall / F1 / accuracy** and detailed error analysis.
- Optionally upload a **Rules JSON** (`{{"labels":[...]}}`) to override the default allowed label set.

- **Output schema (model):** `{{"labels": [...]}}`
  """
      )

      with gr.Row():
          custom_repo = gr.Textbox(
-             label="Model repo (leave empty to use default)",
              placeholder="e.g. swiss-ai/Apertus-8B-Instruct-2509"
          )
          load_4bit = gr.Checkbox(value=True, label="Load in 4-bit (GPU only)")

      rules_file = gr.File(label="Rules JSON (optional) — overrides allowed labels", file_types=[".json"])

-     system = gr.Textbox(label="Instructions (System)", value=SYSTEM_INSTRUCTIONS, lines=6)
      context = gr.Textbox(label="Context (User prefix)", value=CONTEXT_GUIDE, lines=6)

      with gr.Row():
          soft_cap = gr.Slider(512, 32768, value=2048, step=1, label="Soft token cap (approx)")
          preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
-     with gr.Row():
          add_header = gr.Checkbox(value=True, label="Add cues header")
          strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
-         hourly_rate = gr.Number(value=0.40, precision=4, label="Hourly hardware price (USD) for cost estimate")

      with gr.Tabs():
          with gr.Tab("Single Transcript"):
@@ -603,8 +656,11 @@ with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics") as demo:
              repo_used = gr.Textbox(label="Repo used", interactive=False)
              rev_used = gr.Textbox(label="Revision", interactive=False)
              json_out = gr.Code(label="Predicted JSON", language="json")
-             diag_out = gr.Textbox(label="Diagnostics", lines=12)
-             metrics_out = gr.Code(label="Metrics (PR/RC/F1/Acc, tokens, latency, errors)", language="json")

              def _single(*args):
                  return run_single(*args)
@@ -614,9 +670,9 @@ with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics") as demo:
                  inputs=[
                      custom_repo, rules_file, system, context, transcript,
                      soft_cap, preprocess, add_header, strip_smalltalk,
-                     load_4bit, hourly_rate, gt_single
                  ],
-                 outputs=[repo_used, rev_used, json_out, diag_out, metrics_out],
              )

          with gr.Tab("Batch (ZIP)"):
@@ -627,10 +683,10 @@ with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics") as demo:
              repo_used_b = gr.Textbox(label="Repo used", interactive=False)
              rev_used_b = gr.Textbox(label="Revision", interactive=False)
              csv_out = gr.Textbox(label="Predictions CSV (filename,labels)", lines=12)
-             diag_out_b = gr.Textbox(label="Diagnostics", lines=12)
-             metrics_out_b = gr.Code(label="Summary Metrics (micro PR/RC/F1, per-label counts, tokens, latency)", language="json")

-             # Downloadables
              preds_file = gr.File(label="Download predictions.csv")
              per_sample_file = gr.File(label="Download per_sample_metrics.csv")
              summary_file = gr.File(label="Download summary_metrics.json")
@@ -643,9 +699,9 @@ with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics") as demo:
                  inputs=[
                      custom_repo, rules_file, system, context, zip_in, gt_zip,
                      soft_cap, preprocess, add_header, strip_smalltalk,
-                     load_4bit, hourly_rate
                  ],
-                 outputs=[repo_used_b, rev_used_b, csv_out, diag_out_b, metrics_out_b, preds_file, per_sample_file, summary_file],
              )

      gr.Markdown(
 
  # app.py
+ # From Talk to Task — Accuracy & Diagnostics with user-friendly metric cards
+ # Model: swiss-ai/Apertus-8B-Instruct-2509
+ # Multilingual (EN/FR/DE/IT), writable cache, few-shot prompting, smart fallback,
+ # per-sample & batch metrics, and downloadable artifacts.

  import os
  import io
 

  DEFAULT_REPO = "swiss-ai/Apertus-8B-Instruct-2509"

  DEFAULT_LABEL_SET = [
      "plan_contact",
      "schedule_meeting",
 
      "update_kyc_total_assets",
  ]

+ SYSTEM_INSTRUCTIONS_BASE = (
+     "You are a task extraction assistant. Input transcript language may be English, French, "
+     "German, or Italian. Return ONLY valid JSON with a single field:\n"
      '"labels": a list of strings chosen ONLY from the allowed label set.\n'
+     "Do NOT add other fields or prose. Do NOT translate labels. If multiple labels apply, return all.\n"
+     "If none apply, return an empty list."
  )

  CONTEXT_GUIDE = (
+     "- plan_contact: conversation without a firm date/time\n"
+     "- schedule_meeting: explicit date/time/modality is agreed\n"
      "- update_contact_info_non_postal: email/phone updates\n"
      "- update_contact_info_postal_address: mailing address updates\n"
      "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)\n"
  )

+ # Few-shot exemplars to improve recall/F1 across languages
+ FEW_SHOTS = [
+     # EN
+     {
+         "transcript": "Agent: Can we meet on Friday at 3pm on Teams?\nClient: Yes, Friday 3pm works.\nAgent: Great, I'll send an invite.",
+         "labels": ["schedule_meeting"]
+     },
+     # DE
+     {
+         "transcript": "Kunde: Meine Telefonnummer hat sich geändert: +41 44 000 00 00.\nBerater: Alles klar, ich aktualisiere Ihre Kontaktdaten.",
+         "labels": ["update_contact_info_non_postal"]
+     },
+     # FR
+     {
+         "transcript": "Client: Nous avons acheté un nouvel appartement, l'adresse postale est Avenue X 12, 1200 Genève.\nConseiller: Merci, je mets à jour l'adresse postale.",
+         "labels": ["update_contact_info_postal_address"]
+     },
+     # IT
+     {
+         "transcript": "Cliente: Vorrei chiarire lo scopo del rapporto: gestione patrimoniale a lungo termine.\nConsulente: Perfetto, aggiorno lo scopo KYC.",
+         "labels": ["update_kyc_purpose_of_businessrelation"]
+     },
+     # EN KYC totals
+     {
+         "transcript": "Agent: To confirm, your total assets are 8,000,000 CHF with 3,700,000 in real estate.\nClient: Yes, correct.",
+         "labels": ["update_kyc_total_assets"]
+     },
+ ]
+
  # --------------------- WRITABLE HF CACHE -----------------------------
80
 
81
  HOME = Path(os.environ.get("HOME", "/home/user"))
82
  CACHE_DIR = HOME / ".cache" / "huggingface"
83
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
84
  os.environ.setdefault("HF_HOME", str(CACHE_DIR))
85
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
86
 
87
  HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None
88
 
 
93
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
94
  except Exception as e:
95
  raise RuntimeError(
96
+ "Missing deps. requirements.txt must include: transformers>=4.56.0, torch, accelerate, huggingface_hub, bitsandbytes, gradio"
97
  ) from e
98
 
99
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
100
  GPU_NAME = torch.cuda.get_device_name(0) if DEVICE == "cuda" else "cpu"
101
+ # T4 doesn't support bf16 use fp16; CPU uses fp32
102
  DTYPE_FALLBACK = torch.float16 if DEVICE == "cuda" else torch.float32
103
 
 
 
 
 
 
 
 
104
  # -------------------------- HELPERS ---------------------------------
105
 
106
  RE_DISCLAIMER = re.compile(r"^\s*disclaimer\s*:", re.IGNORECASE)
 
      except Exception:
          return None

+ def read_single_ground_truth(file_obj: Optional[gr.File]) -> Optional[List[str]]:
+     if not file_obj:
+         return None
+     try:
+         data = json.loads(Path(file_obj.name).read_text(encoding="utf-8"))
+         labels = data.get("labels", [])
+         return [lab for lab in labels if isinstance(lab, str)]
+     except Exception:
+         return None
+
+ def read_batch_ground_truth_zip(zip_file: Optional[gr.File]) -> Dict[str, List[str]]:
+     out: Dict[str, List[str]] = {}
+     if not zip_file:
+         return out
+     try:
+         with zipfile.ZipFile(zip_file.name) as z:
+             for name in z.namelist():
+                 if not name.lower().endswith(".json"):
+                     continue
+                 try:
+                     data = json.loads(z.read(name).decode("utf-8", errors="replace"))
+                     labs = [lab for lab in data.get("labels", []) if isinstance(lab, str)]
+                     out[Path(name).with_suffix("").name] = labs
+                 except Exception:
+                     pass
+     except Exception:
+         pass
+     return out
+
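For orientation, the batch inputs pair by filename stem (`Path(name).with_suffix("").name` above, matched by `gt_map.get(stem)` in `run_batch`); a sketch with hypothetical filenames:

```
transcripts.zip              ground_truth.zip
├── call_001.txt             ├── call_001.json   -> {"labels": ["schedule_meeting"]}
└── call_002.txt             └── call_002.json   -> {"labels": []}
```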
+ def build_fewshot_block(allowed: List[str]) -> str:
+     shots = []
+     for ex in FEW_SHOTS:
+         shots.append(
+             f"- Transcript:\n{ex['transcript']}\n- Correct labels (choose subset from {allowed}): {ex['labels']}\n"
+         )
+     return "\n".join(shots)
+
+ def build_prompt(system: str, context: str, transcript: str, allowed: List[str], use_fewshot: bool) -> str:
+     fewshot_section = f"\n### Examples\n{build_fewshot_block(allowed)}\n" if use_fewshot else ""
      return (
          f"### System\n{system}\n\n"
+         f"### Allowed label set\n{allowed}\n\n"
+         f"### Context\n{context}\n"
+         f"{fewshot_section}\n"
          f"### Transcript\n{transcript}\n\n"
+         "### Output\nReturn JSON only: {\"labels\": [...]}"
      )
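Read top to bottom, the assembled prompt has this shape (angle-bracket placeholders stand in for the actual strings; the section order comes straight from `build_prompt`):

```
### System
<SYSTEM_INSTRUCTIONS_BASE or user override>

### Allowed label set
['plan_contact', 'schedule_meeting', ...]

### Context
<CONTEXT_GUIDE or user override>

### Examples        (only when use_fewshot is True)
- Transcript: ...
- Correct labels (choose subset from [...]): [...]

### Transcript
<preprocessed, soft-capped input text>

### Output
Return JSON only: {"labels": [...]}
```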

  def prf1_accuracy(pred: List[str], gold: List[str]) -> Tuple[float, float, float, float, Dict[str, int]]:
      pset, gset = set(pred), set(gold)
+     tp = len(pset & gset); fp = len(pset - gset); fn = len(gset - pset)
      prec = tp / (tp + fp) if (tp + fp) else 0.0
      rec = tp / (tp + fn) if (tp + fn) else 0.0
      f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0

      return prec, rec, f1, acc, {"tp": tp, "fp": fp, "fn": fn, "pred_total": len(pset), "gold_total": len(gset)}

  def per_label_counts(pred: List[str], gold: List[str], all_labels: List[str]) -> Dict[str, Dict[str, int]]:
      pset, gset = set(pred), set(gold)
      out = {}
      for lab in all_labels:

      return out

  def hamming_loss(pred: List[str], gold: List[str], all_labels: List[str]) -> float:
      pset, gset = set(pred), set(gold)
      wrong = 0
      for lab in all_labels:
          in_p, in_g = (lab in pset), (lab in gset)
+         wrong += int(in_p != in_g)
      return wrong / max(1, len(all_labels))

  def write_csv(path: Path, rows: List[List[str]]):
      with path.open("w", newline="", encoding="utf-8") as f:
+         w = csv.writer(f); w.writerows(rows)

  # -------------------------- MODEL -----------------------------------
 
          self.model = self.model.to(DEVICE)

      @torch.inference_mode()
+     def generate_json(self, prompt: str, max_new_tokens=64, allow_sampling=False) -> Tuple[str, Dict[str, int]]:
          """
+         Deterministic by default. If allow_sampling=True (fallback), we use mild temperature.
+         Returns (json_text, token_stats)
          """
          tok = self.tokenizer
          mdl = self.model

          templated = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
          inputs = tok([templated], return_tensors="pt", add_special_tokens=False).to(mdl.device)

+         kwargs = dict(
              max_new_tokens=max_new_tokens,
              pad_token_id=tok.eos_token_id,
              eos_token_id=tok.eos_token_id,
          )
+         if allow_sampling:
+             kwargs.update(dict(do_sample=True, temperature=0.25, top_p=0.9))
+         else:
+             kwargs.update(dict(do_sample=False, temperature=0.0, top_p=1.0))
+
+         out = mdl.generate(**inputs, **kwargs)

          prompt_tokens = int(inputs.input_ids.shape[-1])
          output_tokens = int(out.shape[-1] - inputs.input_ids.shape[-1])
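The decode policy this enables is "deterministic first, sample only on an empty parse". A minimal sketch of that retry loop, assuming a `model` exposing the `generate_json` above and the app's `safe_json_labels` helper:

```python
def predict_labels(model, prompt: str, allowed: list) -> tuple:
    """Greedy pass first; one mildly-sampled retry if nothing parses."""
    raw, stats = model.generate_json(prompt, max_new_tokens=64, allow_sampling=False)
    labels = safe_json_labels(raw, allowed)
    fallback_used = False
    if not labels:  # empty or unparseable JSON -> retry with temperature 0.25
        raw2, stats2 = model.generate_json(prompt, max_new_tokens=64, allow_sampling=True)
        labels2 = safe_json_labels(raw2, allowed)
        if labels2:
            labels, stats, fallback_used = labels2, stats2, True
    return labels, stats, fallback_used
```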
 
      cleaned = "\n".join(lines[-32768:])
      return f"[EMAIL/MESSAGE SIGNAL]\n{cleaned}" if add_header else cleaned

+ def card_markdown(title: str, value: str, hint: str = "") -> str:
+     hint_md = f"<div style='font-size:12px;opacity:0.8'>{hint}</div>" if hint else ""
+     return f"""
+ <div style="border:1px solid #3a3a3a;border-radius:10px;padding:10px;margin:6px">
+   <div style="font-weight:600">{title}</div>
+   <div style="font-size:20px;margin-top:4px">{value}</div>
+   {hint_md}
+ </div>
+ """
+
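For example, `card_markdown("Precision", "0.875", "Correct positive labels / All predicted positive labels")` (the value is an illustrative placeholder) returns an HTML fragment like:

```
<div style="border:1px solid #3a3a3a;border-radius:10px;padding:10px;margin:6px">
  <div style="font-weight:600">Precision</div>
  <div style="font-size:20px;margin-top:4px">0.875</div>
  <div style='font-size:12px;opacity:0.8'>Correct positive labels / All predicted positive labels</div>
</div>
```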
  def run_single(
      custom_repo_id: str,
      rules_json: Optional[gr.File],
+     system_instructions: str,
+     context_text: str,
      transcript: str,
      soft_token_cap: int,
      preprocess: bool,

      load_in_4bit: bool,
      hourly_rate: float,
      gt_json_file: Optional[gr.File],
+     use_fewshot: bool,
  ):
+     """Returns: repo, revision, predicted_json, metrics_cards_md, diag_cards_md, raw_metrics_json"""
+
      repo = (custom_repo_id or DEFAULT_REPO).strip()
      revision = "main"
      allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET

+     # Preprocess + cap
      effective_len = len(transcript)
      if preprocess:
          transcript = preprocess_text(transcript, add_header, strip_smalltalk)
          effective_len = len(transcript)

      cap_info = ""
      if soft_token_cap and soft_token_cap > 0:
          approx_chars = int(soft_token_cap * 4)
          if len(transcript) > approx_chars:
              transcript = transcript[-approx_chars:]
+             cap_info = f"(soft cap ~{soft_token_cap}t)"

+     # Build prompt (few-shot helps recall)
+     system = system_instructions or SYSTEM_INSTRUCTIONS_BASE
+     prompt = build_prompt(system, context_text or CONTEXT_GUIDE, transcript, allowed, use_fewshot)

      model = get_model(repo, revision, load_in_4bit)
+
+     # First pass: deterministic
+     t0 = time.perf_counter()
+     raw_json, tok_stats = model.generate_json(prompt, max_new_tokens=64, allow_sampling=False)
      pred_labels = safe_json_labels(raw_json, allowed)

+     # Fallback: if empty, try mild sampling once
+     fallback_used = False
+     if not pred_labels:
+         raw_json2, tok_stats2 = model.generate_json(prompt, max_new_tokens=64, allow_sampling=True)
+         pred_labels2 = safe_json_labels(raw_json2, allowed)
+         if pred_labels2:
+             pred_labels = pred_labels2
+             tok_stats = tok_stats2
+             fallback_used = True
+
      total_latency = time.perf_counter() - t0
      est_cost = (total_latency / 3600.0) * max(0.0, float(hourly_rate or 0.0))

      # Ground truth
      gt_labels = read_single_ground_truth(gt_json_file)
      pr = rc = f1 = acc = 0.0
+     ham = None
+     missing = []; extra = []; per_label = {}
      if gt_labels is not None:
          pr, rc, f1, acc, counts = prf1_accuracy(pred_labels, gt_labels)
          ham = hamming_loss(pred_labels, gt_labels, allowed)
          per_label = per_label_counts(pred_labels, gt_labels, allowed)
          missing = sorted(list(set(gt_labels) - set(pred_labels)))
+         extra = sorted(list(set(pred_labels) - set(gt_labels)))
+
+     # ------- User-friendly metric cards -------
+     metric_cards = ""
+     metric_cards += card_markdown("Precision", f"{pr:.3f}" if gt_labels is not None else "—", "Correct positive labels / All predicted positive labels")
+     metric_cards += card_markdown("Recall", f"{rc:.3f}" if gt_labels is not None else "—", "Correct positive labels / All actual positive labels")
+     metric_cards += card_markdown("F1 score", f"{f1:.3f}" if gt_labels is not None else "—", "Harmonic mean of Precision and Recall")
+     metric_cards += card_markdown("Exact match", f"{(1.0 if set(pred_labels) == set(gt_labels) else 0.0) if gt_labels is not None else '—'}", "1.0 if predicted labels exactly equal ground truth")
+     metric_cards += card_markdown("Hamming loss", f"{ham:.3f}" if ham is not None else "—", "Fraction of labels where prediction disagrees with truth (lower is better)")
+     metric_cards += card_markdown("Missing labels", json.dumps(missing, ensure_ascii=False) if gt_labels is not None else "—", "Expected but not predicted")
+     metric_cards += card_markdown("Extra labels", json.dumps(extra, ensure_ascii=False) if gt_labels is not None else "—", "Predicted but not expected")
+
+     # ------- Diagnostics cards -------
+     diag_cards = ""
+     diag_cards += card_markdown("Model / Rev", f"{repo} / {revision}")
+     diag_cards += card_markdown("Device", f"{DEVICE} ({GPU_NAME})")
+     diag_cards += card_markdown("Precision dtype", f"{DTYPE_FALLBACK}")
+     diag_cards += card_markdown("4-bit", f"{bool(load_in_4bit)}")
+     diag_cards += card_markdown("Allowed labels", json.dumps(allowed, ensure_ascii=False))
+     diag_cards += card_markdown("Effective text length", f"{effective_len} chars {cap_info}")
+     diag_cards += card_markdown("Tokens", f"prompt={tok_stats['prompt_tokens']}, output={tok_stats['output_tokens']}, total={tok_stats['total_tokens']}", "Token counts help explain latency and cost")
+     diag_cards += card_markdown("Latency", f"{total_latency:.2f} s", "End-to-end time (first run includes caching)")
+     diag_cards += card_markdown("Cost (est.)", f"${est_cost:.6f} @ {hourly_rate:.4f}/hr")
+     if fallback_used:
+         diag_cards += card_markdown("Fallback used", "Yes", "Empty prediction in first pass; retried with mild sampling to improve recall")
+     else:
+         diag_cards += card_markdown("Fallback used", "No")
+
+     raw_metrics = {
          "labels_pred": pred_labels,
          "ground_truth_labels": gt_labels,
+         "precision": round(pr, 4) if gt_labels is not None else None,
+         "recall": round(rc, 4) if gt_labels is not None else None,
+         "f1": round(f1, 4) if gt_labels is not None else None,
+         "exact_match": (1.0 if set(pred_labels) == set(gt_labels) else 0.0) if gt_labels is not None else None,
+         "hamming_loss": round(ham, 4) if ham is not None else None,
+         "missing": missing if gt_labels is not None else None,
+         "extra": extra if gt_labels is not None else None,
+         "per_label": per_label if gt_labels is not None else None,
          "token_stats": tok_stats,
          "latency_seconds": round(total_latency, 3),
          "estimated_cost_usd": round(est_cost, 6),
+         "fallback_used": fallback_used,
      }

+     return (
+         repo, revision,
+         json.dumps({"labels": pred_labels}, ensure_ascii=False),
+         metric_cards, diag_cards,
+         json.dumps(raw_metrics, indent=2)
+     )
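`safe_json_labels` itself is defined outside this diff. A plausible sketch of what such a helper does (parse the first JSON object in the completion, keep only strings from the allowed set); the body below is an assumption for illustration, not the committed implementation:

```python
import json, re

def safe_json_labels_sketch(raw: str, allowed: list) -> list:
    """ASSUMPTION: illustrative stand-in for the app's safe_json_labels."""
    m = re.search(r"\{.*\}", raw, re.DOTALL)  # grab the first {...} blob
    if not m:
        return []
    try:
        labels = json.loads(m.group(0)).get("labels", [])
    except Exception:
        return []
    # keep only known labels, preserving order and dropping duplicates
    seen, out = set(), []
    for lab in labels:
        if isinstance(lab, str) and lab in allowed and lab not in seen:
            seen.add(lab); out.append(lab)
    return out
```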

  def run_batch(
      custom_repo_id: str,
      rules_json: Optional[gr.File],
+     system_instructions: str,
+     context_text: str,
      transcripts_zip: Optional[gr.File],
      gt_zip: Optional[gr.File],
      soft_token_cap: int,

      strip_smalltalk: bool,
      load_in_4bit: bool,
      hourly_rate: float,
+     use_fewshot: bool,
  ):
      repo = (custom_repo_id or DEFAULT_REPO).strip()
      revision = "main"
      if not transcripts_zip:
          return repo, revision, "filename,labels\n", "No transcript ZIP provided.", "{}", None, None, None

      allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET
      try:
          z = zipfile.ZipFile(transcripts_zip.name)
          txt_names = [n for n in z.namelist() if n.lower().endswith(".txt")]
      except Exception as e:
          return repo, revision, "filename,labels\n", f"Bad transcript ZIP: {e}", "{}", None, None, None

+     gt_map = read_batch_ground_truth_zip(gt_zip)
      model = get_model(repo, revision, load_in_4bit)

      rows = [["filename","labels"]]
      per_sample_rows = [["filename","pred_labels","gold_labels","precision","recall","f1","exact_match","hamming_loss","missing","extra"]]
      totals = {"tp":0,"fp":0,"fn":0,"pred_total":0,"gold_total":0}
      label_global = {lab: {"tp":0,"fp":0,"fn":0} for lab in allowed}
+     total_prompt_tokens = 0; total_output_tokens = 0; total_secs = 0.0; n = 0; with_gt = 0

+     system = system_instructions or SYSTEM_INSTRUCTIONS_BASE

      for name in txt_names:
          try:
              txt = z.read(name).decode("utf-8", errors="replace")
          except Exception:
+             rows.append([name, "[] # unreadable"]); continue

          if preprocess:
              txt = preprocess_text(txt, add_header, strip_smalltalk)

              if len(txt) > approx_chars:
                  txt = txt[-approx_chars:]

+         prompt = build_prompt(system, context_text or CONTEXT_GUIDE, txt, allowed, use_fewshot)

          t0 = time.perf_counter()
+         raw_json, tok_stats = model.generate_json(prompt, max_new_tokens=64, allow_sampling=False)
+         pred = safe_json_labels(raw_json, allowed)
+         if not pred:
+             raw_json2, tok_stats2 = model.generate_json(prompt, max_new_tokens=64, allow_sampling=True)
+             pred2 = safe_json_labels(raw_json2, allowed)
+             if pred2:
+                 pred = pred2
+                 tok_stats = tok_stats2
+
          total_secs += (time.perf_counter() - t0)
          total_prompt_tokens += tok_stats["prompt_tokens"]
          total_output_tokens += tok_stats["output_tokens"]
          n += 1

          rows.append([name, json.dumps(pred, ensure_ascii=False)])

          stem = Path(name).with_suffix("").name
          gold = gt_map.get(stem)
          if gold is not None:
+             with_gt += 1
              pr, rc, f1, acc, counts = prf1_accuracy(pred, gold)
              ham = hamming_loss(pred, gold, allowed)
              missing = sorted(list(set(gold) - set(pred)))
              extra = sorted(list(set(pred) - set(gold)))
              for k in ["tp","fp","fn","pred_total","gold_total"]:
                  totals[k] += counts[k]
              pl = per_label_counts(pred, gold, allowed)
              for lab, c in pl.items():
                  for k in ["tp","fp","fn"]:
                      label_global[lab][k] += c[k]
              per_sample_rows.append([
                  name,
                  json.dumps(pred, ensure_ascii=False),

                  json.dumps(extra, ensure_ascii=False),
              ])

      tp, fp, fn = totals["tp"], totals["fp"], totals["fn"]
      prec = tp / (tp + fp) if (tp + fp) else 0.0
      rec = tp / (tp + fn) if (tp + fn) else 0.0
      f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
+     est_cost = (total_secs / 3600.0) * max(0.0, float(hourly_rate or 0.0))

      coverage = {lab: 0 for lab in allowed}
      for r in rows[1:]:
          try:

      summary = {
          "files_processed": n,
+         "files_with_ground_truth": with_gt,
          "labels_allowed": allowed,
          "precision_micro": round(prec, 4),
          "recall_micro": round(rec, 4),

          "estimated_cost_usd": round(est_cost, 6),
      }

+     diag_cards = ""
+     diag_cards += card_markdown("Model / Rev", f"{repo} / {revision}")
+     diag_cards += card_markdown("Device", f"{DEVICE} ({GPU_NAME})")
+     diag_cards += card_markdown("Precision dtype", f"{DTYPE_FALLBACK}")
+     diag_cards += card_markdown("4-bit", f"{bool(load_in_4bit)}")
+     diag_cards += card_markdown("Files processed", f"{n} (with GT: {with_gt})")
+     diag_cards += card_markdown("Tokens (totals)", f"prompt={total_prompt_tokens}, output={total_output_tokens}")
+     diag_cards += card_markdown("Latency", f"total={summary['latency_seconds_total']} s, avg={summary['avg_latency_seconds']} s")
+     diag_cards += card_markdown("Cost (est.)", f"${summary['estimated_cost_usd']} @ {hourly_rate:.4f}/hr")
+     diag_cards += card_markdown("Allowed labels", json.dumps(allowed, ensure_ascii=False))

      # Write artifacts
      tmp_dir = Path("/tmp")
      pred_csv = tmp_dir / "predictions.csv"
      per_sample_csv = tmp_dir / "per_sample_metrics.csv"
      summary_json = tmp_dir / "summary_metrics.json"
+
+     # CSV/text outputs
      write_csv(pred_csv, rows)
      write_csv(per_sample_csv, per_sample_rows)
      summary_json.write_text(json.dumps(summary, indent=2), encoding="utf-8")

      return (
          repo, revision,
          "\n".join([",".join(r) for r in rows]),
+         diag_cards,
          json.dumps(summary, indent=2),
          str(pred_csv), str(per_sample_csv), str(summary_json)
      )
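For reference, `summary_metrics.json` contains at least the fields visible in this diff; the values below are illustrative placeholders (two files at the UI's default $0.40/hr rate), and the committed file has additional keys elided here:

```json
{
  "files_processed": 2,
  "files_with_ground_truth": 2,
  "labels_allowed": ["plan_contact", "schedule_meeting", "..."],
  "precision_micro": 0.75,
  "recall_micro": 0.6,
  "latency_seconds_total": 12.3,
  "avg_latency_seconds": 6.15,
  "estimated_cost_usd": 0.001367
}
```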
 
      f"""
  # From Talk to Task — Accuracy & Diagnostics (EN/FR/DE/IT)

+ **Default model:** `{DEFAULT_REPO}` (GPU + 4-bit recommended).
+ Upload **ground truth** to compute **Precision / Recall / F1 / Exact match / Hamming loss**.
+ You can also upload a **Rules JSON** (`{{"labels":[...]}}`) to override the allowed label set.

+ **Model output schema:** `{{"labels": [...]}}`
  """
      )

      with gr.Row():
          custom_repo = gr.Textbox(
+             label="Model repo (leave empty for default)",
              placeholder="e.g. swiss-ai/Apertus-8B-Instruct-2509"
          )
          load_4bit = gr.Checkbox(value=True, label="Load in 4-bit (GPU only)")
+         use_fewshot = gr.Checkbox(value=True, label="Use few-shot examples (better recall/F1)")

      rules_file = gr.File(label="Rules JSON (optional) — overrides allowed labels", file_types=[".json"])

+     system = gr.Textbox(label="Instructions (System)", value=SYSTEM_INSTRUCTIONS_BASE, lines=6)
      context = gr.Textbox(label="Context (User prefix)", value=CONTEXT_GUIDE, lines=6)

      with gr.Row():
          soft_cap = gr.Slider(512, 32768, value=2048, step=1, label="Soft token cap (approx)")
          preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
          add_header = gr.Checkbox(value=True, label="Add cues header")
          strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
+     hourly_rate = gr.Number(value=0.40, precision=4, label="Hourly hardware price (USD) for cost estimate")

      with gr.Tabs():
          with gr.Tab("Single Transcript"):

              repo_used = gr.Textbox(label="Repo used", interactive=False)
              rev_used = gr.Textbox(label="Revision", interactive=False)
              json_out = gr.Code(label="Predicted JSON", language="json")
+
+             # Metric & Diagnostic cards (rendered as HTML)
+             metric_cards_md = gr.HTML(label="Metrics (cards)")
+             diag_cards_md = gr.HTML(label="Diagnostics (cards)")
+             raw_metrics = gr.Code(label="Raw metrics JSON", language="json")

              def _single(*args):
                  return run_single(*args)

                  inputs=[
                      custom_repo, rules_file, system, context, transcript,
                      soft_cap, preprocess, add_header, strip_smalltalk,
+                     load_4bit, hourly_rate, gt_single, use_fewshot
                  ],
+                 outputs=[repo_used, rev_used, json_out, metric_cards_md, diag_cards_md, raw_metrics],
              )

          with gr.Tab("Batch (ZIP)"):

              repo_used_b = gr.Textbox(label="Repo used", interactive=False)
              rev_used_b = gr.Textbox(label="Revision", interactive=False)
              csv_out = gr.Textbox(label="Predictions CSV (filename,labels)", lines=12)

+             diag_cards_b = gr.HTML(label="Diagnostics (cards)")
+             metrics_out_b = gr.Code(label="Summary metrics JSON", language="json")
+
              preds_file = gr.File(label="Download predictions.csv")
              per_sample_file = gr.File(label="Download per_sample_metrics.csv")
              summary_file = gr.File(label="Download summary_metrics.json")

                  inputs=[
                      custom_repo, rules_file, system, context, zip_in, gt_zip,
                      soft_cap, preprocess, add_header, strip_smalltalk,
+                     load_4bit, hourly_rate, use_fewshot
                  ],
+                 outputs=[repo_used_b, rev_used_b, csv_out, diag_cards_b, metrics_out_b, preds_file, per_sample_file, summary_file],
              )

      gr.Markdown(