Update app.py
app.py CHANGED
@@ -1,4 +1,3 @@
-# app.py
 import os
 import re
 import io
@@ -28,15 +27,13 @@ SPACE_CACHE.mkdir(parents=True, exist_ok=True)
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 
-# Deterministic, compact outputs
 GEN_CONFIG = GenerationConfig(
     temperature=0.0,
     top_p=1.0,
     do_sample=False,
-    max_new_tokens=128,  # raise if
+    max_new_tokens=128,  # raise if JSON truncates
 )
 
-# Canonical labels (UBS)
 OFFICIAL_LABELS = [
     "plan_contact",
     "schedule_meeting",
@@ -72,7 +69,7 @@ DEFAULT_LABEL_GLOSSARY = {
     "update_kyc_total_assets": "Discussion/confirmation of total assets/net worth.",
 }
 
-#
+# Tiny multilingual fallback rules (optional) to avoid empty outputs
 DEFAULT_FALLBACK_CUES = {
     "plan_contact": [
         r"\b(get|got|will|we'?ll|i'?ll)\s+back to you\b", r"\bfollow\s*up\b", r"\breach out\b", r"\btouch base\b",
@@ -250,14 +247,15 @@ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
     return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
 
 # =========================
-# HF model wrapper (robust
+# HF model wrapper (robust: fast→slow tokenizer + load fallbacks)
 # =========================
 class ModelWrapper:
-    def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool):
+    def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool, force_tok_redownload: bool):
        self.repo_id = repo_id
        self.hf_token = hf_token
        self.load_in_4bit = load_in_4bit
        self.use_sdpa = use_sdpa
+       self.force_tok_redownload = force_tok_redownload
        self.tokenizer = None
        self.model = None
        self.load_path = "uninitialized"
@@ -265,18 +263,21 @@ class ModelWrapper:
     def _load_tokenizer(self):
         fast_err = None
         tok = None
+        common = dict(
+            pretrained_model_name_or_path=self.repo_id,
+            token=self.hf_token,
+            cache_dir=str(SPACE_CACHE),
+            trust_remote_code=True,
+            local_files_only=False,
+            force_download=True if self.force_tok_redownload else False,
+            revision=None,
+        )
         try:
-            tok = AutoTokenizer.from_pretrained(
-                self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
-                trust_remote_code=True, use_fast=True
-            )
+            tok = AutoTokenizer.from_pretrained(use_fast=True, **common)
         except Exception as e:
             fast_err = e
         if tok is None:
-            tok = AutoTokenizer.from_pretrained(
-                self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
-                trust_remote_code=True, use_fast=False
-            )
+            tok = AutoTokenizer.from_pretrained(use_fast=False, **common)
         if tok.pad_token is None and tok.eos_token:
             tok.pad_token = tok.eos_token
         return tok, fast_err
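
Note on the tokenizer hunk above: the repeated from_pretrained arguments are collected into a shared kwargs dict, and the new force_tok_redownload flag maps to force_download. A minimal, self-contained sketch of the same fast→slow fallback pattern (the repo id and cache path are placeholders, not values from this Space):

from transformers import AutoTokenizer

def load_tokenizer_with_fallback(repo_id: str, force_redownload: bool = False):
    # Shared kwargs for both attempts, mirroring the `common` dict in the diff.
    common = dict(
        pretrained_model_name_or_path=repo_id,
        cache_dir="/tmp/hf-cache",        # placeholder cache location
        trust_remote_code=True,
        force_download=force_redownload,  # re-fetch files even if a cached copy exists
    )
    try:
        # Prefer the fast (Rust-backed) tokenizer when its files load cleanly.
        return AutoTokenizer.from_pretrained(use_fast=True, **common)
    except Exception:
        # Fall back to the slow (pure-Python) tokenizer, e.g. if the fast files are corrupted.
        return AutoTokenizer.from_pretrained(use_fast=False, **common)

# Usage (placeholder repo id):
# tok = load_tokenizer_with_fallback("some-org/some-model", force_redownload=True)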
@@ -372,10 +373,10 @@ class ModelWrapper:
         return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
 
 _MODEL_CACHE: Dict[str, ModelWrapper] = {}
-def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool) -> ModelWrapper:
-    key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}::{'sdpa' if use_sdpa else 'nosdpa'}"
+def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool, force_tok_redownload: bool) -> ModelWrapper:
+    key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}::{'sdpa' if use_sdpa else 'nosdpa'}::{'force' if force_tok_redownload else 'cache'}"
     if key not in _MODEL_CACHE:
-        m = ModelWrapper(repo_id, hf_token, load_in_4bit, use_sdpa)
+        m = ModelWrapper(repo_id, hf_token, load_in_4bit, use_sdpa, force_tok_redownload)
         m.load()
         _MODEL_CACHE[key] = m
     return _MODEL_CACHE[key]
@@ -425,7 +426,7 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
     return float(max(0.0, min(1.0, np.mean(per_sample))))
 
 # =========================
-# Multilingual regex fallback
+# Multilingual regex fallback (optional)
 # =========================
 def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
     low = text.lower()
@@ -452,10 +453,10 @@ def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
 def build_glossary_str(glossary: Dict[str, str], allowed: List[str]) -> str:
     return "\n".join([f"- {lab}: {glossary.get(lab, '')}" for lab in allowed])
 
-def warmup_model(model_repo: str, use_4bit: bool, use_sdpa: bool, hf_token: str) -> str:
+def warmup_model(model_repo: str, use_4bit: bool, use_sdpa: bool, hf_token: str, force_tok_redownload: bool) -> str:
     t0 = _now_ms()
     try:
-        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
+        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
         _ = model.generate("Return JSON only.", '{"labels": [], "tasks": []}')
         return f"Warm-up complete in {_now_ms() - t0} ms. Load path: {model.load_path}"
     except Exception as e:
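
The cache key in get_model now also encodes the redownload flag, so toggling the new checkbox yields a distinct cache entry rather than reusing a previously loaded wrapper. A small usage sketch under that assumption (the repo id is a placeholder):

# Same settings -> same key -> the wrapper is built and loaded only once.
a = get_model("some-org/some-model", None, load_in_4bit=True, use_sdpa=True, force_tok_redownload=False)
b = get_model("some-org/some-model", None, load_in_4bit=True, use_sdpa=True, force_tok_redownload=False)
assert a is b

# Flipping force_tok_redownload changes the key and triggers a fresh load.
c = get_model("some-org/some-model", None, load_in_4bit=True, use_sdpa=True, force_tok_redownload=True)
assert c is not a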
@@ -477,11 +478,11 @@ def run_single(
     use_sdpa: bool,
     max_input_tokens: int,
     hf_token: str,
+    force_tok_redownload: bool,
 ) -> Tuple[str, str, str, str, str, str, str, str, str]:
 
     t0 = _now_ms()
 
-    # Transcript
     raw_text = ""
     if transcript_file:
         raw_text = read_text_file_any(transcript_file)
@@ -491,36 +492,29 @@ def run_single(
 
     text = clean_transcript(raw_text) if use_cleaning else raw_text
 
-    # Allowed labels
     user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
     allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
 
-    # Editable configs
     try:
         sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip() or DEFAULT_SYSTEM_INSTRUCTIONS
     except Exception:
         sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
-
     try:
         label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
     except Exception:
         label_glossary = DEFAULT_LABEL_GLOSSARY
-
     try:
         fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
     except Exception:
         fallback_cues = DEFAULT_FALLBACK_CUES
 
-    # Model
     try:
-        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
+        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
     except Exception as e:
         return "", "", f"Model load failed: {e}", "", "", "", "", "", ""
 
-    # Truncate
     trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
 
-    # Build prompt
     glossary_str = build_glossary_str(label_glossary, allowed)
     allowed_list_str = "\n".join(f"- {l}" for l in allowed)
     user_prompt = USER_PROMPT_TEMPLATE.format(
@@ -529,13 +523,11 @@ def run_single(
         glossary=glossary_str,
     )
 
-    # Token info + prompt preview
     transcript_tokens = len(model.tokenizer(trunc, add_special_tokens=False)["input_ids"])
     prompt_tokens = len(model.tokenizer(user_prompt, add_special_tokens=False)["input_ids"])
     token_info_text = f"Transcript tokens: {transcript_tokens} | Prompt tokens: {prompt_tokens} | Load path: {model.load_path}"
     prompt_preview_text = "```\n" + user_prompt[:4000] + ("\n... (truncated)" if len(user_prompt) > 4000 else "") + "\n```"
 
-    # Generate
     t1 = _now_ms()
     try:
         out = model.generate(sys_instructions, user_prompt)
@@ -546,7 +538,6 @@ def run_single(
     parsed = robust_json_extract(out)
     filtered = restrict_to_allowed(parsed, allowed)
 
-    # Fallback merge for recall
     if use_fallback:
         fb = multilingual_fallback(trunc, allowed, fallback_cues)
         if fb["labels"]:
@@ -555,7 +546,6 @@ def run_single(
             merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
             filtered = {"labels": merged_labels, "tasks": merged_tasks}
 
-    # Diagnostics
     diag = "\n".join([
         f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
         f"Model: {model_repo}",
@@ -567,7 +557,6 @@ def run_single(
         f"Allowed labels: {', '.join(allowed)}",
     ])
 
-    # Summaries
     labs = filtered.get("labels", [])
     tasks = filtered.get("tasks", [])
     summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
@@ -580,7 +569,6 @@ def run_single(
         summary += "\n\nTasks: (none)"
     json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
 
-    # Single-file metrics if GT provided
     metrics = ""
     if gt_json_file or (gt_json_text and gt_json_text.strip()):
         truth_obj = None
@@ -613,9 +601,8 @@ def run_single(
     else:
         metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
 
-
-
-    instructions_preview = "```\n" + sys_instructions + "\n```"
+    context_preview = "### Label Glossary (used)\n" + "\n".join(f"- {k}: {v}" for k, v in DEFAULT_LABEL_GLOSSARY.items() if k in allowed)
+    instructions_preview = "```\n" + (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS) + "\n```"
 
     return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text
 
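
For readers following the fallback-merge lines in the run_single hunks: `existing` and `merged_labels` are defined outside the lines shown, so the sketch below assumes the obvious definitions (de-duplicate by label); the data is illustrative only:

# Hypothetical model output and regex-fallback output.
filtered = {"labels": ["plan_contact"], "tasks": [{"label": "plan_contact", "text": "call the client back"}]}
fb = {"labels": ["plan_contact", "schedule_meeting"], "tasks": [{"label": "schedule_meeting", "text": "book a review meeting"}]}

existing = set(filtered.get("labels", []))
merged_labels = filtered.get("labels", []) + [l for l in fb["labels"] if l not in existing]
merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
filtered = {"labels": merged_labels, "tasks": merged_tasks}
# filtered now carries both labels, with one fallback task added for the newly found label.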
@@ -642,6 +629,7 @@ def run_batch(
     use_sdpa: bool,
     max_input_tokens: int,
     hf_token: str,
+    force_tok_redownload: bool,
     limit_files: int,
 ) -> Tuple[str, str, pd.DataFrame, str]:
 
@@ -661,7 +649,6 @@ def run_batch(
     except Exception:
         fallback_cues = DEFAULT_FALLBACK_CUES
 
-    # Workspace
     work = Path("/tmp/batch")
     if work.exists():
         for p in sorted(work.rglob("*"), reverse=True):
@@ -686,9 +673,8 @@ def run_batch(
     if not stems:
         return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
 
-    # Model
     try:
-        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
+        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
     except Exception as e:
         return (f"Model load failed: {e}", "", pd.DataFrame(), "")
 
@@ -787,12 +773,11 @@
 # UI
 # =========================
 MODEL_CHOICES = [
-    "swiss-ai/Apertus-8B-Instruct-2509",
+    "swiss-ai/Apertus-8B-Instruct-2509",
     "meta-llama/Meta-Llama-3-8B-Instruct",
     "mistralai/Mistral-7B-Instruct-v0.3",
 ]
 
-# White, modern UI (no purple)
 custom_css = """
 :root { --radius: 14px; }
 .gradio-container { font-family: Inter, ui-sans-serif, system-ui; background: #ffffff; color: #111827; }
@@ -806,7 +791,7 @@ a, .prose a { color: #0ea5e9; }
 
 with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
     gr.Markdown("<div class='header'>Talk2Task — Multilingual Task Extraction (UBS Challenge)</div>")
-    gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN)
+    gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN). Optional rules fallback for recall. Batch evaluation included.</div>")
 
     with gr.Tab("Single transcript"):
         with gr.Row():
@@ -850,6 +835,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
             repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
             use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
             use_sdpa = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
+            force_tok_redownload = gr.Checkbox(label="Force fresh tokenizer download", value=False)
             max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
             hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
             warm_btn = gr.Button("Warm up model (load & compile kernels)")
@@ -875,8 +861,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
 
     # Reset labels
     reset_btn.click(fn=lambda: OFFICIAL_LABELS_TEXT, inputs=None, outputs=labels_text)
+
     # Warm-up
-    warm_btn.click(
+    warm_btn.click(
+        fn=warmup_model,
+        inputs=[repo, use_4bit, use_sdpa, hf_token, force_tok_redownload],
+        outputs=diag
+    )
 
     def _pack_context_md(glossary_json, allowed_text):
         try:
@@ -894,7 +885,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
         inputs=[
             text, file, gt_text, gt_file, use_cleaning, use_fallback,
             labels_text, sys_instr_tb, glossary_tb, fallback_tb,
-            repo, use_4bit, use_sdpa, max_tokens, hf_token
+            repo, use_4bit, use_sdpa, max_tokens, hf_token, force_tok_redownload
         ],
         outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
     )
@@ -912,6 +903,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
             repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
             use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
             use_sdpa_b = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
+            force_tok_redownload_b = gr.Checkbox(label="Force fresh tokenizer download", value=False)
             max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
             hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
             sys_instr_tb_b = gr.Textbox(label="System Instructions (editable for batch)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=4)
@@ -934,10 +926,20 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
         inputs=[
             zip_in, use_cleaning_b, use_fallback_b,
             sys_instr_tb_b, glossary_tb_b, fallback_tb_b,
-            repo_b, use_4bit_b, use_sdpa_b, max_tokens_b, hf_token_b, limit_files
+            repo_b, use_4bit_b, use_sdpa_b, max_tokens_b, hf_token_b, force_tok_redownload_b, limit_files
         ],
         outputs=[status, diag_b, df_out, csv_out],
     )
 
 if __name__ == "__main__":
+    # Optional: print environment info to logs
+    try:
+        print("Torch version:", torch.__version__)
+        print("CUDA available:", torch.cuda.is_available())
+        if torch.cuda.is_available():
+            print("CUDA (compiled):", torch.version.cuda)
+            print("Device:", torch.cuda.get_device_name(0))
+    except Exception as _:
+        pass
+
     demo.launch()
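
The warm-up wiring and the new checkboxes follow Gradio's standard event pattern: each component listed in inputs is passed positionally to fn, and the return value fills outputs. A minimal, self-contained sketch of that pattern (component names and the warm-up body here are illustrative, not this Space's actual code):

import gradio as gr

def warmup(repo: str, use_4bit: bool) -> str:
    # Placeholder for a real model warm-up.
    return f"Warmed up {repo} (4-bit: {use_4bit})"

with gr.Blocks() as demo:
    repo = gr.Dropdown(choices=["model-a", "model-b"], value="model-a", label="Model")
    use_4bit = gr.Checkbox(label="Use 4-bit", value=True)
    diag = gr.Textbox(label="Diagnostics")
    btn = gr.Button("Warm up")
    # Input component values are passed to fn in order; the return value fills the output.
    btn.click(fn=warmup, inputs=[repo, use_4bit], outputs=diag)

if __name__ == "__main__":
    demo.launch()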