RishiRP committed (verified)
Commit db991b8 · 1 Parent(s): e5618ba

Update app.py

Files changed (1):
  1. app.py +196 -60

app.py CHANGED
@@ -209,21 +209,35 @@ def clean_transcript(text: str) -> str:
    s = re.sub(r"\n{3,}", "\n\n", s).strip()
    return s

-def read_text_from_file(file: gr.File) -> str:
-    if not file or not file.name:
+def read_text_file_any(file_input) -> str:
+    """Works for gr.File(type='filepath') and raw strings/Path and file-like."""
+    if not file_input:
        return ""
-    name = file.name.lower()
-    data = file.read()
-    if name.endswith(".json"):
+    # filepath string
+    if isinstance(file_input, (str, Path)):
        try:
-            obj = json.loads(data.decode("utf-8", errors="ignore"))
-            if isinstance(obj, dict) and "transcript" in obj:
-                return str(obj["transcript"])
-            return json.dumps(obj, ensure_ascii=False)
+            return Path(file_input).read_text(encoding="utf-8", errors="ignore")
        except Exception:
-            return data.decode("utf-8", errors="ignore")
-    else:
+            return ""
+    # gr.File object or file-like
+    try:
+        data = file_input.read()
        return data.decode("utf-8", errors="ignore")
+    except Exception:
+        return ""
+
+def read_json_file_any(file_input) -> Optional[dict]:
+    if not file_input:
+        return None
+    if isinstance(file_input, (str, Path)):
+        try:
+            return json.loads(Path(file_input).read_text(encoding="utf-8", errors="ignore"))
+        except Exception:
+            return None
+    try:
+        return json.loads(file_input.read().decode("utf-8", errors="ignore"))
+    except Exception:
+        return None

def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
    toks = tokenizer(text, add_special_tokens=False)["input_ids"]
@@ -270,30 +284,47 @@ class ModelWrapper:

    @torch.inference_mode()
    def generate(self, system_prompt: str, user_prompt: str) -> str:
+        # Build inputs as input_ids=... (avoid **tensor bug)
        if hasattr(self.tokenizer, "apply_chat_template"):
-            msgs = [{"role": "system", "content": system_prompt},
-                    {"role": "user", "content": user_prompt}]
-            inputs = self.tokenizer.apply_chat_template(
-                msgs, add_generation_prompt=True, return_tensors="pt"
-            ).to(self.model.device)
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+            input_ids = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+            )
+            input_ids = input_ids.to(self.model.device)
+            gen_kwargs = dict(
+                input_ids=input_ids,
+                generation_config=GEN_CONFIG,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
        else:
-            text = f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n"
-            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
-
-        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
-            out_ids = self.model.generate(
-                **inputs,
+            enc = self.tokenizer(
+                f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n",
+                return_tensors="pt"
+            ).to(self.model.device)
+            gen_kwargs = dict(
+                **enc,
                generation_config=GEN_CONFIG,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )
+
+        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
+            out_ids = self.model.generate(**gen_kwargs)
        return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)

_MODEL_CACHE: Dict[str, ModelWrapper] = {}
def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
    key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
    if key not in _MODEL_CACHE:
-        m = ModelWrapper(repo_id, hf_token, load_in_4bit); m.load()
+        m = ModelWrapper(repo_id, hf_token, load_in_4bit)
+        m.load()
        _MODEL_CACHE[key] = m
    return _MODEL_CACHE[key]

@@ -303,8 +334,6 @@ def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
    ALLOWED_LABELS = OFFICIAL_LABELS
    LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
-    FN_PENALTY = 2.0
-    FP_PENALTY = 1.0

    def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
        if not isinstance(sample_labels, list):
@@ -315,13 +344,10 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
                raise ValueError(f"{sample_name} contains non-string: {label} (type: {type(label)})")
            if label in seen:
                raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
-            seen.add(label); uniq.append(label)
-        valid = []
-        for label in uniq:
            if label not in ALLOWED_LABELS:
                raise ValueError(f"{sample_name} contains invalid label: '{label}'. Allowed: {ALLOWED_LABELS}")
-            valid.append(label)
-        return valid
+            seen.add(label); uniq.append(label)
+        return uniq

    if len(y_true) != len(y_pred):
        raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")
@@ -339,13 +365,37 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
        for label in _process_sample_labels(sample_labels, f"y_pred[{i}]"):
            y_pred_binary[i, LABEL_TO_IDX[label]] = 1

-    fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
-    fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
+    fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)  # penalty 2x
+    fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)  # penalty 1x
    weighted = 2.0 * fn + 1.0 * fp
    max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))
    per_sample = np.where(max_err > 0, 1.0 - (weighted / max_err), 1.0)
    return float(max(0.0, min(1.0, np.mean(per_sample))))

+# =========================
+# Fallback: keyword heuristics if model returns empty
+# =========================
+def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
+    low = text.lower()
+    labels = []
+    tasks = []
+    for lab in allowed:
+        hits = []
+        for kw in LABEL_KEYWORDS.get(lab, []):
+            if kw.lower() in low:
+                # capture small evidence window
+                i = low.find(kw.lower())
+                start = max(0, i - 40); end = min(len(text), i + len(kw) + 40)
+                hits.append(text[start:end].strip())
+        if hits:
+            labels.append(lab)
+            tasks.append({
+                "label": lab,
+                "explanation": "Keyword match in transcript.",
+                "evidence": hits[0]
+            })
+    return {"labels": normalize_labels(labels), "tasks": tasks}
+
# =========================
# Inference helpers
# =========================
@@ -358,34 +408,44 @@ def build_keyword_context(allowed: List[str]) -> str:

def run_single(
    transcript_text: str,
-    transcript_file: gr.File,
+    transcript_file,  # filepath or file-like
+    gt_json_text: str,
+    gt_json_file,  # filepath or file-like
    use_cleaning: bool,
+    use_keyword_fallback: bool,
    allowed_labels_text: str,
    model_repo: str,
    use_4bit: bool,
    max_input_tokens: int,
    hf_token: str,
-) -> Tuple[str, str, str, str]:
+) -> Tuple[str, str, str, str, str, str]:

    t0 = _now_ms()

-    raw_text = read_text_from_file(transcript_file) if transcript_file else (transcript_text or "")
-    raw_text = (raw_text or "").strip()
+    # Transcript
+    raw_text = ""
+    if transcript_file:
+        raw_text = read_text_file_any(transcript_file)
+    raw_text = (raw_text or transcript_text or "").strip()
    if not raw_text:
-        return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, indent=2)
+        return "", "", "No transcript provided.", "", "", ""

    text = clean_transcript(raw_text) if use_cleaning else raw_text

+    # Allowed labels
    user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
    allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)

+    # Model
    try:
        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
    except Exception as e:
-        return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
+        return "", "", f"Model load failed: {e}", "", "", ""

+    # Truncate
    trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)

+    # Build prompt
    allowed_list_str = "\n".join(f"- {l}" for l in allowed)
    keyword_ctx = build_keyword_context(allowed)
    user_prompt = USER_PROMPT_TEMPLATE.format(
@@ -394,25 +454,38 @@ def run_single(
        keyword_context=keyword_ctx,
    )

+    # Generate
    t1 = _now_ms()
    try:
        out = model.generate(SYSTEM_PROMPT, user_prompt)
    except Exception as e:
-        return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
+        return "", "", f"Generation error: {e}", "", "", ""
    t2 = _now_ms()

    parsed = robust_json_extract(out)
    filtered = restrict_to_allowed(parsed, allowed)

+    # Fallback if empty
+    if use_keyword_fallback and not filtered.get("labels"):
+        fb = keyword_fallback(trunc, allowed)
+        if fb["labels"]:
+            filtered = fb
+
+    # Diagnostics
    diag = "\n".join([
        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
        f"Model: {model_repo}",
        f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
+        f"Keyword fallback: {'Yes' if use_keyword_fallback else 'No'}",
        f"Tokens (input, approx): ≤ {max_input_tokens}",
        f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
        f"Allowed labels: {', '.join(allowed)}",
    ])

+    # Context preview shown in UI
+    context_preview = "Allowed Labels:\n" + "\n".join(f"- {l}" for l in allowed) + "\n\nKeyword cues:\n" + keyword_ctx
+
+    # Summary
    labs = filtered.get("labels", [])
    tasks = filtered.get("tasks", [])
    summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
@@ -424,20 +497,59 @@ def run_single(
    else:
        summary += "\n\nTasks: (none)"

-    return summary, json.dumps(filtered, indent=2, ensure_ascii=False), diag, out.strip()
+    json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
+
+    # Optional single-file scoring if GT provided
+    metrics = ""
+    true_labels = None
+    if gt_json_file or (gt_json_text and gt_json_text.strip()):
+        truth_obj = None
+        if gt_json_file:
+            truth_obj = read_json_file_any(gt_json_file)
+        if (not truth_obj) and gt_json_text:
+            try:
+                truth_obj = json.loads(gt_json_text)
+            except Exception:
+                pass
+        if isinstance(truth_obj, dict) and isinstance(truth_obj.get("labels"), list):
+            true_labels = [x for x in truth_obj["labels"] if x in OFFICIAL_LABELS]
+            pred_labels = labs
+            try:
+                score = evaluate_predictions([true_labels], [pred_labels])
+                tp = len(set(true_labels) & set(pred_labels))
+                fp = len(set(pred_labels) - set(true_labels))
+                fn = len(set(true_labels) - set(pred_labels))
+                recall = tp / (tp + fn) if (tp + fn) else 1.0
+                precision = tp / (tp + fp) if (tp + fp) else 1.0
+                f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
+                metrics = (
+                    f"Weighted score: {score:.3f}\n"
+                    f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}\n"
+                    f"TP={tp} FP={fp} FN={fn}\n"
+                    f"Truth: {', '.join(true_labels)}"
+                )
+            except Exception as e:
+                metrics = f"Scoring error: {e}"
+        else:
+            metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
+
+    return summary, json_out, diag, out.strip(), context_preview, metrics

# =========================
# Batch mode (ZIP with transcripts + truths)
# =========================
-def read_zip(fileobj: io.BytesIO, exdir: Path) -> List[Path]:
+def read_zip_from_path(path: str, exdir: Path) -> List[Path]:
    exdir.mkdir(parents=True, exist_ok=True)
-    with zipfile.ZipFile(fileobj) as zf:
+    with open(path, "rb") as f:
+        data = f.read()
+    with zipfile.ZipFile(io.BytesIO(data)) as zf:
        zf.extractall(exdir)
    return [p for p in exdir.rglob("*") if p.is_file()]

def run_batch(
-    zip_file: gr.File,
+    zip_path,  # filepath string
    use_cleaning: bool,
+    use_keyword_fallback: bool,
    model_repo: str,
    use_4bit: bool,
    max_input_tokens: int,
@@ -445,24 +557,20 @@ def run_batch(
    limit_files: int,
) -> Tuple[str, str, pd.DataFrame, str]:

-    if not zip_file:
+    if not zip_path:
        return ("No ZIP provided.", "", pd.DataFrame(), "")

    work = Path("/tmp/batch")
    if work.exists():
        for p in sorted(work.rglob("*"), reverse=True):
-            try:
-                p.unlink()
-            except Exception:
-                pass
-        try:
-            work.rmdir()
-        except Exception:
-            pass
+            try: p.unlink()
+            except Exception: pass
+        try: work.rmdir()
+        except Exception: pass
    work.mkdir(parents=True, exist_ok=True)

-    data = zip_file.read()
-    files = read_zip(io.BytesIO(data), work)
+    # Unzip
+    files = read_zip_from_path(zip_path, work)

    txts: Dict[str, Path] = {}
    gts: Dict[str, Path] = {}
@@ -508,6 +616,12 @@ def run_batch(

        parsed = robust_json_extract(out)
        filtered = restrict_to_allowed(parsed, allowed)
+
+        if use_keyword_fallback and not filtered.get("labels"):
+            fb = keyword_fallback(trunc, allowed)
+            if fb["labels"]:
+                filtered = fb
+
        pred_labels = filtered.get("labels", [])
        y_pred.append(pred_labels)

@@ -543,6 +657,7 @@ def run_batch(
        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
        f"Model: {model_repo}",
        f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
+        f"Keyword fallback: {'Yes' if use_keyword_fallback else 'No'}",
        f"Tokens (input, approx): ≤ {max_input_tokens}",
        f"Batch time: {_now_ms()-t_start} ms",
    ]
@@ -563,7 +678,6 @@ def run_batch(
    # save CSV for download
    out_csv = Path("/tmp/batch_results.csv")
    df.to_csv(out_csv, index=False, encoding="utf-8")
-
    return ("Batch done.", diag_str, df, str(out_csv))

# =========================
@@ -585,16 +699,31 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    with gr.Tab("Single transcript"):
        with gr.Row():
            with gr.Column(scale=3):
+                gr.Markdown("### Transcript")
                file = gr.File(
                    label="Drag & drop transcript (.txt / .md / .json)",
                    file_types=[".txt", ".md", ".json"],
                    type="filepath",
                )
-                text = gr.Textbox(label="Or paste transcript", lines=14)
+                text = gr.Textbox(label="Or paste transcript", lines=10)
+
+                gr.Markdown("### Ground truth JSON (optional)")
+                gt_file = gr.File(
+                    label="Upload ground truth JSON (expects {'labels': [...]})",
+                    file_types=[".json"],
+                    type="filepath",
+                )
+                gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{"labels": ["schedule_meeting"]}')
+
                use_cleaning = gr.Checkbox(
                    label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
                    value=True,
                )
+                use_keyword_fallback = gr.Checkbox(
+                    label="Keyword fallback if model returns empty",
+                    value=True,
+                )
+
                labels_text = gr.Textbox(
                    label="Allowed Labels (one per line; empty = official list)",
                    value="",
@@ -613,11 +742,17 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
        with gr.Row():
            diag = gr.Textbox(label="Diagnostics", lines=8)
            raw = gr.Textbox(label="Raw Model Output", lines=8)
+        with gr.Row():
+            context_used = gr.Code(label="Effective context used this run (labels + keyword cues)", language="markdown")
+            single_metrics = gr.Textbox(label="Single-file metrics (if ground truth provided)", lines=6)

        run_btn.click(
            fn=run_single,
-            inputs=[text, file, use_cleaning, labels_text, repo, use_4bit, max_tokens, hf_token],
-            outputs=[summary, json_out, diag, raw],
+            inputs=[
+                text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
+                labels_text, repo, use_4bit, max_tokens, hf_token
+            ],
+            outputs=[summary, json_out, diag, raw, context_used, single_metrics],
        )

    with gr.Tab("Batch evaluation"):
@@ -625,6 +760,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
            with gr.Column(scale=3):
                zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
                use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
+                use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
            with gr.Column(scale=2):
                repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
                use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
@@ -636,15 +772,15 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
        with gr.Row():
            status = gr.Textbox(label="Status", lines=1)
            diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
-
        df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
        csv_out = gr.File(label="Download CSV", interactive=False)

        run_batch_btn.click(
            fn=run_batch,
-            inputs=[zip_in, use_cleaning_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
+            inputs=[zip_in, use_cleaning_b, use_keyword_fallback_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
            outputs=[status, diag_b, df_out, csv_out],
        )

if __name__ == "__main__":
+    demo = demo  # to satisfy some runtimes
    demo.launch()
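
As a quick illustration of the keyword fallback introduced in this commit, here is a self-contained sketch with a made-up LABEL_KEYWORDS map and transcript; the label names and keywords below are placeholders only (in app.py the real keyword table is used and the returned labels also pass through normalize_labels):

# Standalone sketch of the keyword-fallback heuristic; illustrative data only.
from typing import Any, Dict, List

LABEL_KEYWORDS: Dict[str, List[str]] = {
    "schedule_meeting": ["schedule a meeting", "set up a call"],
    "update_crm": ["update the crm", "log the account"],
}

def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
    low = text.lower()
    labels: List[str] = []
    tasks: List[Dict[str, str]] = []
    for lab in allowed:
        hits = []
        for kw in LABEL_KEYWORDS.get(lab, []):
            if kw.lower() in low:
                # capture a small evidence window around the first match
                i = low.find(kw.lower())
                start, end = max(0, i - 40), min(len(text), i + len(kw) + 40)
                hits.append(text[start:end].strip())
        if hits:
            labels.append(lab)
            tasks.append({"label": lab,
                          "explanation": "Keyword match in transcript.",
                          "evidence": hits[0]})
    return {"labels": labels, "tasks": tasks}

if __name__ == "__main__":
    sample = "Advisor: let's schedule a meeting for Tuesday and update the CRM afterwards."
    print(keyword_fallback(sample, list(LABEL_KEYWORDS)))
    # both illustrative labels fire, each with a short evidence snippet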