Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# =========================
|
| 15 |
# Utilities
|
|
@@ -55,8 +161,7 @@ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, A
|
|
| 55 |
continue
|
| 56 |
k = str(t.get("label", "")).strip().lower()
|
| 57 |
if k in allowed_map:
|
| 58 |
-
new_t = dict(t)
|
| 59 |
-
new_t["label"] = allowed_map[k]
|
| 60 |
filt_tasks.append(new_t)
|
| 61 |
merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
|
| 62 |
out["labels"] = merged
|
|
@@ -64,10 +169,8 @@ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, A
|
|
| 64 |
return out
|
| 65 |
|
| 66 |
# =========================
|
| 67 |
-
# Default pre-processing
|
| 68 |
# =========================
|
| 69 |
-
# These are conservative; they remove boilerplate that appears in many files
|
| 70 |
-
# and does not affect tasks. You can toggle this in the UI.
|
| 71 |
_DISCLAIMER_PATTERNS = [
|
| 72 |
r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
|
| 73 |
r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
|
|
@@ -87,8 +190,7 @@ def clean_transcript(text: str) -> str:
|
|
| 87 |
if not text:
|
| 88 |
return text
|
| 89 |
s = text
|
| 90 |
-
|
| 91 |
-
# Remove common timestamps and speaker prefixes (line-wise)
|
| 92 |
lines = []
|
| 93 |
for ln in s.splitlines():
|
| 94 |
ln2 = ln
|
|
@@ -96,16 +198,13 @@ def clean_transcript(text: str) -> str:
|
|
| 96 |
ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
|
| 97 |
lines.append(ln2)
|
| 98 |
s = "\n".join(lines)
|
| 99 |
-
|
| 100 |
-
# Remove top disclaimers
|
| 101 |
for pat in _DISCLAIMER_PATTERNS:
|
| 102 |
s = re.sub(pat, "", s).strip()
|
| 103 |
-
|
| 104 |
-
# Remove trailing footers/signatures
|
| 105 |
for pat in _FOOTER_PATTERNS:
|
| 106 |
s = re.sub(pat, "", s)
|
| 107 |
-
|
| 108 |
-
# Collapse repeated whitespace
|
| 109 |
s = re.sub(r"[ \t]+", " ", s)
|
| 110 |
s = re.sub(r"\n{3,}", "\n\n", s).strip()
|
| 111 |
return s
|
|
@@ -194,8 +293,7 @@ _MODEL_CACHE: Dict[str, ModelWrapper] = {}
|
|
| 194 |
def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
|
| 195 |
key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
|
| 196 |
if key not in _MODEL_CACHE:
|
| 197 |
-
m = ModelWrapper(repo_id, hf_token, load_in_4bit)
|
| 198 |
-
m.load()
|
| 199 |
_MODEL_CACHE[key] = m
|
| 200 |
return _MODEL_CACHE[key]
|
| 201 |
|
|
@@ -211,7 +309,6 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> fl
|
|
| 211 |
def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
|
| 212 |
if not isinstance(sample_labels, list):
|
| 213 |
raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
|
| 214 |
-
# dedupe
|
| 215 |
seen, uniq = set(), []
|
| 216 |
for label in sample_labels:
|
| 217 |
if not isinstance(label, str):
|
|
@@ -219,7 +316,6 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> fl
|
|
| 219 |
if label in seen:
|
| 220 |
raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
|
| 221 |
seen.add(label); uniq.append(label)
|
| 222 |
-
# validity
|
| 223 |
valid = []
|
| 224 |
for label in uniq:
|
| 225 |
if label not in ALLOWED_LABELS:
|
|
@@ -257,10 +353,7 @@ def build_keyword_context(allowed: List[str]) -> str:
|
|
| 257 |
parts = []
|
| 258 |
for lab in allowed:
|
| 259 |
kws = LABEL_KEYWORDS.get(lab, [])
|
| 260 |
-
if kws
|
| 261 |
-
parts.append(f"- {lab}: " + ", ".join(kws))
|
| 262 |
-
else:
|
| 263 |
-
parts.append(f"- {lab}: (no default cues)")
|
| 264 |
return "\n".join(parts)
|
| 265 |
|
| 266 |
def run_single(
|
|
@@ -276,29 +369,23 @@ def run_single(
|
|
| 276 |
|
| 277 |
t0 = _now_ms()
|
| 278 |
|
| 279 |
-
# Get transcript
|
| 280 |
raw_text = read_text_from_file(transcript_file) if transcript_file else (transcript_text or "")
|
| 281 |
raw_text = (raw_text or "").strip()
|
| 282 |
if not raw_text:
|
| 283 |
return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, indent=2)
|
| 284 |
|
| 285 |
-
# Cleaning
|
| 286 |
text = clean_transcript(raw_text) if use_cleaning else raw_text
|
| 287 |
|
| 288 |
-
# Allowed labels
|
| 289 |
user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
|
| 290 |
allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
|
| 291 |
|
| 292 |
-
# Model
|
| 293 |
try:
|
| 294 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 295 |
except Exception as e:
|
| 296 |
return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
|
| 297 |
|
| 298 |
-
# Truncate
|
| 299 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
| 300 |
|
| 301 |
-
# Build prompt
|
| 302 |
allowed_list_str = "\n".join(f"- {l}" for l in allowed)
|
| 303 |
keyword_ctx = build_keyword_context(allowed)
|
| 304 |
user_prompt = USER_PROMPT_TEMPLATE.format(
|
|
@@ -307,7 +394,6 @@ def run_single(
|
|
| 307 |
keyword_context=keyword_ctx,
|
| 308 |
)
|
| 309 |
|
| 310 |
-
# Generate
|
| 311 |
t1 = _now_ms()
|
| 312 |
try:
|
| 313 |
out = model.generate(SYSTEM_PROMPT, user_prompt)
|
|
@@ -315,11 +401,9 @@ def run_single(
|
|
| 315 |
return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
|
| 316 |
t2 = _now_ms()
|
| 317 |
|
| 318 |
-
# Parse + filter
|
| 319 |
parsed = robust_json_extract(out)
|
| 320 |
filtered = restrict_to_allowed(parsed, allowed)
|
| 321 |
|
| 322 |
-
# Diagnostics
|
| 323 |
diag = "\n".join([
|
| 324 |
f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
|
| 325 |
f"Model: {model_repo}",
|
|
@@ -329,7 +413,6 @@ def run_single(
|
|
| 329 |
f"Allowed labels: {', '.join(allowed)}",
|
| 330 |
])
|
| 331 |
|
| 332 |
-
# Summary
|
| 333 |
labs = filtered.get("labels", [])
|
| 334 |
tasks = filtered.get("tasks", [])
|
| 335 |
summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
|
|
@@ -350,11 +433,7 @@ def read_zip(fileobj: io.BytesIO, exdir: Path) -> List[Path]:
|
|
| 350 |
exdir.mkdir(parents=True, exist_ok=True)
|
| 351 |
with zipfile.ZipFile(fileobj) as zf:
|
| 352 |
zf.extractall(exdir)
|
| 353 |
-
|
| 354 |
-
for p in exdir.rglob("*"):
|
| 355 |
-
if p.is_file():
|
| 356 |
-
out.append(p)
|
| 357 |
-
return out
|
| 358 |
|
| 359 |
def run_batch(
|
| 360 |
zip_file: gr.File,
|
|
@@ -364,25 +443,27 @@ def run_batch(
|
|
| 364 |
max_input_tokens: int,
|
| 365 |
hf_token: str,
|
| 366 |
limit_files: int,
|
| 367 |
-
) -> Tuple[str, str,
|
| 368 |
|
| 369 |
if not zip_file:
|
| 370 |
-
return ("No ZIP provided.", "",
|
| 371 |
|
| 372 |
work = Path("/tmp/batch")
|
| 373 |
if work.exists():
|
| 374 |
-
for p in work.rglob("*"):
|
| 375 |
-
try:
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
work.mkdir(parents=True, exist_ok=True)
|
| 380 |
|
| 381 |
-
# Unzip
|
| 382 |
data = zip_file.read()
|
| 383 |
files = read_zip(io.BytesIO(data), work)
|
| 384 |
|
| 385 |
-
# Gather pairs by stem
|
| 386 |
txts: Dict[str, Path] = {}
|
| 387 |
gts: Dict[str, Path] = {}
|
| 388 |
for p in files:
|
|
@@ -395,15 +476,14 @@ def run_batch(
|
|
| 395 |
if limit_files > 0:
|
| 396 |
stems = stems[:limit_files]
|
| 397 |
if not stems:
|
| 398 |
-
return ("No .txt transcripts found in ZIP.", "",
|
| 399 |
|
| 400 |
-
# Model
|
| 401 |
try:
|
| 402 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 403 |
except Exception as e:
|
| 404 |
-
return (f"Model load failed: {e}", "",
|
| 405 |
|
| 406 |
-
allowed = OFFICIAL_LABELS[:]
|
| 407 |
allowed_list_str = "\n".join(f"- {l}" for l in allowed)
|
| 408 |
keyword_ctx = build_keyword_context(allowed)
|
| 409 |
|
|
@@ -431,20 +511,17 @@ def run_batch(
|
|
| 431 |
pred_labels = filtered.get("labels", [])
|
| 432 |
y_pred.append(pred_labels)
|
| 433 |
|
| 434 |
-
# Ground truth (optional)
|
| 435 |
gt_labels = []
|
| 436 |
if stem in gts:
|
| 437 |
try:
|
| 438 |
gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
|
| 439 |
-
if isinstance(gt_obj, dict) and
|
| 440 |
gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
|
| 441 |
except Exception:
|
| 442 |
pass
|
| 443 |
y_true.append(gt_labels)
|
| 444 |
|
| 445 |
-
|
| 446 |
-
gt_set = set(gt_labels)
|
| 447 |
-
pr_set = set(pred_labels)
|
| 448 |
tp = sorted(gt_set & pr_set)
|
| 449 |
fp = sorted(pr_set - gt_set)
|
| 450 |
fn = sorted(gt_set - pr_set)
|
|
@@ -457,8 +534,6 @@ def run_batch(
|
|
| 457 |
"gen_ms": t1 - t0
|
| 458 |
})
|
| 459 |
|
| 460 |
-
# Metrics
|
| 461 |
-
# If there is no ground truth in the ZIP, we still compute a table and skip score.
|
| 462 |
have_truth = any(len(v) > 0 for v in y_true)
|
| 463 |
score = evaluate_predictions(y_true, y_pred) if have_truth else None
|
| 464 |
|
|
@@ -472,7 +547,6 @@ def run_batch(
|
|
| 472 |
f"Batch time: {_now_ms()-t_start} ms",
|
| 473 |
]
|
| 474 |
if have_truth and score is not None:
|
| 475 |
-
# Simple derived metrics
|
| 476 |
total_tp = int(df["TP"].sum())
|
| 477 |
total_fp = int(df["FP"].sum())
|
| 478 |
total_fn = int(df["FN"].sum())
|
|
@@ -486,12 +560,11 @@ def run_batch(
|
|
| 486 |
]
|
| 487 |
diag_str = "\n".join(diag)
|
| 488 |
|
| 489 |
-
# CSV
|
| 490 |
-
|
| 491 |
-
df.to_csv(
|
| 492 |
-
csv_data = csv_buf.getvalue()
|
| 493 |
|
| 494 |
-
return ("Batch done.", diag_str,
|
| 495 |
|
| 496 |
# =========================
|
| 497 |
# UI
|
|
@@ -505,10 +578,8 @@ MODEL_CHOICES = [
|
|
| 505 |
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
| 506 |
gr.Markdown("# Talk2Task — Task Extraction (UBS Challenge)")
|
| 507 |
gr.Markdown(
|
| 508 |
-
"
|
| 509 |
-
"
|
| 510 |
-
"_Note: False negatives are penalised twice as much as false positives in the official metric; "
|
| 511 |
-
"we bias for recall._"
|
| 512 |
)
|
| 513 |
|
| 514 |
with gr.Tab("Single transcript"):
|
|
@@ -520,9 +591,12 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
|
| 520 |
type="filepath",
|
| 521 |
)
|
| 522 |
text = gr.Textbox(label="Or paste transcript", lines=14)
|
| 523 |
-
use_cleaning = gr.Checkbox(
|
|
|
|
|
|
|
|
|
|
| 524 |
labels_text = gr.Textbox(
|
| 525 |
-
label="Allowed Labels (one per line;
|
| 526 |
value="",
|
| 527 |
lines=8,
|
| 528 |
)
|
|
@@ -561,23 +635,15 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
|
| 561 |
|
| 562 |
with gr.Row():
|
| 563 |
status = gr.Textbox(label="Status", lines=1)
|
| 564 |
-
diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=
|
| 565 |
-
|
| 566 |
-
with gr.Row():
|
| 567 |
-
df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, times)", interactive=False)
|
| 568 |
-
csv_out = gr.File(label="Download CSV (click to save)", interactive=False)
|
| 569 |
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
return ""
|
| 573 |
-
out_path = Path("/tmp/batch_results.csv")
|
| 574 |
-
out_path.write_text(csv_text, encoding="utf-8")
|
| 575 |
-
return str(out_path)
|
| 576 |
|
| 577 |
run_batch_btn.click(
|
| 578 |
fn=run_batch,
|
| 579 |
inputs=[zip_in, use_cleaning_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
|
| 580 |
-
outputs=[status, diag_b,
|
| 581 |
)
|
| 582 |
|
| 583 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
import zipfile
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Dict, Any, Tuple, Optional
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import gradio as gr
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import (
|
| 17 |
+
AutoTokenizer,
|
| 18 |
+
AutoModelForCausalLM,
|
| 19 |
+
BitsAndBytesConfig,
|
| 20 |
+
GenerationConfig,
|
| 21 |
+
)
|
| 22 |
|
| 23 |
+
# =========================
|
| 24 |
+
# Global config
|
| 25 |
+
# =========================
|
| 26 |
+
SPACE_CACHE = Path.home() / ".cache" / "huggingface"
|
| 27 |
+
SPACE_CACHE.mkdir(parents=True, exist_ok=True)
|
| 28 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
+
|
| 30 |
+
GEN_CONFIG = GenerationConfig(
|
| 31 |
+
temperature=0.2,
|
| 32 |
+
top_p=0.9,
|
| 33 |
+
do_sample=False,
|
| 34 |
+
max_new_tokens=256,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Official UBS label set (strict)
|
| 38 |
+
OFFICIAL_LABELS = [
|
| 39 |
+
"plan_contact",
|
| 40 |
+
"schedule_meeting",
|
| 41 |
+
"update_contact_info_non_postal",
|
| 42 |
+
"update_contact_info_postal_address",
|
| 43 |
+
"update_kyc_activity",
|
| 44 |
+
"update_kyc_origin_of_assets",
|
| 45 |
+
"update_kyc_purpose_of_businessrelation",
|
| 46 |
+
"update_kyc_total_assets",
|
| 47 |
+
]
|
| 48 |
|
| 49 |
+
# Per-label keyword cues (static prompt context to improve recall)
|
| 50 |
+
LABEL_KEYWORDS: Dict[str, List[str]] = {
|
| 51 |
+
"plan_contact": [
|
| 52 |
+
"call back", "follow up", "reach out", "contact later", "check-in",
|
| 53 |
+
"email them", "touch base", "remind", "send a note"
|
| 54 |
+
],
|
| 55 |
+
"schedule_meeting": [
|
| 56 |
+
"book a meeting", "set up a meeting", "schedule a call",
|
| 57 |
+
"appointment", "calendar", "meeting next week", "meet on", "time slot"
|
| 58 |
+
],
|
| 59 |
+
"update_contact_info_non_postal": [
|
| 60 |
+
"phone change", "new phone", "email change", "new email",
|
| 61 |
+
"update contact details", "update mobile", "alternate phone"
|
| 62 |
+
],
|
| 63 |
+
"update_contact_info_postal_address": [
|
| 64 |
+
"moved to", "new address", "postal address", "mailing address",
|
| 65 |
+
"change of address", "residential address"
|
| 66 |
+
],
|
| 67 |
+
"update_kyc_activity": [
|
| 68 |
+
"activity update", "economic activity", "employment status",
|
| 69 |
+
"occupation", "job change", "business activity"
|
| 70 |
+
],
|
| 71 |
+
"update_kyc_origin_of_assets": [
|
| 72 |
+
"source of funds", "origin of assets", "where money comes from",
|
| 73 |
+
"inheritance", "salary", "business income", "asset origin"
|
| 74 |
+
],
|
| 75 |
+
"update_kyc_purpose_of_businessrelation": [
|
| 76 |
+
"purpose of relationship", "why the account", "reason for banking",
|
| 77 |
+
"investment purpose", "relationship purpose"
|
| 78 |
+
],
|
| 79 |
+
"update_kyc_total_assets": [
|
| 80 |
+
"total assets", "net worth", "assets under ownership",
|
| 81 |
+
"portfolio size", "how much you own"
|
| 82 |
+
],
|
| 83 |
+
}
|
| 84 |
|
| 85 |
+
# =========================
|
| 86 |
+
# Instructions (string-safe; concatenated)
|
| 87 |
+
# =========================
|
| 88 |
+
SYSTEM_PROMPT = (
|
| 89 |
+
"You are a precise banking assistant that extracts ACTIONABLE TASKS from "
|
| 90 |
+
"client–advisor transcripts. Be conservative with hallucinations but "
|
| 91 |
+
"prioritise RECALL: if unsure and the transcript plausibly implies an "
|
| 92 |
+
"action, include the label and explain briefly.\n\n"
|
| 93 |
+
"Output STRICT JSON only:\n\n"
|
| 94 |
+
"{\n"
|
| 95 |
+
' "labels": ["<Label1>", "..."],\n'
|
| 96 |
+
' "tasks": [\n'
|
| 97 |
+
' {"label": "<Label1>", "explanation": "<why>", "evidence": "<quoted text/snippet>"}\n'
|
| 98 |
+
" ]\n"
|
| 99 |
+
"}\n\n"
|
| 100 |
+
"Rules:\n"
|
| 101 |
+
"- Use ONLY allowed labels supplied to you. Case-insensitive during reasoning, "
|
| 102 |
+
" but output the canonical label text exactly.\n"
|
| 103 |
+
"- If none truly apply, return empty lists.\n"
|
| 104 |
+
"- Keep explanations concise; put the minimal evidence snippet that justifies the task.\n"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
USER_PROMPT_TEMPLATE = (
|
| 108 |
+
"Transcript (cleaned):\n"
|
| 109 |
+
"```\n{transcript}\n```\n\n"
|
| 110 |
+
"Allowed Labels (canonical; use only these):\n"
|
| 111 |
+
"{allowed_labels_list}\n\n"
|
| 112 |
+
"Context cues (keywords/phrases that often indicate each label):\n"
|
| 113 |
+
"{keyword_context}\n\n"
|
| 114 |
+
"Instructions:\n"
|
| 115 |
+
"- Identify EVERY concrete task implied by the conversation.\n"
|
| 116 |
+
"- Choose ONE label from Allowed Labels for each task (or none if truly inapplicable).\n"
|
| 117 |
+
"- Return STRICT JSON only in the exact schema described by the system prompt.\n"
|
| 118 |
+
)
|
| 119 |
|
| 120 |
# =========================
|
| 121 |
# Utilities
|
|
|
|
| 161 |
continue
|
| 162 |
k = str(t.get("label", "")).strip().lower()
|
| 163 |
if k in allowed_map:
|
| 164 |
+
new_t = dict(t); new_t["label"] = allowed_map[k]
|
|
|
|
| 165 |
filt_tasks.append(new_t)
|
| 166 |
merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
|
| 167 |
out["labels"] = merged
|
|
|
|
| 169 |
return out
|
| 170 |
|
| 171 |
# =========================
|
| 172 |
+
# Default pre-processing (toggleable)
|
| 173 |
# =========================
|
|
|
|
|
|
|
| 174 |
_DISCLAIMER_PATTERNS = [
|
| 175 |
r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
|
| 176 |
r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
|
|
|
|
| 190 |
if not text:
|
| 191 |
return text
|
| 192 |
s = text
|
| 193 |
+
# remove timestamps/speaker prefixes line-wise
|
|
|
|
| 194 |
lines = []
|
| 195 |
for ln in s.splitlines():
|
| 196 |
ln2 = ln
|
|
|
|
| 198 |
ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
|
| 199 |
lines.append(ln2)
|
| 200 |
s = "\n".join(lines)
|
| 201 |
+
# remove top disclaimers
|
|
|
|
| 202 |
for pat in _DISCLAIMER_PATTERNS:
|
| 203 |
s = re.sub(pat, "", s).strip()
|
| 204 |
+
# remove trailing footers
|
|
|
|
| 205 |
for pat in _FOOTER_PATTERNS:
|
| 206 |
s = re.sub(pat, "", s)
|
| 207 |
+
# collapse whitespace
|
|
|
|
| 208 |
s = re.sub(r"[ \t]+", " ", s)
|
| 209 |
s = re.sub(r"\n{3,}", "\n\n", s).strip()
|
| 210 |
return s
|
|
|
|
| 293 |
def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
|
| 294 |
key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
|
| 295 |
if key not in _MODEL_CACHE:
|
| 296 |
+
m = ModelWrapper(repo_id, hf_token, load_in_4bit); m.load()
|
|
|
|
| 297 |
_MODEL_CACHE[key] = m
|
| 298 |
return _MODEL_CACHE[key]
|
| 299 |
|
|
|
|
| 309 |
def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
|
| 310 |
if not isinstance(sample_labels, list):
|
| 311 |
raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
|
|
|
|
| 312 |
seen, uniq = set(), []
|
| 313 |
for label in sample_labels:
|
| 314 |
if not isinstance(label, str):
|
|
|
|
| 316 |
if label in seen:
|
| 317 |
raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
|
| 318 |
seen.add(label); uniq.append(label)
|
|
|
|
| 319 |
valid = []
|
| 320 |
for label in uniq:
|
| 321 |
if label not in ALLOWED_LABELS:
|
|
|
|
| 353 |
parts = []
|
| 354 |
for lab in allowed:
|
| 355 |
kws = LABEL_KEYWORDS.get(lab, [])
|
| 356 |
+
parts.append(f"- {lab}: " + (", ".join(kws) if kws else "(no default cues)"))
|
|
|
|
|
|
|
|
|
|
| 357 |
return "\n".join(parts)
|
| 358 |
|
| 359 |
def run_single(
|
|
|
|
| 369 |
|
| 370 |
t0 = _now_ms()
|
| 371 |
|
|
|
|
| 372 |
raw_text = read_text_from_file(transcript_file) if transcript_file else (transcript_text or "")
|
| 373 |
raw_text = (raw_text or "").strip()
|
| 374 |
if not raw_text:
|
| 375 |
return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, indent=2)
|
| 376 |
|
|
|
|
| 377 |
text = clean_transcript(raw_text) if use_cleaning else raw_text
|
| 378 |
|
|
|
|
| 379 |
user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
|
| 380 |
allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
|
| 381 |
|
|
|
|
| 382 |
try:
|
| 383 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 384 |
except Exception as e:
|
| 385 |
return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
|
| 386 |
|
|
|
|
| 387 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
| 388 |
|
|
|
|
| 389 |
allowed_list_str = "\n".join(f"- {l}" for l in allowed)
|
| 390 |
keyword_ctx = build_keyword_context(allowed)
|
| 391 |
user_prompt = USER_PROMPT_TEMPLATE.format(
|
|
|
|
| 394 |
keyword_context=keyword_ctx,
|
| 395 |
)
|
| 396 |
|
|
|
|
| 397 |
t1 = _now_ms()
|
| 398 |
try:
|
| 399 |
out = model.generate(SYSTEM_PROMPT, user_prompt)
|
|
|
|
| 401 |
return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
|
| 402 |
t2 = _now_ms()
|
| 403 |
|
|
|
|
| 404 |
parsed = robust_json_extract(out)
|
| 405 |
filtered = restrict_to_allowed(parsed, allowed)
|
| 406 |
|
|
|
|
| 407 |
diag = "\n".join([
|
| 408 |
f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
|
| 409 |
f"Model: {model_repo}",
|
|
|
|
| 413 |
f"Allowed labels: {', '.join(allowed)}",
|
| 414 |
])
|
| 415 |
|
|
|
|
| 416 |
labs = filtered.get("labels", [])
|
| 417 |
tasks = filtered.get("tasks", [])
|
| 418 |
summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
|
|
|
|
| 433 |
exdir.mkdir(parents=True, exist_ok=True)
|
| 434 |
with zipfile.ZipFile(fileobj) as zf:
|
| 435 |
zf.extractall(exdir)
|
| 436 |
+
return [p for p in exdir.rglob("*") if p.is_file()]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
def run_batch(
|
| 439 |
zip_file: gr.File,
|
|
|
|
| 443 |
max_input_tokens: int,
|
| 444 |
hf_token: str,
|
| 445 |
limit_files: int,
|
| 446 |
+
) -> Tuple[str, str, pd.DataFrame, str]:
|
| 447 |
|
| 448 |
if not zip_file:
|
| 449 |
+
return ("No ZIP provided.", "", pd.DataFrame(), "")
|
| 450 |
|
| 451 |
work = Path("/tmp/batch")
|
| 452 |
if work.exists():
|
| 453 |
+
for p in sorted(work.rglob("*"), reverse=True):
|
| 454 |
+
try:
|
| 455 |
+
p.unlink()
|
| 456 |
+
except Exception:
|
| 457 |
+
pass
|
| 458 |
+
try:
|
| 459 |
+
work.rmdir()
|
| 460 |
+
except Exception:
|
| 461 |
+
pass
|
| 462 |
work.mkdir(parents=True, exist_ok=True)
|
| 463 |
|
|
|
|
| 464 |
data = zip_file.read()
|
| 465 |
files = read_zip(io.BytesIO(data), work)
|
| 466 |
|
|
|
|
| 467 |
txts: Dict[str, Path] = {}
|
| 468 |
gts: Dict[str, Path] = {}
|
| 469 |
for p in files:
|
|
|
|
| 476 |
if limit_files > 0:
|
| 477 |
stems = stems[:limit_files]
|
| 478 |
if not stems:
|
| 479 |
+
return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
|
| 480 |
|
|
|
|
| 481 |
try:
|
| 482 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 483 |
except Exception as e:
|
| 484 |
+
return (f"Model load failed: {e}", "", pd.DataFrame(), "")
|
| 485 |
|
| 486 |
+
allowed = OFFICIAL_LABELS[:]
|
| 487 |
allowed_list_str = "\n".join(f"- {l}" for l in allowed)
|
| 488 |
keyword_ctx = build_keyword_context(allowed)
|
| 489 |
|
|
|
|
| 511 |
pred_labels = filtered.get("labels", [])
|
| 512 |
y_pred.append(pred_labels)
|
| 513 |
|
|
|
|
| 514 |
gt_labels = []
|
| 515 |
if stem in gts:
|
| 516 |
try:
|
| 517 |
gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
|
| 518 |
+
if isinstance(gt_obj, dict) and isinstance(gt_obj.get("labels"), list):
|
| 519 |
gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
|
| 520 |
except Exception:
|
| 521 |
pass
|
| 522 |
y_true.append(gt_labels)
|
| 523 |
|
| 524 |
+
gt_set, pr_set = set(gt_labels), set(pred_labels)
|
|
|
|
|
|
|
| 525 |
tp = sorted(gt_set & pr_set)
|
| 526 |
fp = sorted(pr_set - gt_set)
|
| 527 |
fn = sorted(gt_set - pr_set)
|
|
|
|
| 534 |
"gen_ms": t1 - t0
|
| 535 |
})
|
| 536 |
|
|
|
|
|
|
|
| 537 |
have_truth = any(len(v) > 0 for v in y_true)
|
| 538 |
score = evaluate_predictions(y_true, y_pred) if have_truth else None
|
| 539 |
|
|
|
|
| 547 |
f"Batch time: {_now_ms()-t_start} ms",
|
| 548 |
]
|
| 549 |
if have_truth and score is not None:
|
|
|
|
| 550 |
total_tp = int(df["TP"].sum())
|
| 551 |
total_fp = int(df["FP"].sum())
|
| 552 |
total_fn = int(df["FN"].sum())
|
|
|
|
| 560 |
]
|
| 561 |
diag_str = "\n".join(diag)
|
| 562 |
|
| 563 |
+
# save CSV for download
|
| 564 |
+
out_csv = Path("/tmp/batch_results.csv")
|
| 565 |
+
df.to_csv(out_csv, index=False, encoding="utf-8")
|
|
|
|
| 566 |
|
| 567 |
+
return ("Batch done.", diag_str, df, str(out_csv))
|
| 568 |
|
| 569 |
# =========================
|
| 570 |
# UI
|
|
|
|
| 578 |
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
| 579 |
gr.Markdown("# Talk2Task — Task Extraction (UBS Challenge)")
|
| 580 |
gr.Markdown(
|
| 581 |
+
"Extract challenge labels from transcripts. False negatives are penalised 2× more than false positives "
|
| 582 |
+
"in the official score, so the app biases for recall."
|
|
|
|
|
|
|
| 583 |
)
|
| 584 |
|
| 585 |
with gr.Tab("Single transcript"):
|
|
|
|
| 591 |
type="filepath",
|
| 592 |
)
|
| 593 |
text = gr.Textbox(label="Or paste transcript", lines=14)
|
| 594 |
+
use_cleaning = gr.Checkbox(
|
| 595 |
+
label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
|
| 596 |
+
value=True,
|
| 597 |
+
)
|
| 598 |
labels_text = gr.Textbox(
|
| 599 |
+
label="Allowed Labels (one per line; empty = official list)",
|
| 600 |
value="",
|
| 601 |
lines=8,
|
| 602 |
)
|
|
|
|
| 635 |
|
| 636 |
with gr.Row():
|
| 637 |
status = gr.Textbox(label="Status", lines=1)
|
| 638 |
+
diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
|
| 640 |
+
df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
|
| 641 |
+
csv_out = gr.File(label="Download CSV", interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 642 |
|
| 643 |
run_batch_btn.click(
|
| 644 |
fn=run_batch,
|
| 645 |
inputs=[zip_in, use_cleaning_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
|
| 646 |
+
outputs=[status, diag_b, df_out, csv_out],
|
| 647 |
)
|
| 648 |
|
| 649 |
if __name__ == "__main__":
|