Update app.py
app.py
CHANGED
@@ -1,17 +1,15 @@
import os, io, re, sys, time, json, zipfile, statistics
from pathlib import Path
- from typing import List, Dict, Tuple, Union

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

- #
- # If the 'spaces' package is available (on Spaces), we use @spaces.GPU.
- # Locally / on CPU hardware, we create a no-op decorator so the code still runs.
try:
- import spaces  #
except Exception:
class _DummySpaces:
def GPU(self, *args, **kwargs):
@@ -19,16 +17,22 @@ except Exception:
return deco
spaces = _DummySpaces()

- #
HF_TOKEN = (
os.getenv("HF_TOKEN")
or os.getenv("HUGGINGFACE_HUB_TOKEN")
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
)

- #
-
-
ALLOWED_LABELS = [
"plan_contact",
"schedule_meeting",
@@ -39,7 +43,7 @@ ALLOWED_LABELS = [
"update_kyc_purpose_of_businessrelation",
"update_kyc_total_assets",
]
- LABEL_TO_IDX = {l:i for i,l in enumerate(ALLOWED_LABELS)}
FN_PENALTY = 2.0
FP_PENALTY = 1.0

@@ -48,7 +52,7 @@ def safe_json_load(s: str):
return json.loads(s)
except Exception:
pass
- m = re.search(r
if m:
try:
return json.loads(m.group(0))
@@ -62,29 +66,29 @@ def _coerce_labels_list(x):
for it in x:
if isinstance(it, str): out.append(it)
elif isinstance(it, dict):
- for k in ("label","value","task","category","name"):
v = it.get(k)
if isinstance(v, str):
out.append(v); break
else:
if isinstance(it.get("labels"), list):
out += [s for s in it["labels"] if isinstance(s, str)]
-
for s in out:
if s not in seen:
norm.append(s); seen.add(s)
return norm
if isinstance(x, dict):
- for k in ("expected_labels","labels","targets","y_true"):
if k in x: return _coerce_labels_list(x[k])
if "one_hot" in x and isinstance(x["one_hot"], dict):
- return [k for k,v in x["one_hot"].items() if v]
return []

def classic_metrics(pred_labels, exp_labels):
-
-
- pred = set(pred_labels); gold = set(exp_labels)
if not pred and not gold:
return True, 1.0, 1.0, 1.0, 1.0
inter = pred & gold; union = pred | gold
@@ -108,9 +112,7 @@ def ubs_score_one(true_labels, pred_labels) -> float:
score = 1.0 if max_err == 0 else (1.0 - (weighted / max_err))
return float(max(0.0, min(1.0, score)))

- #
- # Lightweight Preprocess
- # =======================
EMAIL_RX = re.compile(r'\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b', re.I)
TIME_RX = re.compile(r'\b(\d{1,2}:\d{2}\b|\b\d{1,2}\s?(am|pm)\b|\bafternoon\b|\bmorning\b|\bevening\b)', re.I)
DATE_RX = re.compile(r'\b(jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)\b|\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b|\b20\d{2}\b', re.I)
@@ -201,11 +203,9 @@ def shrink_to_token_cap_by_lines(text: str, soft_cap_tokens: int, tokenizer,
ids = tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids
est = len(ids)
threshold = int(soft_cap_tokens * apply_only_if_ratio)
- if est <= threshold:
- return text
parts = text.splitlines()
- if len(parts) <= min_lines_keep:
- return text

keep_flags=[]
for ln in parts:
@@ -230,15 +230,13 @@ def shrink_to_token_cap_by_lines(text: str, soft_cap_tokens: int, tokenizer,
candidate2_tokens = len(tokenizer(candidate2, return_tensors=None, add_special_tokens=False).input_ids)
candidate = candidate if cand_tokens <= candidate2_tokens else candidate2

- if len(candidate.splitlines()) < min_lines_keep:
- return text
return candidate

def enforce_rules(labels, transcript_text):
labels = set(labels or [])
if (TIME_RX.search(transcript_text) or DATE_RX.search(transcript_text)) and MEET_RX.search(transcript_text):
- labels.add("schedule_meeting")
- labels.discard("plan_contact")
if EMAIL_RX.search(transcript_text) and re.search(r'\b(update|new|set|change|confirm(ed)?|for all communication)\b', transcript_text, re.I):
labels.add("update_contact_info_non_postal")
kyc_rx = re.compile(r'\b(kyc|aml|compliance|employer|occupation|purpose of (relationship|account)|source of (wealth|funds)|net worth|total assets)\b', re.I)
@@ -246,9 +244,7 @@ def enforce_rules(labels, transcript_text):
labels.discard("update_kyc_activity")
return sorted(labels)

- #
- # HF Model Wrapper
- # =======================
class HFModel:
def __init__(self, repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
self.repo_id = repo_id
@@ -260,19 +256,16 @@ class HFModel:
self.model = None
if load_4bit:
try:
-
- load_in_4bit=True,
-
- bnb_4bit_compute_dtype=torch_dtype,
- bnb_4bit_quant_type="nf4"
)
self.model = AutoModelForCausalLM.from_pretrained(
repo_id, device_map="auto", trust_remote_code=trust_remote_code,
- quantization_config=
)
except Exception as e:
print(f"[WARN] 4-bit load failed for {repo_id}: {e}\nFalling back to normal load...", file=sys.stderr)
-
if self.model is None:
self.model = AutoModelForCausalLM.from_pretrained(
repo_id, device_map="auto", trust_remote_code=trust_remote_code,
@@ -282,9 +275,6 @@ class HFModel:
self.max_context = getattr(self.model.config, "max_position_embeddings", None) \
or getattr(self.model.config, "max_sequence_length", None) or 8192

- def encode_len(self, text: str) -> int:
- return len(self.tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids)
-
def apply_chat_template(self, system_text: str, user_text: str) -> str:
if getattr(self.tokenizer, "chat_template", None):
messages = [{"role":"system","content":system_text},
@@ -300,78 +290,63 @@ class HFModel:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
t0 = time.perf_counter()
out = self.model.generate(
- **inputs,
-
-
- temperature=None,
- top_p=None,
- eos_token_id=self.tokenizer.eos_token_id,
)
latency_ms = int((time.perf_counter() - t0) * 1000)
text = self.tokenizer.decode(out[0], skip_special_tokens=True)
- if text.startswith(prompt):
- text = text[len(prompt):]
return latency_ms, text, prompt

- # Cache
MODEL_CACHE: Dict[str, HFModel] = {}
-
def get_model(repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
if repo_id not in MODEL_CACHE:
MODEL_CACHE[repo_id] = HFModel(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
return MODEL_CACHE[repo_id]

- #
- #
- # =======================
- @spaces.GPU(duration=180)  # required by ZeroGPU; no-op on CPU
def gpu_generate(repo_id: str, system_text: str, user_text: str,
load_4bit: bool, dtype: str, trust_remote_code: bool):
hf = get_model(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
-

- #
- # Utility (ZIP I/O)
- # =======================
def _read_zip_bytes(dataset_zip: Union[bytes, str, dict, None]) -> bytes:
- if dataset_zip is None:
-
- if isinstance(dataset_zip, bytes):
- return dataset_zip
if isinstance(dataset_zip, str):
- with open(dataset_zip, "rb") as f:
- return f.read()
if isinstance(dataset_zip, dict) and "path" in dataset_zip:
- with open(dataset_zip["path"], "rb") as f:
- return f.read()
path = getattr(dataset_zip, "name", None)
if path and os.path.exists(path):
- with open(path, "rb") as f:
-
- raise ValueError("Unsupported file object received from Gradio")

def parse_zip(zip_bytes: bytes) -> Dict[str, Tuple[str, List[str]]]:
zf = zipfile.ZipFile(io.BytesIO(zip_bytes))
- names = zf.namelist()
samples = {}
- for n in
p = Path(n)
if p.suffix.lower() == ".txt":
-
- txt = zf.read(n).decode("utf-8", "replace")
- samples.setdefault(sample_id, ["", []])[0] = txt
elif p.suffix.lower() == ".json":
- sample_id = p.stem
try:
js = json.loads(zf.read(n).decode("utf-8", "replace"))
except Exception:
js = []
- samples.setdefault(
return samples

- #
- # Core Inference (shared)
- # =======================
DEFAULT_SYSTEM = (
"You are a task extraction assistant. "
"Always output valid JSON with a field \"labels\" (list of strings). "
@@ -386,6 +361,7 @@ DEFAULT_CONTEXT = (
"- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)"
)

def prepare_input_text(raw_txt: str, soft_cap: int, preprocess: bool, pre_window: int,
add_cues: bool, strip_smalltalk: bool, tokenizer) -> Tuple[str, int, int]:
before = len(tokenizer(raw_txt, return_tensors=None, add_special_tokens=False).input_ids)
@@ -395,10 +371,10 @@ def prepare_input_text(raw_txt: str, soft_cap: int, preprocess: bool, pre_window
lines = [ln.strip() for ln in t_norm.splitlines() if ln.strip()]
cue_lines = find_cue_lines(lines)
if cue_lines:
-
else:
-
- t_kept = "\n".join(
cues = extract_cues(t_kept)
header = build_cues_header(cues) if add_cues else ""
proc_text = (header + "\n\n" + t_kept).strip() if header else t_kept
@@ -419,9 +395,7 @@ def explain_params_markdown() -> str:
"- **Load in 4-bit (GPU only)**: memory-saving quantization; has no effect on CPU Spaces."
)

- #
- # Single Transcript Mode
- # =======================
def single_mode(
preset_model: str, custom_model: str,
system_text: str, context_text: str,
@@ -432,14 +406,14 @@
):
repo_id = custom_model.strip() or preset_model.strip()
if not repo_id:
- return "Please choose a model.", "", "", "", None, None, None

txt = (transcript_text or "").strip()
if transcript_file and hasattr(transcript_file, "name") and os.path.exists(transcript_file.name):
with open(transcript_file.name, "r", encoding="utf-8", errors="replace") as f:
txt = f.read()
if not txt:
- return "Please paste a transcript or upload a .txt file.", "", "", "", None, None, None

exp = []
if expected_labels_json and hasattr(expected_labels_json, "name") and os.path.exists(expected_labels_json.name):
@@ -449,27 +423,27 @@
except Exception:
exp = []

- # tokenizer for preprocessing
try:
- dummy_tok = AutoTokenizer.from_pretrained(
- repo_id, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN
- )
except Exception as e:
- msg = ("Failed to load tokenizer for `{}`.
- "Space → Settings → Secrets.\n\nError:
- return msg, "", "", "", None, None, None

proc_text, tok_before, tok_after = prepare_input_text(
txt, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
)
- user = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()
system = (system_text or DEFAULT_SYSTEM).strip()

try:
- latency_ms, raw_text,
except Exception as e:
- msg = ("Failed to run `{}`. If gated, accept license and set HF_TOKEN.\n\nError: {}")
- return msg, "", "", "", None, None, None

out = safe_json_load(raw_text)
pred_labels = enforce_rules(out.get("labels", []), proc_text)
@@ -499,12 +473,8 @@ def single_mode(
"model_calls": 1
},
"evaluation": None if not exp else {
- "exact_match": exact,
- "
- "recall": rec,
- "f1": f1,
- "hamming": ham,
- "ubs_score": ubs
}
}
zout.writestr("FINAL.json", json.dumps(final_json, ensure_ascii=False, indent=2))
@@ -526,45 +496,45 @@
"ubs_score": round(ubs,6) if ubs is not None else None
}])

-
-
-
-

- #
- # Batch Mode (ZIP)
- # =======================
def run_batch_ui(models_list, custom_models_str, instructions_text, context_text, dataset_zip,
soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
repeats, max_total_runs, load_4bit, dtype, trust_remote_code):

models = [m for m in (models_list or [])]
-
- models.extend(custom)
- models = [m for m in models if m]
if not models:
- return pd.DataFrame(), None, None, "Please pick at least one model."

if not dataset_zip:
- return pd.DataFrame(), None, None, "Please upload a ZIP with *.txt (+ optional matching *.json)."

try:
zip_bytes = _read_zip_bytes(dataset_zip)
samples = parse_zip(zip_bytes)
except Exception as e:
- return pd.DataFrame(), None, None, f"Failed to read ZIP: {e}"

- rows = []
- total_runs = 0
all_artifacts = io.BytesIO()
zout = zipfile.ZipFile(all_artifacts, "w", zipfile.ZIP_DEFLATED)

for repo_id in models:
try:
- dummy_tok = AutoTokenizer.from_pretrained(
- repo_id, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN
- )
except Exception as e:
rows.append({
"timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
"sample_id": None,
@@ -593,14 +563,10 @@ def run_batch_ui(models_list, custom_models_str, instructions_text, context_text
continue

for sample_id, (transcript_text, exp_labels) in samples.items():
- if not transcript_text.strip():
-
- latencies = []
- last_pred = None
for r in range(1, repeats+1):
- if total_runs >= max_total_runs:
- break
-
proc_text, before_tok, after_tok = prepare_input_text(
transcript_text, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
)
@@ -608,7 +574,10 @@
user_text = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()

try:
- latency_ms, raw_text,
except Exception as e:
base = f"{repo_id.replace('/','_')}/{sample_id}/error_r{r}"
zout.writestr(base + "/ERROR.txt", f"Failed to run model via @spaces.GPU. If gated, accept license and set HF_TOKEN.\n\n{e}")
@@ -706,14 +675,35 @@ def run_batch_ui(models_list, custom_models_str, instructions_text, context_text
zout.close()
df = pd.DataFrame(rows)
if df.empty:
- return pd.DataFrame(), None, None, "No runs executed (empty dataset / exceeded cap / gated models)."

-
-

-
-
-
DARK_RED_CSS = """
:root, .gradio-container {
--color-background: #0b0b0d;
@@ -741,44 +731,36 @@ button, .gr-button {
}
"""

- PRESET_MODELS = [
- "mistralai/Mistral-7B-Instruct-v0.2",
- "Qwen/Qwen2.5-7B-Instruct",
- "HuggingFaceH4/zephyr-7b-beta",
- "tiiuae/falcon-7b-instruct"
- ]
-
- DEFAULT_SYSTEM = (
- "You are a task extraction assistant. "
- "Always output valid JSON with a field \"labels\" (list of strings). "
- "Use only from this set: " + json.dumps(ALLOWED_LABELS) + ". "
- "Return JSON only."
- )
- DEFAULT_CONTEXT = (
- "- plan_contact: conversation without a concrete meeting (no date/time)\n"
- "- schedule_meeting: explicit date/time/modality confirmation\n"
- "- update_contact_info_non_postal: changes to email/phone\n"
- "- update_contact_info_postal_address: changes to mailing address\n"
- "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)"
- )
-
with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo:
gr.Markdown("## 🔥 From Talk to Task – Batch & Single Task Extraction")
-
- "This tool extracts **task labels** from
"1) Pick a model (or paste a custom repo id). \n"
"2) Provide **Instructions** and **Context**, then supply a transcript (single) or a ZIP (batch). \n"
"3) Adjust parameters (soft token cap, preprocessing). \n"
- "4) Run and review **latency**, **precision/recall/F1**, **UBS score**, and download artifacts
- "_ZeroGPU-ready: model calls run inside an @spaces.GPU function when available._"
)

with gr.Tabs():
# Single
with gr.TabItem("Single Transcript (default)"):
with gr.Row():
with gr.Column():
- preset_model = gr.Dropdown(choices=
custom_model = gr.Textbox(label="Custom model repo id (overrides preset)",
placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct")
instructions = gr.Textbox(label="Instructions (System)", lines=8, value=DEFAULT_SYSTEM)
@@ -795,6 +777,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
pre_window_s = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
add_cues_s = gr.Checkbox(value=True, label="Add cues header")
strip_smalltalk_s = gr.Checkbox(value=False, label="Strip smalltalk")
with gr.Column():
load_4bit_s = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")
dtype_s = gr.Dropdown(choices=["bfloat16","float16","float32"], value="bfloat16", label="Compute dtype")
@@ -808,11 +791,8 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
single_status = gr.Markdown("")

def _run_single(*args):
- status, m1, m2, m3, df, csv_buf, zip_buf = single_mode(*args)
- if isinstance(df, pd.DataFrame)
- return m1, m2, m3, df, csv_buf, zip_buf, status
- else:
- return m1 or "", m2 or "", m3 or "", pd.DataFrame(), None, None, status

run_single_btn.click(
_run_single,
@@ -820,7 +800,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
transcript_text, transcript_file, expected_labels_json,
soft_cap_s, preprocess_s, pre_window_s, add_cues_s, strip_smalltalk_s,
load_4bit_s, dtype_s, trust_remote_code_s],
- outputs=[kpi1, kpi2, kpi3, single_table, single_csv, single_zip, single_status]
)

# Batch
@@ -828,7 +808,8 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
with gr.Row():
with gr.Column():
models_list = gr.Checkboxgroup(
- choices=
)
custom_models = gr.Textbox(label="Custom model repo ids (comma-separated)",
placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct, Qwen/Qwen2.5-7B-Instruct")
@@ -839,6 +820,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
label="Upload ZIP of transcripts (*.txt) + expected (*.json)",
file_types=[".zip"], file_count="single", type="filepath"
)

with gr.Row():
with gr.Column():
@@ -847,6 +829,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
pre_window = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
add_cues = gr.Checkbox(value=True, label="Add cues header")
strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
with gr.Column():
repeats = gr.Slider(1, 6, value=3, step=1, label="Repeats per config")
max_total_runs = gr.Slider(1, 200, value=40, step=1, label="Max total runs")
@@ -862,7 +845,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
status = gr.Markdown("")

def _run_batch(*args):
- df, csv_pair, zip_pair, msg = run_batch_ui(*args)
m1 = m2 = m3 = ""
if isinstance(df, pd.DataFrame) and not df.empty:
summaries = df[df["is_summary"] == True]
@@ -874,19 +857,17 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
m3 = f"**Median latency (ms)**\n\n{int(med) if pd.notna(med) else '–'}"
csv_buf = zip_buf = None
if isinstance(csv_pair, tuple):
- name, data = csv_pair
- csv_buf = io.BytesIO(data); csv_buf.name = name
if isinstance(zip_pair, tuple):
- name, data = zip_pair
-
- return m1, m2, m3, df, csv_buf, zip_buf, msg

run_btn.click(
_run_batch,
inputs=[models_list, custom_models, instructions_b, context_b, dataset_zip,
soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
repeats, max_total_runs, load_4bit, dtype, trust_remote_code],
- outputs=[kpi_b1, kpi_b2, kpi_b3, table, csv_dl, zip_dl, status]
)

demo.launch()
app.py (after the change):

@@ -1,17 +1,15 @@
import os, io, re, sys, time, json, zipfile, statistics
from pathlib import Path
+ from typing import List, Dict, Tuple, Union

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

+ # ========= ZeroGPU support =========
try:
+ import spaces  # available on HF Spaces
except Exception:
class _DummySpaces:
def GPU(self, *args, **kwargs):
@@ -19,16 +17,22 @@ except Exception:
return deco
spaces = _DummySpaces()

+ # ========= Auth token =========
HF_TOKEN = (
os.getenv("HF_TOKEN")
or os.getenv("HUGGINGFACE_HUB_TOKEN")
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
)

+ # Console warning at startup (helps when logs are open)
+ if not HF_TOKEN:
+ print(
+ "[WARN] HF_TOKEN is not set. Gated models will fail. "
+ "Set it in Space → Settings → Variables and secrets.",
+ file=sys.stderr
+ )
+
+ # ========= Labels & metrics =========
ALLOWED_LABELS = [
"plan_contact",
"schedule_meeting",
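Not part of the diff: a minimal sketch of the fallback the first hunk relies on when the `spaces` package is unavailable. The body of `GPU()` is elided in the diff, but it is assumed to return a decorator that hands the function back unchanged, so `@spaces.GPU(...)` becomes a no-op off-Spaces.

# Hedged sketch: the exact elided body may differ.
class _DummySpaces:
    def GPU(self, *args, **kwargs):
        def deco(fn):
            return fn              # leave the wrapped callable untouched
        return deco

spaces = _DummySpaces()

@spaces.GPU(duration=180)          # accepted but ignored locally / on CPU
def double(x):
    return x * 2

assert double(3) == 6              # behaves exactly like the undecorated function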
@@ -39,7 +43,7 @@ ALLOWED_LABELS = [
"update_kyc_purpose_of_businessrelation",
"update_kyc_total_assets",
]
+ LABEL_TO_IDX = {l: i for i, l in enumerate(ALLOWED_LABELS)}
FN_PENALTY = 2.0
FP_PENALTY = 1.0

@@ -48,7 +52,7 @@ def safe_json_load(s: str):
return json.loads(s)
except Exception:
pass
+ m = re.search(r"\{.*\}", s, re.S)
if m:
try:
return json.loads(m.group(0))
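An illustrative usage sketch (not from the diff) of what the added regex fallback in safe_json_load buys: it pulls the outermost {...} span out of a chatty model reply, so trailing prose no longer breaks parsing. The empty-dict fallback at the end is an assumption about the elided tail of the function.

import json, re

def safe_json_load_sketch(s: str):
    # same strategy as safe_json_load above: strict parse first, then brace-span fallback
    try:
        return json.loads(s)
    except Exception:
        pass
    m = re.search(r"\{.*\}", s, re.S)
    if m:
        try:
            return json.loads(m.group(0))
        except Exception:
            pass
    return {}  # assumed final fallback

print(safe_json_load_sketch('Sure! {"labels": ["plan_contact"]} Hope this helps.'))
# -> {'labels': ['plan_contact']}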
@@ -62,29 +66,29 @@ def _coerce_labels_list(x):
for it in x:
if isinstance(it, str): out.append(it)
elif isinstance(it, dict):
+ for k in ("label", "value", "task", "category", "name"):
v = it.get(k)
if isinstance(v, str):
out.append(v); break
else:
if isinstance(it.get("labels"), list):
out += [s for s in it["labels"] if isinstance(s, str)]
+ # dedupe keep order
+ seen = set(); norm = []
for s in out:
if s not in seen:
norm.append(s); seen.add(s)
return norm
if isinstance(x, dict):
+ for k in ("expected_labels", "labels", "targets", "y_true"):
if k in x: return _coerce_labels_list(x[k])
if "one_hot" in x and isinstance(x["one_hot"], dict):
+ return [k for k, v in x["one_hot"].items() if v]
return []

def classic_metrics(pred_labels, exp_labels):
+ pred = set([str(x) for x in (pred_labels or []) if isinstance(x, (str,int,float,bool))])
+ gold = set([str(x) for x in (exp_labels or []) if isinstance(x, (str,int,float,bool))])
if not pred and not gold:
return True, 1.0, 1.0, 1.0, 1.0
inter = pred & gold; union = pred | gold
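For orientation, a small worked example (hypothetical labels, standard set-based definitions) of the quantities classic_metrics() derives from pred and gold; the function's internals beyond this hunk are not shown in the diff.

pred = {"schedule_meeting", "update_kyc_total_assets"}
gold = {"schedule_meeting", "plan_contact"}
inter = pred & gold                                   # {"schedule_meeting"}
union = pred | gold                                   # 3 distinct labels
precision = len(inter) / len(pred)                    # 0.5
recall = len(inter) / len(gold)                       # 0.5
f1 = 2 * precision * recall / (precision + recall)    # 0.5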
@@ -108,9 +112,7 @@ def ubs_score_one(true_labels, pred_labels) -> float:
score = 1.0 if max_err == 0 else (1.0 - (weighted / max_err))
return float(max(0.0, min(1.0, score)))

+ # ========= Lightweight preprocessing =========
EMAIL_RX = re.compile(r'\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b', re.I)
TIME_RX = re.compile(r'\b(\d{1,2}:\d{2}\b|\b\d{1,2}\s?(am|pm)\b|\bafternoon\b|\bmorning\b|\bevening\b)', re.I)
DATE_RX = re.compile(r'\b(jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)\b|\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b|\b20\d{2}\b', re.I)
@@ -201,11 +203,9 @@ def shrink_to_token_cap_by_lines(text: str, soft_cap_tokens: int, tokenizer,
ids = tokenizer(text, return_tensors=None, add_special_tokens=False).input_ids
est = len(ids)
threshold = int(soft_cap_tokens * apply_only_if_ratio)
+ if est <= threshold: return text
parts = text.splitlines()
+ if len(parts) <= min_lines_keep: return text

keep_flags=[]
for ln in parts:
@@ -230,15 +230,13 @@ def shrink_to_token_cap_by_lines(text: str, soft_cap_tokens: int, tokenizer,
candidate2_tokens = len(tokenizer(candidate2, return_tensors=None, add_special_tokens=False).input_ids)
candidate = candidate if cand_tokens <= candidate2_tokens else candidate2

+ if len(candidate.splitlines()) < min_lines_keep: return text
return candidate

def enforce_rules(labels, transcript_text):
labels = set(labels or [])
if (TIME_RX.search(transcript_text) or DATE_RX.search(transcript_text)) and MEET_RX.search(transcript_text):
+ labels.add("schedule_meeting"); labels.discard("plan_contact")
if EMAIL_RX.search(transcript_text) and re.search(r'\b(update|new|set|change|confirm(ed)?|for all communication)\b', transcript_text, re.I):
labels.add("update_contact_info_non_postal")
kyc_rx = re.compile(r'\b(kyc|aml|compliance|employer|occupation|purpose of (relationship|account)|source of (wealth|funds)|net worth|total assets)\b', re.I)
@@ -246,9 +244,7 @@ def enforce_rules(labels, transcript_text):
labels.discard("update_kyc_activity")
return sorted(labels)

+ # ========= HF model wrapper =========
class HFModel:
def __init__(self, repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
self.repo_id = repo_id
@@ -260,19 +256,16 @@ class HFModel:
self.model = None
if load_4bit:
try:
+ q = BitsAndBytesConfig(
+ load_in_4bit=True, bnb_4bit_use_double_quant=True,
+ bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_quant_type="nf4"
)
self.model = AutoModelForCausalLM.from_pretrained(
repo_id, device_map="auto", trust_remote_code=trust_remote_code,
+ quantization_config=q, torch_dtype=torch_dtype, token=HF_TOKEN
)
except Exception as e:
print(f"[WARN] 4-bit load failed for {repo_id}: {e}\nFalling back to normal load...", file=sys.stderr)
if self.model is None:
self.model = AutoModelForCausalLM.from_pretrained(
repo_id, device_map="auto", trust_remote_code=trust_remote_code,
@@ -282,9 +275,6 @@ class HFModel:
self.max_context = getattr(self.model.config, "max_position_embeddings", None) \
or getattr(self.model.config, "max_sequence_length", None) or 8192

def apply_chat_template(self, system_text: str, user_text: str) -> str:
if getattr(self.tokenizer, "chat_template", None):
messages = [{"role":"system","content":system_text},
@@ -300,78 +290,63 @@ class HFModel:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
t0 = time.perf_counter()
out = self.model.generate(
+ **inputs, max_new_tokens=max_new_tokens,
+ do_sample=False, temperature=None, top_p=None,
+ eos_token_id=self.tokenizer.eos_token_id
)
latency_ms = int((time.perf_counter() - t0) * 1000)
text = self.tokenizer.decode(out[0], skip_special_tokens=True)
+ if text.startswith(prompt): text = text[len(prompt):]
return latency_ms, text, prompt

MODEL_CACHE: Dict[str, HFModel] = {}
def get_model(repo_id: str, load_4bit: bool, dtype: str, trust_remote_code: bool):
if repo_id not in MODEL_CACHE:
MODEL_CACHE[repo_id] = HFModel(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
return MODEL_CACHE[repo_id]

+ # ========= ZeroGPU functions =========
+ @spaces.GPU(duration=180, secrets=["HF_TOKEN"])  # pass token into ZeroGPU job
def gpu_generate(repo_id: str, system_text: str, user_text: str,
load_4bit: bool, dtype: str, trust_remote_code: bool):
+ token_seen = bool(os.getenv("HF_TOKEN"))
hf = get_model(repo_id, load_4bit=load_4bit, dtype=dtype, trust_remote_code=trust_remote_code)
+ lat, txt, prmpt = hf.generate_json(system_text.strip(), user_text.strip(), max_new_tokens=256)
+ return lat, txt, prmpt, token_seen
+
+ @spaces.GPU(duration=15, secrets=["HF_TOKEN"])
+ def gpu_check_token():
+ return bool(os.getenv("HF_TOKEN"))

+ # ========= ZIP helpers =========
def _read_zip_bytes(dataset_zip: Union[bytes, str, dict, None]) -> bytes:
+ if dataset_zip is None: raise ValueError("No ZIP provided")
+ if isinstance(dataset_zip, bytes): return dataset_zip
if isinstance(dataset_zip, str):
+ with open(dataset_zip, "rb") as f: return f.read()
if isinstance(dataset_zip, dict) and "path" in dataset_zip:
+ with open(dataset_zip["path"], "rb") as f: return f.read()
path = getattr(dataset_zip, "name", None)
if path and os.path.exists(path):
+ with open(path, "rb") as f: return f.read()
+ raise ValueError("Unsupported file object from Gradio")

def parse_zip(zip_bytes: bytes) -> Dict[str, Tuple[str, List[str]]]:
zf = zipfile.ZipFile(io.BytesIO(zip_bytes))
samples = {}
+ for n in zf.namelist():
p = Path(n)
if p.suffix.lower() == ".txt":
+ samples.setdefault(p.stem, ["", []])[0] = zf.read(n).decode("utf-8", "replace")
elif p.suffix.lower() == ".json":
try:
js = json.loads(zf.read(n).decode("utf-8", "replace"))
except Exception:
js = []
+ samples.setdefault(p.stem, ["", []])[1] = _coerce_labels_list(js)
return samples

+ # ========= Prompts =========
DEFAULT_SYSTEM = (
"You are a task extraction assistant. "
"Always output valid JSON with a field \"labels\" (list of strings). "
@@ -386,6 +361,7 @@ DEFAULT_CONTEXT = (
"- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)"
)

+ # ========= Preprocess + build input =========
def prepare_input_text(raw_txt: str, soft_cap: int, preprocess: bool, pre_window: int,
add_cues: bool, strip_smalltalk: bool, tokenizer) -> Tuple[str, int, int]:
before = len(tokenizer(raw_txt, return_tensors=None, add_special_tokens=False).input_ids)
@@ -395,10 +371,10 @@ def prepare_input_text(raw_txt: str, soft_cap: int, preprocess: bool, pre_window
lines = [ln.strip() for ln in t_norm.splitlines() if ln.strip()]
cue_lines = find_cue_lines(lines)
if cue_lines:
+ kept = prune_by_window(lines, cue_lines, window=pre_window, strip_smalltalk=strip_smalltalk)
else:
+ kept = [ln for ln in lines if not (strip_smalltalk and SMALLTALK_RX.search(ln))]
+ t_kept = "\n".join(kept)
cues = extract_cues(t_kept)
header = build_cues_header(cues) if add_cues else ""
proc_text = (header + "\n\n" + t_kept).strip() if header else t_kept
@@ -419,9 +395,7 @@ def explain_params_markdown() -> str:
"- **Load in 4-bit (GPU only)**: memory-saving quantization; has no effect on CPU Spaces."
)

+ # ========= Single mode =========
def single_mode(
preset_model: str, custom_model: str,
system_text: str, context_text: str,
@@ -432,14 +406,14 @@
):
repo_id = custom_model.strip() or preset_model.strip()
if not repo_id:
+ return "Please choose a model.", "", "", "", None, None, None, ""

txt = (transcript_text or "").strip()
if transcript_file and hasattr(transcript_file, "name") and os.path.exists(transcript_file.name):
with open(transcript_file.name, "r", encoding="utf-8", errors="replace") as f:
txt = f.read()
if not txt:
+ return "Please paste a transcript or upload a .txt file.", "", "", "", None, None, None, ""

exp = []
if expected_labels_json and hasattr(expected_labels_json, "name") and os.path.exists(expected_labels_json.name):
@@ -449,27 +423,27 @@
except Exception:
exp = []

+ # tokenizer for preprocessing
try:
+ dummy_tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN)
except Exception as e:
+ msg = (f"Failed to load tokenizer for `{repo_id}`. "
+ "If gated, accept license and set HF_TOKEN in Space → Settings → Secrets.\n\nError: " + str(e))
+ return msg, "", "", "", None, None, None, banner_text()

proc_text, tok_before, tok_after = prepare_input_text(
txt, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
)
system = (system_text or DEFAULT_SYSTEM).strip()
+ user = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()

try:
+ latency_ms, raw_text, _prompt, gpu_token_seen = gpu_generate(
+ repo_id, system, user, load_4bit, dtype, trust_remote_code
+ )
except Exception as e:
+ msg = (f"Failed to run `{repo_id}`. If gated, accept license and set HF_TOKEN.\n\nError: {e}")
+ return msg, "", "", "", None, None, None, banner_text()

out = safe_json_load(raw_text)
pred_labels = enforce_rules(out.get("labels", []), proc_text)
@@ -499,12 +473,8 @@ def single_mode(
"model_calls": 1
},
"evaluation": None if not exp else {
+ "exact_match": exact, "precision": prec, "recall": rec,
+ "f1": f1, "hamming": ham, "ubs_score": ubs
}
}
zout.writestr("FINAL.json", json.dumps(final_json, ensure_ascii=False, indent=2))
@@ -526,45 +496,45 @@
"ubs_score": round(ubs,6) if ubs is not None else None
}])

+ csv_buf = io.BytesIO(row.to_csv(index=False).encode("utf-8")); csv_buf.name = "results_single.csv"
+
+ return (
+ "Done.",
+ kpi1, kpi2, kpi3,
+ row, csv_buf, zbuf,
+ banner_text(gpu_token_seen)
+ )

+ # ========= Batch mode =========
def run_batch_ui(models_list, custom_models_str, instructions_text, context_text, dataset_zip,
soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
repeats, max_total_runs, load_4bit, dtype, trust_remote_code):

models = [m for m in (models_list or [])]
+ models += [m.strip() for m in (custom_models_str or "").split(",") if m.strip()]
if not models:
+ return pd.DataFrame(), None, None, "Please pick at least one model.", banner_text()

if not dataset_zip:
+ return pd.DataFrame(), None, None, "Please upload a ZIP with *.txt (+ optional matching *.json).", banner_text()

try:
zip_bytes = _read_zip_bytes(dataset_zip)
samples = parse_zip(zip_bytes)
except Exception as e:
+ return pd.DataFrame(), None, None, f"Failed to read ZIP: {e}", banner_text()

+ rows = []; total_runs = 0
all_artifacts = io.BytesIO()
zout = zipfile.ZipFile(all_artifacts, "w", zipfile.ZIP_DEFLATED)
+ last_gpu_token_seen = None

for repo_id in models:
+ # tokenizer for preprocessing (auth check)
try:
+ dummy_tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=trust_remote_code, token=HF_TOKEN)
except Exception as e:
+ # gated or missing token; record a summary row and continue
rows.append({
"timestamp": pd.Timestamp.now().isoformat(timespec="seconds"),
"sample_id": None,
@@ -593,14 +563,10 @@ def run_batch_ui(models_list, custom_models_str, instructions_text, context_text
continue

for sample_id, (transcript_text, exp_labels) in samples.items():
+ if not transcript_text.strip(): continue
+ latencies = []; last_pred = None
for r in range(1, repeats+1):
+ if total_runs >= max_total_runs: break
proc_text, before_tok, after_tok = prepare_input_text(
transcript_text, soft_cap, preprocess, pre_window, add_cues, strip_smalltalk, dummy_tok
)
@@ -608,7 +574,10 @@
user_text = (context_text or DEFAULT_CONTEXT).strip() + "\n\nTRANSCRIPT\n" + proc_text.strip()

try:
+ latency_ms, raw_text, _prompt, token_seen = gpu_generate(
+ repo_id, system_text, user_text, load_4bit, dtype, trust_remote_code
+ )
+ last_gpu_token_seen = token_seen
except Exception as e:
base = f"{repo_id.replace('/','_')}/{sample_id}/error_r{r}"
zout.writestr(base + "/ERROR.txt", f"Failed to run model via @spaces.GPU. If gated, accept license and set HF_TOKEN.\n\n{e}")
@@ -706,14 +675,35 @@ def run_batch_ui(models_list, custom_models_str, instructions_text, context_text
zout.close()
df = pd.DataFrame(rows)
if df.empty:
+ return pd.DataFrame(), None, None, "No runs executed (empty dataset / exceeded cap / gated models).", banner_text(last_gpu_token_seen)
+
+ csv_pair = ("results.csv", df.to_csv(index=False).encode("utf-8"))
+ zip_pair = ("artifacts.zip", all_artifacts.getvalue())
+ return df, csv_pair, zip_pair, "Done.", banner_text(last_gpu_token_seen)

+ # ========= UI helpers =========
+ OPEN_MODEL_PRESETS = [
+ "mistralai/Mistral-7B-Instruct-v0.2",
+ "Qwen/Qwen2.5-7B-Instruct",
+ "HuggingFaceH4/zephyr-7b-beta",
+ "tiiuae/falcon-7b-instruct",
+ ]

+ def banner_text(gpu_token_seen: bool | None = None) -> str:
+ app_seen = bool(HF_TOKEN)
+ lines = []
+ if not app_seen:
+ lines.append("🟡 **HF_TOKEN not detected in App** – gated models will fail unless you set it in **Settings → Variables and secrets**.")
+ else:
+ lines.append("🟢 **HF_TOKEN detected in App**.")
+ if gpu_token_seen is None:
+ lines.append("ℹ️ ZeroGPU token status: click **Run** or **Check ZeroGPU token** to verify.")
+ else:
+ lines.append("🟢 **HF_TOKEN detected inside ZeroGPU job.**" if gpu_token_seen else "🔴 **HF_TOKEN missing inside ZeroGPU job** (add `secrets=[\"HF_TOKEN\"]` to @spaces.GPU).")
+ lines.append("✅ Tip: use **Open models** (no license gating): " + ", ".join(OPEN_MODEL_PRESETS))
+ return "\n\n".join(lines)
+
+ # ========= UI (dark red) =========
DARK_RED_CSS = """
:root, .gradio-container {
--color-background: #0b0b0d;
@@ -741,44 +731,36 @@ button, .gr-button {
}
"""

with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo:
gr.Markdown("## 🔥 From Talk to Task – Batch & Single Task Extraction")
+ help_md = (
+ "This tool extracts **task labels** from transcripts using Hugging Face models. \n"
"1) Pick a model (or paste a custom repo id). \n"
"2) Provide **Instructions** and **Context**, then supply a transcript (single) or a ZIP (batch). \n"
"3) Adjust parameters (soft token cap, preprocessing). \n"
+ "4) Run and review **latency**, **precision/recall/F1**, **UBS score**, and download artifacts."
)
+ gr.Markdown(help_md)
+
+ # Status banner (token presence info)
+ banner = gr.Markdown(banner_text())
+
+ check_btn = gr.Button("Check ZeroGPU token")
+ def _check_token():
+ try:
+ present = gpu_check_token()
+ except Exception:
+ present = None
+ return banner_text(present)
+ check_btn.click(_check_token, outputs=banner)

with gr.Tabs():
# Single
with gr.TabItem("Single Transcript (default)"):
with gr.Row():
with gr.Column():
+ preset_model = gr.Dropdown(choices=OPEN_MODEL_PRESETS, value=OPEN_MODEL_PRESETS[0],
+ label="Model (Open presets – no gating)")
custom_model = gr.Textbox(label="Custom model repo id (overrides preset)",
placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct")
instructions = gr.Textbox(label="Instructions (System)", lines=8, value=DEFAULT_SYSTEM)
@@ -795,6 +777,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
pre_window_s = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
add_cues_s = gr.Checkbox(value=True, label="Add cues header")
strip_smalltalk_s = gr.Checkbox(value=False, label="Strip smalltalk")
+ gr.Markdown(explain_params_markdown())
with gr.Column():
load_4bit_s = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")
dtype_s = gr.Dropdown(choices=["bfloat16","float16","float32"], value="bfloat16", label="Compute dtype")
@@ -808,11 +791,8 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
single_status = gr.Markdown("")

def _run_single(*args):
+ status, m1, m2, m3, df, csv_buf, zip_buf, btxt = single_mode(*args)
+ return m1 or "", m2 or "", m3 or "", (df if isinstance(df, pd.DataFrame) else pd.DataFrame()), csv_buf, zip_buf, (status or ""), (btxt or banner_text())

run_single_btn.click(
_run_single,
@@ -820,7 +800,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
transcript_text, transcript_file, expected_labels_json,
soft_cap_s, preprocess_s, pre_window_s, add_cues_s, strip_smalltalk_s,
load_4bit_s, dtype_s, trust_remote_code_s],
+ outputs=[kpi1, kpi2, kpi3, single_table, single_csv, single_zip, single_status, banner]
)

# Batch
@@ -828,7 +808,8 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
with gr.Row():
with gr.Column():
models_list = gr.Checkboxgroup(
+ choices=OPEN_MODEL_PRESETS, value=[OPEN_MODEL_PRESETS[0]],
+ label="Models (Open presets – select one or more)"
)
custom_models = gr.Textbox(label="Custom model repo ids (comma-separated)",
placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct, Qwen/Qwen2.5-7B-Instruct")
@@ -839,6 +820,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
label="Upload ZIP of transcripts (*.txt) + expected (*.json)",
file_types=[".zip"], file_count="single", type="filepath"
)
+ gr.Markdown("Zip must contain pairs like `ID.txt` and optional `ID.json` with expected labels (same base filename).")

with gr.Row():
with gr.Column():
@@ -847,6 +829,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
pre_window = gr.Slider(0, 6, value=3, step=1, label="Window ± lines around cues")
add_cues = gr.Checkbox(value=True, label="Add cues header")
strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
+ gr.Markdown(explain_params_markdown())
with gr.Column():
repeats = gr.Slider(1, 6, value=3, step=1, label="Repeats per config")
max_total_runs = gr.Slider(1, 200, value=40, step=1, label="Max total runs")
@@ -862,7 +845,7 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
status = gr.Markdown("")

def _run_batch(*args):
+ df, csv_pair, zip_pair, msg, btxt = run_batch_ui(*args)
m1 = m2 = m3 = ""
if isinstance(df, pd.DataFrame) and not df.empty:
summaries = df[df["is_summary"] == True]
@@ -874,19 +857,17 @@ with gr.Blocks(title="From Talk to Task – HF Space", css=DARK_RED_CSS) as demo
m3 = f"**Median latency (ms)**\n\n{int(med) if pd.notna(med) else '–'}"
csv_buf = zip_buf = None
if isinstance(csv_pair, tuple):
+ name, data = csv_pair; csv_buf = io.BytesIO(data); csv_buf.name = name
if isinstance(zip_pair, tuple):
+ name, data = zip_pair; zip_buf = io.BytesIO(data); zip_buf.name = name
+ return m1, m2, m3, (df if isinstance(df, pd.DataFrame) else pd.DataFrame()), csv_buf, zip_buf, (msg or ""), (btxt or banner_text())

run_btn.click(
_run_batch,
inputs=[models_list, custom_models, instructions_b, context_b, dataset_zip,
soft_cap, preprocess, pre_window, add_cues, strip_smalltalk,
repeats, max_total_runs, load_4bit, dtype, trust_remote_code],
+ outputs=[kpi_b1, kpi_b2, kpi_b3, table, csv_dl, zip_dl, status, banner]
)

demo.launch()
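Finally, a minimal sketch (hypothetical file names) of a dataset ZIP that parse_zip() and the batch tab accept: each transcript is NAME.txt, and optional expected labels sit in NAME.json with the same base filename, in any shape _coerce_labels_list() understands (a plain list, {"labels": [...]}, a one_hot dict, ...).

import io, json, zipfile

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    # transcript and expected labels share the stem "call_001"
    zf.writestr("call_001.txt", "Advisor: shall we meet Tuesday at 10:00 to review the KYC update?")
    zf.writestr("call_001.json", json.dumps({"labels": ["schedule_meeting", "update_kyc_activity"]}))
    # a transcript without an expected-labels file is also accepted
    zf.writestr("call_002.txt", "Client: please use jane.doe@example.com for all communication.")

with open("dataset.zip", "wb") as f:
    f.write(buf.getvalue())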