Allowed Labels:
{allowed_labels_list}

Output STRICT JSON only, no prose:
{{
  "labels": ["LabelA","LabelB", ...],
  "tasks": [
    {{"label": "LabelA", "explanation": "…", "evidence": "…"}},
    {{"label": "LabelB", "explanation": "…", "evidence": "…"}}
  ]
}}
"""

# =========================
# Utils
# =========================
def _now_ms() -> int:
    return int(time.time() * 1000)


def read_file_to_text(file) -> str:
    """Read an uploaded transcript. With gr.File(type="filepath") Gradio passes a
    plain path string, so accept either a path or an object exposing `.name`."""
    if not file:
        return ""
    path = file if isinstance(file, str) else getattr(file, "name", "")
    if not path:
        return ""
    with open(path, "rb") as fh:
        data = fh.read()
    if path.lower().endswith(".json"):
        try:
            obj = json.loads(data.decode("utf-8", errors="ignore"))
            if isinstance(obj, dict) and "transcript" in obj:
                return str(obj["transcript"])
            return json.dumps(obj, ensure_ascii=False)
        except Exception:
            return data.decode("utf-8", errors="ignore")
    return data.decode("utf-8", errors="ignore")

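# Rough usage sketch (path and contents are made up): a JSON upload with a
# "transcript" key returns just that field; any other JSON is re-serialized,
# and .txt/.md uploads are returned verbatim.
#   read_file_to_text("/tmp/call.json")   # file contains {"transcript": "Hi, send the report."}
#   -> "Hi, send the report."
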
def normalize_labels(labels: List[str]) -> List[str]:
    return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))


def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
    return {lab.lower(): lab for lab in allowed}

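# Illustrative behaviour (made-up inputs): normalize_labels trims, drops empties,
# and de-duplicates exact matches while preserving order; canonicalize_map is the
# case-insensitive lookup used later to restore canonical spellings.
#   normalize_labels(["Billing ", "billing", "", "Refund"]) -> ["Billing", "billing", "Refund"]
#   canonicalize_map(["Billing", "Refund"]) -> {"billing": "Billing", "refund": "Refund"}
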
def robust_json_extract(text: str) -> Dict[str, Any]:
    if not text:
        return {"labels": [], "tasks": []}
    start, end = text.find("{"), text.rfind("}")
    candidate = text[start:end + 1] if (start != -1 and end != -1) else text
    try:
        return json.loads(candidate)
    except Exception:
        candidate = re.sub(r",\s*}", "}", candidate)
        candidate = re.sub(r",\s*]", "]", candidate)
        try:
            return json.loads(candidate)
        except Exception:
            return {"labels": [], "tasks": []}

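# Illustrative recovery (made-up model output): prose around the JSON object and
# a trailing comma are both tolerated; anything still unparseable falls back to
# an empty prediction.
#   robust_json_extract('Sure! {"labels": ["Refund"], "tasks": [],} Done.')
#   -> {"labels": ["Refund"], "tasks": []}
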
def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
    out = {"labels": [], "tasks": []}
    allowed_map = canonicalize_map(allowed)
    filt_labels = []
    for l in pred.get("labels", []):
        k = str(l).strip().lower()
        if k in allowed_map:
            filt_labels.append(allowed_map[k])
    filt_labels = normalize_labels(filt_labels)
    filt_tasks = []
    for t in pred.get("tasks", []):
        if not isinstance(t, dict):
            continue
        k = str(t.get("label", "")).strip().lower()
        if k in allowed_map:
            new_t = dict(t)
            new_t["label"] = allowed_map[k]
            filt_tasks.append(new_t)
    from_tasks = [tt["label"] for tt in filt_tasks]
    # Merge as lists (not sets) so the label order stays deterministic.
    merged = normalize_labels(filt_labels + from_tasks)
    out["labels"], out["tasks"] = merged, filt_tasks
    return out

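# Illustrative filtering (made-up prediction): labels outside the allowed list are
# dropped, casing is folded back to the canonical spelling, and task labels are
# merged into the top-level label set.
#   restrict_to_allowed(
#       {"labels": ["refund", "Chitchat"],
#        "tasks": [{"label": "REFUND", "explanation": "…", "evidence": "…"}]},
#       ["Refund", "Billing"])
#   -> {"labels": ["Refund"],
#       "tasks": [{"label": "Refund", "explanation": "…", "evidence": "…"}]}
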
def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
    # The Gradio slider can hand over a float; slicing below needs an int.
    max_tokens = int(max_tokens)
    toks = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(toks) <= max_tokens:
        return text
    # Keep the tail of the transcript rather than the head.
    return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)

# =========================
# Model
# =========================
class ModelWrapper:
    def __init__(self, repo_id, hf_token, load_in_4bit):
        self.repo_id, self.hf_token, self.load_in_4bit = repo_id, hf_token, load_in_4bit
        self.tokenizer, self.model = None, None

    def load(self):
        qcfg = None
        if self.load_in_4bit and DEVICE == "cuda":
            qcfg = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
            trust_remote_code=True, use_fast=True,
        )
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
            low_cpu_mem_usage=True, quantization_config=qcfg,
            attn_implementation="sdpa",
        )

    @torch.inference_mode()
    def generate(self, system_prompt, user_prompt):
        # Use the tokenizer's chat template when one ships with the model;
        # fall back to a plain tagged prompt otherwise.
        if getattr(self.tokenizer, "chat_template", None):
            msgs = [{"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}]
            # return_dict=True yields a mapping we can splat into generate();
            # the bare tensor returned otherwise cannot be unpacked with **.
            inputs = self.tokenizer.apply_chat_template(
                msgs, add_generation_prompt=True, return_dict=True, return_tensors="pt"
            ).to(self.model.device)
        else:
            text = f"<s>[SYSTEM]{system_prompt}[/SYSTEM][USER]{user_prompt}[/USER]"
            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            out_ids = self.model.generate(
                **inputs, generation_config=GEN_CONFIG,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens; the echoed prompt contains the
        # JSON schema braces and would otherwise confuse downstream parsing.
        prompt_len = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

_MODEL_CACHE: Dict[str, ModelWrapper] = {}


def get_model(repo_id, hf_token, load_in_4bit):
    key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE == 'cuda') else 'full'}"
    if key not in _MODEL_CACHE:
        m = ModelWrapper(repo_id, hf_token, load_in_4bit)
        m.load()
        _MODEL_CACHE[key] = m
    return _MODEL_CACHE[key]

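# Illustrative cache behaviour: the same repo requested in 4-bit and in full
# precision lives under two distinct keys, so toggling the checkbox never
# clobbers an already-loaded model. On CPU both calls collapse to the "::full"
# entry because 4-bit quantization requires CUDA.
#   get_model("mistralai/Mistral-7B-Instruct-v0.3", None, load_in_4bit=True)   # "...::4bit"
#   get_model("mistralai/Mistral-7B-Instruct-v0.3", None, load_in_4bit=False)  # "...::full"
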
# =========================
# Pipeline
# =========================
def run_extraction(text, file, labels_text, repo, use_4bit, max_tokens, hf_token):
    t0 = _now_ms()
    raw = read_file_to_text(file) if file else (text or "")
    raw = raw.strip()
    if not raw:
        return "", "", "No transcript.", json.dumps({"labels": [], "tasks": []}, indent=2)
    user_labels = [ln.strip() for ln in (labels_text or "").splitlines() if ln.strip()]
    allowed = normalize_labels(user_labels or DEFAULT_ALLOWED_LABELS)
    try:
        model = get_model(repo, hf_token.strip() or None, use_4bit)
    except Exception as e:
        return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
    trunc = truncate_tokens(model.tokenizer, raw, max_tokens)
    user_prompt = USER_PROMPT_TEMPLATE.format(
        transcript=trunc,
        allowed_labels_list="\n".join(f"- {l}" for l in allowed),
    )
    t1 = _now_ms()
    try:
        out = model.generate(SYSTEM_PROMPT, user_prompt)
    except Exception as e:
        return "", "", f"Gen error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
    t2 = _now_ms()
    parsed = robust_json_extract(out)
    filtered = restrict_to_allowed(parsed, allowed)
    diag = "\n".join([
        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE == 'cuda') else 'No'})",
        f"Model: {repo}",
        f"Latency: prep {t1 - t0} ms, gen {t2 - t1} ms, total {t2 - t0} ms",
        f"Allowed labels: {', '.join(allowed)}",
    ])
    if filtered["labels"]:
        summary = "Detected labels:\n" + "\n".join(f"- {l}" for l in filtered["labels"])
    else:
        summary = "Detected labels: (none)"
    if filtered["tasks"]:
        summary += "\n\nTasks:\n" + "\n".join(
            f"• [{t['label']}] {t.get('explanation', '')} | ev: {t.get('evidence', '')[:100]}"
            for t in filtered["tasks"]
        )
    else:
        summary += "\n\nTasks: (none)"
    return summary, json.dumps(filtered, indent=2), diag, out.strip()

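# Rough sketch of the return contract (transcript and outputs are invented): the
# tuple order (summary, json_out, diagnostics, raw_model_output) must match the
# four output components wired up in the UI below.
#   run_extraction("Please email the Q3 report to Anna.", None, "Send Email",
#                  MODEL_CHOICES[0], True, 4096, "")
#   -> (summary_text, pretty_json, diagnostics_text, raw_model_text)
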
# =========================
# UI
# =========================
MODEL_CHOICES = [
    "swiss-ai/Apertus-8B-Instruct-2509",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Talk2Task — Task Extraction Demo")
    with gr.Row():
        with gr.Column(scale=3):
            file = gr.File(label="Drag & drop transcript (.txt/.md/.json)",
                           file_types=[".txt", ".md", ".json"], type="filepath")
            text = gr.Textbox(label="Or paste transcript", lines=12)
            labels_text = gr.Textbox(label="Allowed Labels (one per line)", lines=8)
        with gr.Column(scale=2):
            repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
            use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
            max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
            hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password",
                                  value=os.environ.get("HF_TOKEN", ""))
    btn = gr.Button("Run Extraction", variant="primary")
    with gr.Row():
        summary = gr.Textbox(label="Summary", lines=12)
        json_out = gr.Code(label="JSON Output", language="json")
    with gr.Row():
        diag = gr.Textbox(label="Diagnostics", lines=6)
        raw = gr.Textbox(label="Raw Model Output", lines=6)
    btn.click(fn=run_extraction,
              inputs=[text, file, labels_text, repo, use_4bit, max_tokens, hf_token],
              outputs=[summary, json_out, diag, raw])

if __name__ == "__main__":
    demo.launch()