import os
import re
import io
import json
import time
import zipfile
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import pandas as pd
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GenerationConfig,
    LlamaTokenizer,  # manual fallback
)
from huggingface_hub import hf_hub_download
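
# The imports above imply roughly the following runtime dependencies (a sketch,
# not an exact pin list): torch, transformers, gradio, numpy, pandas,
# huggingface_hub, plus bitsandbytes for the optional 4-bit path and
# sentencepiece for the manual LlamaTokenizer fallback. Exact versions are
# whatever the Space's requirements.txt specifies.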
# =========================
# Global config
# =========================
SPACE_CACHE = Path.home() / ".cache" / "huggingface"
SPACE_CACHE.mkdir(parents=True, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Force slow path by default; avoid Rust tokenizer JSON parsing
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("TOKENIZERS_PREFER_FAST", "false")

GEN_CONFIG = GenerationConfig(
    temperature=0.0,
    top_p=1.0,
    do_sample=False,
    max_new_tokens=128,  # raise if JSON truncates
)

OFFICIAL_LABELS = [
    "plan_contact",
    "schedule_meeting",
    "update_contact_info_non_postal",
    "update_contact_info_postal_address",
    "update_kyc_activity",
    "update_kyc_origin_of_assets",
    "update_kyc_purpose_of_businessrelation",
    "update_kyc_total_assets",
]
OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
# =========================
# Editable defaults (shown in UI)
# =========================
DEFAULT_SYSTEM_INSTRUCTIONS = (
    "You extract ACTIONABLE TASKS from client–advisor transcripts. "
    "The transcript may be in German, French, Italian, or English. "
    "Prioritize RECALL: if a label plausibly applies, include it. "
    "Use ONLY the canonical labels provided. "
    "Return STRICT JSON only with keys 'labels' and 'tasks'. "
    "Each task must include 'label', a brief 'explanation', and a short 'evidence' quote from the transcript."
)

DEFAULT_LABEL_GLOSSARY = {
    "plan_contact": "Commitment to contact later (advisor/client will reach out, follow-up promised).",
    "schedule_meeting": "Scheduling or confirming a meeting/call/appointment (time/date/slot/virtual).",
    "update_contact_info_non_postal": "Change or confirmation of phone/email (non-postal contact details).",
    "update_contact_info_postal_address": "Change or confirmation of postal/residential/mailing address.",
    "update_kyc_activity": "Change/confirmation of occupation, employment status, or economic activity.",
    "update_kyc_origin_of_assets": "Discussion/confirmation of source of funds / origin of assets.",
    "update_kyc_purpose_of_businessrelation": "Purpose of the banking relationship/account usage.",
    "update_kyc_total_assets": "Discussion/confirmation of total assets/net worth.",
}

# Minimal multilingual fallback rules (optional)
DEFAULT_FALLBACK_CUES = {
    "plan_contact": [
        r"\b(get|got|will|we'?ll|i'?ll)\s+back to you\b", r"\bfollow\s*up\b", r"\breach out\b", r"\btouch base\b",
        r"\bcontact (you|me|us)\b",
        r"\bin verbindung setzen\b", r"\brückmeldung\b", r"\bich\s+melde\b|\bwir\s+melden\b", r"\bnachfassen\b",
        r"\bje vous recontacte\b|\bnous vous recontacterons\b", r"\bprendre contact\b|\breprendre contact\b",
        r"\bla ricontatter[oò]\b|\bci metteremo in contatto\b", r"\btenersi in contatto\b",
    ],
    "schedule_meeting": [
        r"\b(let'?s\s+)?meet(ing|s)?\b", r"\bschedule( a)? (call|meeting|appointment)\b",
        r"\bbook( a)? (slot|time|meeting)\b", r"\b(next week|tomorrow|this (afternoon|morning|evening))\b",
        r"\bconfirm( the)? (time|meeting|appointment)\b",
        r"\btermin(e|s)?\b|\bvereinbaren\b|\bansetzen\b|\babstimmen\b|\bbesprechung(en)?\b|\bvirtuell(e|en)?\b",
        r"\bnächste(n|r)? woche\b|\b(dienstag|montag|mittwoch|donnerstag|freitag)\b|\bnachmittag|vormittag|morgen\b",
        r"\brendez[- ]?vous\b|\bréunion\b|\bfixer\b|\bplanifier\b|\bse rencontrer\b|\bse voir\b",
        r"\bla semaine prochaine\b|\bdemain\b|\bcet (après-midi|apres-midi|après midi|apres midi|matin|soir)\b",
        r"\bappuntamento\b|\briunione\b|\borganizzare\b|\bprogrammare\b|\bincontrarci\b|\bcalendario\b",
        r"\bla prossima settimana\b|\bdomani\b|\b(questo|questa)\s*(pomeriggio|mattina|sera)\b",
    ],
    "update_kyc_origin_of_assets": [
        r"\bsource of funds\b|\borigin of assets\b|\bproof of (funds|assets)\b",
        r"\bvermögensursprung(e|s)?\b|\bherkunft der mittel\b|\bnachweis\b",
        r"\borigine des fonds\b|\borigine du patrimoine\b|\bjustificatif(s)?\b",
        r"\borigine dei fondi\b|\borigine del patrimonio\b|\bprova dei fondi\b|\bgiustificativo\b",
    ],
    "update_kyc_activity": [
        r"\bemployment status\b|\boccupation\b|\bjob change\b|\bsalary history\b",
        r"\bbeschäftigungsstatus\b|\bberuf\b|\bjobwechsel\b|\bgehaltshistorie\b|\btätigkeit\b",
        r"\bstatut professionnel\b|\bprofession\b|\bchangement d'emploi\b|\bhistorique salarial\b|\bactivité\b",
        r"\bstato occupazionale\b|\bprofessione\b|\bcambio di lavoro\b|\bstoria salariale\b|\battivit[aà]\b",
    ],
}
# =========================
# Prompt template
# =========================
# Literal braces in the JSON schema are doubled ({{ / }}) so that str.format()
# leaves them intact and only fills the three named placeholders.
USER_PROMPT_TEMPLATE = (
    "Transcript (may be DE/FR/IT/EN):\n"
    "```\n{transcript}\n```\n\n"
    "Allowed Labels (canonical; use only these):\n"
    "{allowed_labels_list}\n\n"
    "Label Glossary (concise semantics):\n"
    "{glossary}\n\n"
    "Return STRICT JSON ONLY in this exact schema:\n"
    '{{\n  "labels": ["<Label1>", "..."],\n'
    '  "tasks": [{{"label": "<Label1>", "explanation": "<why>", "evidence": "<quote>"}}]\n}}\n'
)
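
# For illustration, the strict JSON the model is asked to return looks like
# (hypothetical values, not from a real transcript):
#   {
#     "labels": ["schedule_meeting"],
#     "tasks": [{"label": "schedule_meeting",
#                "explanation": "Client asks to book a call next week",
#                "evidence": "could we schedule a call next Tuesday?"}]
#   }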
# =========================
# Utilities
# =========================
def _now_ms() -> int:
    return int(time.time() * 1000)


def normalize_labels(labels: List[str]) -> List[str]:
    # De-duplicate while preserving order; drop empty or non-string entries.
    return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))


def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
    return {lab.lower(): lab for lab in allowed}


def robust_json_extract(text: str) -> Dict[str, Any]:
    # Pull the outermost {...} span out of the model output and parse it,
    # retrying once after stripping trailing commas.
    if not text:
        return {"labels": [], "tasks": []}
    start, end = text.find("{"), text.rfind("}")
    candidate = text[start:end + 1] if (start != -1 and end != -1 and end > start) else text
    try:
        return json.loads(candidate)
    except Exception:
        candidate = re.sub(r",\s*}", "}", candidate)
        candidate = re.sub(r",\s*]", "]", candidate)
        try:
            return json.loads(candidate)
        except Exception:
            return {"labels": [], "tasks": []}


def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
    # Keep only canonical labels (case-insensitive match) and trim task fields.
    out = {"labels": [], "tasks": []}
    allowed_map = canonicalize_map(allowed)
    filt_labels = []
    for l in pred.get("labels", []) or []:
        k = str(l).strip().lower()
        if k in allowed_map:
            filt_labels.append(allowed_map[k])
    filt_labels = normalize_labels(filt_labels)
    filt_tasks = []
    for t in pred.get("tasks", []) or []:
        if not isinstance(t, dict):
            continue
        k = str(t.get("label", "")).strip().lower()
        if k in allowed_map:
            new_t = {
                "label": allowed_map[k],
                "explanation": str(t.get("explanation", ""))[:300],
                "evidence": str(t.get("evidence", ""))[:300],
            }
            filt_tasks.append(new_t)
    merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
    out["labels"] = merged
    out["tasks"] = filt_tasks
    return out
# =========================
# Pre-processing
# =========================
_DISCLAIMER_PATTERNS = [
    r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
    r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
    r"(?is)^\s*this message \(including any attachments\).+?(?:\n{2,}|$)",
]
_FOOTER_PATTERNS = [
    r"(?is)\n+kind regards[^\n]*\n.*$", r"(?is)\n+best regards[^\n]*\n.*$",
    r"(?is)\n+sent from my.*$", r"(?is)\n+ubs ag.*$",
]
_TIMESTAMP_SPEAKER = [
    r"\[\d{1,2}:\d{2}(:\d{2})?\]",
    r"^\s*(advisor|client|client advisor)\s*:\s*",
    r"^\s*(speaker\s*\d+)\s*:\s*",
]


def clean_transcript(text: str) -> str:
    if not text:
        return text
    s = text
    # strip speaker/timestamps
    lines = []
    for ln in s.splitlines():
        ln2 = ln
        for pat in _TIMESTAMP_SPEAKER:
            ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
        lines.append(ln2)
    s = "\n".join(lines)
    # disclaimers (top)
    for pat in _DISCLAIMER_PATTERNS:
        s = re.sub(pat, "", s).strip()
    # footers
    for pat in _FOOTER_PATTERNS:
        s = re.sub(pat, "", s)
    # whitespace tidy
    s = re.sub(r"[ \t]+", " ", s)
    s = re.sub(r"\n{3,}", "\n\n", s).strip()
    return s
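
# Illustrative effect of clean_transcript (hypothetical input):
#   "[00:01] Advisor: We can book a slot next week.\n\nKind regards\nJane"
# becomes roughly
#   "We can book a slot next week."
# (timestamp and speaker tag stripped, sign-off footer removed).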
def read_text_file_any(file_input) -> str:
    if not file_input:
        return ""
    if isinstance(file_input, (str, Path)):
        try:
            return Path(file_input).read_text(encoding="utf-8", errors="ignore")
        except Exception:
            return ""
    try:
        data = file_input.read()
        return data.decode("utf-8", errors="ignore")
    except Exception:
        return ""


def read_json_file_any(file_input) -> Optional[dict]:
    if not file_input:
        return None
    if isinstance(file_input, (str, Path)):
        try:
            return json.loads(Path(file_input).read_text(encoding="utf-8", errors="ignore"))
        except Exception:
            return None
    try:
        return json.loads(file_input.read().decode("utf-8", errors="ignore"))
    except Exception:
        return None


def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
    toks = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(toks) <= max_tokens:
        return text
    return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
# =========================
# Cache purge for fresh downloads
# =========================
def _purge_repo_from_cache(repo_id: str):
    # Best-effort removal of the repo's cached snapshot (models--org--name*) so
    # the next load re-downloads it; all filesystem errors are swallowed.
    try:
        base = SPACE_CACHE
        safe = repo_id.replace("/", "--")
        for p in base.glob(f"models--{safe}*"):
            try:
                if p.is_file():
                    p.unlink()
                else:
                    for sub in sorted(p.rglob("*"), reverse=True):
                        try:
                            if sub.is_file() or sub.is_symlink():
                                sub.unlink()
                            else:
                                sub.rmdir()
                        except Exception:
                            pass
                    p.rmdir()
            except Exception:
                pass
    except Exception:
        pass
# =========================
# HF model wrapper (with manual LlamaTokenizer fallback)
# =========================
class ModelWrapper:
    def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool, force_tok_redownload: bool):
        self.repo_id = repo_id
        self.hf_token = hf_token
        self.load_in_4bit = load_in_4bit
        self.use_sdpa = use_sdpa
        self.force_tok_redownload = force_tok_redownload
        self.tokenizer = None
        self.model = None
        self.load_path = "uninitialized"

    def _try_auto_tokenizer(self, use_fast: bool):
        return AutoTokenizer.from_pretrained(
            self.repo_id,
            token=self.hf_token,
            cache_dir=str(SPACE_CACHE),
            trust_remote_code=True,
            local_files_only=False,
            force_download=bool(self.force_tok_redownload),
            use_fast=use_fast,
        )

    def _try_manual_llama_tokenizer(self):
        # Download only tokenizer.model; ignore tokenizer.json entirely
        sp_path = hf_hub_download(repo_id=self.repo_id, filename="tokenizer.model", token=self.hf_token, cache_dir=str(SPACE_CACHE))
        tok = LlamaTokenizer(vocab_file=sp_path)
        if tok.pad_token is None and tok.eos_token:
            tok.pad_token = tok.eos_token
        return tok

    def _load_tokenizer(self):
        if self.force_tok_redownload:
            _purge_repo_from_cache(self.repo_id)
        # 1) Slow auto
        try:
            tok = self._try_auto_tokenizer(use_fast=False)
            if tok.pad_token is None and tok.eos_token:
                tok.pad_token = tok.eos_token
            self.load_path = "tok:AUTO_SLOW"
            return tok
        except Exception:
            pass
        # 2) Manual LlamaTokenizer from tokenizer.model
        try:
            tok = self._try_manual_llama_tokenizer()
            self.load_path = "tok:LLAMA_SPM"
            return tok
        except Exception:
            pass
        # 3) Fast auto (last resort)
        tok = self._try_auto_tokenizer(use_fast=True)  # will raise if broken
        if tok.pad_token is None and tok.eos_token:
            tok.pad_token = tok.eos_token
        self.load_path = "tok:AUTO_FAST"
        return tok
    def load(self):
        qcfg = None
        if self.load_in_4bit and DEVICE == "cuda":
            qcfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        tok = self._load_tokenizer()
        errors = []
        # Try several loading strategies in order; keep the first one that works.
        for desc, kwargs in [
            ("auto_device_no_lowcpu" + ("_sdpa" if (self.use_sdpa and DEVICE == "cuda") else ""),
             dict(
                 torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                 device_map="auto" if DEVICE == "cuda" else None,
                 low_cpu_mem_usage=False,
                 quantization_config=qcfg,
                 trust_remote_code=True,
                 cache_dir=str(SPACE_CACHE),
                 attn_implementation=("sdpa" if (self.use_sdpa and DEVICE == "cuda") else None),
             )),
            ("auto_device_no_sdpa",
             dict(
                 torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                 device_map="auto" if DEVICE == "cuda" else None,
                 low_cpu_mem_usage=False,
                 quantization_config=qcfg,
                 trust_remote_code=True,
                 cache_dir=str(SPACE_CACHE),
             )),
            ("cpu_then_to_cuda" if DEVICE == "cuda" else "cpu_only",
             dict(
                 torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                 device_map=None,
                 low_cpu_mem_usage=False,
                 quantization_config=None if DEVICE != "cuda" else qcfg,
                 trust_remote_code=True,
                 cache_dir=str(SPACE_CACHE),
             )),
        ]:
            try:
                mdl = AutoModelForCausalLM.from_pretrained(self.repo_id, token=self.hf_token, **kwargs)
                if desc.startswith("cpu_then_to_cuda") and DEVICE == "cuda":
                    mdl = mdl.to(torch.device("cuda"))
                self.tokenizer = tok
                self.model = mdl
                self.load_path = f"{self.load_path} | {desc}"
                return
            except Exception as e:
                errors.append(f"{desc}: {e}")
        raise RuntimeError("All load attempts failed:\n" + "\n".join(errors))
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        if hasattr(self.tokenizer, "apply_chat_template"):
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
            input_ids = self.tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
            )
            input_ids = input_ids.to(self.model.device)
            gen_kwargs = dict(
                input_ids=input_ids,
                generation_config=GEN_CONFIG,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        else:
            enc = self.tokenizer(
                f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n",
                return_tensors="pt"
            ).to(self.model.device)
            gen_kwargs = dict(
                **enc,
                generation_config=GEN_CONFIG,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            out_ids = self.model.generate(**gen_kwargs)
        return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
_MODEL_CACHE: Dict[str, ModelWrapper] = {}


def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool, force_tok_redownload: bool) -> ModelWrapper:
    key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE == 'cuda') else 'full'}::{'sdpa' if use_sdpa else 'nosdpa'}::{'force' if force_tok_redownload else 'cache'}"
    if key not in _MODEL_CACHE:
        m = ModelWrapper(repo_id, hf_token, load_in_4bit, use_sdpa, force_tok_redownload)
        m.load()
        _MODEL_CACHE[key] = m
    return _MODEL_CACHE[key]
# =========================
# Evaluation (official weighted score)
# =========================
def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
    ALLOWED_LABELS = OFFICIAL_LABELS
    LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}

    def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
        if not isinstance(sample_labels, list):
            raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
        seen, uniq = set(), []
        for label in sample_labels:
            if not isinstance(label, str):
                raise ValueError(f"{sample_name} contains non-string: {label} (type: {type(label)})")
            if label in seen:
                raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
            if label not in ALLOWED_LABELS:
                raise ValueError(f"{sample_name} contains invalid label: '{label}'. Allowed: {ALLOWED_LABELS}")
            seen.add(label); uniq.append(label)
        return uniq

    if len(y_true) != len(y_pred):
        raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")
    n_samples = len(y_true)
    n_labels = len(OFFICIAL_LABELS)
    y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
    y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)
    for i, sample_labels in enumerate(y_true):
        for label in _process_sample_labels(sample_labels, f"y_true[{i}]"):
            y_true_binary[i, LABEL_TO_IDX[label]] = 1
    for i, sample_labels in enumerate(y_pred):
        for label in _process_sample_labels(sample_labels, f"y_pred[{i}]"):
            y_pred_binary[i, LABEL_TO_IDX[label]] = 1
    # Per-sample weighted error: false negatives cost 2, false positives cost 1,
    # normalized by the worst possible error for that sample, then averaged.
    fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
    fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
    weighted = 2.0 * fn + 1.0 * fp
    max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))
    per_sample = np.where(max_err > 0, 1.0 - (weighted / max_err), 1.0)
    return float(max(0.0, min(1.0, np.mean(per_sample))))
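
# Worked example of the weighted score (hypothetical, single sample, 8 labels):
#   truth = {plan_contact, schedule_meeting}, prediction = {plan_contact, update_kyc_activity}
#   -> FN = 1 (schedule_meeting missed), FP = 1 (update_kyc_activity extra)
#   -> weighted error = 2*1 + 1*1 = 3; max error = 2*2 + 1*6 = 10
#   -> per-sample score = 1 - 3/10 = 0.7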
# =========================
# Multilingual regex fallback (optional)
# =========================
def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
    low = text.lower()
    labels, tasks = [], []
    for lab in allowed:
        for pat in cues.get(lab, []):
            m = re.search(pat, low)
            if m:
                i = m.start()
                start = max(0, i - 60); end = min(len(text), i + len(m.group(0)) + 60)
                if lab not in labels:
                    labels.append(lab)
                    tasks.append({
                        "label": lab,
                        "explanation": "Rule hit (multilingual fallback)",
                        "evidence": text[start:end].strip()
                    })
                break
    return {"labels": normalize_labels(labels), "tasks": tasks}
# =========================
# Inference helpers
# =========================
def build_glossary_str(glossary: Dict[str, str], allowed: List[str]) -> str:
    return "\n".join([f"- {lab}: {glossary.get(lab, '')}" for lab in allowed])


def warmup_model(model_repo: str, use_4bit: bool, use_sdpa: bool, hf_token: str, force_tok_redownload: bool) -> str:
    t0 = _now_ms()
    try:
        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
        _ = model.generate("Return JSON only.", '{"labels": [], "tasks": []}')
        return f"Warm-up complete in {_now_ms() - t0} ms. Load path: {model.load_path}"
    except Exception as e:
        return f"Warm-up failed: {e}"
def run_single(
    transcript_text: str,
    transcript_file,
    gt_json_text: str,
    gt_json_file,
    use_cleaning: bool,
    use_fallback: bool,
    allowed_labels_text: str,
    sys_instructions_text: str,
    glossary_json_text: str,
    fallback_json_text: str,
    model_repo: str,
    use_4bit: bool,
    use_sdpa: bool,
    max_input_tokens: int,
    hf_token: str,
    force_tok_redownload: bool,
) -> Tuple[str, str, str, str, str, str, str, str, str]:
    t0 = _now_ms()
    raw_text = ""
    if transcript_file:
        raw_text = read_text_file_any(transcript_file)
    raw_text = (raw_text or transcript_text or "").strip()
    if not raw_text:
        return "", "", "No transcript provided.", "", "", "", "", "", ""
    text = clean_transcript(raw_text) if use_cleaning else raw_text

    user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
    allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)

    try:
        sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip() or DEFAULT_SYSTEM_INSTRUCTIONS
    except Exception:
        sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
    try:
        label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
    except Exception:
        label_glossary = DEFAULT_LABEL_GLOSSARY
    try:
        fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
    except Exception:
        fallback_cues = DEFAULT_FALLBACK_CUES

    try:
        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
    except Exception as e:
        return "", "", f"Model load failed: {e}", "", "", "", "", "", ""

    trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
    glossary_str = build_glossary_str(label_glossary, allowed)
    allowed_list_str = "\n".join(f"- {l}" for l in allowed)
    user_prompt = USER_PROMPT_TEMPLATE.format(
        transcript=trunc,
        allowed_labels_list=allowed_list_str,
        glossary=glossary_str,
    )

    transcript_tokens = len(model.tokenizer(trunc, add_special_tokens=False)["input_ids"])
    prompt_tokens = len(model.tokenizer(user_prompt, add_special_tokens=False)["input_ids"])
    token_info_text = f"Transcript tokens: {transcript_tokens} | Prompt tokens: {prompt_tokens} | Load path: {model.load_path}"
    prompt_preview_text = "```\n" + user_prompt[:4000] + ("\n... (truncated)" if len(user_prompt) > 4000 else "") + "\n```"

    t1 = _now_ms()
    try:
        out = model.generate(sys_instructions, user_prompt)
    except Exception as e:
        return "", "", f"Generation error: {e}", "", "", "", prompt_preview_text, token_info_text, ""
    t2 = _now_ms()

    parsed = robust_json_extract(out)
    filtered = restrict_to_allowed(parsed, allowed)
    if use_fallback:
        fb = multilingual_fallback(trunc, allowed, fallback_cues)
        if fb["labels"]:
            merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
            existing = {tt.get("label") for tt in filtered.get("tasks", [])}
            merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
            filtered = {"labels": merged_labels, "tasks": merged_tasks}

    diag = "\n".join([
        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE == 'cuda') else 'No'})",
        f"Model: {model_repo}",
        f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
        f"Fallback rules: {'Yes' if use_fallback else 'No'}",
        f"SDPA attention: {'Yes' if use_sdpa else 'No'}",
        f"Tokens (input limit): ≤ {max_input_tokens}",
        f"Latency: prep {t1 - t0} ms, gen {t2 - t1} ms, total {t2 - t0} ms",
        f"Allowed labels: {', '.join(allowed)}",
    ])

    labs = filtered.get("labels", [])
    tasks = filtered.get("tasks", [])
    summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
    if tasks:
        summary += "\n\nTasks:\n" + "\n".join(
            f"• [{t['label']}] {t.get('explanation', '')} | ev: {t.get('evidence', '')[:140]}{'…' if len(t.get('evidence', '')) > 140 else ''}"
            for t in tasks
        )
    else:
        summary += "\n\nTasks: (none)"
    json_out = json.dumps(filtered, indent=2, ensure_ascii=False)

    metrics = ""
    if gt_json_file or (gt_json_text and gt_json_text.strip()):
        truth_obj = None
        if gt_json_file:
            truth_obj = read_json_file_any(gt_json_file)
        if (not truth_obj) and gt_json_text:
            try:
                truth_obj = json.loads(gt_json_text)
            except Exception:
                pass
        if isinstance(truth_obj, dict) and isinstance(truth_obj.get("labels"), list):
            true_labels = [x for x in truth_obj["labels"] if x in OFFICIAL_LABELS]
            pred_labels = labs
            try:
                score = evaluate_predictions([true_labels], [pred_labels])
                tp = len(set(true_labels) & set(pred_labels))
                fp = len(set(pred_labels) - set(true_labels))
                fn = len(set(true_labels) - set(pred_labels))
                recall = tp / (tp + fn) if (tp + fn) else 1.0
                precision = tp / (tp + fp) if (tp + fp) else 1.0
                f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
                metrics = (
                    f"Weighted score: {score:.3f}\n"
                    f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}\n"
                    f"TP={tp} FP={fp} FN={fn}\n"
                    f"Truth: {', '.join(true_labels)}"
                )
            except Exception as e:
                metrics = f"Scoring error: {e}"
        else:
            metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."

    # Show the glossary actually used in the prompt (falls back to defaults on parse errors above).
    context_preview = "### Label Glossary (used)\n" + "\n".join(f"- {k}: {v}" for k, v in label_glossary.items() if k in allowed)
    instructions_preview = "```\n" + (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS) + "\n```"
    return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text
# =========================
# Batch mode
# =========================
def read_zip_from_path(path: str, exdir: Path) -> List[Path]:
    exdir.mkdir(parents=True, exist_ok=True)
    with open(path, "rb") as f:
        data = f.read()
    with zipfile.ZipFile(io.BytesIO(data)) as zf:
        zf.extractall(exdir)
    return [p for p in exdir.rglob("*") if p.is_file()]
def run_batch(
    zip_path,
    use_cleaning: bool,
    use_fallback: bool,
    sys_instructions_text: str,
    glossary_json_text: str,
    fallback_json_text: str,
    model_repo: str,
    use_4bit: bool,
    use_sdpa: bool,
    max_input_tokens: int,
    hf_token: str,
    force_tok_redownload: bool,
    limit_files: int,
) -> Tuple[str, str, pd.DataFrame, str]:
    if not zip_path:
        return ("No ZIP provided.", "", pd.DataFrame(), "")

    try:
        sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip() or DEFAULT_SYSTEM_INSTRUCTIONS
    except Exception:
        sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
    try:
        label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
    except Exception:
        label_glossary = DEFAULT_LABEL_GLOSSARY
    try:
        fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
    except Exception:
        fallback_cues = DEFAULT_FALLBACK_CUES

    work = Path("/tmp/batch")
    if work.exists():
        for p in sorted(work.rglob("*"), reverse=True):
            try:
                p.unlink()
            except Exception:
                pass
        try:
            work.rmdir()
        except Exception:
            pass
    work.mkdir(parents=True, exist_ok=True)
    files = read_zip_from_path(zip_path, work)

    txts: Dict[str, Path] = {}
    gts: Dict[str, Path] = {}
    for p in files:
        if p.suffix.lower() == ".txt":
            txts[p.stem] = p
        elif p.suffix.lower() == ".json":
            gts[p.stem] = p
    stems = sorted(txts.keys())
    if limit_files > 0:
        stems = stems[:limit_files]
    if not stems:
        return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")

    try:
        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa, force_tok_redownload)
    except Exception as e:
        return (f"Model load failed: {e}", "", pd.DataFrame(), "")

    allowed = OFFICIAL_LABELS[:]
    glossary_str = build_glossary_str(label_glossary, allowed)
    allowed_list_str = "\n".join(f"- {l}" for l in allowed)

    y_true, y_pred = [], []
    rows = []
    t_start = _now_ms()
    for stem in stems:
        raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
        text = clean_transcript(raw) if use_cleaning else raw
        trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
        user_prompt = USER_PROMPT_TEMPLATE.format(
            transcript=trunc,
            allowed_labels_list=allowed_list_str,
            glossary=glossary_str,
        )
        t0 = _now_ms()
        out = model.generate(sys_instructions, user_prompt)
        t1 = _now_ms()
        parsed = robust_json_extract(out)
        filtered = restrict_to_allowed(parsed, allowed)
        if use_fallback:
            fb = multilingual_fallback(trunc, allowed, fallback_cues)
            if fb["labels"]:
                merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
                existing = {tt.get("label") for tt in filtered.get("tasks", [])}
                merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
                filtered = {"labels": merged_labels, "tasks": merged_tasks}
        pred_labels = filtered.get("labels", [])
        y_pred.append(pred_labels)

        gt_labels = []
        if stem in gts:
            try:
                gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
                if isinstance(gt_obj, dict) and isinstance(gt_obj.get("labels"), list):
                    gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
            except Exception:
                pass
        y_true.append(gt_labels)

        gt_set, pr_set = set(gt_labels), set(pred_labels)
        tp = sorted(gt_set & pr_set)
        fp = sorted(pr_set - gt_set)
        fn = sorted(gt_set - pr_set)
        rows.append({
            "file": stem,
            "true_labels": ", ".join(gt_labels),
            "pred_labels": ", ".join(pred_labels),
            "TP": len(tp), "FP": len(fp), "FN": len(fn),
            "gen_ms": t1 - t0
        })

    have_truth = any(len(v) > 0 for v in y_true)
    score = evaluate_predictions(y_true, y_pred) if have_truth else None
    df = pd.DataFrame(rows).sort_values(["FN", "FP", "file"])

    diag = [
        f"Processed files: {len(stems)}",
        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE == 'cuda') else 'No'})",
        f"Model: {model_repo}",
        f"Fallback rules: {'Yes' if use_fallback else 'No'}",
        f"SDPA attention: {'Yes' if use_sdpa else 'No'}",
        f"Tokens (input limit): ≤ {max_input_tokens}",
        f"Batch time: {_now_ms() - t_start} ms",
    ]
    if have_truth and score is not None:
        total_tp = int(df["TP"].sum())
        total_fp = int(df["FP"].sum())
        total_fn = int(df["FN"].sum())
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 1.0
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 1.0
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
        diag += [
            f"Official weighted score (0–1): {score:.3f}",
            f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}",
            f"Total TP={total_tp} FP={total_fp} FN={total_fn}",
        ]
    diag_str = "\n".join(diag)

    out_csv = Path("/tmp/batch_results.csv")
    df.to_csv(out_csv, index=False, encoding="utf-8")
    return ("Batch done.", diag_str, df, str(out_csv))
# =========================
# UI
# =========================
MODEL_CHOICES = [
    "swiss-ai/Apertus-8B-Instruct-2509",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]

# White, modern UI (no purple)
custom_css = """
:root { --radius: 14px; }
.gradio-container { font-family: Inter, ui-sans-serif, system-ui; background: #ffffff; color: #111827; }
.card { border: 1px solid #e5e7eb; border-radius: var(--radius); padding: 14px 16px; background: #ffffff; box-shadow: 0 1px 2px rgba(0,0,0,.03); }
.header { font-weight: 700; font-size: 22px; margin-bottom: 4px; color: #0f172a; }
.subtle { color: #475569; font-size: 14px; margin-bottom: 12px; }
hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 10px 0 16px; }
.gr-button { border-radius: 12px !important; }
a, .prose a { color: #0ea5e9; }
"""
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
    gr.Markdown("<div class='header'>Talk2Task — Multilingual Task Extraction (UBS Challenge)</div>")
    gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN). Optional rules fallback for recall. Batch evaluation included.</div>")

    with gr.Tab("Single transcript"):
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("<div class='card'><div class='header'>Transcript</div>")
                file = gr.File(
                    label="Drag & drop transcript (.txt / .md / .json)",
                    file_types=[".txt", ".md", ".json"],
                    type="filepath",
                )
                text = gr.Textbox(label="Or paste transcript", lines=10, placeholder="Paste transcript in DE/FR/IT/EN…")
                gr.Markdown("<hr class='sep'/>")
                gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
                gt_file = gr.File(
                    label="Upload ground truth JSON (expects {'labels': [...]})",
                    file_types=[".json"],
                    type="filepath",
                )
                gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{"labels": ["schedule_meeting"]}')
                gr.Markdown("</div>")  # close card

                gr.Markdown("<div class='card'><div class='header'>Processing options</div>")
                use_cleaning = gr.Checkbox(label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)", value=True)
                use_fallback = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
                gr.Markdown("</div>")

                gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>")
                labels_text = gr.Textbox(label="Allowed Labels (one per line)", value=OFFICIAL_LABELS_TEXT, lines=8)
                reset_btn = gr.Button("Reset to official labels")
                gr.Markdown("</div>")

                gr.Markdown("<div class='card'><div class='header'>Editable instructions & context</div>")
                sys_instr_tb = gr.Textbox(label="System Instructions (editable)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=5)
                glossary_tb = gr.Code(label="Label Glossary (JSON; editable)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
                fallback_tb = gr.Code(label="Fallback Cues (Multilingual, JSON; editable)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
                gr.Markdown("</div>")

            with gr.Column(scale=2):
                gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
                repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
                use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
                use_sdpa = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
                force_tok_redownload = gr.Checkbox(label="Force fresh tokenizer download", value=False)
                max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
                hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN", ""))
                warm_btn = gr.Button("Warm up model (load & compile kernels)")
                run_btn = gr.Button("Run Extraction", variant="primary")
                gr.Markdown("</div>")

                gr.Markdown("<div class='card'><div class='header'>Outputs</div>")
                summary = gr.Textbox(label="Summary", lines=12)
                json_out = gr.Code(label="Strict JSON Output", language="json")
                diag = gr.Textbox(label="Diagnostics", lines=10)
                raw = gr.Textbox(label="Raw Model Output", lines=8)
                metrics_tb = gr.Textbox(label="Metrics vs Ground Truth (optional)", lines=6)
                prompt_preview = gr.Code(label="Prompt preview (user prompt sent)", language="markdown")
                token_info = gr.Textbox(label="Token counts (transcript / prompt / load path)", lines=2)
                gr.Markdown("</div>")

        with gr.Row():
            with gr.Column():
                with gr.Accordion("Instructions used (system prompt)", open=False):
                    instr_md = gr.Markdown("```\n" + DEFAULT_SYSTEM_INSTRUCTIONS + "\n```")
            with gr.Column():
                with gr.Accordion("Context used (glossary)", open=True):
                    context_md = gr.Markdown("")

        # Reset labels
        reset_btn.click(fn=lambda: OFFICIAL_LABELS_TEXT, inputs=None, outputs=labels_text)

        # Warm-up
        warm_btn.click(
            fn=warmup_model,
            inputs=[repo, use_4bit, use_sdpa, hf_token, force_tok_redownload],
            outputs=diag
        )

        def _pack_context_md(glossary_json, allowed_text):
            try:
                glossary = json.loads(glossary_json) if glossary_json else DEFAULT_LABEL_GLOSSARY
            except Exception:
                glossary = DEFAULT_LABEL_GLOSSARY
            allowed_list = [ln.strip() for ln in (allowed_text or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
            return "### Label Glossary (used)\n" + "\n".join(f"- {k}: {glossary.get(k, '')}" for k in allowed_list)

        context_md.value = _pack_context_md(json.dumps(DEFAULT_LABEL_GLOSSARY), OFFICIAL_LABELS_TEXT)

        # Run single
        run_btn.click(
            fn=run_single,
            inputs=[
                text, file, gt_text, gt_file, use_cleaning, use_fallback,
                labels_text, sys_instr_tb, glossary_tb, fallback_tb,
                repo, use_4bit, use_sdpa, max_tokens, hf_token, force_tok_redownload
            ],
            outputs=[summary, json_out, diag, raw, context_md, instr_md, metrics_tb, prompt_preview, token_info],
        )
| with gr.Tab("Batch evaluation"): | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| gr.Markdown("<div class='card'><div class='header'>ZIP input</div>") | |
| zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath") | |
| use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True) | |
| use_fallback_b = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True) | |
| gr.Markdown("</div>") | |
| with gr.Column(scale=2): | |
| gr.Markdown("<div class='card'><div class='header'>Model & run</div>") | |
| repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0]) | |
| use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True) | |
| use_sdpa_b = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True) | |
| force_tok_redownload_b = gr.Checkbox(label="Force fresh tokenizer download", value=False) | |
| max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048) | |
| hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN","")) | |
| sys_instr_tb_b = gr.Textbox(label="System Instructions (editable for batch)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=4) | |
| glossary_tb_b = gr.Code(label="Label Glossary (JSON; editable for batch)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json") | |
| fallback_tb_b = gr.Code(label="Fallback Cues (Multilingual, JSON; editable for batch)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json") | |
| limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0) | |
| run_batch_btn = gr.Button("Run Batch", variant="primary") | |
| gr.Markdown("</div>") | |
| with gr.Row(): | |
| gr.Markdown("<div class='card'><div class='header'>Batch outputs</div>") | |
| status = gr.Textbox(label="Status", lines=1) | |
| diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12) | |
| df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False) | |
| csv_out = gr.File(label="Download CSV", interactive=False) | |
| gr.Markdown("</div>") | |
| run_batch_btn.click( | |
| fn=run_batch, | |
| inputs=[ | |
| zip_in, use_cleaning_b, use_fallback_b, | |
| sys_instr_tb_b, glossary_tb_b, fallback_tb_b, | |
| repo_b, use_4bit_b, use_sdpa_b, max_tokens_b, hf_token_b, force_tok_redownload_b, limit_files | |
| ], | |
| outputs=[status, diag_b, df_out, csv_out], | |
| ) | |
if __name__ == "__main__":
    try:
        print("Torch:", torch.__version__)
        print("CUDA available:", torch.cuda.is_available())
        if torch.cuda.is_available():
            print("CUDA (compiled):", torch.version.cuda)
            print("Device:", torch.cuda.get_device_name(0))
    except Exception:
        pass
    demo.launch()