Update app.py

app.py (CHANGED)
@@ -209,21 +209,35 @@ def clean_transcript(text: str) -> str:
     s = re.sub(r"\n{3,}", "\n\n", s).strip()
     return s
 
+def read_text_file_any(file_input) -> str:
+    """Works for gr.File(type='filepath') and raw strings/Path and file-like."""
+    if not file_input:
         return ""
-    if name.endswith(".json"):
+    # filepath string
+    if isinstance(file_input, (str, Path)):
        try:
-            if isinstance(obj, dict) and "transcript" in obj:
-                return str(obj["transcript"])
-            return json.dumps(obj, ensure_ascii=False)
+            return Path(file_input).read_text(encoding="utf-8", errors="ignore")
        except Exception:
-            return
+            return ""
+    # gr.File object or file-like
+    try:
+        data = file_input.read()
        return data.decode("utf-8", errors="ignore")
+    except Exception:
+        return ""
+
+def read_json_file_any(file_input) -> Optional[dict]:
+    if not file_input:
+        return None
+    if isinstance(file_input, (str, Path)):
+        try:
+            return json.loads(Path(file_input).read_text(encoding="utf-8", errors="ignore"))
+        except Exception:
+            return None
+    try:
+        return json.loads(file_input.read().decode("utf-8", errors="ignore"))
+    except Exception:
+        return None
 
 def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
     toks = tokenizer(text, add_special_tokens=False)["input_ids"]
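
Note: a quick, illustrative check of the two new readers (not part of the commit; it assumes read_text_file_any and read_json_file_any are in scope, e.g. in a REPL, and the file names are placeholders). With gr.File(type="filepath"), Gradio hands the handlers plain path strings:

    from pathlib import Path
    import json, tempfile

    tmp = Path(tempfile.mkdtemp())
    (tmp / "sample.txt").write_text("Customer asks to reschedule the demo.", encoding="utf-8")
    (tmp / "sample.json").write_text(json.dumps({"labels": ["schedule_meeting"]}), encoding="utf-8")

    print(read_text_file_any(str(tmp / "sample.txt")))    # -> "Customer asks to reschedule the demo."
    print(read_json_file_any(str(tmp / "sample.json")))   # -> {'labels': ['schedule_meeting']}
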
@@ -270,30 +284,47 @@ class ModelWrapper:
 
     @torch.inference_mode()
     def generate(self, system_prompt: str, user_prompt: str) -> str:
+        # Build inputs as input_ids=... (avoid **tensor bug)
         if hasattr(self.tokenizer, "apply_chat_template"):
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+            input_ids = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+            )
+            input_ids = input_ids.to(self.model.device)
+            gen_kwargs = dict(
+                input_ids=input_ids,
+                generation_config=GEN_CONFIG,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
         else:
+            enc = self.tokenizer(
+                f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n",
+                return_tensors="pt"
+            ).to(self.model.device)
+            gen_kwargs = dict(
+                **enc,
                 generation_config=GEN_CONFIG,
                 eos_token_id=self.tokenizer.eos_token_id,
                 pad_token_id=self.tokenizer.pad_token_id,
             )
+
+        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
+            out_ids = self.model.generate(**gen_kwargs)
         return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
 
 _MODEL_CACHE: Dict[str, ModelWrapper] = {}
 def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
     key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
     if key not in _MODEL_CACHE:
-        m = ModelWrapper(repo_id, hf_token, load_in_4bit)
+        m = ModelWrapper(repo_id, hf_token, load_in_4bit)
+        m.load()
         _MODEL_CACHE[key] = m
     return _MODEL_CACHE[key]
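
Note: the chat-template branch follows the standard transformers pattern of building input_ids with apply_chat_template and passing them to generate. A minimal standalone sketch of that flow, not taken from this Space (the repo id and prompts are placeholders; any instruct model that ships a chat template behaves the same way):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")   # placeholder repo
    lm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    messages = [
        {"role": "system", "content": "You label transcripts."},
        {"role": "user", "content": "Transcript: ..."},
    ]
    input_ids = tok.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    out_ids = lm.generate(input_ids=input_ids, max_new_tokens=64,
                          pad_token_id=tok.pad_token_id or tok.eos_token_id)
    # decode only the newly generated continuation
    print(tok.decode(out_ids[0, input_ids.shape[-1]:], skip_special_tokens=True))

This sketch decodes only the continuation; the Space decodes the full sequence and relies on robust_json_extract to pull the JSON back out of it.
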
@@ -303,8 +334,6 @@ def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
 
 def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
     ALLOWED_LABELS = OFFICIAL_LABELS
     LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
-    FN_PENALTY = 2.0
-    FP_PENALTY = 1.0
 
     def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
         if not isinstance(sample_labels, list):
@@ -315,13 +344,10 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
             raise ValueError(f"{sample_name} contains non-string: {label} (type: {type(label)})")
         if label in seen:
             raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
-        seen.add(label); uniq.append(label)
-    valid = []
-    for label in uniq:
         if label not in ALLOWED_LABELS:
             raise ValueError(f"{sample_name} contains invalid label: '{label}'. Allowed: {ALLOWED_LABELS}")
+        seen.add(label); uniq.append(label)
+    return uniq
 
     if len(y_true) != len(y_pred):
         raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")
@@ -339,13 +365,37 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
         for label in _process_sample_labels(sample_labels, f"y_pred[{i}]"):
             y_pred_binary[i, LABEL_TO_IDX[label]] = 1
 
-    fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
-    fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
+    fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)  # penalty 2x
+    fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)  # penalty 1x
     weighted = 2.0 * fn + 1.0 * fp
     max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))
     per_sample = np.where(max_err > 0, 1.0 - (weighted / max_err), 1.0)
     return float(max(0.0, min(1.0, np.mean(per_sample))))
 
+# =========================
+# Fallback: keyword heuristics if model returns empty
+# =========================
+def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
+    low = text.lower()
+    labels = []
+    tasks = []
+    for lab in allowed:
+        hits = []
+        for kw in LABEL_KEYWORDS.get(lab, []):
+            if kw.lower() in low:
+                # capture small evidence window
+                i = low.find(kw.lower())
+                start = max(0, i - 40); end = min(len(text), i + len(kw) + 40)
+                hits.append(text[start:end].strip())
+        if hits:
+            labels.append(lab)
+            tasks.append({
+                "label": lab,
+                "explanation": "Keyword match in transcript.",
+                "evidence": hits[0]
+            })
+    return {"labels": normalize_labels(labels), "tasks": tasks}
+
 # =========================
 # Inference helpers
 # =========================
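
Note: to make the weighting concrete, here is a small worked example of the score computed above (illustrative numbers only; five allowed labels, one sample):

    import numpy as np

    y_true_binary = np.array([[1, 1, 0, 0, 0]])   # truth = {A, B} over 5 labels
    y_pred_binary = np.array([[1, 0, 1, 0, 0]])   # pred  = {A, C}
    n_labels = 5
    fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)   # 1 miss, weighted 2x
    fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)   # 1 false alarm, weighted 1x
    weighted = 2.0 * fn + 1.0 * fp                                     # 3.0
    max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))  # 7.0
    print(1.0 - weighted / max_err)                                    # [0.5714...]
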
@@ -358,34 +408,44 @@ def build_keyword_context(allowed: List[str]) -> str:
 
 def run_single(
     transcript_text: str,
+    transcript_file,       # filepath or file-like
+    gt_json_text: str,
+    gt_json_file,          # filepath or file-like
     use_cleaning: bool,
+    use_keyword_fallback: bool,
     allowed_labels_text: str,
     model_repo: str,
     use_4bit: bool,
     max_input_tokens: int,
     hf_token: str,
-) -> Tuple[str, str, str, str]:
+) -> Tuple[str, str, str, str, str, str]:
 
     t0 = _now_ms()
 
+    # Transcript
+    raw_text = ""
+    if transcript_file:
+        raw_text = read_text_file_any(transcript_file)
+    raw_text = (raw_text or transcript_text or "").strip()
     if not raw_text:
+        return "", "", "No transcript provided.", "", "", ""
 
     text = clean_transcript(raw_text) if use_cleaning else raw_text
 
+    # Allowed labels
     user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
     allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
 
+    # Model
     try:
         model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
     except Exception as e:
+        return "", "", f"Model load failed: {e}", "", "", ""
 
+    # Truncate
     trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
 
+    # Build prompt
     allowed_list_str = "\n".join(f"- {l}" for l in allowed)
     keyword_ctx = build_keyword_context(allowed)
     user_prompt = USER_PROMPT_TEMPLATE.format(
@@ -394,25 +454,38 @@ def run_single(
         keyword_context=keyword_ctx,
     )
 
+    # Generate
     t1 = _now_ms()
     try:
         out = model.generate(SYSTEM_PROMPT, user_prompt)
     except Exception as e:
+        return "", "", f"Generation error: {e}", "", "", ""
     t2 = _now_ms()
 
     parsed = robust_json_extract(out)
     filtered = restrict_to_allowed(parsed, allowed)
 
+    # Fallback if empty
+    if use_keyword_fallback and not filtered.get("labels"):
+        fb = keyword_fallback(trunc, allowed)
+        if fb["labels"]:
+            filtered = fb
+
+    # Diagnostics
     diag = "\n".join([
         f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
         f"Model: {model_repo}",
         f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
+        f"Keyword fallback: {'Yes' if use_keyword_fallback else 'No'}",
         f"Tokens (input, approx): ≤ {max_input_tokens}",
         f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
         f"Allowed labels: {', '.join(allowed)}",
     ])
 
+    # Context preview shown in UI
+    context_preview = "Allowed Labels:\n" + "\n".join(f"- {l}" for l in allowed) + "\n\nKeyword cues:\n" + keyword_ctx
+
+    # Summary
     labs = filtered.get("labels", [])
     tasks = filtered.get("tasks", [])
     summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
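
Note: the fallback only fires when the model output contains no allowed labels. A self-contained sketch of the same heuristic with a made-up keyword map (the real LABEL_KEYWORDS lives in an unchanged part of app.py and is not shown in this diff):

    LABEL_KEYWORDS = {"schedule_meeting": ["reschedule", "calendar invite"]}  # hypothetical cues

    def keyword_hits(text, allowed):
        low = text.lower()
        labels, tasks = [], []
        for lab in allowed:
            for kw in LABEL_KEYWORDS.get(lab, []):
                i = low.find(kw.lower())
                if i >= 0:
                    evidence = text[max(0, i - 40): i + len(kw) + 40].strip()
                    labels.append(lab)
                    tasks.append({"label": lab,
                                  "explanation": "Keyword match in transcript.",
                                  "evidence": evidence})
                    break
        return {"labels": labels, "tasks": tasks}

    print(keyword_hits("Could you reschedule the demo to Friday?", ["schedule_meeting"]))
    # -> {'labels': ['schedule_meeting'], 'tasks': [{'label': 'schedule_meeting', ...}]}
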
@@ -424,20 +497,59 @@ def run_single(
     else:
         summary += "\n\nTasks: (none)"
 
+    json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
+
+    # Optional single-file scoring if GT provided
+    metrics = ""
+    true_labels = None
+    if gt_json_file or (gt_json_text and gt_json_text.strip()):
+        truth_obj = None
+        if gt_json_file:
+            truth_obj = read_json_file_any(gt_json_file)
+        if (not truth_obj) and gt_json_text:
+            try:
+                truth_obj = json.loads(gt_json_text)
+            except Exception:
+                pass
+        if isinstance(truth_obj, dict) and isinstance(truth_obj.get("labels"), list):
+            true_labels = [x for x in truth_obj["labels"] if x in OFFICIAL_LABELS]
+            pred_labels = labs
+            try:
+                score = evaluate_predictions([true_labels], [pred_labels])
+                tp = len(set(true_labels) & set(pred_labels))
+                fp = len(set(pred_labels) - set(true_labels))
+                fn = len(set(true_labels) - set(pred_labels))
+                recall = tp / (tp + fn) if (tp + fn) else 1.0
+                precision = tp / (tp + fp) if (tp + fp) else 1.0
+                f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
+                metrics = (
+                    f"Weighted score: {score:.3f}\n"
+                    f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}\n"
+                    f"TP={tp} FP={fp} FN={fn}\n"
+                    f"Truth: {', '.join(true_labels)}"
+                )
+            except Exception as e:
+                metrics = f"Scoring error: {e}"
+        else:
+            metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
+
+    return summary, json_out, diag, out.strip(), context_preview, metrics
 
 # =========================
 # Batch mode (ZIP with transcripts + truths)
 # =========================
+def read_zip_from_path(path: str, exdir: Path) -> List[Path]:
     exdir.mkdir(parents=True, exist_ok=True)
+    with open(path, "rb") as f:
+        data = f.read()
+    with zipfile.ZipFile(io.BytesIO(data)) as zf:
         zf.extractall(exdir)
     return [p for p in exdir.rglob("*") if p.is_file()]
 
 def run_batch(
+    zip_path,              # filepath string
     use_cleaning: bool,
+    use_keyword_fallback: bool,
     model_repo: str,
     use_4bit: bool,
     max_input_tokens: int,
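
Note: the single-file metrics block uses the usual definitions. As a quick sanity check of the formulas, with placeholder label sets:

    true_labels = {"schedule_meeting", "send_summary"}     # placeholder truth
    pred_labels = {"schedule_meeting", "create_ticket"}    # placeholder prediction
    tp = len(true_labels & pred_labels)                    # 1
    fp = len(pred_labels - true_labels)                    # 1
    fn = len(true_labels - pred_labels)                    # 1
    recall = tp / (tp + fn) if (tp + fn) else 1.0          # 0.5
    precision = tp / (tp + fp) if (tp + fp) else 1.0       # 0.5
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 1.0  # 0.5
    print(recall, precision, f1)
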
@@ -445,24 +557,20 @@ def run_batch(
     limit_files: int,
 ) -> Tuple[str, str, pd.DataFrame, str]:
 
+    if not zip_path:
         return ("No ZIP provided.", "", pd.DataFrame(), "")
 
     work = Path("/tmp/batch")
     if work.exists():
         for p in sorted(work.rglob("*"), reverse=True):
-        try:
-            work.rmdir()
-        except Exception:
-            pass
+            try: p.unlink()
+            except Exception: pass
+        try: work.rmdir()
+        except Exception: pass
     work.mkdir(parents=True, exist_ok=True)
 
+    # Unzip
+    files = read_zip_from_path(zip_path, work)
 
     txts: Dict[str, Path] = {}
     gts: Dict[str, Path] = {}
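
Note: for local testing of the batch tab, a ZIP in the expected shape can be built with the standard library (file names are placeholders; the exact transcript-to-truth pairing happens in an unchanged part of run_batch that this diff does not show):

    import json, zipfile

    with zipfile.ZipFile("batch_sample.zip", "w") as zf:
        zf.writestr("call_001.txt", "Customer asks to reschedule the demo to Friday.")
        zf.writestr("call_001.json", json.dumps({"labels": ["schedule_meeting"]}))
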
@@ -508,6 +616,12 @@ def run_batch(
 
         parsed = robust_json_extract(out)
         filtered = restrict_to_allowed(parsed, allowed)
+
+        if use_keyword_fallback and not filtered.get("labels"):
+            fb = keyword_fallback(trunc, allowed)
+            if fb["labels"]:
+                filtered = fb
+
         pred_labels = filtered.get("labels", [])
         y_pred.append(pred_labels)
 
@@ -543,6 +657,7 @@ def run_batch(
         f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
         f"Model: {model_repo}",
         f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
+        f"Keyword fallback: {'Yes' if use_keyword_fallback else 'No'}",
         f"Tokens (input, approx): ≤ {max_input_tokens}",
         f"Batch time: {_now_ms()-t_start} ms",
     ]
@@ -563,7 +678,6 @@ def run_batch(
     # save CSV for download
     out_csv = Path("/tmp/batch_results.csv")
     df.to_csv(out_csv, index=False, encoding="utf-8")
-
     return ("Batch done.", diag_str, df, str(out_csv))
 
 # =========================
@@ -585,16 +699,31 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     with gr.Tab("Single transcript"):
         with gr.Row():
             with gr.Column(scale=3):
+                gr.Markdown("### Transcript")
                 file = gr.File(
                     label="Drag & drop transcript (.txt / .md / .json)",
                     file_types=[".txt", ".md", ".json"],
                     type="filepath",
                 )
+                text = gr.Textbox(label="Or paste transcript", lines=10)
+
+                gr.Markdown("### Ground truth JSON (optional)")
+                gt_file = gr.File(
+                    label="Upload ground truth JSON (expects {'labels': [...]})",
+                    file_types=[".json"],
+                    type="filepath",
+                )
+                gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{"labels": ["schedule_meeting"]}')
+
                 use_cleaning = gr.Checkbox(
                     label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
                     value=True,
                 )
+                use_keyword_fallback = gr.Checkbox(
+                    label="Keyword fallback if model returns empty",
+                    value=True,
+                )
+
                 labels_text = gr.Textbox(
                     label="Allowed Labels (one per line; empty = official list)",
                     value="",
@@ -613,11 +742,17 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
             with gr.Row():
                 diag = gr.Textbox(label="Diagnostics", lines=8)
                 raw = gr.Textbox(label="Raw Model Output", lines=8)
+            with gr.Row():
+                context_used = gr.Code(label="Effective context used this run (labels + keyword cues)", language="markdown")
+                single_metrics = gr.Textbox(label="Single-file metrics (if ground truth provided)", lines=6)
 
             run_btn.click(
                 fn=run_single,
+                inputs=[
+                    text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
+                    labels_text, repo, use_4bit, max_tokens, hf_token
+                ],
+                outputs=[summary, json_out, diag, raw, context_used, single_metrics],
             )
 
     with gr.Tab("Batch evaluation"):
@@ -625,6 +760,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
             with gr.Column(scale=3):
                 zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
                 use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
+                use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
             with gr.Column(scale=2):
                 repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
                 use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
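
Note: Gradio matches click inputs to handler parameters purely by position, so the new wiring lines up with run_single as follows:

    # run_btn.click inputs   ->  run_single parameters
    # text                   ->  transcript_text
    # file                   ->  transcript_file
    # gt_text                ->  gt_json_text
    # gt_file                ->  gt_json_file
    # use_cleaning           ->  use_cleaning
    # use_keyword_fallback   ->  use_keyword_fallback
    # labels_text            ->  allowed_labels_text
    # repo                   ->  model_repo
    # use_4bit               ->  use_4bit
    # max_tokens             ->  max_input_tokens
    # hf_token               ->  hf_token
    # outputs: summary, json_out, diag, raw, context_used, single_metrics
    #   -> the six values returned by run_single
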
@@ -636,15 +772,15 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
         with gr.Row():
             status = gr.Textbox(label="Status", lines=1)
             diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
-
         df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
         csv_out = gr.File(label="Download CSV", interactive=False)
 
         run_batch_btn.click(
             fn=run_batch,
-            inputs=[zip_in, use_cleaning_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
+            inputs=[zip_in, use_cleaning_b, use_keyword_fallback_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
             outputs=[status, diag_b, df_out, csv_out],
         )
 
 if __name__ == "__main__":
+    demo = demo  # to satisfy some runtimes
     demo.launch()