RishiRP committed
Commit c99502f · verified · 1 Parent(s): 41b65ed

Update app.py

Files changed (1): app.py (+258 -154)
app.py CHANGED
@@ -1,19 +1,44 @@
 import os
 import json
 import gradio as gr
 import torch
-from typing import Optional, Tuple, Dict, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM

 # =========================
-# Runtime / Model Defaults
 # =========================
-# Small, ungated default to avoid permission/download issues.
-# You can switch at runtime via the dropdown or set MODEL_ID env var.
-DEFAULT_MODEL_ID = os.environ.get("MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")

 def _has_bnb_and_cuda() -> bool:
-    if not torch.cuda.is_available():
         return False
     try:
         import bitsandbytes as _bnb  # noqa: F401
@@ -22,10 +47,9 @@ def _has_bnb_and_cuda() -> bool:
         return False

 USE_BNB = _has_bnb_and_cuda()
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 # =========================
-# Model Load (safe + flexible)
 # =========================
 _tokenizer: Optional[AutoTokenizer] = None
 _model: Optional[AutoModelForCausalLM] = None
@@ -33,8 +57,8 @@ _current_model_id: Optional[str] = None

 def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
     """
-    Loads (or reuses) a model/tokenizer. Uses bitsandbytes 4-bit only if
-    CUDA is available AND bnb is installed. Otherwise plain CPU/GPU.
     """
     global _tokenizer, _model, _current_model_id

@@ -64,9 +88,9 @@ def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
     _tokenizer, _model, _current_model_id = tokenizer, model, model_id
     return tokenizer, model

-# ======================================
-# Helpers: Ingest TXT/JSON from Tabs box
-# ======================================
 def read_file(file_obj: Optional[gr.File]) -> Optional[str]:
     if not file_obj:
         return None
@@ -77,20 +101,81 @@ def read_file(file_obj: Optional[gr.File]) -> Optional[str]:
         return None

 def normalize_txt_input(paste_txt: str, upload_file: Optional[gr.File]) -> str:
-    file_text = read_file(upload_file)
-    if paste_txt and paste_txt.strip():
-        return paste_txt
-    return file_text or ""

 def normalize_json_input(paste_json: str, upload_file: Optional[gr.File]) -> str:
-    file_text = read_file(upload_file)
-    candidate = paste_json.strip() if paste_json else ""
-    if not candidate and file_text:
-        candidate = file_text
-    return candidate

 # =========================
-# Core Extraction (placeholder)
 # =========================
 def run_extraction(
     model_choice: str,
@@ -104,58 +189,66 @@ def run_extraction(
     max_new_tokens: int,
     temperature: float,
     top_p: float,
 ) -> Tuple[str, str, str, str, str]:
-    """
-    Wire your real extraction here.
-    Returns:
-        tasks_out, entities_out, cleaned_out, summary_out, diagnostics
-    """
     diagnostics_lines = []

-    # Resolve inputs from single-box Tab controls
     input_txt = normalize_txt_input(txt_paste, txt_upload)
     input_json_raw = normalize_json_input(json_paste, json_upload)

-    diagnostics_lines.append(f"Model: {model_choice}")
-    diagnostics_lines.append(f"Params: {params_checked}")
-    diagnostics_lines.append(f"Instructions length: {len(instructions_text)} chars")
-    diagnostics_lines.append(f"Context length: {len(context_text)} chars")
-    diagnostics_lines.append(f"TXT length: {len(input_txt)} chars")
-
-    # Try parse JSON (optional)
     parsed_json: Dict[str, Any] = {}
     if input_json_raw:
         try:
             parsed_json = json.loads(input_json_raw)
-            diagnostics_lines.append("JSON: parsed successfully")
         except Exception as e:
             diagnostics_lines.append(f"JSON parse error: {e}")

-    # Load selected model (safe)
     try:
         tokenizer, model = load_model(model_choice)
     except Exception as e:
-        # If model fails to load, still return diagnostics
-        diag = "\n".join(diagnostics_lines + [f"Model load failed: {e}"])
         return "", "", "", "", diag

-    # ---------- Dummy generation (replace with your real prompts) ----------
-    # Build a prompt from inputs (very basic)
     user_prompt = (
-        "You are an assistant that extracts tasks and entities.\n"
         f"Instructions: {instructions_text}\n"
         f"Context: {context_text}\n"
         "----\n"
         f"TEXT:\n{input_txt[:4000]}\n"
         "----\n"
         f"JSON:\n{json.dumps(parsed_json)[:2000]}\n"
-        "Extract:\n- Tasks list\n- Entities list\n- Cleaned text (sanitized)\n- 1-2 line summary\n"
     )

     try:
         inputs = tokenizer(user_prompt, return_tensors="pt").to(DEVICE)
         with torch.no_grad():
-            outputs = _model.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
                 do_sample=True,
@@ -165,184 +258,195 @@ def run_extraction(
             )
         full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     except Exception as e:
-        diag = "\n".join(diagnostics_lines + [f"Inference failed: {e}"])
         return "", "", "", "", diag

-    # Very naive post-split (replace with your own structured parsing)
-    tasks_out = "• Task 1\n• Task 2\n(Replace with your parser)"
-    entities_out = "• Entity A\n• Entity B\n(Replace with your parser)"
-    cleaned_out = "Cleaned text here… (Replace with your cleaning pipeline)"
-    summary_out = "Short summary here… (Replace with your summarizer)"
-
-    diagnostics_lines.append("Generation completed successfully.")
     diagnostics = "\n".join(diagnostics_lines)

     return tasks_out, entities_out, cleaned_out, summary_out, diagnostics

 # =========================
-# UI (Gradio Blocks)
 # =========================
 THEME_CSS = """
-/* Global colors: white background, black text */
 :root {
   --body-background-fill: #ffffff !important;
   --body-text-color: #111111 !important;
-  --link-text-color: #0b63ce !important; /* blue */
-  --shadow-spread: 0px;
-}
-
-/* Ensure all text is readable (black-ish) */
-.gradio-container, .prose, .prose * {
-  color: #111111 !important;
 }
-
-/* Accent elements in blue (no purple) */
-label, .tabitem .label-wrap, .wrap .label-wrap {
-  color: #0b63ce !important;
-}
-
-/* Cards / Boxes */
-.gr-box, .gr-panel, .gr-group, .gr-accordion {
-  border: 1px solid #e5e7eb !important; /* light gray border */
-  border-radius: 14px !important;
-}
-
-/* Red run button */
 button#run-btn {
   background: #e11900 !important;
-  color: #ffffff !important;
   border: 1px solid #b50f00 !important;
 }
-button#run-btn:hover {
-  filter: brightness(0.95);
-}
-
-/* Inputs layout polish */
-.input-card {
-  padding: 10px;
-}
 """

 def build_interface() -> gr.Blocks:
     with gr.Blocks(title="Talk2Task Demo", css=THEME_CSS) as demo:
-        # 1) MODEL SELECTION (full width) + checklist embedded
         with gr.Group():
-            gr.Markdown("### Model & Parameters", elem_id="model-header")
-            with gr.Row(equal_height=True):
                 model_choice = gr.Dropdown(
                     label="Model",
-                    choices=[
-                        DEFAULT_MODEL_ID,
-                        "mistralai/Mistral-7B-Instruct-v0.2",
-                        "meta-llama/Llama-3.1-8B-Instruct",  # if accessible
-                    ],
                     value=DEFAULT_MODEL_ID,
-                    scale=3
                 )
                 params_checked = gr.CheckboxGroup(
                     label="Options",
                     choices=[
                         "Default cleaning",
                         "Remove PII",
-                        "Allow 4-bit (if available)",
                         "Detect language",
                     ],
-                    value=["Default cleaning"],
-                    scale=2
                 )
             with gr.Row():
-                # generation controls (kept compact)
                 temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
                 top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
-                max_new_tokens = gr.Slider(32, 1024, value=200, step=8, label="Max new tokens")

-        # 2) SINGLE “BOX” PER TYPE via Tabs (Paste OR Drag & Drop) — side-by-side
-        gr.Markdown("### Input", elem_id="input-header")
         with gr.Row(equal_height=True):
-            with gr.Group(elem_classes=["input-card"]):
-                gr.Markdown("**TXT Input** (Paste or Drag & Drop)", elem_id="txt-box-title")
                 with gr.Tabs():
                     with gr.TabItem("Paste"):
-                        txt_paste = gr.TextArea(
-                            label="Paste TXT",
-                            placeholder="Paste raw transcript or text here...",
-                            lines=12,
-                        )
-                    with gr.TabItem("Drag & Drop"):
-                        txt_upload = gr.File(
-                            label="Upload .txt file",
-                            file_types=[".txt"],
-                        )
-
-            with gr.Group(elem_classes=["input-card"]):
-                gr.Markdown("**JSON Input** (Paste or Drag & Drop)", elem_id="json-box-title")
                 with gr.Tabs():
                     with gr.TabItem("Paste"):
-                        json_paste = gr.Code(
-                            label="Paste JSON",
-                            language="json",
-                            value="{\n  \"example\": true\n}",
-                            lines=12,
-                        )
-                    with gr.TabItem("Drag & Drop"):
-                        json_upload = gr.File(
-                            label="Upload .json file",
-                            file_types=[".json"],
-                        )
-
-        # 3) RUN BUTTON (red), then collapsible Instructions & Context
-        run_btn = gr.Button("Run Extraction", elem_id="run-btn", variant="primary")

         with gr.Row():
             with gr.Accordion("Instructions (editable)", open=False):
                 instructions_text = gr.TextArea(
-                    label="Instructions",
                     value=(
-                        "Extract tasks, entities, and a short summary. "
-                        "Apply default cleaning unless unchecked."
                     ),
-                    lines=5,
                 )
             with gr.Accordion("Context (editable)", open=False):
                 context_text = gr.TextArea(
-                    label="Context",
                     value=(
-                        "Use banking/consulting context if relevant. "
-                        "Prefer concise actionable phrasing."
                     ),
-                    lines=5,
                 )

-        # 4) OUTPUT LAYOUT — symmetrical boxes
-        gr.Markdown("### Results", elem_id="results-header")
         with gr.Row(equal_height=True):
-            tasks_out = gr.TextArea(label="Tasks", lines=10)
-            entities_out = gr.TextArea(label="Entities", lines=10)
         with gr.Row(equal_height=True):
-            cleaned_out = gr.TextArea(label="Cleaned Text", lines=10)
-            summary_out = gr.TextArea(label="Summary", lines=10)

-        gr.Markdown("### Diagnostics", elem_id="diagnostics-header")
-        diagnostics = gr.TextArea(label="Diagnostics / Logs", lines=10)

-        # Wire up button
         run_inputs = [
             model_choice, params_checked, instructions_text, context_text,
             txt_paste, txt_upload, json_paste, json_upload,
-            max_new_tokens, temperature, top_p
         ]
         run_outputs = [tasks_out, entities_out, cleaned_out, summary_out, diagnostics]
-
-        run_btn.click(
-            fn=run_extraction,
-            inputs=run_inputs,
-            outputs=run_outputs
-        )

     return demo

 demo = build_interface()

 if __name__ == "__main__":
-    # Let Gradio/Spaces choose host & port; this keeps local runs easy too.
     demo.launch()
 
app.py (new version, full file; added lines marked "+", elided unchanged regions marked "…"):

 import os
 import json
+from typing import Optional, Tuple, Dict, Any, List
+
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from langdetect import detect, DetectorFactory
+
+# Make langdetect deterministic
+DetectorFactory.seed = 7

 # =========================
+# Challenge: allowed labels (from UBS repo)
+# =========================
+# Source: GitHub repo "From-Talk-to-Task-Insights-from-Client-Conversations"
+ALLOWED_LABELS = [
+    "plan_contact",
+    "schedule_meeting",
+    "update_contact_info_non_postal",
+    "update_contact_info_postal_address",
+    "update_kyc_activity",
+    "update_kyc_origin_of_assets",
+    "update_kyc_purpose_of_businessrelation",
+    "update_kyc_total_assets",
+]
+
 # =========================
+# Models / Defaults
+# =========================
+DEFAULT_MODEL_ID = os.environ.get("MODEL_ID", "Apertus/Apertus-8B")
+SUPPORTED_MODELS = [
+    "Apertus/Apertus-8B",
+    "meta-llama/Meta-Llama-3-8B-Instruct",
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+]
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 def _has_bnb_and_cuda() -> bool:
+    if DEVICE != "cuda":
         return False
     try:
         import bitsandbytes as _bnb  # noqa: F401
…
         return False

 USE_BNB = _has_bnb_and_cuda()

 # =========================
+# Model cache
 # =========================
 _tokenizer: Optional[AutoTokenizer] = None
 _model: Optional[AutoModelForCausalLM] = None
…

 def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
     """
+    Loads (or reuses) a model/tokenizer.
+    Uses bitsandbytes 4-bit only if CUDA + bnb available; otherwise standard load.
     """
     global _tokenizer, _model, _current_model_id
…
     _tokenizer, _model, _current_model_id = tokenizer, model, model_id
     return tokenizer, model

+# =========================
+# Helpers
+# =========================
 def read_file(file_obj: Optional[gr.File]) -> Optional[str]:
     if not file_obj:
         return None
…
         return None

 def normalize_txt_input(paste_txt: str, upload_file: Optional[gr.File]) -> str:
+    return paste_txt.strip() if (paste_txt and paste_txt.strip()) else (read_file(upload_file) or "")

 def normalize_json_input(paste_json: str, upload_file: Optional[gr.File]) -> str:
+    if paste_json and paste_json.strip():
+        return paste_json
+    return read_file(upload_file) or ""
+
+def safe_lang_detect(text: str) -> str:
+    try:
+        if not text or not text.strip():
+            return "unknown"
+        return detect(text)
+    except Exception:
+        return "unknown"
+
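# Example (illustrative, not part of the commit): because DetectorFactory.seed
# is fixed above, langdetect gives stable results across runs. A call such as
# safe_lang_detect("Bonjour, comment allez-vous ?") should return "fr", while
# empty or undetectable input falls back to "unknown".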
+def count_tokens(tokenizer: AutoTokenizer, text: str) -> int:
+    try:
+        return len(tokenizer(text, return_tensors=None).get("input_ids", []))
+    except Exception:
+        # Fallback rough estimate if tokenizer path fails
+        return max(1, len(text.split()))
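
# Worked example (not part of the commit): count_tokens feeds the cost estimate
# in run_extraction, est_cost = (total_tokens / 1000.0) * usd_per_1k_tokens.
# At the default $0.002 per 1k tokens, 1,500 total tokens come to
# 1.5 * 0.002 = $0.003.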

 # =========================
+# Evaluation function (from repo)
+# =========================
+# Source: UBS GitHub README "Evaluation" snippet (weighted FN/FP, custom penalties)
+def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
+    import numpy as np
+
+    LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
+    FN_PENALTY = 2.0
+    FP_PENALTY = 1.0
+
+    if len(y_true) != len(y_pred):
+        raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")
+
+    n_samples = len(y_true)
+    n_labels = len(ALLOWED_LABELS)
+
+    y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
+    y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)
+
+    def _process(sample_labels: List[str], sample_name: str) -> List[str]:
+        if not isinstance(sample_labels, list):
+            raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
+        seen = set()
+        valid = []
+        for lbl in sample_labels:
+            if not isinstance(lbl, str):
+                raise ValueError(f"{sample_name} contains non-string label: {lbl}")
+            if lbl in seen:
+                raise ValueError(f"{sample_name} contains duplicate label: '{lbl}'")
+            seen.add(lbl)
+            if lbl not in ALLOWED_LABELS:
+                raise ValueError(f"{sample_name} contains invalid label: '{lbl}'. Allowed: {ALLOWED_LABELS}")
+            valid.append(lbl)
+        return valid
+
+    for i, lbls in enumerate(y_true):
+        for lbl in _process(lbls, f"y_true[{i}]"):
+            y_true_binary[i, LABEL_TO_IDX[lbl]] = 1
+
+    for i, lbls in enumerate(y_pred):
+        for lbl in _process(lbls, f"y_pred[{i}]"):
+            y_pred_binary[i, LABEL_TO_IDX[lbl]] = 1
+
+    false_negatives = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
+    false_positives = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
+    weighted_errors = FN_PENALTY * false_negatives + FP_PENALTY * false_positives
+    max_errors_per_sample = FN_PENALTY * np.sum(y_true_binary, axis=1) + FP_PENALTY * (n_labels - np.sum(y_true_binary, axis=1))
+    per_sample_scores = np.where(max_errors_per_sample > 0, 1.0 - (weighted_errors / max_errors_per_sample), 1.0)
+    return float(np.mean(per_sample_scores))
+
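# Worked example (not part of the commit), using the penalties above
# (FN_PENALTY=2.0, FP_PENALTY=1.0, n_labels=8):
#
#   y_true = [["plan_contact"], ["schedule_meeting", "update_kyc_total_assets"]]
#   y_pred = [["plan_contact"], ["schedule_meeting"]]
#
# Sample 1 matches exactly (score 1.0). Sample 2 has one false negative, so its
# score is 1 - (2*1) / (2*2 + 1*6) = 0.8; evaluate_predictions returns the
# mean, 0.9.
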
+# =========================
+# Core Extraction
 # =========================
 def run_extraction(
     model_choice: str,
…
     max_new_tokens: int,
     temperature: float,
     top_p: float,
+    usd_per_1k_tokens: float,
 ) -> Tuple[str, str, str, str, str]:
     diagnostics_lines = []

+    # Resolve inputs from the unified boxes
     input_txt = normalize_txt_input(txt_paste, txt_upload)
     input_json_raw = normalize_json_input(json_paste, json_upload)

+    # Language detection & JSON parse
+    lang = safe_lang_detect(input_txt)
     parsed_json: Dict[str, Any] = {}
+    json_parse_ok = False
     if input_json_raw:
         try:
             parsed_json = json.loads(input_json_raw)
+            json_parse_ok = True
         except Exception as e:
             diagnostics_lines.append(f"JSON parse error: {e}")

+    # Load model
     try:
         tokenizer, model = load_model(model_choice)
     except Exception as e:
+        diag = "\n".join([
+            f"Model: {model_choice}",
+            f"Params: {params_checked}",
+            f"Language detected: {lang}",
+            f"TXT length: {len(input_txt)}",
+            f"JSON parsed: {json_parse_ok}",
+            f"Model load failed: {e}"
+        ])
         return "", "", "", "", diag

+    # Token counts & rough cost estimate
+    in_tokens = count_tokens(tokenizer, input_txt) + count_tokens(tokenizer, json.dumps(parsed_json) if parsed_json else "")
+    # Build multilingual-aware prompt (summary in English; extraction language-agnostic)
     user_prompt = (
+        "You analyze client-conversation transcripts.\n"
+        "Transcripts may be multilingual. Detect the language automatically. "
+        "Extract tasks and entities correctly regardless of language. "
+        "Always write the short summary in English.\n"
+        "Include only information present in the inputs; avoid hallucinations.\n"
         f"Instructions: {instructions_text}\n"
         f"Context: {context_text}\n"
         "----\n"
         f"TEXT:\n{input_txt[:4000]}\n"
         "----\n"
         f"JSON:\n{json.dumps(parsed_json)[:2000]}\n"
+        "Output:\n"
+        "- Tasks list (use allowed labels where possible)\n"
+        "- Entities list\n"
+        "- Cleaned text\n"
+        "- Short summary (English)\n"
     )
+    prompt_tokens = count_tokens(tokenizer, user_prompt)

     try:
         inputs = tokenizer(user_prompt, return_tensors="pt").to(DEVICE)
         with torch.no_grad():
+            outputs = model.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
                 do_sample=True,
…
             )
         full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     except Exception as e:
+        diag = "\n".join([
+            f"Model: {model_choice}",
+            f"Params: {params_checked}",
+            f"Language detected: {lang}",
+            f"TXT length: {len(input_txt)}",
+            f"JSON parsed: {json_parse_ok}",
+            f"Inference failed: {e}"
+        ])
         return "", "", "", "", diag

+    # (Replace this with your structured parser that maps to ALLOWED_LABELS)
+    # For now, placeholders to keep UI working:
+    tasks_out = "• plan_contact\n• schedule_meeting"
+ tasks_out = "• plan_contact\n• schedule_meeting"
274
+ entities_out = " Client: John Doe\n• Product: Mortgage"
275
+ cleaned_out = "Cleaned transcript text here…"
276
+ summary_out = "A short English summary of the conversation."
277
+
278
+ # Output token count and cost
279
+ out_tokens = count_tokens(tokenizer, full_text)
280
+ total_tokens = in_tokens + prompt_tokens + out_tokens
281
+ est_cost = (total_tokens / 1000.0) * max(0.0, float(usd_per_1k_tokens))
282
+
283
+ diagnostics_lines.extend([
284
+ f"Model: {model_choice}",
285
+ f"Params: {params_checked}",
286
+ f"Language detected: {lang}",
287
+ f"TXT length: {len(input_txt)}",
288
+ f"JSON parsed: {json_parse_ok}",
289
+ f"Input tokens (txt+json): {in_tokens}",
290
+ f"Prompt tokens: {prompt_tokens}",
291
+ f"Output tokens: {out_tokens}",
292
+ f"Total tokens (approx): {total_tokens}",
293
+ f"Est. cost @ ${usd_per_1k_tokens:.4f}/1k toks: ${est_cost:.6f}",
294
+ "Generation completed successfully.",
295
+ ])
296
  diagnostics = "\n".join(diagnostics_lines)
297
 
298
  return tasks_out, entities_out, cleaned_out, summary_out, diagnostics
299
 
300
  # =========================
301
+ # Evaluation handler (JSON arrays or files)
302
+ # =========================
303
+ def evaluate_ui(y_true_text: str, y_true_file: Optional[gr.File], y_pred_text: str, y_pred_file: Optional[gr.File]) -> str:
304
+ """
305
+ Accepts pasted JSON (e.g., [["plan_contact"], ["schedule_meeting", ...], ...])
306
+ or uploaded .json files for y_true and y_pred. Returns the score or an error.
307
+ """
308
+ def _load_json(text: str, file_obj: Optional[gr.File]) -> Any:
309
+ if text and text.strip():
310
+ return json.loads(text)
311
+ ftxt = read_file(file_obj)
312
+ if ftxt:
313
+ return json.loads(ftxt)
314
+ raise ValueError("Missing JSON input")
315
+
316
+ try:
317
+ y_true = _load_json(y_true_text, y_true_file)
318
+ y_pred = _load_json(y_pred_text, y_pred_file)
319
+ score = evaluate_predictions(y_true, y_pred)
320
+ return f"Evaluation score: {score:.4f} (higher is better; weighted FN>FP)"
321
+ except Exception as e:
322
+ return f"Evaluation error: {e}"
323
+
324
+ # =========================
325
+ # UI Styling (black text on white; blue accents; red Run)
326
  # =========================
327
  THEME_CSS = """
 
328
  :root {
329
  --body-background-fill: #ffffff !important;
330
  --body-text-color: #111111 !important;
331
+ --link-text-color: #0b63ce !important;
 
 
 
 
 
 
332
  }
333
+ .gradio-container, .prose, .prose * { color: #111111 !important; }
334
+ label { color: #0b63ce !important; }
 
 
 
 
 
 
 
 
 
 
 
335
  button#run-btn {
336
  background: #e11900 !important;
337
+ color: #fff !important;
338
  border: 1px solid #b50f00 !important;
339
  }
 
 
 
 
 
 
 
 
340
  """
341
 
342
+ # =========================
343
+ # UI Layout
344
+ # =========================
345
  def build_interface() -> gr.Blocks:
346
  with gr.Blocks(title="Talk2Task Demo", css=THEME_CSS) as demo:
347
+ # Model selection (full width) with checklist + sliders + price input
348
  with gr.Group():
349
+ gr.Markdown("### Model & Parameters")
350
+ with gr.Row():
351
  model_choice = gr.Dropdown(
352
  label="Model",
353
+ choices=SUPPORTED_MODELS,
 
 
 
 
354
  value=DEFAULT_MODEL_ID,
355
+ scale=3,
356
  )
357
  params_checked = gr.CheckboxGroup(
358
  label="Options",
359
  choices=[
360
  "Default cleaning",
361
  "Remove PII",
 
362
  "Detect language",
363
+ "Use 4-bit if available",
364
  ],
365
+ value=["Default cleaning", "Detect language"],
366
+ scale=2,
367
  )
368
  with gr.Row():
369
+ max_new_tokens = gr.Slider(64, 1024, value=200, step=16, label="Max new tokens")
370
  temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
371
  top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
372
+ usd_per_1k_tokens = gr.Number(value=0.002, label="Est. $ per 1k tokens (edit)")
373
 
374
+ # Single boxes for TXT and JSON via Tabs (left/right)
375
+ gr.Markdown("### Input")
376
  with gr.Row(equal_height=True):
377
+ with gr.Group():
378
+ gr.Markdown("**TXT Input** (Paste or Upload)")
379
  with gr.Tabs():
380
  with gr.TabItem("Paste"):
381
+ txt_paste = gr.TextArea(label="Paste TXT", lines=12, placeholder="Paste transcript here (any language)…")
382
+ with gr.TabItem("Upload"):
383
+ txt_upload = gr.File(label="Upload TXT", file_types=[".txt"])
384
+ with gr.Group():
385
+ gr.Markdown("**JSON Input** (Paste or Upload)")
 
 
 
 
 
 
 
 
386
  with gr.Tabs():
387
  with gr.TabItem("Paste"):
388
+ json_paste = gr.Code(label="Paste JSON", language="json", value="{\n \"example\": true\n}", lines=12)
389
+ with gr.TabItem("Upload"):
390
+ json_upload = gr.File(label="Upload JSON", file_types=[".json"])
 
 
 
 
 
 
 
 
 
 
 
391
 
392
+ # Red run button
393
+ run_btn = gr.Button("Run Extraction", elem_id="run-btn")
394
+
395
+ # Collapsible instructions/context (defaults aligned to brief)
396
  with gr.Row():
397
  with gr.Accordion("Instructions (editable)", open=False):
398
  instructions_text = gr.TextArea(
 
399
  value=(
400
+ "Extract key tasks (use allowed labels when applicable), entities, cleaned text, and a short summary.\n"
401
+ "Be robust to noisy or incomplete data. Avoid hallucinations."
402
  ),
403
+ lines=5
404
  )
405
  with gr.Accordion("Context (editable)", open=False):
406
  context_text = gr.TextArea(
 
407
  value=(
408
+ "Client-advisor banking context. Assume transcripts may include multiple languages; "
409
+ "summaries must be in English."
410
  ),
411
+ lines=5
412
  )
413
 
414
+ # Outputs (symmetrical)
415
+ gr.Markdown("### Results")
416
  with gr.Row(equal_height=True):
417
+ tasks_out = gr.TextArea(label="Tasks", lines=8)
418
+ entities_out = gr.TextArea(label="Entities", lines=8)
419
  with gr.Row(equal_height=True):
420
+ cleaned_out = gr.TextArea(label="Cleaned Text", lines=8)
421
+ summary_out = gr.TextArea(label="Summary (English)", lines=8)
422
+
423
+ gr.Markdown("### Diagnostics / Metrics")
424
+ diagnostics = gr.TextArea(label="Diagnostics", lines=12)
425
 
426
+ # Evaluation accordion (cost-accuracy comparison support)
427
+ with gr.Accordion("Evaluation (paste or upload y_true / y_pred arrays)", open=False):
428
+ with gr.Row():
429
+ y_true_text = gr.Code(label="y_true (JSON)", language="json", lines=10)
430
+ y_pred_text = gr.Code(label="y_pred (JSON)", language="json", lines=10)
431
+ with gr.Row():
432
+ y_true_file = gr.File(label="Upload y_true.json", file_types=[".json"])
433
+ y_pred_file = gr.File(label="Upload y_pred.json", file_types=[".json"])
434
+ eval_btn = gr.Button("Compute Official Score")
435
+ eval_result = gr.Textbox(label="Evaluation Result")
436
+ eval_btn.click(evaluate_ui, inputs=[y_true_text, y_true_file, y_pred_text, y_pred_file], outputs=eval_result)
437
 
438
+ # Wire main run
439
  run_inputs = [
440
  model_choice, params_checked, instructions_text, context_text,
441
  txt_paste, txt_upload, json_paste, json_upload,
442
+ max_new_tokens, temperature, top_p, usd_per_1k_tokens
443
  ]
444
  run_outputs = [tasks_out, entities_out, cleaned_out, summary_out, diagnostics]
445
+ run_btn.click(fn=run_extraction, inputs=run_inputs, outputs=run_outputs)
 
 
 
 
 
446
 
447
  return demo
448
 
449
  demo = build_interface()
450
 
451
  if __name__ == "__main__":
 
452
  demo.launch()
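
The commit leaves the structured parser as a placeholder ("Replace this with
your structured parser that maps to ALLOWED_LABELS"). A minimal sketch, assuming
the model is prompted to echo allowed label names verbatim (parse_labels is a
hypothetical helper, not part of this commit; it relies on the module's
ALLOWED_LABELS and the typing import already present):

def parse_labels(generated_text: str) -> List[str]:
    # Keep each allowed label at most once and in a fixed order;
    # evaluate_predictions rejects duplicates and unknown labels.
    return [lbl for lbl in ALLOWED_LABELS if lbl in generated_text]

Constraining predictions to ALLOWED_LABELS this way keeps them scoreable by
evaluate_predictions, which raises on any out-of-vocabulary or duplicate label.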