RishiRP commited on
Commit
954d97c
·
verified ·
1 Parent(s): db991b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -41
app.py CHANGED
@@ -45,6 +45,7 @@ OFFICIAL_LABELS = [
45
  "update_kyc_purpose_of_businessrelation",
46
  "update_kyc_total_assets",
47
  ]
 
48
 
49
  # Per-label keyword cues (static prompt context to improve recall)
50
  LABEL_KEYWORDS: Dict[str, List[str]] = {
@@ -213,13 +214,11 @@ def read_text_file_any(file_input) -> str:
213
  """Works for gr.File(type='filepath') and raw strings/Path and file-like."""
214
  if not file_input:
215
  return ""
216
- # filepath string
217
  if isinstance(file_input, (str, Path)):
218
  try:
219
  return Path(file_input).read_text(encoding="utf-8", errors="ignore")
220
  except Exception:
221
  return ""
222
- # gr.File object or file-like
223
  try:
224
  data = file_input.read()
225
  return data.decode("utf-8", errors="ignore")
@@ -284,7 +283,7 @@ class ModelWrapper:
284
 
285
  @torch.inference_mode()
286
  def generate(self, system_prompt: str, user_prompt: str) -> str:
287
- # Build inputs as input_ids=... (avoid **tensor bug)
288
  if hasattr(self.tokenizer, "apply_chat_template"):
289
  messages = [
290
  {"role": "system", "content": system_prompt},
@@ -382,10 +381,11 @@ def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
382
  for lab in allowed:
383
  hits = []
384
  for kw in LABEL_KEYWORDS.get(lab, []):
385
- if kw.lower() in low:
 
386
  # capture small evidence window
387
- i = low.find(kw.lower())
388
- start = max(0, i - 40); end = min(len(text), i + len(kw) + 40)
389
  hits.append(text[start:end].strip())
390
  if hits:
391
  labels.append(lab)
@@ -418,7 +418,7 @@ def run_single(
418
  use_4bit: bool,
419
  max_input_tokens: int,
420
  hf_token: str,
421
- ) -> Tuple[str, str, str, str, str, str]:
422
 
423
  t0 = _now_ms()
424
 
@@ -428,11 +428,11 @@ def run_single(
428
  raw_text = read_text_file_any(transcript_file)
429
  raw_text = (raw_text or transcript_text or "").strip()
430
  if not raw_text:
431
- return "", "", "No transcript provided.", "", "", ""
432
 
433
  text = clean_transcript(raw_text) if use_cleaning else raw_text
434
 
435
- # Allowed labels
436
  user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
437
  allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
438
 
@@ -440,7 +440,7 @@ def run_single(
440
  try:
441
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
442
  except Exception as e:
443
- return "", "", f"Model load failed: {e}", "", "", ""
444
 
445
  # Truncate
446
  trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
@@ -459,7 +459,7 @@ def run_single(
459
  try:
460
  out = model.generate(SYSTEM_PROMPT, user_prompt)
461
  except Exception as e:
462
- return "", "", f"Generation error: {e}", "", "", ""
463
  t2 = _now_ms()
464
 
465
  parsed = robust_json_extract(out)
@@ -482,10 +482,16 @@ def run_single(
482
  f"Allowed labels: {', '.join(allowed)}",
483
  ])
484
 
485
- # Context preview shown in UI
486
- context_preview = "Allowed Labels:\n" + "\n".join(f"- {l}" for l in allowed) + "\n\nKeyword cues:\n" + keyword_ctx
 
 
 
 
 
 
487
 
488
- # Summary
489
  labs = filtered.get("labels", [])
490
  tasks = filtered.get("tasks", [])
491
  summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
@@ -496,7 +502,6 @@ def run_single(
496
  )
497
  else:
498
  summary += "\n\nTasks: (none)"
499
-
500
  json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
501
 
502
  # Optional single-file scoring if GT provided
@@ -533,7 +538,7 @@ def run_single(
533
  else:
534
  metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
535
 
536
- return summary, json_out, diag, out.strip(), context_preview, metrics
537
 
538
  # =========================
539
  # Batch mode (ZIP with transcripts + truths)
@@ -569,7 +574,6 @@ def run_batch(
569
  except Exception: pass
570
  work.mkdir(parents=True, exist_ok=True)
571
 
572
- # Unzip
573
  files = read_zip_from_path(zip_path, work)
574
 
575
  txts: Dict[str, Path] = {}
@@ -642,7 +646,7 @@ def run_batch(
642
 
643
  rows.append({
644
  "file": stem,
645
- "true_labels": ", ".join(gt_labels),
646
  "pred_labels": ", ".join(pred_labels),
647
  "TP": len(tp), "FP": len(fp), "FN": len(fn),
648
  "gen_ms": t1 - t0
@@ -689,32 +693,43 @@ MODEL_CHOICES = [
689
  "mistralai/Mistral-7B-Instruct-v0.3",
690
  ]
691
 
692
- with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
693
- gr.Markdown("# Talk2Task Task Extraction (UBS Challenge)")
694
- gr.Markdown(
695
- "Extract challenge labels from transcripts. False negatives are penalised more than false positives "
696
- "in the official score, so the app biases for recall."
697
- )
 
 
 
 
 
 
 
 
698
 
699
  with gr.Tab("Single transcript"):
700
  with gr.Row():
701
  with gr.Column(scale=3):
702
- gr.Markdown("### Transcript")
703
  file = gr.File(
704
  label="Drag & drop transcript (.txt / .md / .json)",
705
  file_types=[".txt", ".md", ".json"],
706
  type="filepath",
707
  )
708
  text = gr.Textbox(label="Or paste transcript", lines=10)
 
709
 
710
- gr.Markdown("### Ground truth JSON (optional)")
711
  gt_file = gr.File(
712
  label="Upload ground truth JSON (expects {'labels': [...]})",
713
  file_types=[".json"],
714
  type="filepath",
715
  )
716
- gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{"labels": ["schedule_meeting"]}')
 
717
 
 
718
  use_cleaning = gr.Checkbox(
719
  label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
720
  value=True,
@@ -723,28 +738,51 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
723
  label="Keyword fallback if model returns empty",
724
  value=True,
725
  )
 
726
 
 
727
  labels_text = gr.Textbox(
728
- label="Allowed Labels (one per line; empty = official list)",
729
- value="",
730
  lines=8,
731
  )
 
 
 
732
  with gr.Column(scale=2):
 
733
  repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
734
  use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
735
  max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
736
  hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
737
  run_btn = gr.Button("Run Extraction", variant="primary")
 
 
 
 
 
 
 
 
738
 
739
  with gr.Row():
740
- summary = gr.Textbox(label="Summary", lines=12)
741
- json_out = gr.Code(label="Strict JSON Output", language="json")
742
- with gr.Row():
743
- diag = gr.Textbox(label="Diagnostics", lines=8)
744
- raw = gr.Textbox(label="Raw Model Output", lines=8)
745
- with gr.Row():
746
- context_used = gr.Code(label="Effective context used this run (labels + keyword cues)", language="markdown")
747
- single_metrics = gr.Textbox(label="Single-file metrics (if ground truth provided)", lines=6)
 
 
 
 
 
 
 
 
 
748
 
749
  run_btn.click(
750
  fn=run_single,
@@ -752,28 +790,38 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
752
  text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
753
  labels_text, repo, use_4bit, max_tokens, hf_token
754
  ],
755
- outputs=[summary, json_out, diag, raw, context_used, single_metrics],
756
  )
757
 
 
 
 
 
758
  with gr.Tab("Batch evaluation"):
759
  with gr.Row():
760
  with gr.Column(scale=3):
 
761
  zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
762
  use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
763
  use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
 
764
  with gr.Column(scale=2):
 
765
  repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
766
  use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
767
  max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
768
  hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
769
  limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
770
  run_batch_btn = gr.Button("Run Batch", variant="primary")
 
771
 
772
  with gr.Row():
 
773
  status = gr.Textbox(label="Status", lines=1)
774
  diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
775
- df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
776
- csv_out = gr.File(label="Download CSV", interactive=False)
 
777
 
778
  run_batch_btn.click(
779
  fn=run_batch,
@@ -782,5 +830,4 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
782
  )
783
 
784
  if __name__ == "__main__":
785
- demo = demo # to satisfy some runtimes
786
  demo.launch()
 
45
  "update_kyc_purpose_of_businessrelation",
46
  "update_kyc_total_assets",
47
  ]
48
+ OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
49
 
50
  # Per-label keyword cues (static prompt context to improve recall)
51
  LABEL_KEYWORDS: Dict[str, List[str]] = {
 
214
  """Works for gr.File(type='filepath') and raw strings/Path and file-like."""
215
  if not file_input:
216
  return ""
 
217
  if isinstance(file_input, (str, Path)):
218
  try:
219
  return Path(file_input).read_text(encoding="utf-8", errors="ignore")
220
  except Exception:
221
  return ""
 
222
  try:
223
  data = file_input.read()
224
  return data.decode("utf-8", errors="ignore")
 
283
 
284
  @torch.inference_mode()
285
  def generate(self, system_prompt: str, user_prompt: str) -> str:
286
+ # Build inputs as input_ids=... (avoid **tensor bug from earlier)
287
  if hasattr(self.tokenizer, "apply_chat_template"):
288
  messages = [
289
  {"role": "system", "content": system_prompt},
 
381
  for lab in allowed:
382
  hits = []
383
  for kw in LABEL_KEYWORDS.get(lab, []):
384
+ k = kw.lower()
385
+ if k in low:
386
  # capture small evidence window
387
+ i = low.find(k)
388
+ start = max(0, i - 40); end = min(len(text), i + len(k) + 40)
389
  hits.append(text[start:end].strip())
390
  if hits:
391
  labels.append(lab)
 
418
  use_4bit: bool,
419
  max_input_tokens: int,
420
  hf_token: str,
421
+ ) -> Tuple[str, str, str, str, str, str, str]:
422
 
423
  t0 = _now_ms()
424
 
 
428
  raw_text = read_text_file_any(transcript_file)
429
  raw_text = (raw_text or transcript_text or "").strip()
430
  if not raw_text:
431
+ return "", "", "No transcript provided.", "", "", "", ""
432
 
433
  text = clean_transcript(raw_text) if use_cleaning else raw_text
434
 
435
+ # Allowed labels (pre-filled defaults)
436
  user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
437
  allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
438
 
 
440
  try:
441
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
442
  except Exception as e:
443
+ return "", "", f"Model load failed: {e}", "", "", "", ""
444
 
445
  # Truncate
446
  trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
 
459
  try:
460
  out = model.generate(SYSTEM_PROMPT, user_prompt)
461
  except Exception as e:
462
+ return "", "", f"Generation error: {e}", "", "", "", ""
463
  t2 = _now_ms()
464
 
465
  parsed = robust_json_extract(out)
 
482
  f"Allowed labels: {', '.join(allowed)}",
483
  ])
484
 
485
+ # Context & instructions preview shown in UI
486
+ context_preview = (
487
+ "### Allowed Labels\n"
488
+ + "\n".join(f"- {l}" for l in allowed)
489
+ + "\n\n### Keyword cues per label\n"
490
+ + keyword_ctx
491
+ )
492
+ instructions_preview = "```\n" + SYSTEM_PROMPT + "\n```"
493
 
494
+ # Summary & JSON
495
  labs = filtered.get("labels", [])
496
  tasks = filtered.get("tasks", [])
497
  summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
 
502
  )
503
  else:
504
  summary += "\n\nTasks: (none)"
 
505
  json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
506
 
507
  # Optional single-file scoring if GT provided
 
538
  else:
539
  metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
540
 
541
+ return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics
542
 
543
  # =========================
544
  # Batch mode (ZIP with transcripts + truths)
 
574
  except Exception: pass
575
  work.mkdir(parents=True, exist_ok=True)
576
 
 
577
  files = read_zip_from_path(zip_path, work)
578
 
579
  txts: Dict[str, Path] = {}
 
646
 
647
  rows.append({
648
  "file": stem,
649
+ "true_labels": ", ".join(gt_labels),
650
  "pred_labels": ", ".join(pred_labels),
651
  "TP": len(tp), "FP": len(fp), "FN": len(fn),
652
  "gen_ms": t1 - t0
 
693
  "mistralai/Mistral-7B-Instruct-v0.3",
694
  ]
695
 
696
+ custom_css = """
697
+ :root { --radius: 14px; }
698
+ .gradio-container { font-family: Inter, ui-sans-serif, system-ui; }
699
+ .card { border: 1px solid rgba(255,255,255,.08); border-radius: var(--radius); padding: 14px 16px; background: rgba(255,255,255,.02); box-shadow: 0 1px 10px rgba(0,0,0,.12) inset; }
700
+ .header { font-weight: 700; font-size: 22px; margin-bottom: 4px; }
701
+ .subtle { color: rgba(255,255,255,.65); font-size: 14px; margin-bottom: 12px; }
702
+ hr.sep { border: none; border-top: 1px solid rgba(255,255,255,.08); margin: 10px 0 16px; }
703
+ .accordion-title { font-weight: 600; }
704
+ .gr-button { border-radius: 12px !important; }
705
+ """
706
+
707
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
708
+ gr.Markdown("<div class='header'>Talk2Task — Task Extraction (UBS Challenge)</div>")
709
+ gr.Markdown("<div class='subtle'>False negatives are penalised 2× more than false positives in the official score. This UI biases for recall, shows the exact instructions & context, and supports single or batch evaluation.</div>")
710
 
711
  with gr.Tab("Single transcript"):
712
  with gr.Row():
713
  with gr.Column(scale=3):
714
+ gr.Markdown("<div class='card'><div class='header'>Transcript</div>", elem_id="card1")
715
  file = gr.File(
716
  label="Drag & drop transcript (.txt / .md / .json)",
717
  file_types=[".txt", ".md", ".json"],
718
  type="filepath",
719
  )
720
  text = gr.Textbox(label="Or paste transcript", lines=10)
721
+ gr.Markdown("<hr class='sep'/>")
722
 
723
+ gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>", elem_id="card1b")
724
  gt_file = gr.File(
725
  label="Upload ground truth JSON (expects {'labels': [...]})",
726
  file_types=[".json"],
727
  type="filepath",
728
  )
729
+ gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
730
+ gr.Markdown("</div>") # close card
731
 
732
+ gr.Markdown("<div class='card'><div class='header'>Preprocessing & heuristics</div>", elem_id="card2")
733
  use_cleaning = gr.Checkbox(
734
  label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
735
  value=True,
 
738
  label="Keyword fallback if model returns empty",
739
  value=True,
740
  )
741
+ gr.Markdown("</div>")
742
 
743
+ gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>", elem_id="card3")
744
  labels_text = gr.Textbox(
745
+ label="Allowed Labels (one per line)",
746
+ value=OFFICIAL_LABELS_TEXT, # prefilled
747
  lines=8,
748
  )
749
+ reset_btn = gr.Button("Reset to official labels")
750
+ gr.Markdown("</div>")
751
+
752
  with gr.Column(scale=2):
753
+ gr.Markdown("<div class='card'><div class='header'>Model & run</div>", elem_id="card4")
754
  repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
755
  use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
756
  max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
757
  hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
758
  run_btn = gr.Button("Run Extraction", variant="primary")
759
+ gr.Markdown("</div>")
760
+
761
+ gr.Markdown("<div class='card'><div class='header'>Outputs</div>", elem_id="card5")
762
+ summary = gr.Textbox(label="Summary", lines=12)
763
+ json_out = gr.Code(label="Strict JSON Output", language="json")
764
+ diag = gr.Textbox(label="Diagnostics", lines=8)
765
+ raw = gr.Textbox(label="Raw Model Output", lines=8)
766
+ gr.Markdown("</div>")
767
 
768
  with gr.Row():
769
+ with gr.Column():
770
+ with gr.Accordion("Instructions used (system prompt)", open=False):
771
+ instr_md = gr.Markdown("")
772
+ with gr.Column():
773
+ with gr.Accordion("Context used (allowed labels + keyword cues)", open=True):
774
+ context_md = gr.Markdown("")
775
+
776
+ # reset button behavior
777
+ def _reset_labels():
778
+ return OFFICIAL_LABELS_TEXT
779
+ reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
780
+
781
+ # single run
782
+ def _pack_context_md(allowed: str) -> str:
783
+ allowed_list = [ln.strip() for ln in (allowed or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
784
+ ctx = build_keyword_context(allowed_list)
785
+ return "### Allowed Labels\n" + "\n".join(f"- {l}" for l in allowed_list) + "\n\n### Keyword cues per label\n" + ctx
786
 
787
  run_btn.click(
788
  fn=run_single,
 
790
  text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
791
  labels_text, repo, use_4bit, max_tokens, hf_token
792
  ],
793
+ outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False)],
794
  )
795
 
796
+ # also keep instructions visible at initial load
797
+ instr_md.value = "```\n" + SYSTEM_PROMPT + "\n```"
798
+ context_md.value = _pack_context_md(OFFICIAL_LABELS_TEXT)
799
+
800
  with gr.Tab("Batch evaluation"):
801
  with gr.Row():
802
  with gr.Column(scale=3):
803
+ gr.Markdown("<div class='card'><div class='header'>ZIP input</div>", elem_id="card6")
804
  zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
805
  use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
806
  use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
807
+ gr.Markdown("</div>")
808
  with gr.Column(scale=2):
809
+ gr.Markdown("<div class='card'><div class='header'>Model & run</div>", elem_id="card7")
810
  repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
811
  use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
812
  max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
813
  hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
814
  limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
815
  run_batch_btn = gr.Button("Run Batch", variant="primary")
816
+ gr.Markdown("</div>")
817
 
818
  with gr.Row():
819
+ gr.Markdown("<div class='card'><div class='header'>Batch outputs</div>", elem_id="card8")
820
  status = gr.Textbox(label="Status", lines=1)
821
  diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
822
+ df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
823
+ csv_out = gr.File(label="Download CSV", interactive=False)
824
+ gr.Markdown("</div>")
825
 
826
  run_batch_btn.click(
827
  fn=run_batch,
 
830
  )
831
 
832
  if __name__ == "__main__":
 
833
  demo.launch()