RishiRP committed on
Commit 38169c5 · verified · 1 Parent(s): 34720ab

Update app.py

Files changed (1):
  app.py +159 -178

app.py CHANGED
@@ -1,17 +1,15 @@
import os, io, re, sys, time, json, zipfile, statistics
from pathlib import Path
- from typing import List, Dict, Tuple, Union, Optional

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

- # --- ZeroGPU support ---------------------------------------------------------
- # If the 'spaces' package is available (on Spaces), we use @spaces.GPU.
- # Locally / on CPU hardware, we create a no-op decorator so the code still runs.
try:
-     import spaces  # provided in HF Spaces runtime
except Exception:
    class _DummySpaces:
        def GPU(self, *args, **kwargs):
@@ -19,16 +17,22 @@ except Exception:
            return deco
    spaces = _DummySpaces()

- # --- Auth token for gated models --------------------------------------------
HF_TOKEN = (
    os.getenv("HF_TOKEN")
    or os.getenv("HUGGINGFACE_HUB_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
)

- # =======================
- # Label Set / Scoring
- # =======================
ALLOWED_LABELS = [
    "plan_contact",
    "schedule_meeting",
@@ -39,7 +43,7 @@ ALLOWED_LABELS = [
    "update_kyc_purpose_of_businessrelation",
    "update_kyc_total_assets",
]
- LABEL_TO_IDX = {l:i for i,l in enumerate(ALLOWED_LABELS)}
FN_PENALTY = 2.0
FP_PENALTY = 1.0

@@ -48,7 +52,7 @@ def safe_json_load(s: str):
        return json.loads(s)
    except Exception:
        pass
-     m = re.search(r'\{.*\}', s, re.S)
    if m:
        try:
            return json.loads(m.group(0))
@@ -62,29 +66,29 @@ def _coerce_labels_list(x):
        for it in x:
            if isinstance(it, str): out.append(it)
            elif isinstance(it, dict):
-                 for k in ("label","value","task","category","name"):
                    v = it.get(k)
                    if isinstance(v, str):
                        out.append(v); break
                else:
                    if isinstance(it.get("labels"), list):
                        out += [s for s in it["labels"] if isinstance(s, str)]
-         seen=set(); norm=[]
        for s in out:
            if s not in seen:
                norm.append(s); seen.add(s)
        return norm
    if isinstance(x, dict):
-         for k in ("expected_labels","labels","targets","y_true"):
            if k in x: return _coerce_labels_list(x[k])
        if "one_hot" in x and isinstance(x["one_hot"], dict):
-             return [k for k,v in x["one_hot"].items() if v]
    return []

def classic_metrics(pred_labels, exp_labels):
-     pred_labels = [str(x) for x in (pred_labels or []) if isinstance(x, (str,int,float,bool))]
-     exp_labels = [str(x) for x in (exp_labels or []) if isinstance(x, (str,int,float,bool))]
-     pred = set(pred_labels); gold = set(exp_labels)
    if not pred and not gold:
        return True, 1.0, 1.0, 1.0, 1.0
    inter = pred & gold; union = pred | gold
@@ -108,9 +112,7 @@ def ubs_score_one(true_labels, pred_labels) -> float:
    score = 1.0 if max_err == 0 else (1.0 - (weighted / max_err))
    return float(max(0.0, min(1.0, score)))

- # =======================
- # Lightweight Preprocess
- # =======================
EMAIL_RX = re.compile(r'\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b', re.I)
TIME_RX = re.compile(r'\b(\d{1,2}:\d{2}\b|\b\d{1,2}\s?(am|pm)\b|\bafternoon\b|\bmorning\b|\bevening\b)', re.I)
DATE_RX = re.compile(r'\b(jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)\b|\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b|\b20\d{2}\b', re.I)
@@ -201,11 +203,9 @@ def shrink_to_token_cap_by_lines(text: str, soft_cap_tokens: int, tokenizer,
    ids = tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids
    est = len(ids)
    threshold = int(soft_cap_tokens * apply_only_if_ratio)
-     if est <= threshold:
-         return text
    parts = text.splitlines()
-     if len(parts) <= min_lines_keep:
-         return text

    keep_flags=[]
    for ln in parts:
@@ -230,15 +230,13 @@ def shrink_to_token_cap_by_lines(text: str, soft_cap_tokens: int, tokenizer,
    candidate2_tokens = len(tokenizer(candidate2, return_tensors=None, add_special_tokens=False).input_ids)
    candidate = candidate if cand_tokens <= candidate2_tokens else candidate2

-     if len(candidate.splitlines()) < min_lines_keep:
-         return text
    return candidate

def enforce_rules(labels, transcript_text):
    labels = set(labels or [])
    if (TIME_RX.search(transcript_text) or DATE_RX.search(transcript_text)) and MEET_RX.search(transcript_text):
-         labels.add("schedule_meeting")
-         labels.discard("plan_contact")
    if EMAIL_RX.search(transcript_text) and re.search(r'\b(update|new|set|change|confirm(ed)?|for all communication)\b', transcript_text, re.I):
        labels.add("update_contact_info_non_postal")
    kyc_rx = re.compile(r'\b(kyc|aml|compliance|employer|occupation|purpose of (relationship|account)|source of (wealth|funds)|net worth|total assets)\b', re.I)
@@ -246,9 +244,7 @@ def enforce_rules(labels, transcript_text):
        labels.discard("update_kyc_activity")
    return sorted(labels)

- # =======================
- # HF Model Wrapper
- # =======================
class HFModel:
    def __init__(self, repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
        self.repo_id = repo_id
@@ -260,19 +256,16 @@ class HFModel:
        self.model = None
        if load_4bit:
            try:
-                 quant = BitsAndBytesConfig(
-                     load_in_4bit=True,
-                     bnb_4bit_use_double_quant=True,
-                     bnb_4bit_compute_dtype=torch_dtype,
-                     bnb_4bit_quant_type="nf4"
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    repo_id, device_map="auto", trust_remote_code=trust_remote_code,
-                     quantization_config=quant, torch_dtype=torch_dtype, token=HF_TOKEN
                )
            except Exception as e:
                print(f"[WARN] 4-bit load failed for {repo_id}: {e}\nFalling back to normal load...", file=sys.stderr)
-
        if self.model is None:
            self.model = AutoModelForCausalLM.from_pretrained(
                repo_id, device_map="auto", trust_remote_code=trust_remote_code,
@@ -282,9 +275,6 @@ class HFModel:
        self.max_context = getattr(self.model.config, "max_position_embeddings", None) \
            or getattr(self.model.config, "max_sequence_length", None) or 8192

-     def encode_len(self, text: str) -> int:
-         return len(self.tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids)
-
    def apply_chat_template(self, system_text: str, user_text: str) -> str:
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role":"system","content":system_text},
@@ -300,78 +290,63 @@ class HFModel:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        t0 = time.perf_counter()
        out = self.model.generate(
-             **inputs,
-             max_new_tokens=max_new_tokens,
-             do_sample=False,
-             temperature=None,
-             top_p=None,
-             eos_token_id=self.tokenizer.eos_token_id,
        )
        latency_ms = int((time.perf_counter() - t0) * 1000)
        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
-         if text.startswith(prompt):
-             text = text[len(prompt):]
        return latency_ms, text, prompt

- # Cache
MODEL_CACHE: Dict[str, HFModel] = {}
-
def get_model(repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
    if repo_id not in MODEL_CACHE:
        MODEL_CACHE[repo_id] = HFModel(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
    return MODEL_CACHE[repo_id]

- # =======================
- # ZeroGPU-decorated generator
- # =======================
- @spaces.GPU(duration=180)  # required by ZeroGPU; no-op on CPU
def gpu_generate(repo_id: str, system_text: str, user_text: str,
                 load_4bit: bool, dtype: str, trust_remote_code: bool):
    hf = get_model(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
-     return hf.generate_json(system_text.strip(), user_text.strip(), max_new_tokens=256)

- # =======================
- # Utility (ZIP I/O)
- # =======================
def _read_zip_bytes(dataset_zip: Union[bytes, str, dict, None]) -> bytes:
-     if dataset_zip is None:
-         raise ValueError("No ZIP provided")
-     if isinstance(dataset_zip, bytes):
-         return dataset_zip
    if isinstance(dataset_zip, str):
-         with open(dataset_zip, "rb") as f:
-             return f.read()
    if isinstance(dataset_zip, dict) and "path" in dataset_zip:
-         with open(dataset_zip["path"], "rb") as f:
-             return f.read()
    path = getattr(dataset_zip, "name", None)
    if path and os.path.exists(path):
-         with open(path, "rb") as f:
-             return f.read()
-     raise ValueError("Unsupported file object received from Gradio")

def parse_zip(zip_bytes: bytes) -> Dict[str, Tuple[str, List[str]]]:
    zf = zipfile.ZipFile(io.BytesIO(zip_bytes))
-     names = zf.namelist()
    samples = {}
-     for n in names:
        p = Path(n)
        if p.suffix.lower() == ".txt":
-             sample_id = p.stem
-             txt = zf.read(n).decode("utf-8", "replace")
-             samples.setdefault(sample_id, ["", []])[0] = txt
        elif p.suffix.lower() == ".json":
-             sample_id = p.stem
            try:
                js = json.loads(zf.read(n).decode("utf-8", "replace"))
            except Exception:
                js = []
-             samples.setdefault(sample_id, ["", []])[1] = _coerce_labels_list(js)
    return samples

- # =======================
- # Core Inference (shared)
- # =======================
DEFAULT_SYSTEM = (
    "You are a task extraction assistant. "
    "Always output valid JSON with a field \"labels\" (list of strings). "
@@ -386,6 +361,7 @@ DEFAULT_CONTEXT = (
    "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)"
)

def prepare_input_text(raw_txt: str, soft_cap: int, preprocess: bool, pre_window: int,
                       add_cues: bool, strip_smalltalk: bool, tokenizer) -> Tuple[str, int, int]:
    before = len(tokenizer(raw_txt, return_tensors=None, add_special_tokens=False).input_ids)

@@ -395,10 +371,10 @@ def prepare_input_text(raw_txt: str, soft_cap: int, preprocess: bool, pre_window
    lines = [ln.strip() for ln in t_norm.splitlines() if ln.strip()]
    cue_lines = find_cue_lines(lines)
    if cue_lines:
-         lines_kept = prune_by_window(lines, cue_lines, window=pre_window, strip_smalltalk=strip_smalltalk)
    else:
-         lines_kept = [ln for ln in lines if not (strip_smalltalk and SMALLTALK_RX.search(ln))]
-     t_kept = "\n".join(lines_kept)
    cues = extract_cues(t_kept)
    header = build_cues_header(cues) if add_cues else ""
    proc_text = (header + "\n\n" + t_kept).strip() if header else t_kept
@@ -419,9 +395,7 @@ def explain_params_markdown() -> str:
        "- **Load in 4-bit (GPU only)**: memory-saving quantization; has no effect on CPU Spaces."
    )

- # =======================
- # Single Transcript Mode
- # =======================
def single_mode(
    preset_model: str, custom_model: str,
    system_text: str, context_text: str,
@@ -432,14 +406,14 @@
):
    repo_id = custom_model.strip() or preset_model.strip()
    if not repo_id:
-         return "Please choose a model.", "", "", "", None, None, None

    txt = (transcript_text or "").strip()
    if transcript_file and hasattr(transcript_file, "name") and os.path.exists(transcript_file.name):
        with open(transcript_file.name, "r", encoding="utf-8", errors="replace") as f:
            txt = f.read()
    if not txt:
-         return "Please paste a transcript or upload a .txt file.", "", "", "", None, None, None

    exp = []
    if expected_labels_json and hasattr(expected_labels_json, "name") and os.path.exists(expected_labels_json.name):
@@ -449,27 +423,27 @@ def single_mode(
        except Exception:
            exp = []

-     # tokenizer for preprocessing (with token)
    try:
-         dummy_tok = AutoTokenizer.from_pretrained(
-             repo_id, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN
-         )
    except Exception as e:
-         msg = ("Failed to load tokenizer for `{}`. If the model is gated, accept its license and set HF_TOKEN in "
-                "Space → Settings → Secrets.\n\nError: {}").format(repo_id, e)
-         return msg, "", "", "", None, None, None

    proc_text, tok_before, tok_after = prepare_input_text(
        txt, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
    )
-     user = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()
    system = (system_text or DEFAULT_SYSTEM).strip()

    try:
-         latency_ms, raw_text, _ = gpu_generate(repo_id, system, user, load_4bit, dtype, trust_remote_code)
    except Exception as e:
-         msg = ("Failed to run `{}`. If gated, accept license and set HF_TOKEN.\n\nError: {}").format(repo_id, e)
-         return msg, "", "", "", None, None, None

    out = safe_json_load(raw_text)
    pred_labels = enforce_rules(out.get("labels", []), proc_text)
@@ -499,12 +473,8 @@ def single_mode(
            "model_calls": 1
        },
        "evaluation": None if not exp else {
-             "exact_match": exact,
-             "precision": prec,
-             "recall": rec,
-             "f1": f1,
-             "hamming": ham,
-             "ubs_score": ubs
        }
    }
    zout.writestr("FINAL.json", json.dumps(final_json, ensure_ascii=False, indent=2))
@@ -526,45 +496,45 @@ def single_mode(
        "ubs_score": round(ubs,6) if ubs is not None else None
    }])

-     csv_bytes = row.to_csv(index=False).encode("utf-8")
-     csv_buf = io.BytesIO(csv_bytes); csv_buf.name = "results_single.csv"
-     status = "Done. (ZeroGPU-ready: model calls run inside @spaces.GPU)."
-     return status, kpi1, kpi2, kpi3, row, csv_buf, zbuf

- # =======================
- # Batch Mode (ZIP)
- # =======================
def run_batch_ui(models_list, custom_models_str, instructions_text, context_text, dataset_zip,
                 soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
                 repeats, max_total_runs, load_4bit, dtype, trust_remote_code):

    models = [m for m in (models_list or [])]
-     custom = [m.strip() for m in (custom_models_str or "").split(",") if m.strip()]
-     models.extend(custom)
-     models = [m for m in models if m]
    if not models:
-         return pd.DataFrame(), None, None, "Please pick at least one model."

    if not dataset_zip:
-         return pd.DataFrame(), None, None, "Please upload a ZIP with *.txt (+ optional matching *.json)."

    try:
        zip_bytes = _read_zip_bytes(dataset_zip)
        samples = parse_zip(zip_bytes)
    except Exception as e:
-         return pd.DataFrame(), None, None, f"Failed to read ZIP: {e}"

-     rows = []
-     total_runs = 0
    all_artifacts = io.BytesIO()
    zout = zipfile.ZipFile(all_artifacts, "w", zipfile.ZIP_DEFLATED)

    for repo_id in models:
        try:
-             dummy_tok = AutoTokenizer.from_pretrained(
-                 repo_id, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN
-             )
        except Exception as e:
            rows.append({
                "timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
                "sample_id": None,
@@ -593,14 +563,10 @@ def run_batch_ui(models_list, custom_models_str, instructions_text, context_text
            continue

        for sample_id, (transcript_text, exp_labels) in samples.items():
-             if not transcript_text.strip():
-                 continue
-             latencies = []
-             last_pred = None
            for r in range(1, repeats+1):
-                 if total_runs >= max_total_runs:
-                     break
-
                proc_text, before_tok, after_tok = prepare_input_text(
                    transcript_text, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
                )
@@ -608,7 +574,10 @@ def run_batch_ui(models_list, custom_models_str, instructions_text, context_text
                user_text = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()

                try:
-                     latency_ms, raw_text, _ = gpu_generate(repo_id, system_text, user_text, load_4bit, dtype, trust_remote_code)
                except Exception as e:
                    base = f"{repo_id.replace('/','_')}/{sample_id}/error_r{r}"
                    zout.writestr(base + "/ERROR.txt", f"Failed to run model via @spaces.GPU. If gated, accept license and set HF_TOKEN.\n\n{e}")
@@ -706,14 +675,35 @@ def run_batch_ui(models_list, custom_models_str, instructions_text, context_text
    zout.close()
    df = pd.DataFrame(rows)
    if df.empty:
-         return pd.DataFrame(), None, None, "No runs executed (empty dataset / exceeded cap / gated models)."

-     csv_bytes = df.to_csv(index=False).encode("utf-8")
-     return df, ("results.csv", csv_bytes), ("artifacts.zip", all_artifacts.getvalue()), "Done."

- # =======================
- # UI (same dark theme)
- # =======================
DARK_RED_CSS = """
:root, .gradio-container {
    --color-background: #0b0b0d;
@@ -741,44 +731,36 @@ button, .gr-button {
}
"""

- PRESET_MODELS = [
-     "mistralai/Mistral-7B-Instruct-v0.2",
-     "Qwen/Qwen2.5-7B-Instruct",
-     "HuggingFaceH4/zephyr-7b-beta",
-     "tiiuae/falcon-7b-instruct"
- ]
-
- DEFAULT_SYSTEM = (
-     "You are a task extraction assistant. "
-     "Always output valid JSON with a field \"labels\" (list of strings). "
-     "Use only from this set: " + json.dumps(ALLOWED_LABELS) + ". "
-     "Return JSON only."
- )
- DEFAULT_CONTEXT = (
-     "- plan_contact: conversation without a concrete meeting (no date/time)\n"
-     "- schedule_meeting: explicit date/time/modality confirmation\n"
-     "- update_contact_info_non_postal: changes to email/phone\n"
-     "- update_contact_info_postal_address: changes to mailing address\n"
-     "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)"
- )
-
with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo:
    gr.Markdown("## 🟥 From Talk to Task — Batch & Single Task Extraction")
-     gr.Markdown(
-         "This tool extracts **task labels** from client–advisor transcripts using Hugging Face models. \n"
        "1) Pick a model (or paste a custom repo id). \n"
        "2) Provide **Instructions** and **Context**, then supply a transcript (single) or a ZIP (batch). \n"
        "3) Adjust parameters (soft token cap, preprocessing). \n"
-         "4) Run and review **latency**, **precision/recall/F1**, **UBS score**, and download artifacts.\n"
-         "_ZeroGPU-ready: model calls run inside an @spaces.GPU function when available._"
    )

    with gr.Tabs():
        # Single
        with gr.TabItem("Single Transcript (default)"):
            with gr.Row():
                with gr.Column():
-                     preset_model = gr.Dropdown(choices=PRESET_MODELS, value=PRESET_MODELS[0], label="Model (preset)")
                    custom_model = gr.Textbox(label="Custom model repo id (overrides preset)",
                                              placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct")
                    instructions = gr.Textbox(label="Instructions (System)", lines=8, value=DEFAULT_SYSTEM)
@@ -795,6 +777,7 @@ with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo
                    pre_window_s = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
                    add_cues_s = gr.Checkbox(value=True, label="Add cues header")
                    strip_smalltalk_s = gr.Checkbox(value=False, label="Strip smalltalk")
                with gr.Column():
                    load_4bit_s = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")
                    dtype_s = gr.Dropdown(choices=["bfloat16","float16","float32"], value="bfloat16", label="Compute dtype")
@@ -808,11 +791,8 @@ with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo
            single_status = gr.Markdown("")

            def _run_single(*args):
-                 status, m1, m2, m3, df, csv_buf, zip_buf = single_mode(*args)
-                 if isinstance(df, pd.DataFrame) and not df.empty:
-                     return m1, m2, m3, df, csv_buf, zip_buf, status
-                 else:
-                     return m1 or "", m2 or "", m3 or "", pd.DataFrame(), None, None, status

            run_single_btn.click(
                _run_single,

@@ -820,7 +800,7 @@ with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo
                 transcript_text, transcript_file, expected_labels_json,
                 soft_cap_s, preprocess_s, pre_window_s, add_cues_s, strip_smalltalk_s,
                 load_4bit_s, dtype_s, trust_remote_code_s],
-                 outputs=[kpi1, kpi2, kpi3, single_table, single_csv, single_zip, single_status]
            )

        # Batch
@@ -828,7 +808,8 @@ with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo
            with gr.Row():
                with gr.Column():
                    models_list = gr.CheckboxGroup(
-                         choices=PRESET_MODELS, value=[PRESET_MODELS[0]], label="Models (select one or more presets)"
                    )
                    custom_models = gr.Textbox(label="Custom model repo ids (comma-separated)",
                                               placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct, Qwen/Qwen2.5-7B-Instruct")
@@ -839,6 +820,7 @@ with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo
                    label="Upload ZIP of transcripts (*.txt) + expected (*.json)",
                    file_types=[".zip"], file_count="single", type="filepath"
                )

            with gr.Row():
                with gr.Column():

@@ -847,6 +829,7 @@ with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo
                    pre_window = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
                    add_cues = gr.Checkbox(value=True, label="Add cues header")
                    strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
                with gr.Column():
                    repeats = gr.Slider(1, 6, value=3, step=1, label="Repeats per config")
                    max_total_runs = gr.Slider(1, 200, value=40, step=1, label="Max total runs")
@@ -862,7 +845,7 @@ with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo
            status = gr.Markdown("")

            def _run_batch(*args):
-                 df, csv_pair, zip_pair, msg = run_batch_ui(*args)
                m1 = m2 = m3 = ""
                if isinstance(df, pd.DataFrame) and not df.empty:
                    summaries = df[df["is_summary"] == True]

@@ -874,19 +857,17 @@ with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo
                    m3 = f"**Median latency (ms)**\n\n{int(med) if pd.notna(med) else '—'}"
                csv_buf = zip_buf = None
                if isinstance(csv_pair, tuple):
-                     name, data = csv_pair
-                     csv_buf = io.BytesIO(data); csv_buf.name = name
                if isinstance(zip_pair, tuple):
-                     name, data = zip_pair
-                     zip_buf = io.BytesIO(data); zip_buf.name = name
-                 return m1, m2, m3, df, csv_buf, zip_buf, msg

            run_btn.click(
                _run_batch,
                inputs=[models_list, custom_models, instructions_b, context_b, dataset_zip,
                        soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
                        repeats, max_total_runs, load_4bit, dtype, trust_remote_code],
-                 outputs=[kpi_b1, kpi_b2, kpi_b3, table, csv_dl, zip_dl, status]
            )

demo.launch()
 
import os, io, re, sys, time, json, zipfile, statistics
from pathlib import Path
+ from typing import List, Dict, Tuple, Union

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

+ # ========= ZeroGPU support =========
try:
+     import spaces  # available on HF Spaces
except Exception:
    class _DummySpaces:
        def GPU(self, *args, **kwargs):

            return deco
    spaces = _DummySpaces()

+ # ========= Auth token =========
HF_TOKEN = (
    os.getenv("HF_TOKEN")
    or os.getenv("HUGGINGFACE_HUB_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
)

+ # Console warning at startup (helps when logs are open)
+ if not HF_TOKEN:
+     print(
+         "[WARN] HF_TOKEN is not set. Gated models will fail. "
+         "Set it in Space → Settings → Variables and secrets.",
+         file=sys.stderr
+     )
+
+ # ========= Labels & metrics =========
ALLOWED_LABELS = [
    "plan_contact",
    "schedule_meeting",

    "update_kyc_purpose_of_businessrelation",
    "update_kyc_total_assets",
]
+ LABEL_TO_IDX = {l: i for i, l in enumerate(ALLOWED_LABELS)}
FN_PENALTY = 2.0
FP_PENALTY = 1.0

        return json.loads(s)
    except Exception:
        pass
+     m = re.search(r"\{.*\}", s, re.S)
    if m:
        try:
            return json.loads(m.group(0))

        for it in x:
            if isinstance(it, str): out.append(it)
            elif isinstance(it, dict):
+                 for k in ("label", "value", "task", "category", "name"):
                    v = it.get(k)
                    if isinstance(v, str):
                        out.append(v); break
                else:
                    if isinstance(it.get("labels"), list):
                        out += [s for s in it["labels"] if isinstance(s, str)]
+         # dedupe keep order
+         seen = set(); norm = []
        for s in out:
            if s not in seen:
                norm.append(s); seen.add(s)
        return norm
    if isinstance(x, dict):
+         for k in ("expected_labels", "labels", "targets", "y_true"):
            if k in x: return _coerce_labels_list(x[k])
        if "one_hot" in x and isinstance(x["one_hot"], dict):
+             return [k for k, v in x["one_hot"].items() if v]
    return []

def classic_metrics(pred_labels, exp_labels):
+     pred = set([str(x) for x in (pred_labels or []) if isinstance(x, (str,int,float,bool))])
+     gold = set([str(x) for x in (exp_labels or []) if isinstance(x, (str,int,float,bool))])
    if not pred and not gold:
        return True, 1.0, 1.0, 1.0, 1.0
    inter = pred & gold; union = pred | gold

    score = 1.0 if max_err == 0 else (1.0 - (weighted / max_err))
    return float(max(0.0, min(1.0, score)))

+ # ========= Lightweight preprocessing =========
EMAIL_RX = re.compile(r'\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b', re.I)
TIME_RX = re.compile(r'\b(\d{1,2}:\d{2}\b|\b\d{1,2}\s?(am|pm)\b|\bafternoon\b|\bmorning\b|\bevening\b)', re.I)
DATE_RX = re.compile(r'\b(jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)\b|\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b|\b20\d{2}\b', re.I)

    ids = tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids
    est = len(ids)
    threshold = int(soft_cap_tokens * apply_only_if_ratio)
+     if est <= threshold: return text
    parts = text.splitlines()
+     if len(parts) <= min_lines_keep: return text

    keep_flags=[]
    for ln in parts:

  candidate2_tokens = len(tokenizer(candidate2, return_tensors=None, add_special_tokens=False).input_ids)
231
  candidate = candidate if cand_tokens <= candidate2_tokens else candidate2
232
 
233
+ if len(candidate.splitlines()) < min_lines_keep: return text
 
234
  return candidate
235
 
236
  def enforce_rules(labels, transcript_text):
237
  labels = set(labels or [])
238
  if (TIME_RX.search(transcript_text) or DATE_RX.search(transcript_text)) and MEET_RX.search(transcript_text):
239
+ labels.add("schedule_meeting"); labels.discard("plan_contact")
 
240
  if EMAIL_RX.search(transcript_text) and re.search(r'\b(update|new|set|change|confirm(ed)?|for all communication)\b', transcript_text, re.I):
241
  labels.add("update_contact_info_non_postal")
242
  kyc_rx = re.compile(r'\b(kyc|aml|compliance|employer|occupation|purpose of (relationship|account)|source of (wealth|funds)|net worth|total assets)\b', re.I)
 
244
  labels.discard("update_kyc_activity")
245
  return sorted(labels)
246
 
247
+ # ========= HF model wrapper =========
 
 
248
  class HFModel:
249
  def __init__(self, repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
250
  self.repo_id = repo_id
 
256
  self.model = None
257
  if load_4bit:
258
  try:
259
+ q = BitsAndBytesConfig(
260
+ load_in_4bit=True, bnb_4bit_use_double_quant=True,
261
+ bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_quant_type="nf4"
 
 
262
  )
263
  self.model = AutoModelForCausalLM.from_pretrained(
264
  repo_id, device_map="auto", trust_remote_code=trust_remote_code,
265
+ quantization_config=q, torch_dtype=torch_dtype, token=HF_TOKEN
266
  )
267
  except Exception as e:
268
  print(f"[WARN] 4-bit load failed for {repo_id}: {e}\nFalling back to normal load...", file=sys.stderr)
 
269
  if self.model is None:
270
  self.model = AutoModelForCausalLM.from_pretrained(
271
  repo_id, device_map="auto", trust_remote_code=trust_remote_code,
 
275
  self.max_context = getattr(self.model.config, "max_position_embeddings", None) \
276
  or getattr(self.model.config, "max_sequence_length", None) or 8192
277
 
 
 
 
278
  def apply_chat_template(self, system_text: str, user_text: str) -> str:
279
  if getattr(self.tokenizer, "chat_template", None):
280
  messages = [{"role":"system","content":system_text},
 
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        t0 = time.perf_counter()
        out = self.model.generate(
+             **inputs, max_new_tokens=max_new_tokens,
+             do_sample=False, temperature=None, top_p=None,
+             eos_token_id=self.tokenizer.eos_token_id
        )
        latency_ms = int((time.perf_counter() - t0) * 1000)
        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
+         if text.startswith(prompt): text = text[len(prompt):]
        return latency_ms, text, prompt

MODEL_CACHE: Dict[str, HFModel] = {}

def get_model(repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
    if repo_id not in MODEL_CACHE:
        MODEL_CACHE[repo_id] = HFModel(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
    return MODEL_CACHE[repo_id]

+ # ========= ZeroGPU functions =========
+ @spaces.GPU(duration=180, secrets=["HF_TOKEN"])  # pass token into ZeroGPU job
def gpu_generate(repo_id: str, system_text: str, user_text: str,
                 load_4bit: bool, dtype: str, trust_remote_code: bool):
+     token_seen = bool(os.getenv("HF_TOKEN"))
    hf = get_model(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
+     lat, txt, prmpt = hf.generate_json(system_text.strip(), user_text.strip(), max_new_tokens=256)
+     return lat, txt, prmpt, token_seen
+
+ @spaces.GPU(duration=15, secrets=["HF_TOKEN"])
+ def gpu_check_token():
+     return bool(os.getenv("HF_TOKEN"))

+ # ========= ZIP helpers =========
def _read_zip_bytes(dataset_zip: Union[bytes, str, dict, None]) -> bytes:
+     if dataset_zip is None: raise ValueError("No ZIP provided")
+     if isinstance(dataset_zip, bytes): return dataset_zip
    if isinstance(dataset_zip, str):
+         with open(dataset_zip, "rb") as f: return f.read()
    if isinstance(dataset_zip, dict) and "path" in dataset_zip:
+         with open(dataset_zip["path"], "rb") as f: return f.read()
    path = getattr(dataset_zip, "name", None)
    if path and os.path.exists(path):
+         with open(path, "rb") as f: return f.read()
+     raise ValueError("Unsupported file object from Gradio")

def parse_zip(zip_bytes: bytes) -> Dict[str, Tuple[str, List[str]]]:
    zf = zipfile.ZipFile(io.BytesIO(zip_bytes))
    samples = {}
+     for n in zf.namelist():
        p = Path(n)
        if p.suffix.lower() == ".txt":
+             samples.setdefault(p.stem, ["", []])[0] = zf.read(n).decode("utf-8", "replace")
        elif p.suffix.lower() == ".json":
            try:
                js = json.loads(zf.read(n).decode("utf-8", "replace"))
            except Exception:
                js = []
+             samples.setdefault(p.stem, ["", []])[1] = _coerce_labels_list(js)
    return samples

+ # ========= Prompts =========
 
 
350
  DEFAULT_SYSTEM = (
351
  "You are a task extraction assistant. "
352
  "Always output valid JSON with a field \"labels\" (list of strings). "
 
361
  "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)"
362
  )
363
 
364
+ # ========= Preprocess + build input =========
365
  def prepare_input_text(raw_txt: str, soft_cap: int, preprocess: bool, pre_window: int,
366
  add_cues: bool, strip_smalltalk: bool, tokenizer) -> Tuple[str, int, int]:
367
  before = len(tokenizer(raw_txt, return_tensors=None, add_special_tokens=False).input_ids)
 
371
  lines = [ln.strip() for ln in t_norm.splitlines() if ln.strip()]
372
  cue_lines = find_cue_lines(lines)
373
  if cue_lines:
374
+ kept = prune_by_window(lines, cue_lines, window=pre_window, strip_smalltalk=strip_smalltalk)
375
  else:
376
+ kept = [ln for ln in lines if not (strip_smalltalk and SMALLTALK_RX.search(ln))]
377
+ t_kept = "\n".join(kept)
378
  cues = extract_cues(t_kept)
379
  header = build_cues_header(cues) if add_cues else ""
380
  proc_text = (header + "\n\n" + t_kept).strip() if header else t_kept
 
395
  "- **Load in 4-bit (GPU only)**: memory-saving quantization; has no effect on CPU Spaces."
396
  )
397
 
398
+ # ========= Single mode =========
 
 
399
  def single_mode(
400
  preset_model: str, custom_model: str,
401
  system_text: str, context_text: str,
 
406
  ):
407
  repo_id = custom_model.strip() or preset_model.strip()
408
  if not repo_id:
409
+ return "Please choose a model.", "", "", "", None, None, None, ""
410
 
411
  txt = (transcript_text or "").strip()
412
  if transcript_file and hasattr(transcript_file, "name") and os.path.exists(transcript_file.name):
413
  with open(transcript_file.name, "r", encoding="utf-8", errors="replace") as f:
414
  txt = f.read()
415
  if not txt:
416
+ return "Please paste a transcript or upload a .txt file.", "", "", "", None, None, None, ""
417
 
418
  exp = []
419
  if expected_labels_json and hasattr(expected_labels_json, "name") and os.path.exists(expected_labels_json.name):
 
423
  except Exception:
424
  exp = []
425
 
426
+ # tokenizer for preprocessing
427
  try:
428
+ dummy_tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN)
 
 
429
  except Exception as e:
430
+ msg = (f"Failed to load tokenizer for `{repo_id}`. "
431
+ "If gated, accept license and set HF_TOKEN in Space β†’ Settings β†’ Secrets.\n\nError: " + str(e))
432
+ return msg, "", "", "", None, None, None, banner_text()
433
 
434
  proc_text, tok_before, tok_after = prepare_input_text(
435
  txt, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
436
  )
 
437
  system = (system_text or DEFAULT_SYSTEM).strip()
438
+ user = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()
439
 
440
  try:
441
+ latency_ms, raw_text, _prompt, gpu_token_seen = gpu_generate(
442
+ repo_id, system, user, load_4bit, dtype, trust_remote_code
443
+ )
444
  except Exception as e:
445
+ msg = (f"Failed to run `{repo_id}`. If gated, accept license and set HF_TOKEN.\n\nError: {e}")
446
+ return msg, "", "", "", None, None, None, banner_text()
447
 
448
  out = safe_json_load(raw_text)
449
  pred_labels = enforce_rules(out.get("labels", []), proc_text)
 
            "model_calls": 1
        },
        "evaluation": None if not exp else {
+             "exact_match": exact, "precision": prec, "recall": rec,
+             "f1": f1, "hamming": ham, "ubs_score": ubs
        }
    }
    zout.writestr("FINAL.json", json.dumps(final_json, ensure_ascii=False, indent=2))

        "ubs_score": round(ubs,6) if ubs is not None else None
    }])

+     csv_buf = io.BytesIO(row.to_csv(index=False).encode("utf-8")); csv_buf.name = "results_single.csv"
+
+     return (
+         "Done.",
+         kpi1, kpi2, kpi3,
+         row, csv_buf, zbuf,
+         banner_text(gpu_token_seen)
+     )

+ # ========= Batch mode =========
def run_batch_ui(models_list, custom_models_str, instructions_text, context_text, dataset_zip,
                 soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
                 repeats, max_total_runs, load_4bit, dtype, trust_remote_code):

    models = [m for m in (models_list or [])]
+     models += [m.strip() for m in (custom_models_str or "").split(",") if m.strip()]
    if not models:
+         return pd.DataFrame(), None, None, "Please pick at least one model.", banner_text()

    if not dataset_zip:
+         return pd.DataFrame(), None, None, "Please upload a ZIP with *.txt (+ optional matching *.json).", banner_text()

    try:
        zip_bytes = _read_zip_bytes(dataset_zip)
        samples = parse_zip(zip_bytes)
    except Exception as e:
+         return pd.DataFrame(), None, None, f"Failed to read ZIP: {e}", banner_text()

+     rows = []; total_runs = 0
    all_artifacts = io.BytesIO()
    zout = zipfile.ZipFile(all_artifacts, "w", zipfile.ZIP_DEFLATED)
+     last_gpu_token_seen = None

    for repo_id in models:
+         # tokenizer for preprocessing (auth check)
        try:
+             dummy_tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN)
        except Exception as e:
+             # gated or missing token; record a summary row and continue
            rows.append({
                "timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
                "sample_id": None,

            continue

        for sample_id, (transcript_text, exp_labels) in samples.items():
+             if not transcript_text.strip(): continue
+             latencies = []; last_pred = None
            for r in range(1, repeats+1):
+                 if total_runs >= max_total_runs: break
                proc_text, before_tok, after_tok = prepare_input_text(
                    transcript_text, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
                )

                user_text = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()

                try:
+                     latency_ms, raw_text, _prompt, token_seen = gpu_generate(
+                         repo_id, system_text, user_text, load_4bit, dtype, trust_remote_code
+                     )
+                     last_gpu_token_seen = token_seen
                except Exception as e:
                    base = f"{repo_id.replace('/','_')}/{sample_id}/error_r{r}"
                    zout.writestr(base + "/ERROR.txt", f"Failed to run model via @spaces.GPU. If gated, accept license and set HF_TOKEN.\n\n{e}")

    zout.close()
    df = pd.DataFrame(rows)
    if df.empty:
+         return pd.DataFrame(), None, None, "No runs executed (empty dataset / exceeded cap / gated models).", banner_text(last_gpu_token_seen)
+
+     csv_pair = ("results.csv", df.to_csv(index=False).encode("utf-8"))
+     zip_pair = ("artifacts.zip", all_artifacts.getvalue())
+     return df, csv_pair, zip_pair, "Done.", banner_text(last_gpu_token_seen)

+ # ========= UI helpers =========
+ OPEN_MODEL_PRESETS = [
+     "mistralai/Mistral-7B-Instruct-v0.2",
+     "Qwen/Qwen2.5-7B-Instruct",
+     "HuggingFaceH4/zephyr-7b-beta",
+     "tiiuae/falcon-7b-instruct",
+ ]

+ def banner_text(gpu_token_seen: bool | None = None) -> str:
+     app_seen = bool(HF_TOKEN)
+     lines = []
+     if not app_seen:
+         lines.append("🟡 **HF_TOKEN not detected in App** — gated models will fail unless you set it in **Settings → Variables and secrets**.")
+     else:
+         lines.append("🟢 **HF_TOKEN detected in App**.")
+     if gpu_token_seen is None:
+         lines.append("ℹ️ ZeroGPU token status: click **Run** or **Check ZeroGPU token** to verify.")
+     else:
+         lines.append("🟢 **HF_TOKEN detected inside ZeroGPU job.**" if gpu_token_seen else "🔴 **HF_TOKEN missing inside ZeroGPU job** (add `secrets=[\"HF_TOKEN\"]` to @spaces.GPU).")
+     lines.append("✅ Tip: use **Open models** (no license gating): " + ", ".join(OPEN_MODEL_PRESETS))
+     return "\n\n".join(lines)
+
+ # ========= UI (dark red) =========
DARK_RED_CSS = """
:root, .gradio-container {
    --color-background: #0b0b0d;

}
"""

with gr.Blocks(title="From Talk to Task — HF Space", css=DARK_RED_CSS) as demo:
    gr.Markdown("## 🟥 From Talk to Task — Batch & Single Task Extraction")
+     help_md = (
+         "This tool extracts **task labels** from transcripts using Hugging Face models. \n"
        "1) Pick a model (or paste a custom repo id). \n"
        "2) Provide **Instructions** and **Context**, then supply a transcript (single) or a ZIP (batch). \n"
        "3) Adjust parameters (soft token cap, preprocessing). \n"
+         "4) Run and review **latency**, **precision/recall/F1**, **UBS score**, and download artifacts."
    )
+     gr.Markdown(help_md)
+
+     # Status banner (token presence info)
+     banner = gr.Markdown(banner_text())
+
+     check_btn = gr.Button("Check ZeroGPU token")
+     def _check_token():
+         try:
+             present = gpu_check_token()
+         except Exception:
+             present = None
+         return banner_text(present)
+     check_btn.click(_check_token, outputs=banner)

    with gr.Tabs():
        # Single
        with gr.TabItem("Single Transcript (default)"):
            with gr.Row():
                with gr.Column():
+                     preset_model = gr.Dropdown(choices=OPEN_MODEL_PRESETS, value=OPEN_MODEL_PRESETS[0],
+                                                label="Model (Open presets — no gating)")
                    custom_model = gr.Textbox(label="Custom model repo id (overrides preset)",
                                              placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct")
                    instructions = gr.Textbox(label="Instructions (System)", lines=8, value=DEFAULT_SYSTEM)

                    pre_window_s = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
                    add_cues_s = gr.Checkbox(value=True, label="Add cues header")
                    strip_smalltalk_s = gr.Checkbox(value=False, label="Strip smalltalk")
+                     gr.Markdown(explain_params_markdown())
                with gr.Column():
                    load_4bit_s = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")
                    dtype_s = gr.Dropdown(choices=["bfloat16","float16","float32"], value="bfloat16", label="Compute dtype")

            single_status = gr.Markdown("")

            def _run_single(*args):
+                 status, m1, m2, m3, df, csv_buf, zip_buf, btxt = single_mode(*args)
+                 return m1 or "", m2 or "", m3 or "", (df if isinstance(df, pd.DataFrame) else pd.DataFrame()), csv_buf, zip_buf, (status or ""), (btxt or banner_text())

            run_single_btn.click(
                _run_single,

                 transcript_text, transcript_file, expected_labels_json,
                 soft_cap_s, preprocess_s, pre_window_s, add_cues_s, strip_smalltalk_s,
                 load_4bit_s, dtype_s, trust_remote_code_s],
+                 outputs=[kpi1, kpi2, kpi3, single_table, single_csv, single_zip, single_status, banner]
            )

        # Batch

            with gr.Row():
                with gr.Column():
                    models_list = gr.CheckboxGroup(
+                         choices=OPEN_MODEL_PRESETS, value=[OPEN_MODEL_PRESETS[0]],
+                         label="Models (Open presets — select one or more)"
                    )
                    custom_models = gr.Textbox(label="Custom model repo ids (comma-separated)",
                                               placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct, Qwen/Qwen2.5-7B-Instruct")

                        label="Upload ZIP of transcripts (*.txt) + expected (*.json)",
                        file_types=[".zip"], file_count="single", type="filepath"
                    )
+                 gr.Markdown("Zip must contain pairs like `ID.txt` and optional `ID.json` with expected labels (same base filename).")

  with gr.Row():
826
  with gr.Column():
 
829
  pre_window = gr.Slider(0, 6, value=3, step=1, label="Window Β± lines around cues")
830
  add_cues = gr.Checkbox(value=True, label="Add cues header")
831
  strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
832
+ gr.Markdown(explain_params_markdown())
833
  with gr.Column():
834
  repeats = gr.Slider(1, 6, value=3, step=1, label="Repeats per config")
835
  max_total_runs = gr.Slider(1, 200, value=40, step=1, label="Max total runs")
 
            status = gr.Markdown("")

            def _run_batch(*args):
+                 df, csv_pair, zip_pair, msg, btxt = run_batch_ui(*args)
                m1 = m2 = m3 = ""
                if isinstance(df, pd.DataFrame) and not df.empty:
                    summaries = df[df["is_summary"] == True]

                    m3 = f"**Median latency (ms)**\n\n{int(med) if pd.notna(med) else '—'}"
                csv_buf = zip_buf = None
                if isinstance(csv_pair, tuple):
+                     name, data = csv_pair; csv_buf = io.BytesIO(data); csv_buf.name = name
                if isinstance(zip_pair, tuple):
+                     name, data = zip_pair; zip_buf = io.BytesIO(data); zip_buf.name = name
+                 return m1, m2, m3, (df if isinstance(df, pd.DataFrame) else pd.DataFrame()), csv_buf, zip_buf, (msg or ""), (btxt or banner_text())

            run_btn.click(
                _run_batch,
                inputs=[models_list, custom_models, instructions_b, context_b, dataset_zip,
                        soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
                        repeats, max_total_runs, load_4bit, dtype, trust_remote_code],
+                 outputs=[kpi_b1, kpi_b2, kpi_b3, table, csv_dl, zip_dl, status, banner]
            )

demo.launch()
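
For reference, the batch tab's `parse_zip` / `_coerce_labels_list` pair (see the diff above) accepts a ZIP of `<id>.txt` transcripts, each with an optional `<id>.json` of expected labels, where the JSON may be a plain list of strings or a dict carrying `labels`, `expected_labels`, `targets`, `y_true`, or a `one_hot` map. A minimal sketch of building such a dataset ZIP with only the standard library; the sample id, transcript text, and label below are illustrative, not taken from the commit:

# Sketch: build a dataset.zip that the Batch tab should accept.
# File names and contents here are hypothetical examples.
import io, json, zipfile

transcript = "Client: Can we meet on March 5 at 10:30?\nAdvisor: Confirmed."
expected = {"labels": ["schedule_meeting"]}  # one of the shapes _coerce_labels_list handles

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
    zf.writestr("sample_001.txt", transcript)             # transcript: <id>.txt
    zf.writestr("sample_001.json", json.dumps(expected))  # expected labels: <id>.json

with open("dataset.zip", "wb") as f:                      # upload this in the Batch tab
    f.write(buf.getvalue())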