Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -45,6 +45,7 @@ OFFICIAL_LABELS = [
|
|
| 45 |
"update_kyc_purpose_of_businessrelation",
|
| 46 |
"update_kyc_total_assets",
|
| 47 |
]
|
|
|
|
| 48 |
|
| 49 |
# Per-label keyword cues (static prompt context to improve recall)
|
| 50 |
LABEL_KEYWORDS: Dict[str, List[str]] = {
|
|
@@ -213,13 +214,11 @@ def read_text_file_any(file_input) -> str:
|
|
| 213 |
"""Works for gr.File(type='filepath') and raw strings/Path and file-like."""
|
| 214 |
if not file_input:
|
| 215 |
return ""
|
| 216 |
-
# filepath string
|
| 217 |
if isinstance(file_input, (str, Path)):
|
| 218 |
try:
|
| 219 |
return Path(file_input).read_text(encoding="utf-8", errors="ignore")
|
| 220 |
except Exception:
|
| 221 |
return ""
|
| 222 |
-
# gr.File object or file-like
|
| 223 |
try:
|
| 224 |
data = file_input.read()
|
| 225 |
return data.decode("utf-8", errors="ignore")
|
|
@@ -284,7 +283,7 @@ class ModelWrapper:
|
|
| 284 |
|
| 285 |
@torch.inference_mode()
|
| 286 |
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 287 |
-
# Build inputs as input_ids=... (avoid **tensor bug)
|
| 288 |
if hasattr(self.tokenizer, "apply_chat_template"):
|
| 289 |
messages = [
|
| 290 |
{"role": "system", "content": system_prompt},
|
|
@@ -382,10 +381,11 @@ def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
|
|
| 382 |
for lab in allowed:
|
| 383 |
hits = []
|
| 384 |
for kw in LABEL_KEYWORDS.get(lab, []):
|
| 385 |
-
|
|
|
|
| 386 |
# capture small evidence window
|
| 387 |
-
i = low.find(
|
| 388 |
-
start = max(0, i - 40); end = min(len(text), i + len(
|
| 389 |
hits.append(text[start:end].strip())
|
| 390 |
if hits:
|
| 391 |
labels.append(lab)
|
|
@@ -418,7 +418,7 @@ def run_single(
|
|
| 418 |
use_4bit: bool,
|
| 419 |
max_input_tokens: int,
|
| 420 |
hf_token: str,
|
| 421 |
-
) -> Tuple[str, str, str, str, str, str]:
|
| 422 |
|
| 423 |
t0 = _now_ms()
|
| 424 |
|
|
@@ -428,11 +428,11 @@ def run_single(
|
|
| 428 |
raw_text = read_text_file_any(transcript_file)
|
| 429 |
raw_text = (raw_text or transcript_text or "").strip()
|
| 430 |
if not raw_text:
|
| 431 |
-
return "", "", "No transcript provided.", "", "", ""
|
| 432 |
|
| 433 |
text = clean_transcript(raw_text) if use_cleaning else raw_text
|
| 434 |
|
| 435 |
-
# Allowed labels
|
| 436 |
user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
|
| 437 |
allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
|
| 438 |
|
|
@@ -440,7 +440,7 @@ def run_single(
|
|
| 440 |
try:
|
| 441 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 442 |
except Exception as e:
|
| 443 |
-
return "", "", f"Model load failed: {e}", "", "", ""
|
| 444 |
|
| 445 |
# Truncate
|
| 446 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
|
@@ -459,7 +459,7 @@ def run_single(
|
|
| 459 |
try:
|
| 460 |
out = model.generate(SYSTEM_PROMPT, user_prompt)
|
| 461 |
except Exception as e:
|
| 462 |
-
return "", "", f"Generation error: {e}", "", "", ""
|
| 463 |
t2 = _now_ms()
|
| 464 |
|
| 465 |
parsed = robust_json_extract(out)
|
|
@@ -482,10 +482,16 @@ def run_single(
|
|
| 482 |
f"Allowed labels: {', '.join(allowed)}",
|
| 483 |
])
|
| 484 |
|
| 485 |
-
# Context preview shown in UI
|
| 486 |
-
context_preview =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
-
# Summary
|
| 489 |
labs = filtered.get("labels", [])
|
| 490 |
tasks = filtered.get("tasks", [])
|
| 491 |
summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
|
|
@@ -496,7 +502,6 @@ def run_single(
|
|
| 496 |
)
|
| 497 |
else:
|
| 498 |
summary += "\n\nTasks: (none)"
|
| 499 |
-
|
| 500 |
json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
|
| 501 |
|
| 502 |
# Optional single-file scoring if GT provided
|
|
@@ -533,7 +538,7 @@ def run_single(
|
|
| 533 |
else:
|
| 534 |
metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
|
| 535 |
|
| 536 |
-
return summary, json_out, diag, out.strip(), context_preview, metrics
|
| 537 |
|
| 538 |
# =========================
|
| 539 |
# Batch mode (ZIP with transcripts + truths)
|
|
@@ -569,7 +574,6 @@ def run_batch(
|
|
| 569 |
except Exception: pass
|
| 570 |
work.mkdir(parents=True, exist_ok=True)
|
| 571 |
|
| 572 |
-
# Unzip
|
| 573 |
files = read_zip_from_path(zip_path, work)
|
| 574 |
|
| 575 |
txts: Dict[str, Path] = {}
|
|
@@ -642,7 +646,7 @@ def run_batch(
|
|
| 642 |
|
| 643 |
rows.append({
|
| 644 |
"file": stem,
|
| 645 |
-
"true_labels": ", "
|
| 646 |
"pred_labels": ", ".join(pred_labels),
|
| 647 |
"TP": len(tp), "FP": len(fp), "FN": len(fn),
|
| 648 |
"gen_ms": t1 - t0
|
|
@@ -689,32 +693,43 @@ MODEL_CHOICES = [
|
|
| 689 |
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 690 |
]
|
| 691 |
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 698 |
|
| 699 |
with gr.Tab("Single transcript"):
|
| 700 |
with gr.Row():
|
| 701 |
with gr.Column(scale=3):
|
| 702 |
-
gr.Markdown("
|
| 703 |
file = gr.File(
|
| 704 |
label="Drag & drop transcript (.txt / .md / .json)",
|
| 705 |
file_types=[".txt", ".md", ".json"],
|
| 706 |
type="filepath",
|
| 707 |
)
|
| 708 |
text = gr.Textbox(label="Or paste transcript", lines=10)
|
|
|
|
| 709 |
|
| 710 |
-
gr.Markdown("
|
| 711 |
gt_file = gr.File(
|
| 712 |
label="Upload ground truth JSON (expects {'labels': [...]})",
|
| 713 |
file_types=[".json"],
|
| 714 |
type="filepath",
|
| 715 |
)
|
| 716 |
-
gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{"labels": ["schedule_meeting"]}')
|
|
|
|
| 717 |
|
|
|
|
| 718 |
use_cleaning = gr.Checkbox(
|
| 719 |
label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
|
| 720 |
value=True,
|
|
@@ -723,28 +738,51 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
|
| 723 |
label="Keyword fallback if model returns empty",
|
| 724 |
value=True,
|
| 725 |
)
|
|
|
|
| 726 |
|
|
|
|
| 727 |
labels_text = gr.Textbox(
|
| 728 |
-
label="Allowed Labels (one per line
|
| 729 |
-
value=
|
| 730 |
lines=8,
|
| 731 |
)
|
|
|
|
|
|
|
|
|
|
| 732 |
with gr.Column(scale=2):
|
|
|
|
| 733 |
repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
| 734 |
use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 735 |
max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
|
| 736 |
hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
| 737 |
run_btn = gr.Button("Run Extraction", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
|
| 739 |
with gr.Row():
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 748 |
|
| 749 |
run_btn.click(
|
| 750 |
fn=run_single,
|
|
@@ -752,28 +790,38 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
|
| 752 |
text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
|
| 753 |
labels_text, repo, use_4bit, max_tokens, hf_token
|
| 754 |
],
|
| 755 |
-
outputs=[summary, json_out, diag, raw,
|
| 756 |
)
|
| 757 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
with gr.Tab("Batch evaluation"):
|
| 759 |
with gr.Row():
|
| 760 |
with gr.Column(scale=3):
|
|
|
|
| 761 |
zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
|
| 762 |
use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
|
| 763 |
use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
|
|
|
|
| 764 |
with gr.Column(scale=2):
|
|
|
|
| 765 |
repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
| 766 |
use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 767 |
max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
|
| 768 |
hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
| 769 |
limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
|
| 770 |
run_batch_btn = gr.Button("Run Batch", variant="primary")
|
|
|
|
| 771 |
|
| 772 |
with gr.Row():
|
|
|
|
| 773 |
status = gr.Textbox(label="Status", lines=1)
|
| 774 |
diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
|
| 775 |
-
|
| 776 |
-
|
|
|
|
| 777 |
|
| 778 |
run_batch_btn.click(
|
| 779 |
fn=run_batch,
|
|
@@ -782,5 +830,4 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
|
| 782 |
)
|
| 783 |
|
| 784 |
if __name__ == "__main__":
|
| 785 |
-
demo = demo # to satisfy some runtimes
|
| 786 |
demo.launch()
|
|
|
|
| 45 |
"update_kyc_purpose_of_businessrelation",
|
| 46 |
"update_kyc_total_assets",
|
| 47 |
]
|
| 48 |
+
OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
|
| 49 |
|
| 50 |
# Per-label keyword cues (static prompt context to improve recall)
|
| 51 |
LABEL_KEYWORDS: Dict[str, List[str]] = {
|
|
|
|
| 214 |
"""Works for gr.File(type='filepath') and raw strings/Path and file-like."""
|
| 215 |
if not file_input:
|
| 216 |
return ""
|
|
|
|
| 217 |
if isinstance(file_input, (str, Path)):
|
| 218 |
try:
|
| 219 |
return Path(file_input).read_text(encoding="utf-8", errors="ignore")
|
| 220 |
except Exception:
|
| 221 |
return ""
|
|
|
|
| 222 |
try:
|
| 223 |
data = file_input.read()
|
| 224 |
return data.decode("utf-8", errors="ignore")
|
|
|
|
| 283 |
|
| 284 |
@torch.inference_mode()
|
| 285 |
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 286 |
+
# Build inputs as input_ids=... (avoid **tensor bug from earlier)
|
| 287 |
if hasattr(self.tokenizer, "apply_chat_template"):
|
| 288 |
messages = [
|
| 289 |
{"role": "system", "content": system_prompt},
|
|
|
|
| 381 |
for lab in allowed:
|
| 382 |
hits = []
|
| 383 |
for kw in LABEL_KEYWORDS.get(lab, []):
|
| 384 |
+
k = kw.lower()
|
| 385 |
+
if k in low:
|
| 386 |
# capture small evidence window
|
| 387 |
+
i = low.find(k)
|
| 388 |
+
start = max(0, i - 40); end = min(len(text), i + len(k) + 40)
|
| 389 |
hits.append(text[start:end].strip())
|
| 390 |
if hits:
|
| 391 |
labels.append(lab)
|
|
|
|
| 418 |
use_4bit: bool,
|
| 419 |
max_input_tokens: int,
|
| 420 |
hf_token: str,
|
| 421 |
+
) -> Tuple[str, str, str, str, str, str, str]:
|
| 422 |
|
| 423 |
t0 = _now_ms()
|
| 424 |
|
|
|
|
| 428 |
raw_text = read_text_file_any(transcript_file)
|
| 429 |
raw_text = (raw_text or transcript_text or "").strip()
|
| 430 |
if not raw_text:
|
| 431 |
+
return "", "", "No transcript provided.", "", "", "", ""
|
| 432 |
|
| 433 |
text = clean_transcript(raw_text) if use_cleaning else raw_text
|
| 434 |
|
| 435 |
+
# Allowed labels (pre-filled defaults)
|
| 436 |
user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
|
| 437 |
allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
|
| 438 |
|
|
|
|
| 440 |
try:
|
| 441 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 442 |
except Exception as e:
|
| 443 |
+
return "", "", f"Model load failed: {e}", "", "", "", ""
|
| 444 |
|
| 445 |
# Truncate
|
| 446 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
|
|
|
| 459 |
try:
|
| 460 |
out = model.generate(SYSTEM_PROMPT, user_prompt)
|
| 461 |
except Exception as e:
|
| 462 |
+
return "", "", f"Generation error: {e}", "", "", "", ""
|
| 463 |
t2 = _now_ms()
|
| 464 |
|
| 465 |
parsed = robust_json_extract(out)
|
|
|
|
| 482 |
f"Allowed labels: {', '.join(allowed)}",
|
| 483 |
])
|
| 484 |
|
| 485 |
+
# Context & instructions preview shown in UI
|
| 486 |
+
context_preview = (
|
| 487 |
+
"### Allowed Labels\n"
|
| 488 |
+
+ "\n".join(f"- {l}" for l in allowed)
|
| 489 |
+
+ "\n\n### Keyword cues per label\n"
|
| 490 |
+
+ keyword_ctx
|
| 491 |
+
)
|
| 492 |
+
instructions_preview = "```\n" + SYSTEM_PROMPT + "\n```"
|
| 493 |
|
| 494 |
+
# Summary & JSON
|
| 495 |
labs = filtered.get("labels", [])
|
| 496 |
tasks = filtered.get("tasks", [])
|
| 497 |
summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
|
|
|
|
| 502 |
)
|
| 503 |
else:
|
| 504 |
summary += "\n\nTasks: (none)"
|
|
|
|
| 505 |
json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
|
| 506 |
|
| 507 |
# Optional single-file scoring if GT provided
|
|
|
|
| 538 |
else:
|
| 539 |
metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
|
| 540 |
|
| 541 |
+
return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics
|
| 542 |
|
| 543 |
# =========================
|
| 544 |
# Batch mode (ZIP with transcripts + truths)
|
|
|
|
| 574 |
except Exception: pass
|
| 575 |
work.mkdir(parents=True, exist_ok=True)
|
| 576 |
|
|
|
|
| 577 |
files = read_zip_from_path(zip_path, work)
|
| 578 |
|
| 579 |
txts: Dict[str, Path] = {}
|
|
|
|
| 646 |
|
| 647 |
rows.append({
|
| 648 |
"file": stem,
|
| 649 |
+
"true_labels": ", "..join(gt_labels),
|
| 650 |
"pred_labels": ", ".join(pred_labels),
|
| 651 |
"TP": len(tp), "FP": len(fp), "FN": len(fn),
|
| 652 |
"gen_ms": t1 - t0
|
|
|
|
| 693 |
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 694 |
]
|
| 695 |
|
| 696 |
+
custom_css = """
|
| 697 |
+
:root { --radius: 14px; }
|
| 698 |
+
.gradio-container { font-family: Inter, ui-sans-serif, system-ui; }
|
| 699 |
+
.card { border: 1px solid rgba(255,255,255,.08); border-radius: var(--radius); padding: 14px 16px; background: rgba(255,255,255,.02); box-shadow: 0 1px 10px rgba(0,0,0,.12) inset; }
|
| 700 |
+
.header { font-weight: 700; font-size: 22px; margin-bottom: 4px; }
|
| 701 |
+
.subtle { color: rgba(255,255,255,.65); font-size: 14px; margin-bottom: 12px; }
|
| 702 |
+
hr.sep { border: none; border-top: 1px solid rgba(255,255,255,.08); margin: 10px 0 16px; }
|
| 703 |
+
.accordion-title { font-weight: 600; }
|
| 704 |
+
.gr-button { border-radius: 12px !important; }
|
| 705 |
+
"""
|
| 706 |
+
|
| 707 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
|
| 708 |
+
gr.Markdown("<div class='header'>Talk2Task — Task Extraction (UBS Challenge)</div>")
|
| 709 |
+
gr.Markdown("<div class='subtle'>False negatives are penalised 2× more than false positives in the official score. This UI biases for recall, shows the exact instructions & context, and supports single or batch evaluation.</div>")
|
| 710 |
|
| 711 |
with gr.Tab("Single transcript"):
|
| 712 |
with gr.Row():
|
| 713 |
with gr.Column(scale=3):
|
| 714 |
+
gr.Markdown("<div class='card'><div class='header'>Transcript</div>", elem_id="card1")
|
| 715 |
file = gr.File(
|
| 716 |
label="Drag & drop transcript (.txt / .md / .json)",
|
| 717 |
file_types=[".txt", ".md", ".json"],
|
| 718 |
type="filepath",
|
| 719 |
)
|
| 720 |
text = gr.Textbox(label="Or paste transcript", lines=10)
|
| 721 |
+
gr.Markdown("<hr class='sep'/>")
|
| 722 |
|
| 723 |
+
gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>", elem_id="card1b")
|
| 724 |
gt_file = gr.File(
|
| 725 |
label="Upload ground truth JSON (expects {'labels': [...]})",
|
| 726 |
file_types=[".json"],
|
| 727 |
type="filepath",
|
| 728 |
)
|
| 729 |
+
gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
|
| 730 |
+
gr.Markdown("</div>") # close card
|
| 731 |
|
| 732 |
+
gr.Markdown("<div class='card'><div class='header'>Preprocessing & heuristics</div>", elem_id="card2")
|
| 733 |
use_cleaning = gr.Checkbox(
|
| 734 |
label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
|
| 735 |
value=True,
|
|
|
|
| 738 |
label="Keyword fallback if model returns empty",
|
| 739 |
value=True,
|
| 740 |
)
|
| 741 |
+
gr.Markdown("</div>")
|
| 742 |
|
| 743 |
+
gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>", elem_id="card3")
|
| 744 |
labels_text = gr.Textbox(
|
| 745 |
+
label="Allowed Labels (one per line)",
|
| 746 |
+
value=OFFICIAL_LABELS_TEXT, # prefilled
|
| 747 |
lines=8,
|
| 748 |
)
|
| 749 |
+
reset_btn = gr.Button("Reset to official labels")
|
| 750 |
+
gr.Markdown("</div>")
|
| 751 |
+
|
| 752 |
with gr.Column(scale=2):
|
| 753 |
+
gr.Markdown("<div class='card'><div class='header'>Model & run</div>", elem_id="card4")
|
| 754 |
repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
| 755 |
use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 756 |
max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
|
| 757 |
hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
| 758 |
run_btn = gr.Button("Run Extraction", variant="primary")
|
| 759 |
+
gr.Markdown("</div>")
|
| 760 |
+
|
| 761 |
+
gr.Markdown("<div class='card'><div class='header'>Outputs</div>", elem_id="card5")
|
| 762 |
+
summary = gr.Textbox(label="Summary", lines=12)
|
| 763 |
+
json_out = gr.Code(label="Strict JSON Output", language="json")
|
| 764 |
+
diag = gr.Textbox(label="Diagnostics", lines=8)
|
| 765 |
+
raw = gr.Textbox(label="Raw Model Output", lines=8)
|
| 766 |
+
gr.Markdown("</div>")
|
| 767 |
|
| 768 |
with gr.Row():
|
| 769 |
+
with gr.Column():
|
| 770 |
+
with gr.Accordion("Instructions used (system prompt)", open=False):
|
| 771 |
+
instr_md = gr.Markdown("")
|
| 772 |
+
with gr.Column():
|
| 773 |
+
with gr.Accordion("Context used (allowed labels + keyword cues)", open=True):
|
| 774 |
+
context_md = gr.Markdown("")
|
| 775 |
+
|
| 776 |
+
# reset button behavior
|
| 777 |
+
def _reset_labels():
|
| 778 |
+
return OFFICIAL_LABELS_TEXT
|
| 779 |
+
reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
|
| 780 |
+
|
| 781 |
+
# single run
|
| 782 |
+
def _pack_context_md(allowed: str) -> str:
|
| 783 |
+
allowed_list = [ln.strip() for ln in (allowed or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
|
| 784 |
+
ctx = build_keyword_context(allowed_list)
|
| 785 |
+
return "### Allowed Labels\n" + "\n".join(f"- {l}" for l in allowed_list) + "\n\n### Keyword cues per label\n" + ctx
|
| 786 |
|
| 787 |
run_btn.click(
|
| 788 |
fn=run_single,
|
|
|
|
| 790 |
text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
|
| 791 |
labels_text, repo, use_4bit, max_tokens, hf_token
|
| 792 |
],
|
| 793 |
+
outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False)],
|
| 794 |
)
|
| 795 |
|
| 796 |
+
# also keep instructions visible at initial load
|
| 797 |
+
instr_md.value = "```\n" + SYSTEM_PROMPT + "\n```"
|
| 798 |
+
context_md.value = _pack_context_md(OFFICIAL_LABELS_TEXT)
|
| 799 |
+
|
| 800 |
with gr.Tab("Batch evaluation"):
|
| 801 |
with gr.Row():
|
| 802 |
with gr.Column(scale=3):
|
| 803 |
+
gr.Markdown("<div class='card'><div class='header'>ZIP input</div>", elem_id="card6")
|
| 804 |
zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
|
| 805 |
use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
|
| 806 |
use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
|
| 807 |
+
gr.Markdown("</div>")
|
| 808 |
with gr.Column(scale=2):
|
| 809 |
+
gr.Markdown("<div class='card'><div class='header'>Model & run</div>", elem_id="card7")
|
| 810 |
repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
| 811 |
use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 812 |
max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
|
| 813 |
hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
| 814 |
limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
|
| 815 |
run_batch_btn = gr.Button("Run Batch", variant="primary")
|
| 816 |
+
gr.Markdown("</div>")
|
| 817 |
|
| 818 |
with gr.Row():
|
| 819 |
+
gr.Markdown("<div class='card'><div class='header'>Batch outputs</div>", elem_id="card8")
|
| 820 |
status = gr.Textbox(label="Status", lines=1)
|
| 821 |
diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
|
| 822 |
+
df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
|
| 823 |
+
csv_out = gr.File(label="Download CSV", interactive=False)
|
| 824 |
+
gr.Markdown("</div>")
|
| 825 |
|
| 826 |
run_batch_btn.click(
|
| 827 |
fn=run_batch,
|
|
|
|
| 830 |
)
|
| 831 |
|
| 832 |
if __name__ == "__main__":
|
|
|
|
| 833 |
demo.launch()
|