RishiRP committed
Commit e84ddb8 · verified · 1 Parent(s): 2f8734d

Update app.py

Files changed (1)
  1. app.py +41 -41
app.py CHANGED
@@ -17,7 +17,9 @@ from transformers import (
     AutoModelForCausalLM,
     BitsAndBytesConfig,
     GenerationConfig,
+    LlamaTokenizer,  # manual fallback
 )
+from huggingface_hub import hf_hub_download
 
 # =========================
 # Global config
@@ -26,7 +28,7 @@ SPACE_CACHE = Path.home() / ".cache" / "huggingface"
 SPACE_CACHE.mkdir(parents=True, exist_ok=True)
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Force slow tokenizer path by default; avoids Rust tokenizer.json parsing issues
+# Force slow path by default; avoid Rust tokenizer JSON parsing
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 os.environ.setdefault("TOKENIZERS_PREFER_FAST", "false")
 
@@ -72,7 +74,7 @@ DEFAULT_LABEL_GLOSSARY = {
     "update_kyc_total_assets": "Discussion/confirmation of total assets/net worth.",
 }
 
-# Tiny multilingual fallback rules (optional) to avoid empty outputs
+# Minimal multilingual fallback rules (optional)
 DEFAULT_FALLBACK_CUES = {
     "plan_contact": [
         r"\b(get|got|will|we'?ll|i'?ll)\s+back to you\b", r"\bfollow\s*up\b", r"\breach out\b", r"\btouch base\b",
@@ -253,7 +255,6 @@ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
 # Cache purge for fresh downloads
 # =========================
 def _purge_repo_from_cache(repo_id: str):
-    """Delete cached files of a specific repo to guarantee a fresh download."""
    try:
        base = SPACE_CACHE
        safe = repo_id.replace("/", "--")
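The rest of _purge_repo_from_cache lies outside this hunk. A sketch of how a purge along these lines could work against the standard Hub cache layout; the recursive glob and directory check are assumptions, not the app's exact code:

import shutil
from pathlib import Path

def purge_repo_from_cache(repo_id: str, cache_dir: Path) -> None:
    # The Hub caches repos in directories like <cache>/hub/models--<org>--<name>;
    # remove every directory whose name ends with the escaped repo id.
    safe = repo_id.replace("/", "--")
    for path in cache_dir.glob(f"**/*--{safe}"):
        if path.is_dir():
            shutil.rmtree(path, ignore_errors=True)

purge_repo_from_cache("your-org/your-model", Path.home() / ".cache" / "huggingface")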
@@ -277,7 +278,7 @@ def _purge_repo_from_cache(repo_id: str):
        pass
 
 # =========================
-# HF model wrapper (robust: slow tokenizer first + load fallbacks)
+# HF model wrapper (with manual LlamaTokenizer fallback)
 # =========================
 class ModelWrapper:
     def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool, force_tok_redownload: bool):
@@ -290,52 +291,52 @@ class ModelWrapper:
         self.model = None
         self.load_path = "uninitialized"
 
-    def _load_tokenizer(self):
-        """
-        Prefer the slow (SentencePiece) tokenizer first to avoid Rust tokenizers JSON parsing.
-        If user asked to force fresh download, purge local cache first.
-        """
-        if self.force_tok_redownload:
-            _purge_repo_from_cache(self.repo_id)
-
-        common = dict(
-            pretrained_model_name_or_path=self.repo_id,
+    def _try_auto_tokenizer(self, use_fast: bool):
+        return AutoTokenizer.from_pretrained(
+            self.repo_id,
             token=self.hf_token,
             cache_dir=str(SPACE_CACHE),
             trust_remote_code=True,
             local_files_only=False,
             force_download=True if self.force_tok_redownload else False,
-            revision=None,
+            use_fast=use_fast,
         )
 
-        # 1) SLOW PATH FIRST
-        slow_err = None
-        tok = None
-        try:
-            tok = AutoTokenizer.from_pretrained(use_fast=False, **common)
-        except Exception as e:
-            slow_err = e
+    def _try_manual_llama_tokenizer(self):
+        # Download only tokenizer.model; ignore tokenizer.json entirely
+        sp_path = hf_hub_download(repo_id=self.repo_id, filename="tokenizer.model", token=self.hf_token, cache_dir=str(SPACE_CACHE))
+        tok = LlamaTokenizer(vocab_file=sp_path)
+        if tok.pad_token is None and tok.eos_token:
+            tok.pad_token = tok.eos_token
+        return tok
 
-        # 2) If slow somehow failed, try FAST as a last resort
-        fast_err = None
-        if tok is None:
-            try:
-                tok = AutoTokenizer.from_pretrained(use_fast=True, **common)
-            except Exception as e:
-                fast_err = e
+    def _load_tokenizer(self):
+        if self.force_tok_redownload:
+            _purge_repo_from_cache(self.repo_id)
 
-        if tok is None:
-            raise RuntimeError(f"Tokenizer failed (slow: {slow_err}) (fast: {fast_err})")
+        # 1) Slow auto
+        try:
+            tok = self._try_auto_tokenizer(use_fast=False)
+            if tok.pad_token is None and tok.eos_token:
+                tok.pad_token = tok.eos_token
+            self.load_path = "tok:AUTO_SLOW"
+            return tok
+        except Exception:
+            pass
 
+        # 2) Manual LlamaTokenizer from tokenizer.model
+        try:
+            tok = self._try_manual_llama_tokenizer()
+            self.load_path = "tok:LLAMA_SPM"
+            return tok
+        except Exception:
+            pass
+
+        # 3) Fast auto (last resort)
+        tok = self._try_auto_tokenizer(use_fast=True)  # will raise if broken
         if tok.pad_token is None and tok.eos_token:
             tok.pad_token = tok.eos_token
-
-        # Tag which path we used
-        if slow_err is None:
-            self.load_path = "tok:SLOW"
-        else:
-            self.load_path = "tok:FAST"
-
+        self.load_path = "tok:AUTO_FAST"
         return tok
 
     def load(self):
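Taken together, the new methods form a three-step chain: the slow AutoTokenizer, then a LlamaTokenizer built directly from tokenizer.model, then the fast tokenizer as a last resort. A condensed standalone sketch of the same ordering (the repo id and token are placeholders, and error handling is reduced to the bare minimum):

from typing import Optional, Tuple

from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, LlamaTokenizer

def load_tokenizer_with_fallback(repo_id: str, token: Optional[str] = None) -> Tuple[object, str]:
    # 1) Slow (SentencePiece) AutoTokenizer: avoids parsing tokenizer.json.
    try:
        return AutoTokenizer.from_pretrained(repo_id, token=token, use_fast=False), "tok:AUTO_SLOW"
    except Exception:
        pass
    # 2) Build a LlamaTokenizer straight from tokenizer.model.
    try:
        sp_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.model", token=token)
        return LlamaTokenizer(vocab_file=sp_path), "tok:LLAMA_SPM"
    except Exception:
        pass
    # 3) Fast tokenizer as a last resort; this raises if tokenizer.json is unusable.
    return AutoTokenizer.from_pretrained(repo_id, token=token, use_fast=True), "tok:AUTO_FAST"

tok, path = load_tokenizer_with_fallback("your-org/your-llama-model")
print(path, type(tok).__name__)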
@@ -989,14 +990,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
 )
 
 if __name__ == "__main__":
-    # Optional: print environment info to logs
     try:
-        print("Torch version:", torch.__version__)
+        print("Torch:", torch.__version__)
         print("CUDA available:", torch.cuda.is_available())
         if torch.cuda.is_available():
             print("CUDA (compiled):", torch.version.cuda)
             print("Device:", torch.cuda.get_device_name(0))
-    except Exception as _:
+    except Exception:
         pass
 
     demo.launch()
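For completeness, the wrapper is driven through the constructor arguments shown above. A usage sketch, assuming load() still builds the tokenizer (and hence sets load_path) as in the previous revision; the repo id and flag values are illustrative:

import os

from app import ModelWrapper  # the Space's own module

wrapper = ModelWrapper(
    repo_id="your-org/your-llama-model",      # placeholder
    hf_token=os.environ.get("HF_TOKEN"),
    load_in_4bit=True,
    use_sdpa=True,
    force_tok_redownload=False,
)
wrapper.load()
print("Tokenizer path:", wrapper.load_path)  # tok:AUTO_SLOW, tok:LLAMA_SPM, or tok:AUTO_FAST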
 