RishiRP committed
Commit 5f0642c · verified · 1 Parent(s): aa5f588

Update app.py

Files changed (1)
  1. app.py +50 -48
app.py CHANGED
@@ -1,4 +1,3 @@
-# app.py
 import os
 import re
 import io
@@ -28,15 +27,13 @@ SPACE_CACHE.mkdir(parents=True, exist_ok=True)
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

-# Deterministic, compact outputs
 GEN_CONFIG = GenerationConfig(
     temperature=0.0,
     top_p=1.0,
     do_sample=False,
-    max_new_tokens=128, # raise if your JSON truncates
+    max_new_tokens=128, # raise if JSON truncates
 )

-# Canonical labels (UBS)
 OFFICIAL_LABELS = [
     "plan_contact",
     "schedule_meeting",
@@ -72,7 +69,7 @@ DEFAULT_LABEL_GLOSSARY = {
     "update_kyc_total_assets": "Discussion/confirmation of total assets/net worth.",
 }

-# Minimal multilingual fallback rules (optional)
+# Tiny multilingual fallback rules (optional) to avoid empty outputs
 DEFAULT_FALLBACK_CUES = {
     "plan_contact": [
         r"\b(get|got|will|we'?ll|i'?ll)\s+back to you\b", r"\bfollow\s*up\b", r"\breach out\b", r"\btouch base\b",
@@ -250,14 +247,15 @@ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
     return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)

 # =========================
-# HF model wrapper (robust loader + fast→slow tokenizer fallback)
+# HF model wrapper (robust: fast→slow tokenizer + load fallbacks)
 # =========================
 class ModelWrapper:
-    def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool):
+    def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool, force_tok_redownload: bool):
         self.repo_id = repo_id
         self.hf_token = hf_token
         self.load_in_4bit = load_in_4bit
         self.use_sdpa = use_sdpa
+        self.force_tok_redownload = force_tok_redownload
         self.tokenizer = None
         self.model = None
         self.load_path = "uninitialized"
@@ -265,18 +263,21 @@ class ModelWrapper:
     def _load_tokenizer(self):
         fast_err = None
         tok = None
+        common = dict(
+            pretrained_model_name_or_path=self.repo_id,
+            token=self.hf_token,
+            cache_dir=str(SPACE_CACHE),
+            trust_remote_code=True,
+            local_files_only=False,
+            force_download=True if self.force_tok_redownload else False,
+            revision=None,
+        )
         try:
-            tok = AutoTokenizer.from_pretrained(
-                self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
-                trust_remote_code=True, use_fast=True
-            )
+            tok = AutoTokenizer.from_pretrained(use_fast=True, **common)
         except Exception as e:
             fast_err = e
         if tok is None:
-            tok = AutoTokenizer.from_pretrained(
-                self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
-                trust_remote_code=True, use_fast=False
-            )
+            tok = AutoTokenizer.from_pretrained(use_fast=False, **common)
         if tok.pad_token is None and tok.eos_token:
             tok.pad_token = tok.eos_token
         return tok, fast_err
@@ -372,10 +373,10 @@ class ModelWrapper:
         return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)

 _MODEL_CACHE: Dict[str, ModelWrapper] = {}
-def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool) -> ModelWrapper:
-    key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}::{'sdpa' if use_sdpa else 'nosdpa'}"
+def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool, force_tok_redownload: bool) -> ModelWrapper:
+    key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}::{'sdpa' if use_sdpa else 'nosdpa'}::{'force' if force_tok_redownload else 'cache'}"
     if key not in _MODEL_CACHE:
-        m = ModelWrapper(repo_id, hf_token, load_in_4bit, use_sdpa)
+        m = ModelWrapper(repo_id, hf_token, load_in_4bit, use_sdpa, force_tok_redownload)
         m.load()
         _MODEL_CACHE[key] = m
     return _MODEL_CACHE[key]
@@ -425,7 +426,7 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
     return float(max(0.0, min(1.0, np.mean(per_sample))))

 # =========================
-# Multilingual regex fallback
+# Multilingual regex fallback (optional)
 # =========================
 def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
     low = text.lower()
@@ -452,10 +453,10 @@ def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
 def build_glossary_str(glossary: Dict[str, str], allowed: List[str]) -> str:
     return "\n".join([f"- {lab}: {glossary.get(lab, '')}" for lab in allowed])

-def warmup_model(model_repo: str, use_4bit: bool, use_sdpa: bool, hf_token: str) -> str:
+def warmup_model(model_repo: str, use_4bit: bool, use_sdpa: bool, hf_token: str, force_tok_redownload: bool) -> str:
     t0 = _now_ms()
     try:
-        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
+        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
         _ = model.generate("Return JSON only.", '{"labels": [], "tasks": []}')
         return f"Warm-up complete in {_now_ms() - t0} ms. Load path: {model.load_path}"
     except Exception as e:
@@ -477,11 +478,11 @@ def run_single(
     use_sdpa: bool,
     max_input_tokens: int,
     hf_token: str,
+    force_tok_redownload: bool,
 ) -> Tuple[str, str, str, str, str, str, str, str, str]:

     t0 = _now_ms()

-    # Transcript
     raw_text = ""
     if transcript_file:
         raw_text = read_text_file_any(transcript_file)
@@ -491,36 +492,29 @@

     text = clean_transcript(raw_text) if use_cleaning else raw_text

-    # Allowed labels
     user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
     allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)

-    # Editable configs
     try:
         sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip() or DEFAULT_SYSTEM_INSTRUCTIONS
     except Exception:
         sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
-
     try:
         label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
     except Exception:
         label_glossary = DEFAULT_LABEL_GLOSSARY
-
     try:
         fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
     except Exception:
         fallback_cues = DEFAULT_FALLBACK_CUES

-    # Model
     try:
-        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
+        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
     except Exception as e:
         return "", "", f"Model load failed: {e}", "", "", "", "", "", ""

-    # Truncate
     trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)

-    # Build prompt
     glossary_str = build_glossary_str(label_glossary, allowed)
     allowed_list_str = "\n".join(f"- {l}" for l in allowed)
     user_prompt = USER_PROMPT_TEMPLATE.format(
@@ -529,13 +523,11 @@
         glossary=glossary_str,
     )

-    # Token info + prompt preview
     transcript_tokens = len(model.tokenizer(trunc, add_special_tokens=False)["input_ids"])
     prompt_tokens = len(model.tokenizer(user_prompt, add_special_tokens=False)["input_ids"])
     token_info_text = f"Transcript tokens: {transcript_tokens} | Prompt tokens: {prompt_tokens} | Load path: {model.load_path}"
     prompt_preview_text = "```\n" + user_prompt[:4000] + ("\n... (truncated)" if len(user_prompt) > 4000 else "") + "\n```"

-    # Generate
     t1 = _now_ms()
     try:
         out = model.generate(sys_instructions, user_prompt)
@@ -546,7 +538,6 @@
     parsed = robust_json_extract(out)
     filtered = restrict_to_allowed(parsed, allowed)

-    # Fallback merge for recall
     if use_fallback:
         fb = multilingual_fallback(trunc, allowed, fallback_cues)
         if fb["labels"]:
@@ -555,7 +546,6 @@
             merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
             filtered = {"labels": merged_labels, "tasks": merged_tasks}

-    # Diagnostics
     diag = "\n".join([
         f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
         f"Model: {model_repo}",
@@ -567,7 +557,6 @@
         f"Allowed labels: {', '.join(allowed)}",
     ])

-    # Summaries
     labs = filtered.get("labels", [])
     tasks = filtered.get("tasks", [])
     summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
@@ -580,7 +569,6 @@
         summary += "\n\nTasks: (none)"
     json_out = json.dumps(filtered, indent=2, ensure_ascii=False)

-    # Single-file metrics if GT provided
     metrics = ""
     if gt_json_file or (gt_json_text and gt_json_text.strip()):
         truth_obj = None
@@ -613,9 +601,8 @@
         else:
             metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."

-    # Previews
-    context_preview = "### Label Glossary (used)\n" + "\n".join(f"- {k}: {v}" for k, v in label_glossary.items() if k in allowed)
-    instructions_preview = "```\n" + sys_instructions + "\n```"
+    context_preview = "### Label Glossary (used)\n" + "\n".join(f"- {k}: {v}" for k, v in DEFAULT_LABEL_GLOSSARY.items() if k in allowed)
+    instructions_preview = "```\n" + (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS) + "\n```"

     return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text

@@ -642,6 +629,7 @@ def run_batch(
     use_sdpa: bool,
     max_input_tokens: int,
     hf_token: str,
+    force_tok_redownload: bool,
     limit_files: int,
 ) -> Tuple[str, str, pd.DataFrame, str]:

@@ -661,7 +649,6 @@
     except Exception:
         fallback_cues = DEFAULT_FALLBACK_CUES

-    # Workspace
     work = Path("/tmp/batch")
     if work.exists():
         for p in sorted(work.rglob("*"), reverse=True):
@@ -686,9 +673,8 @@
     if not stems:
         return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")

-    # Model
     try:
-        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
+        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
     except Exception as e:
         return (f"Model load failed: {e}", "", pd.DataFrame(), "")

@@ -787,12 +773,11 @@
 # UI
 # =========================
 MODEL_CHOICES = [
-    "swiss-ai/Apertus-8B-Instruct-2509", # multilingual
+    "swiss-ai/Apertus-8B-Instruct-2509",
     "meta-llama/Meta-Llama-3-8B-Instruct",
     "mistralai/Mistral-7B-Instruct-v0.3",
 ]

-# White, modern UI (no purple)
 custom_css = """
 :root { --radius: 14px; }
 .gradio-container { font-family: Inter, ui-sans-serif, system-ui; background: #ffffff; color: #111827; }
@@ -806,7 +791,7 @@ a, .prose a { color: #0ea5e9; }

 with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
     gr.Markdown("<div class='header'>Talk2Task — Multilingual Task Extraction (UBS Challenge)</div>")
-    gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN) with compact prompts. Optional rule fallback ensures recall. Batch evaluation & scoring included.</div>")
+    gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN). Optional rules fallback for recall. Batch evaluation included.</div>")

     with gr.Tab("Single transcript"):
         with gr.Row():
@@ -850,6 +835,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
             repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
             use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
             use_sdpa = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
+            force_tok_redownload = gr.Checkbox(label="Force fresh tokenizer download", value=False)
             max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
             hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
             warm_btn = gr.Button("Warm up model (load & compile kernels)")
@@ -875,8 +861,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:

         # Reset labels
         reset_btn.click(fn=lambda: OFFICIAL_LABELS_TEXT, inputs=None, outputs=labels_text)
+
         # Warm-up
-        warm_btn.click(fn=warmup_model, inputs=[repo, use_4bit, use_sdpa, hf_token], outputs=diag)
+        warm_btn.click(
+            fn=warmup_model,
+            inputs=[repo, use_4bit, use_sdpa, hf_token, force_tok_redownload],
+            outputs=diag
+        )

         def _pack_context_md(glossary_json, allowed_text):
             try:
@@ -894,7 +885,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
             inputs=[
                 text, file, gt_text, gt_file, use_cleaning, use_fallback,
                 labels_text, sys_instr_tb, glossary_tb, fallback_tb,
-                repo, use_4bit, use_sdpa, max_tokens, hf_token
+                repo, use_4bit, use_sdpa, max_tokens, hf_token, force_tok_redownload
             ],
             outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
         )
@@ -912,6 +903,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
             repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
             use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
             use_sdpa_b = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
+            force_tok_redownload_b = gr.Checkbox(label="Force fresh tokenizer download", value=False)
             max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
             hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
             sys_instr_tb_b = gr.Textbox(label="System Instructions (editable for batch)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=4)
@@ -934,10 +926,20 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
             inputs=[
                 zip_in, use_cleaning_b, use_fallback_b,
                 sys_instr_tb_b, glossary_tb_b, fallback_tb_b,
-                repo_b, use_4bit_b, use_sdpa_b, max_tokens_b, hf_token_b, limit_files
+                repo_b, use_4bit_b, use_sdpa_b, max_tokens_b, hf_token_b, force_tok_redownload_b, limit_files
             ],
             outputs=[status, diag_b, df_out, csv_out],
         )

 if __name__ == "__main__":
+    # Optional: print environment info to logs
+    try:
+        print("Torch version:", torch.__version__)
+        print("CUDA available:", torch.cuda.is_available())
+        if torch.cuda.is_available():
+            print("CUDA (compiled):", torch.version.cuda)
+            print("Device:", torch.cuda.get_device_name(0))
+    except Exception as _:
+        pass
+
     demo.launch()
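
Reviewer note: below is a minimal, hypothetical sketch of how the new force_tok_redownload flag threads through the updated get_model / warmup_model signatures outside the Gradio UI. The module import name (app), the repo choice, and the token handling are illustrative assumptions, not part of this commit.

# Hypothetical smoke test for the updated signatures (not part of this commit).
import os
import app  # assumes app.py is importable from the working directory

repo = "mistralai/Mistral-7B-Instruct-v0.3"  # any entry from MODEL_CHOICES
token = (os.environ.get("HF_TOKEN") or "").strip() or None

# warmup_model now takes force_tok_redownload as its final argument.
print(app.warmup_model(repo, use_4bit=False, use_sdpa=True,
                       hf_token=token or "", force_tok_redownload=False))

# get_model caches one ModelWrapper per (repo, precision, attention, tokenizer-download mode),
# so toggling the flag yields a separate cache entry.
wrapper = app.get_model(repo, token, load_in_4bit=False, use_sdpa=True,
                        force_tok_redownload=True)
print(wrapper.load_path)

Importing app builds the Gradio Blocks at import time (launch only runs under __main__), and both calls download model weights, so treat this as a quick signature check rather than something to run in CI.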