RishiRP committed on
Commit 05188c4 · verified · 1 Parent(s): b80450d

Update app.py

Files changed (1):
  1. app.py +93 -302

app.py CHANGED
@@ -1,11 +1,8 @@
 
- Allowed Labels (strict, case-insensitive match; output must use canonical label text exactly):
  {allowed_labels_list}
 
- Instructions:
- 1) Extract every concrete task the advisor or client must take.
- 2) For each, choose ONE label from Allowed Labels (or leave empty if none match).
- 3) Output STRICT JSON only, no prose:
  {{
    "labels": ["LabelA","LabelB", ...],
    "tasks": [
@@ -16,405 +13,199 @@ Instructions:
  """
 
  # =========================
- # Utilities
  # =========================
- def _now_ms() -> int:
-     return int(time.time() * 1000)
 
  def read_file_to_text(file: gr.File) -> str:
      if not file or not file.name:
          return ""
      name = file.name.lower()
      data = file.read()
-     # Restrict to light parsers (txt/md/json) for speed/reliability
      if name.endswith(".json"):
          try:
              obj = json.loads(data.decode("utf-8", errors="ignore"))
-             # Accept either {"transcript": "..."} or list/str
              if isinstance(obj, dict) and "transcript" in obj:
                  return str(obj["transcript"])
              return json.dumps(obj, ensure_ascii=False)
          except Exception:
              return data.decode("utf-8", errors="ignore")
      else:
-         # txt / md or anything texty
-         try:
-             return data.decode("utf-8", errors="ignore")
-         except Exception:
-             try:
-                 return data.decode("latin-1", errors="ignore")
-             except Exception:
-                 return ""
 
  def normalize_labels(labels: List[str]) -> List[str]:
      return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
 
  def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
-     """
-     Build a case-insensitive map: lowercase -> canonical label
-     """
-     m = {}
-     for lab in allowed:
-         m[lab.lower()] = lab
-     return m
 
  def robust_json_extract(text: str) -> Dict[str, Any]:
-     """
-     Try to parse strict JSON from model output.
-     If the model added extra tokens, strip to first {...} block.
-     """
      if not text:
          return {"labels": [], "tasks": []}
-
-     # Find first JSON object
-     start = text.find("{")
-     end = text.rfind("}")
-     if start != -1 and end != -1 and end > start:
-         candidate = text[start : end + 1]
-     else:
-         candidate = text
-
-     # Remove trailing junk commas and try json.loads
      try:
          return json.loads(candidate)
      except Exception:
-         # Fallback: try to repair common issues
          candidate = re.sub(r",\s*}", "}", candidate)
          candidate = re.sub(r",\s*]", "]", candidate)
-         try:
-             return json.loads(candidate)
-         except Exception:
-             return {"labels": [], "tasks": []}
 
  def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
-     """
-     Keep only tasks whose label ∈ allowed; map case-insensitively to canonical.
-     """
      out = {"labels": [], "tasks": []}
-     if not isinstance(pred, dict):
-         return out
-     raw_labels = pred.get("labels", []) or []
-     raw_tasks = pred.get("tasks", []) or []
-
      allowed_map = canonicalize_map(allowed)
-
-     # Filter labels
-     filt_labels: List[str] = []
-     for l in raw_labels:
-         if not isinstance(l, str):
-             continue
-         k = l.strip().lower()
-         if k in allowed_map:
-             filt_labels.append(allowed_map[k])
      filt_labels = normalize_labels(filt_labels)
-
-     # Filter tasks
      filt_tasks = []
-     for t in raw_tasks:
-         if not isinstance(t, dict):
-             continue
-         lbl = t.get("label", "")
-         k = str(lbl).strip().lower()
          if k in allowed_map:
-             new_t = dict(t)
-             new_t["label"] = allowed_map[k]
              filt_tasks.append(new_t)
-
-     # Ensure labels reflect tasks (union)
-     from_tasks = [tt["label"] for tt in filt_tasks if isinstance(tt.get("label"), str)]
      merged = normalize_labels(list(set(filt_labels) | set(from_tasks)))
-
-     out["labels"] = merged
-     out["tasks"] = filt_tasks
      return out
 
- def truncate_tokens(tokenizer, text: str, max_input_tokens: int) -> str:
-     if max_input_tokens <= 0:
-         return text
-     toks = tokenizer(text, add_special_tokens=False, return_attention_mask=False, return_tensors=None)["input_ids"]
-     if len(toks) <= max_input_tokens:
-         return text
-     # Keep the tail (most recent part of the convo often carries actionable tasks)
-     keep_ids = toks[-max_input_tokens:]
-     return tokenizer.decode(keep_ids, skip_special_tokens=True)
 
  # =========================
- # Model Loading
  # =========================
  class ModelWrapper:
-     def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
-         self.repo_id = repo_id
-         self.hf_token = hf_token
-         self.load_in_4bit = load_in_4bit
-         self.tokenizer = None
-         self.model = None
 
      def load(self):
          qcfg = None
          if self.load_in_4bit and DEVICE == "cuda":
              qcfg = BitsAndBytesConfig(
-                 load_in_4bit=True,
-                 bnb_4bit_quant_type="nf4",
                  bnb_4bit_compute_dtype=torch.float16,
                  bnb_4bit_use_double_quant=True,
              )
-
-         tok = AutoTokenizer.from_pretrained(
-             self.repo_id,
-             token=self.hf_token,
-             cache_dir=str(SPACE_CACHE),
-             trust_remote_code=True,
-             use_fast=True,
          )
-         # Some models lack pad token—safe default
-         if tok.pad_token is None and tok.eos_token is not None:
-             tok.pad_token = tok.eos_token
-
-         model = AutoModelForCausalLM.from_pretrained(
-             self.repo_id,
-             token=self.hf_token,
-             cache_dir=str(SPACE_CACHE),
              trust_remote_code=True,
              torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
              device_map="auto" if DEVICE == "cuda" else None,
-             low_cpu_mem_usage=True,
-             quantization_config=qcfg,
-             attn_implementation="sdpa",  # T4-safe and faster than 'eager'
          )
-         self.tokenizer = tok
-         self.model = model
 
      @torch.inference_mode()
-     def generate(self, system_prompt: str, user_prompt: str) -> str:
-         # Chat template if available; otherwise a simple format
          if hasattr(self.tokenizer, "apply_chat_template"):
-             messages = [
-                 {"role": "system", "content": system_prompt},
-                 {"role": "user", "content": user_prompt},
-             ]
-             input_ids = self.tokenizer.apply_chat_template(
-                 messages,
-                 add_generation_prompt=True,
-                 return_tensors="pt",
-             ).to(self.model.device)
          else:
-             text = f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n"
-             input_ids = self.tokenizer(text, return_tensors="pt").to(self.model.device)
 
-         with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
-             out_ids = self.model.generate(
-                 **input_ids,
-                 generation_config=GEN_CONFIG,
-                 eos_token_id=self.tokenizer.eos_token_id,
-                 pad_token_id=self.tokenizer.pad_token_id,
-             )
-         out = self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
-         # Heuristic: strip the prompting part if the model echoes input
-         if "}" in out:
-             tail = out[out.rfind("}") + 1 :]
-             body = out[: out.rfind("}") + 1]
-             # Prefer the last JSON object if multiple
-             if "{" in tail and "}" in tail:
-                 # do nothing—rare; handled by robust_json_extract
-                 pass
-             return body
-         return out
-
- # Keep one live model per repo for snappy re-runs
  _MODEL_CACHE: Dict[str, ModelWrapper] = {}
-
- def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
      key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
      if key not in _MODEL_CACHE:
-         mw = ModelWrapper(repo_id, hf_token, load_in_4bit)
-         mw.load()
-         _MODEL_CACHE[key] = mw
      return _MODEL_CACHE[key]
 
  # =========================
- # Inference Pipeline
  # =========================
- def run_extraction(
-     transcript_text: str,
-     transcript_file: gr.File,
-     allowed_labels_text: str,
-     model_repo: str,
-     use_4bit: bool,
-     max_input_tokens: int,
-     hf_token: str,
- ) -> Tuple[str, str, str, str]:
-
      t0 = _now_ms()
 
-     # 1) Get transcript: prefer file (drag-drop), else textarea
-     raw_text = ""
-     if transcript_file:
-         raw_text = read_file_to_text(transcript_file)
-     if not raw_text:
-         raw_text = transcript_text or ""
-     raw_text = raw_text.strip()
-
-     if not raw_text:
-         return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, ensure_ascii=False, indent=2)
-
-     # 2) Allowed labels: combine UI text with default (so we NEVER end up empty)
-     user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
-     allowed = normalize_labels(user_allowed or DEFAULT_ALLOWED_LABELS)
 
-     # 3) Load model
-     hf_tok = hf_token.strip() or None
      try:
-         model = get_model(model_repo, hf_tok, load_in_4bit=use_4bit)
      except Exception as e:
-         msg = (
-             f"Model load failed for '{model_repo}'. If gated/private, set HF_TOKEN in Space secrets.\n"
-             f"Error: {e}"
-         )
-         return "", "", msg, json.dumps({"labels": [], "tasks": []}, ensure_ascii=False, indent=2)
-
-     # 4) Truncate input to speed up
-     trunc_text = truncate_tokens(model.tokenizer, raw_text, max_input_tokens=max_input_tokens)
 
-     # 5) Build prompts
-     allowed_list_str = "\n".join(f"- {lab}" for lab in allowed)
-     user_prompt = USER_PROMPT_TEMPLATE.format(
-         transcript=trunc_text,
-         allowed_labels_list=allowed_list_str,
-     )
 
-     # 6) Generate
      t1 = _now_ms()
      try:
-         model_out = model.generate(SYSTEM_PROMPT, user_prompt)
      except Exception as e:
-         return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, ensure_ascii=False, indent=2)
      t2 = _now_ms()
 
-     # 7) Parse & filter strictly to allowed
-     parsed = robust_json_extract(model_out)
      filtered = restrict_to_allowed(parsed, allowed)
 
-     # 8) Compose UI outputs
-     # Diagnostics
-     diag = [
          f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
-         f"Model: {model_repo}",
-         f"Tokens (input, approx): {max_input_tokens}",
-         f"Latency: load+prep {(t1 - t0)} ms, generate {(t2 - t1)} ms, total {(t2 - t0)} ms",
-         f"Allowed Labels Used (n={len(allowed)}): {', '.join(allowed)}",
-     ]
-     diag_str = "\n".join(diag)
-
-     # Summary plain text
-     labs = filtered.get("labels", [])
-     tasks = filtered.get("tasks", [])
-     summ_lines = []
-     if labs:
-         summ_lines.append("Detected labels:\n - " + "\n - ".join(labs))
-     else:
-         summ_lines.append("Detected labels: (none)")
-
-     if tasks:
-         summ_lines.append("\nTasks:")
-         for t in tasks:
-             lab = t.get("label", "")
-             expl = t.get("explanation", "")
-             ev = t.get("evidence", "")
-             summ_lines.append(f"• [{lab}] {expl} | evidence: {ev[:140]}{'…' if len(ev)>140 else ''}")
      else:
-         summ_lines.append("\nTasks: (none)")
-
-     summary = "\n".join(summ_lines)
-
-     # JSON pretty
-     json_str = json.dumps(filtered, ensure_ascii=False, indent=2)
-
-     # Raw model text (to help debug label empty issues)
-     raw_out = model_out.strip()
-
-     return summary, json_str, diag_str, raw_out
 
  # =========================
  # UI
  # =========================
  MODEL_CHOICES = [
-     "swiss-ai/Apertus-8B-Instruct-2509",   # default
-     "meta-llama/Meta-Llama-3-8B-Instruct", # may be gated; handled in code
-     "mistralai/Mistral-7B-Instruct-v0.3",  # widely available, strong baseline
  ]
 
- with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
      gr.Markdown("# Talk2Task — Task Extraction Demo")
-     gr.Markdown(
-         "Drop a transcript file **or** paste text, choose a model, and get strict JSON back. "
-         "For best speed, keep inputs concise or lower the input token limit."
-     )
 
      with gr.Row():
          with gr.Column(scale=3):
-             transcript_file = gr.File(
-                 label="Drag & drop transcript (.txt / .md / .json)",
-                 file_types=[".txt", ".md", ".json"],
-                 type="filepath",
-             )
-             transcript_text = gr.Textbox(
-                 label="Or paste transcript here",
-                 lines=14,
-                 placeholder="Paste conversation transcript…",
-             )
-             allowed_labels_text = gr.Textbox(
-                 label="Allowed Labels (one per line) — leave empty to use defaults",
-                 value="",
-                 lines=8,
-             )
          with gr.Column(scale=2):
-             model_repo = gr.Dropdown(
-                 label="Model Repository",
-                 choices=MODEL_CHOICES,
-                 value=MODEL_CHOICES[0],
-             )
-             use_4bit = gr.Checkbox(
-                 label="Use 4-bit quantization (recommended on GPU/T4)",
-                 value=True,
-             )
-             max_input_tokens = gr.Slider(
-                 label="Max input tokens (truncate from end for speed)",
-                 minimum=1024,
-                 maximum=8192,
-                 step=512,
-                 value=4096,
-             )
-             hf_token = gr.Textbox(
-                 label="HF_TOKEN (only needed for gated/private models)",
-                 type="password",
-                 value=os.environ.get("HF_TOKEN", ""),
-             )
-             run_btn = gr.Button("Run Extraction", variant="primary")
 
      with gr.Row():
-         with gr.Column():
-             summary_out = gr.Textbox(label="Summary", lines=10)
-         with gr.Column():
-             json_out = gr.Code(label="Strict JSON Output", language="json")
      with gr.Row():
-         with gr.Column():
-             diag_out = gr.Textbox(label="Diagnostics & Timing", lines=8)
-         with gr.Column():
-             raw_out = gr.Textbox(label="Raw Model Output (debug)", lines=8)
 
-     run_btn.click(
-         fn=run_extraction,
-         inputs=[
-             transcript_text,
-             transcript_file,
-             allowed_labels_text,
-             model_repo,
-             use_4bit,
-             max_input_tokens,
-             hf_token,
-         ],
-         outputs=[summary_out, json_out, diag_out, raw_out],
-     )
 
  if __name__ == "__main__":
      demo.launch()
 
 
+ Allowed Labels:
  {allowed_labels_list}
 
+ Output STRICT JSON only, no prose:
  {{
    "labels": ["LabelA","LabelB", ...],
    "tasks": [
  """
 
  # =========================
+ # Utils
  # =========================
+ def _now_ms(): return int(time.time() * 1000)
 
  def read_file_to_text(file: gr.File) -> str:
      if not file or not file.name:
          return ""
      name = file.name.lower()
      data = file.read()
      if name.endswith(".json"):
          try:
              obj = json.loads(data.decode("utf-8", errors="ignore"))
              if isinstance(obj, dict) and "transcript" in obj:
                  return str(obj["transcript"])
              return json.dumps(obj, ensure_ascii=False)
          except Exception:
              return data.decode("utf-8", errors="ignore")
      else:
+         return data.decode("utf-8", errors="ignore")
 
  def normalize_labels(labels: List[str]) -> List[str]:
      return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
 
  def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
+     return {lab.lower(): lab for lab in allowed}
 
  def robust_json_extract(text: str) -> Dict[str, Any]:
      if not text:
          return {"labels": [], "tasks": []}
+     start, end = text.find("{"), text.rfind("}")
+     candidate = text[start:end+1] if (start != -1 and end != -1) else text
      try:
          return json.loads(candidate)
      except Exception:
          candidate = re.sub(r",\s*}", "}", candidate)
          candidate = re.sub(r",\s*]", "]", candidate)
+         try: return json.loads(candidate)
+         except Exception: return {"labels": [], "tasks": []}
 
  def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
      out = {"labels": [], "tasks": []}
      allowed_map = canonicalize_map(allowed)
+     filt_labels = []
+     for l in pred.get("labels", []):
+         k = str(l).strip().lower()
+         if k in allowed_map: filt_labels.append(allowed_map[k])
      filt_labels = normalize_labels(filt_labels)
      filt_tasks = []
+     for t in pred.get("tasks", []):
+         if not isinstance(t, dict): continue
+         k = str(t.get("label", "")).strip().lower()
          if k in allowed_map:
+             new_t = dict(t); new_t["label"] = allowed_map[k]
              filt_tasks.append(new_t)
+     from_tasks = [tt["label"] for tt in filt_tasks]
      merged = normalize_labels(list(set(filt_labels) | set(from_tasks)))
+     out["labels"], out["tasks"] = merged, filt_tasks
      return out
 
+ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
+     toks = tokenizer(text, add_special_tokens=False)["input_ids"]
+     if len(toks) <= max_tokens: return text
+     return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
 
  # =========================
+ # Model
  # =========================
  class ModelWrapper:
+     def __init__(self, repo_id, hf_token, load_in_4bit):
+         self.repo_id, self.hf_token, self.load_in_4bit = repo_id, hf_token, load_in_4bit
+         self.tokenizer, self.model = None, None
 
      def load(self):
          qcfg = None
          if self.load_in_4bit and DEVICE == "cuda":
              qcfg = BitsAndBytesConfig(
+                 load_in_4bit=True, bnb_4bit_quant_type="nf4",
                  bnb_4bit_compute_dtype=torch.float16,
                  bnb_4bit_use_double_quant=True,
              )
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
+             trust_remote_code=True, use_fast=True,
          )
+         if self.tokenizer.pad_token is None and self.tokenizer.eos_token:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
              trust_remote_code=True,
              torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
              device_map="auto" if DEVICE == "cuda" else None,
+             low_cpu_mem_usage=True, quantization_config=qcfg,
+             attn_implementation="sdpa",
          )
 
      @torch.inference_mode()
+     def generate(self, system_prompt, user_prompt):
          if hasattr(self.tokenizer, "apply_chat_template"):
+             msgs = [{"role":"system","content":system_prompt},{"role":"user","content":user_prompt}]
+             inputs = self.tokenizer.apply_chat_template(msgs, add_generation_prompt=True, return_tensors="pt")
+             inputs = inputs.to(self.model.device)
          else:
+             text = f"<s>[SYSTEM]{system_prompt}[/SYSTEM][USER]{user_prompt}[/USER]"
+             inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
+         with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
+             out_ids = self.model.generate(**inputs, generation_config=GEN_CONFIG,
+                 eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id)
+         return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
 
  _MODEL_CACHE: Dict[str, ModelWrapper] = {}
+ def get_model(repo_id, hf_token, load_in_4bit):
      key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
      if key not in _MODEL_CACHE:
+         m = ModelWrapper(repo_id, hf_token, load_in_4bit); m.load()
+         _MODEL_CACHE[key] = m
      return _MODEL_CACHE[key]
 
  # =========================
+ # Pipeline
  # =========================
+ def run_extraction(text, file, labels_text, repo, use_4bit, max_tokens, hf_token):
      t0 = _now_ms()
+     raw = read_file_to_text(file) if file else (text or "")
+     raw = raw.strip()
+     if not raw:
+         return "", "", "No transcript.", json.dumps({"labels":[], "tasks":[]}, indent=2)
 
+     user_labels = [ln.strip() for ln in (labels_text or "").splitlines() if ln.strip()]
+     allowed = normalize_labels(user_labels or DEFAULT_ALLOWED_LABELS)
 
      try:
+         model = get_model(repo, hf_token.strip() or None, use_4bit)
      except Exception as e:
+         return "", "", f"Model load failed: {e}", json.dumps({"labels":[], "tasks":[]}, indent=2)
 
+     trunc = truncate_tokens(model.tokenizer, raw, max_tokens)
+     user_prompt = USER_PROMPT_TEMPLATE.format(transcript=trunc, allowed_labels_list="\n".join(f"- {l}" for l in allowed))
 
      t1 = _now_ms()
      try:
+         out = model.generate(SYSTEM_PROMPT, user_prompt)
      except Exception as e:
+         return "", "", f"Gen error: {e}", json.dumps({"labels":[], "tasks":[]}, indent=2)
      t2 = _now_ms()
 
+     parsed = robust_json_extract(out)
      filtered = restrict_to_allowed(parsed, allowed)
 
+     diag = "\n".join([
          f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
+         f"Model: {repo}",
+         f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
+         f"Allowed labels: {', '.join(allowed)}"
+     ])
+     summary = "Detected labels:\n" + "\n".join(f"- {l}" for l in filtered["labels"]) if filtered["labels"] else "Detected labels: (none)"
+     if filtered["tasks"]:
+         summary += "\n\nTasks:\n" + "\n".join(f"• [{t['label']}] {t.get('explanation','')} | ev: {t.get('evidence','')[:100]}" for t in filtered["tasks"])
      else:
+         summary += "\n\nTasks: (none)"
+     return summary, json.dumps(filtered, indent=2), diag, out.strip()
 
  # =========================
  # UI
  # =========================
  MODEL_CHOICES = [
+     "swiss-ai/Apertus-8B-Instruct-2509",
+     "meta-llama/Meta-Llama-3-8B-Instruct",
+     "mistralai/Mistral-7B-Instruct-v0.3",
  ]
 
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown("# Talk2Task — Task Extraction Demo")
 
      with gr.Row():
          with gr.Column(scale=3):
+             file = gr.File(label="Drag & drop transcript (.txt/.md/.json)", file_types=[".txt",".md",".json"], type="filepath")
+             text = gr.Textbox(label="Or paste transcript", lines=12)
+             labels_text = gr.Textbox(label="Allowed Labels (one per line)", lines=8)
          with gr.Column(scale=2):
+             repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
+             use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
+             max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
+             hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
+             btn = gr.Button("Run Extraction", variant="primary")
 
      with gr.Row():
+         summary = gr.Textbox(label="Summary", lines=12)
+         json_out = gr.Code(label="JSON Output", language="json")
      with gr.Row():
+         diag = gr.Textbox(label="Diagnostics", lines=6)
+         raw = gr.Textbox(label="Raw Model Output", lines=6)
 
+     btn.click(fn=run_extraction, inputs=[text,file,labels_text,repo,use_4bit,max_tokens,hf_token], outputs=[summary,json_out,diag,raw])
 
  if __name__ == "__main__":
      demo.launch()
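
For reference, a minimal sketch of the strict JSON contract the prompt asks for and of how restrict_to_allowed() filters and canonicalizes it. The labels, tasks, and the `from app import ...` line are illustrative assumptions (invented values, and an import path that presumes app.py sits in the working directory), not output from a real transcript.

# Illustrative sketch only: all values below are invented.
# Assumes app.py is importable so restrict_to_allowed (defined above) can be reused.
from app import restrict_to_allowed

allowed = ["Schedule Meeting", "Send Documents"]  # hypothetical allowed labels

# The shape the model is instructed to return: "labels" plus "tasks",
# each task carrying label / explanation / evidence.
model_json = {
    "labels": ["schedule meeting", "Review Portfolio"],
    "tasks": [
        {
            "label": "SCHEDULE MEETING",
            "explanation": "Advisor to set up a follow-up call next week.",
            "evidence": "let's talk again next Tuesday",
        },
        {
            "label": "Review Portfolio",  # not in the allowed list, so it is dropped
            "explanation": "Client asked for a portfolio review.",
            "evidence": "could you look over my positions",
        },
    ],
}

filtered = restrict_to_allowed(model_json, allowed)
print(filtered["labels"])      # ['Schedule Meeting'], canonical casing restored
print(len(filtered["tasks"]))  # 1, only tasks whose label is in the allowed list survive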