Allowed Labels:
{allowed_labels_list}
Output STRICT JSON only, no prose:
{{
"labels": ["LabelA","LabelB", ...],
"tasks": [
{{"label": "LabelA", "explanation": "…", "evidence": "…"}},
{{"label": "LabelB", "explanation": "…", "evidence": "…"}}
]
}}
"""
# =========================
# Utils
# =========================
def _now_ms(): return int(time.time() * 1000)
def read_file_to_text(file) -> str:
    """Read an uploaded transcript. With gr.File(type="filepath") Gradio passes a
    plain path string, so accept both a path and a file-like object."""
    if not file:
        return ""
    if isinstance(file, (str, os.PathLike)):
        name = str(file)
        with open(file, "rb") as fh:
            data = fh.read()
    else:  # file-like object exposing .name / .read()
        if not getattr(file, "name", None):
            return ""
        name = file.name
        data = file.read()
    name = name.lower()
    if name.endswith(".json"):
        try:
            obj = json.loads(data.decode("utf-8", errors="ignore"))
            if isinstance(obj, dict) and "transcript" in obj:
                return str(obj["transcript"])
            return json.dumps(obj, ensure_ascii=False)
        except Exception:
            return data.decode("utf-8", errors="ignore")
    return data.decode("utf-8", errors="ignore")
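# Illustrative usage sketch (defined but never called by the app): write a made-up
# JSON transcript to a temp file and read it back through read_file_to_text, which
# should return just the value of the "transcript" key.
def _example_read_file_to_text() -> str:
    import tempfile
    payload = {"transcript": "Client asks to update the mailing address."}
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fh:
        json.dump(payload, fh)
        path = fh.name
    return read_file_to_text(path)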
def normalize_labels(labels: List[str]) -> List[str]:
    return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
    return {lab.lower(): lab for lab in allowed}
def robust_json_extract(text: str) -> Dict[str, Any]:
    # Pull the outermost {...} span out of the model output, then parse it,
    # repairing trailing commas before giving up.
    if not text:
        return {"labels": [], "tasks": []}
    start, end = text.find("{"), text.rfind("}")
    candidate = text[start:end + 1] if (start != -1 and end != -1) else text
    try:
        return json.loads(candidate)
    except Exception:
        candidate = re.sub(r",\s*}", "}", candidate)
        candidate = re.sub(r",\s*]", "]", candidate)
        try:
            return json.loads(candidate)
        except Exception:
            return {"labels": [], "tasks": []}
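# Sanity sketch (never called): robust_json_extract is meant to survive prose around
# the JSON and trailing commas; the noisy string below is invented model output.
def _example_robust_json_extract() -> Dict[str, Any]:
    noisy = 'Sure, here is the result:\n{"labels": ["Address Change",], "tasks": []}\nHope this helps.'
    return robust_json_extract(noisy)  # -> {"labels": ["Address Change"], "tasks": []}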
def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
    out = {"labels": [], "tasks": []}
    allowed_map = canonicalize_map(allowed)
    filt_labels = []
    for l in pred.get("labels", []):
        k = str(l).strip().lower()
        if k in allowed_map:
            filt_labels.append(allowed_map[k])
    filt_labels = normalize_labels(filt_labels)
    filt_tasks = []
    for t in pred.get("tasks", []):
        if not isinstance(t, dict):
            continue
        k = str(t.get("label", "")).strip().lower()
        if k in allowed_map:
            new_t = dict(t)
            new_t["label"] = allowed_map[k]
            filt_tasks.append(new_t)
    from_tasks = [tt["label"] for tt in filt_tasks]
    merged = normalize_labels(list(set(filt_labels) | set(from_tasks)))
    out["labels"], out["tasks"] = merged, filt_tasks
    return out
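# Illustrative sketch (never called): labels are matched case-insensitively against the
# allowed list and re-emitted with the canonical spelling, and labels that only appear
# inside tasks are merged into the top-level list. Prediction and allowed list are invented.
def _example_restrict_to_allowed() -> Dict[str, Any]:
    pred = {"labels": ["address change", "Smalltalk"],
            "tasks": [{"label": "SCHEDULE MEETING", "explanation": "book a call",
                       "evidence": "let's talk on Friday"}]}
    allowed = ["Address Change", "Schedule Meeting"]
    # -> labels contain "Address Change" and "Schedule Meeting" (set order not guaranteed);
    #    "Smalltalk" is dropped because it is not in the allowed list.
    return restrict_to_allowed(pred, allowed)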
def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
    toks = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(toks) <= max_tokens:
        return text
    return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
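# Note: truncate_tokens keeps the *last* max_tokens tokens, so the tail of a long
# transcript survives and the opening is discarded. A head-keeping variant (an
# alternative sketch, not used by the pipeline) would slice the front instead:
def truncate_tokens_head(tokenizer, text: str, max_tokens: int) -> str:
    toks = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(toks) <= max_tokens:
        return text
    return tokenizer.decode(toks[:max_tokens], skip_special_tokens=True)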
# =========================
# Model
# =========================
class ModelWrapper:
    def __init__(self, repo_id, hf_token, load_in_4bit):
        self.repo_id, self.hf_token, self.load_in_4bit = repo_id, hf_token, load_in_4bit
        self.tokenizer, self.model = None, None
    def load(self):
        qcfg = None
        if self.load_in_4bit and DEVICE == "cuda":
            qcfg = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
            trust_remote_code=True, use_fast=True,
        )
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
            low_cpu_mem_usage=True, quantization_config=qcfg,
            attn_implementation="sdpa",
        )
    @torch.inference_mode()
    def generate(self, system_prompt, user_prompt):
        if hasattr(self.tokenizer, "apply_chat_template"):
            msgs = [{"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}]
            # return_dict=True yields a mapping (input_ids + attention_mask) that can be
            # unpacked into generate(); a bare tensor cannot be passed as **inputs.
            inputs = self.tokenizer.apply_chat_template(
                msgs, add_generation_prompt=True, return_dict=True, return_tensors="pt"
            )
            inputs = inputs.to(self.model.device)
        else:
            text = f"<s>[SYSTEM]{system_prompt}[/SYSTEM][USER]{user_prompt}[/USER]"
            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            out_ids = self.model.generate(**inputs, generation_config=GEN_CONFIG,
                                          eos_token_id=self.tokenizer.eos_token_id,
                                          pad_token_id=self.tokenizer.pad_token_id)
        return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
_MODEL_CACHE: Dict[str, ModelWrapper] = {}
def get_model(repo_id, hf_token, load_in_4bit):
    key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE == 'cuda') else 'full'}"
    if key not in _MODEL_CACHE:
        m = ModelWrapper(repo_id, hf_token, load_in_4bit)
        m.load()
        _MODEL_CACHE[key] = m
    return _MODEL_CACHE[key]
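# Cache note: the key is "<repo_id>::4bit" on GPU with quantization enabled and
# "<repo_id>::full" otherwise, so toggling the 4-bit checkbox or switching models loads a
# separate ModelWrapper while repeated runs with the same settings reuse the cached one.
# Hypothetical check (never called by the app; a first call would actually load the model):
def _example_model_cache_reuse(repo_id: str) -> bool:
    return get_model(repo_id, None, False) is get_model(repo_id, None, False)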
# =========================
# Pipeline
# =========================
def run_extraction(text, file, labels_text, repo, use_4bit, max_tokens, hf_token):
    t0 = _now_ms()
    raw = read_file_to_text(file) if file else (text or "")
    raw = raw.strip()
    if not raw:
        return "", "", "No transcript.", json.dumps({"labels": [], "tasks": []}, indent=2)
    user_labels = [ln.strip() for ln in (labels_text or "").splitlines() if ln.strip()]
    allowed = normalize_labels(user_labels or DEFAULT_ALLOWED_LABELS)
    try:
        model = get_model(repo, hf_token.strip() or None, use_4bit)
    except Exception as e:
        return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
    trunc = truncate_tokens(model.tokenizer, raw, max_tokens)
    user_prompt = USER_PROMPT_TEMPLATE.format(
        transcript=trunc,
        allowed_labels_list="\n".join(f"- {l}" for l in allowed),
    )
    t1 = _now_ms()
    try:
        out = model.generate(SYSTEM_PROMPT, user_prompt)
    except Exception as e:
        return "", "", f"Gen error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
    t2 = _now_ms()
    parsed = robust_json_extract(out)
    filtered = restrict_to_allowed(parsed, allowed)
    diag = "\n".join([
        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE == 'cuda') else 'No'})",
        f"Model: {repo}",
        f"Latency: prep {t1 - t0} ms, gen {t2 - t1} ms, total {t2 - t0} ms",
        f"Allowed labels: {', '.join(allowed)}",
    ])
    if filtered["labels"]:
        summary = "Detected labels:\n" + "\n".join(f"- {l}" for l in filtered["labels"])
    else:
        summary = "Detected labels: (none)"
    if filtered["tasks"]:
        summary += "\n\nTasks:\n" + "\n".join(
            f"• [{t['label']}] {t.get('explanation', '')} | ev: {t.get('evidence', '')[:100]}"
            for t in filtered["tasks"]
        )
    else:
        summary += "\n\nTasks: (none)"
    return summary, json.dumps(filtered, indent=2), diag, out.strip()
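# Minimal local smoke test sketch (assumption: run manually outside the Gradio UI with a
# small, ungated checkpoint; nothing here executes at import time).
def _example_run_extraction() -> str:
    transcript = "Client: please change my mailing address and book a call for Friday."
    labels = "Address Change\nSchedule Meeting"
    summary, json_str, diag, raw_out = run_extraction(
        text=transcript, file=None, labels_text=labels,
        repo="mistralai/Mistral-7B-Instruct-v0.3", use_4bit=False,
        max_tokens=2048, hf_token="",
    )
    return summary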
# =========================
# UI
# =========================
MODEL_CHOICES = [
    "swiss-ai/Apertus-8B-Instruct-2509",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Talk2Task — Task Extraction Demo")
    with gr.Row():
        with gr.Column(scale=3):
            file = gr.File(label="Drag & drop transcript (.txt/.md/.json)",
                           file_types=[".txt", ".md", ".json"], type="filepath")
            text = gr.Textbox(label="Or paste transcript", lines=12)
            labels_text = gr.Textbox(label="Allowed Labels (one per line)", lines=8)
        with gr.Column(scale=2):
            repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
            use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
            max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
            hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password",
                                  value=os.environ.get("HF_TOKEN", ""))
            btn = gr.Button("Run Extraction", variant="primary")
    with gr.Row():
        summary = gr.Textbox(label="Summary", lines=12)
        json_out = gr.Code(label="JSON Output", language="json")
    with gr.Row():
        diag = gr.Textbox(label="Diagnostics", lines=6)
        raw = gr.Textbox(label="Raw Model Output", lines=6)
    btn.click(fn=run_extraction,
              inputs=[text, file, labels_text, repo, use_4bit, max_tokens, hf_token],
              outputs=[summary, json_out, diag, raw])
if __name__ == "__main__":
    demo.launch()