RishiRP committed
Commit 7a1eb70 · verified · 1 Parent(s): ef109ff

Update app.py
Files changed (1)
  app.py +69 -12
app.py CHANGED
@@ -25,7 +25,10 @@ from transformers import (
 SPACE_CACHE = Path.home() / ".cache" / "huggingface"
 SPACE_CACHE.mkdir(parents=True, exist_ok=True)
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Force slow tokenizer path by default; avoids Rust tokenizer.json parsing issues
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+os.environ.setdefault("TOKENIZERS_PREFER_FAST", "false")
 
 GEN_CONFIG = GenerationConfig(
     temperature=0.0,
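Note on the new toggle (not part of the diff itself): os.environ.setdefault only sets a variable that is not already present, so a value exported by the Space runtime still wins. The sketch below is a hypothetical helper showing how application code could honor TOKENIZERS_PREFER_FAST when constructing a tokenizer, assuming the app reads the variable itself rather than relying on the library to pick it up:

import os
from transformers import AutoTokenizer

def load_tokenizer_respecting_env(repo_id: str):
    # Hypothetical helper: interpret TOKENIZERS_PREFER_FAST ourselves and
    # translate it into the use_fast flag that AutoTokenizer does understand.
    prefer_fast = os.environ.get("TOKENIZERS_PREFER_FAST", "false").lower() == "true"
    return AutoTokenizer.from_pretrained(repo_id, use_fast=prefer_fast)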
@@ -247,7 +250,34 @@ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
     return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
 
 # =========================
-# HF model wrapper (robust: fast→slow tokenizer + load fallbacks)
+# Cache purge for fresh downloads
+# =========================
+def _purge_repo_from_cache(repo_id: str):
+    """Delete cached files of a specific repo to guarantee a fresh download."""
+    try:
+        base = SPACE_CACHE
+        safe = repo_id.replace("/", "--")
+        for p in base.glob(f"models--{safe}*"):
+            try:
+                if p.is_file():
+                    p.unlink()
+                else:
+                    for sub in sorted(p.rglob("*"), reverse=True):
+                        try:
+                            if sub.is_file() or sub.is_symlink():
+                                sub.unlink()
+                            else:
+                                sub.rmdir()
+                        except Exception:
+                            pass
+                    p.rmdir()
+            except Exception:
+                pass
+    except Exception:
+        pass
+
+# =========================
+# HF model wrapper (robust: slow tokenizer first + load fallbacks)
 # =========================
 class ModelWrapper:
     def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool, force_tok_redownload: bool):
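The purge helper added above deletes cached snapshots bottom-up by hand, swallowing every exception so a partially removed directory never aborts a reload. A shorter sketch with the same intent, using shutil.rmtree instead (illustration only, not what the commit uses; it assumes the same models--<org>--<name> cache layout):

import shutil
from pathlib import Path

def purge_repo_with_rmtree(cache_dir: Path, repo_id: str) -> None:
    # Hypothetical alternative: drop every cached copy of the repo,
    # ignoring errors so missing or half-deleted paths are not fatal.
    safe = repo_id.replace("/", "--")
    for p in cache_dir.glob(f"models--{safe}*"):
        if p.is_dir():
            shutil.rmtree(p, ignore_errors=True)
        else:
            p.unlink(missing_ok=True)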
@@ -261,8 +291,13 @@ class ModelWrapper:
         self.load_path = "uninitialized"
 
     def _load_tokenizer(self):
-        fast_err = None
-        tok = None
+        """
+        Prefer the slow (SentencePiece) tokenizer first to avoid Rust tokenizers JSON parsing.
+        If user asked to force fresh download, purge local cache first.
+        """
+        if self.force_tok_redownload:
+            _purge_repo_from_cache(self.repo_id)
+
         common = dict(
             pretrained_model_name_or_path=self.repo_id,
             token=self.hf_token,
@@ -272,15 +307,36 @@ class ModelWrapper:
             force_download=True if self.force_tok_redownload else False,
             revision=None,
         )
+
+        # 1) SLOW PATH FIRST
+        slow_err = None
+        tok = None
         try:
-            tok = AutoTokenizer.from_pretrained(use_fast=True, **common)
+            tok = AutoTokenizer.from_pretrained(use_fast=False, **common)
         except Exception as e:
-            fast_err = e
+            slow_err = e
+
+        # 2) If slow somehow failed, try FAST as a last resort
+        fast_err = None
         if tok is None:
-            tok = AutoTokenizer.from_pretrained(use_fast=False, **common)
+            try:
+                tok = AutoTokenizer.from_pretrained(use_fast=True, **common)
+            except Exception as e:
+                fast_err = e
+
+        if tok is None:
+            raise RuntimeError(f"Tokenizer failed (slow: {slow_err}) (fast: {fast_err})")
+
         if tok.pad_token is None and tok.eos_token:
             tok.pad_token = tok.eos_token
-        return tok, fast_err
+
+        # Tag which path we used
+        if slow_err is None:
+            self.load_path = "tok:SLOW"
+        else:
+            self.load_path = "tok:FAST"
+
+        return tok
 
     def load(self):
         qcfg = None
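Stripped of the class plumbing, the new tokenizer policy is: try use_fast=False first, fall back to use_fast=True only if the slow path raises, and surface both errors when neither works. A self-contained sketch of that order (hypothetical function name; the repo id is just an example from MODEL_CHOICES):

from transformers import AutoTokenizer

def load_tokenizer_slow_first(repo_id: str):
    # Slow (SentencePiece) path first; fast (Rust) tokenizer only as a last resort.
    try:
        return AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    except Exception as slow_err:
        try:
            return AutoTokenizer.from_pretrained(repo_id, use_fast=True)
        except Exception as fast_err:
            raise RuntimeError(f"Tokenizer failed (slow: {slow_err}) (fast: {fast_err})")

# e.g. tok = load_tokenizer_slow_first("mistralai/Mistral-7B-Instruct-v0.3")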
@@ -292,7 +348,7 @@ class ModelWrapper:
                 bnb_4bit_use_double_quant=True,
             )
 
-        tok, fast_err = self._load_tokenizer()
+        tok = self._load_tokenizer()
 
         errors = []
         for desc, kwargs in [
@@ -331,13 +387,12 @@ class ModelWrapper:
                     mdl = mdl.to(torch.device("cuda"))
                 self.tokenizer = tok
                 self.model = mdl
-                self.load_path = desc + (" (fast tok)" if fast_err is None else " (slow tok)")
+                self.load_path = f"{self.load_path} | {desc}"
                 return
             except Exception as e:
                 errors.append(f"{desc}: {e}")
 
-        extra = f"\nFast tokenizer error: {fast_err}" if fast_err else ""
-        raise RuntimeError("All load attempts failed:\n" + "\n".join(errors) + extra)
+        raise RuntimeError("All load attempts failed:\n" + "\n".join(errors))
 
     @torch.inference_mode()
     def generate(self, system_prompt: str, user_prompt: str) -> str:
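Since load() now stores the tokenizer tag in self.load_path before trying model configurations, the final label reads like "tok:SLOW | <attempt description>". The attempt loop itself follows a collect-errors-then-raise pattern; a generic sketch of it (the attempt list here is illustrative, the real kwargs in app.py differ):

from transformers import AutoModelForCausalLM

def load_with_fallbacks(repo_id: str):
    errors = []
    # Try progressively simpler configurations; remember why each one failed.
    for desc, kwargs in [
        ("low_cpu_mem_usage", {"low_cpu_mem_usage": True}),
        ("default", {}),
    ]:
        try:
            return desc, AutoModelForCausalLM.from_pretrained(repo_id, **kwargs)
        except Exception as e:
            errors.append(f"{desc}: {e}")
    raise RuntimeError("All load attempts failed:\n" + "\n".join(errors))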
@@ -778,6 +833,7 @@ MODEL_CHOICES = [
     "mistralai/Mistral-7B-Instruct-v0.3",
 ]
 
+# White, modern UI (no purple)
 custom_css = """
 :root { --radius: 14px; }
 .gradio-container { font-family: Inter, ui-sans-serif, system-ui; background: #ffffff; color: #111827; }
@@ -847,6 +903,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
             json_out = gr.Code(label="Strict JSON Output", language="json")
             diag = gr.Textbox(label="Diagnostics", lines=10)
             raw = gr.Textbox(label="Raw Model Output", lines=8)
+            metrics_tb = gr.Textbox(label="Metrics vs Ground Truth (optional)", lines=6)
             prompt_preview = gr.Code(label="Prompt preview (user prompt sent)", language="markdown")
             token_info = gr.Textbox(label="Token counts (transcript / prompt / load path)", lines=2)
             gr.Markdown("</div>")
@@ -887,7 +944,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
                 labels_text, sys_instr_tb, glossary_tb, fallback_tb,
                 repo, use_4bit, use_sdpa, max_tokens, hf_token, force_tok_redownload
             ],
-            outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
+            outputs=[summary, json_out, diag, raw, context_md, instr_md, metrics_tb, prompt_preview, token_info],
         )
 
     with gr.Tab("Batch evaluation"):
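On the UI side, the placeholder gr.Textbox(visible=False) in the click outputs is replaced by the new metrics_tb component, so the handler keeps returning nine values and the seventh now lands in a visible box. A minimal sketch of that wiring contract (hypothetical handler and components, not the app's real ones):

import gradio as gr

def run(text):
    # One return value per declared output, in the same order as outputs=[...].
    summary = text.upper()
    metrics = "no ground truth provided"
    return summary, metrics

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    summary = gr.Textbox(label="Summary")
    metrics_tb = gr.Textbox(label="Metrics vs Ground Truth (optional)")
    gr.Button("Run").click(run, inputs=[inp], outputs=[summary, metrics_tb])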