ruslanmv committed
Commit 239225b · 1 Parent(s): a197317

First commit

Files changed (3)
  1. .gitignore +1 -0
  2. Makefile +20 -20
  3. app.py +82 -121
.gitignore ADDED
@@ -0,0 +1 @@
+ /.venv
Makefile CHANGED
@@ -23,23 +23,23 @@ PORT ?= 7860
  # Core runtime deps (CPU-safe). Torch comes via transitive deps where needed;
  # you may pin torch externally if required by your environment.
  REQS = \
- "numpy<2" \
- "gradio==4.27.0" \
- "python-dotenv" \
- "huggingface_hub" \
- "ffmpeg-python" \
- "nltk" \
- "emoji" \
- "langid" \
- "noisereduce" \
- "TTS" \
- "llama-cpp-python>=0.2.90"
+ "numpy<2" \
+ "gradio==4.27.0" \
+ "python-dotenv" \
+ "huggingface_hub" \
+ "ffmpeg-python" \
+ "nltk" \
+ "emoji" \
+ "langid" \
+ "noisereduce" \
+ "TTS" \
+ "llama-cpp-python>=0.2.90"

  # Dev tools (optional)
  DEV_REQS = \
- "ruff" \
- "black" \
- "pip-tools"
+ "ruff" \
+ "black" \
+ "pip-tools"

  # ================================================================
  # Meta
@@ -113,12 +113,12 @@ check-ffmpeg:
  # ================================================================
  # Pre-download model assets and compute voice latents (runs your app's functions)
  precache: install check-ffmpeg
- $(PY) - <<'PY'
- from app import precache_assets, init_models_and_latents
- precache_assets()
- init_models_and_latents()
- print("Precache complete.")
- PY
+ $(PY) - <<- 'PY'
+ from app import precache_assets, init_models_and_latents
+ precache_assets()
+ init_models_and_latents()
+ print("Precache complete.")
+ PY

  run: install
  @echo "Starting app on port $(PORT)…"
app.py CHANGED
@@ -2,7 +2,6 @@
  # 1) SETUP & IMPORTS
  # ===================================================================================
  from __future__ import annotations
-
  import os
  import sys
  import base64
@@ -10,34 +9,33 @@ import struct
  import textwrap
  import requests
  import atexit
- from typing import List, Dict, Tuple, Generator, Any
+ from typing import List, Dict, Tuple, Generator

  # --- Fast, safe defaults ---
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
  os.environ.setdefault("COQUI_TOS_AGREED", "1")
- os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false") # truly disable analytics
- os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0") # avoid torchaudio/ffmpeg linkage quirks
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
+ os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0") # prevent torchaudio/ffmpeg (torio) path

  # --- .env early (HF_TOKEN / SECRET_TOKEN) ---
  from dotenv import load_dotenv
  load_dotenv()

- # --- NumPy sanity (Torch 2.2.x prefers NumPy 1.x) ---
+ # --- NumPy sanity with torch 2.2.x ---
  import numpy as _np
  if int(_np.__version__.split(".", 1)[0]) >= 2:
      raise RuntimeError(
-         f"Detected numpy=={_np.__version__}. Please ensure numpy<2 (e.g., 1.26.4)."
+         f"Detected numpy=={_np.__version__}. Please pin numpy<2 (e.g., 1.26.4) for this Space."
      )

- # --- Hugging Face Spaces & ZeroGPU (import BEFORE CUDA libs) ---
+ # --- Hugging Face Spaces & ZeroGPU (import BEFORE torch/diffusers) ---
  try:
      import spaces # Required for ZeroGPU on HF
  except Exception:
      class _SpacesShim:
          def GPU(self, *args, **kwargs):
-             def _wrap(fn):
-                 return fn
+             def _wrap(fn): return fn
              return _wrap
      spaces = _SpacesShim()
@@ -49,7 +47,7 @@ import numpy as np
  from huggingface_hub import HfApi, hf_hub_download
  from llama_cpp import Llama

- # --- Audio decoding (pure ffmpeg-python; no torchaudio) ---
+ # --- Audio decode via ffmpeg-python (no torchaudio.load) ---
  import ffmpeg

  # --- TTS Libraries ---
@@ -64,7 +62,6 @@ import langid
  import emoji
  import noisereduce as nr

-
  # ===================================================================================
  # 2) GLOBALS & HELPERS
  # ===================================================================================
@@ -72,12 +69,10 @@ import noisereduce as nr
  # NLTK data
  nltk.download("punkt", quiet=True)

- # Models & caches
+ # Cached models & latents
  tts_model: Xtts | None = None
  llm_model: Llama | None = None
-
- # Store latents as NumPy on CPU for portability; convert to device at inference time
- voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
+ voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

  # Config
  HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -87,6 +82,9 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
  SENTENCE_SPLIT_LENGTH = 250
  LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]

+ # IMPORTANT: With ZeroGPU, DO NOT use CUDA at startup even if torch sees it.
+ USE_STARTUP_CUDA = os.getenv("USE_STARTUP_CUDA", "false").lower() == "true"
+
  # System prompts and roles
  default_system_message = (
      "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
@@ -100,25 +98,7 @@ ROLE_PROMPTS["Pirate"] = (
      "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
  )

-
- # ---------- tiny utilities ----------
- def _model_device(m: torch.nn.Module) -> torch.device:
-     try:
-         return next(m.parameters()).device
-     except StopIteration:
-         return torch.device("cpu")
-
- def _to_device_float_tensor(x: Any, device: torch.device) -> torch.Tensor:
-     if isinstance(x, np.ndarray):
-         return torch.from_numpy(x).float().to(device)
-     if torch.is_tensor(x):
-         return x.to(device, dtype=torch.float32)
-     return torch.as_tensor(x, dtype=torch.float32, device=device)
-
- def _latents_for_device(latents: Tuple[Any, Any], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-     gpt_cond, spk = latents
-     return _to_device_float_tensor(gpt_cond, device), _to_device_float_tensor(spk, device)
-
+ # ---------- small utils ----------
  def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
      if pcm_data.startswith(b"RIFF"):
          return pcm_data
@@ -135,13 +115,13 @@ def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit

  def split_sentences(text: str, max_len: int) -> List[str]:
      sentences = nltk.sent_tokenize(text)
-     chunks: List[str] = []
-     for sent in sentences:
-         if len(sent) > max_len:
-             chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
+     out: List[str] = []
+     for s in sentences:
+         if len(s) > max_len:
+             out.extend(textwrap.wrap(s, max_len, break_long_words=True))
          else:
-             chunks.append(sent)
-     return chunks
+             out.append(s)
+     return out

  def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
      prompt = f"<|system|>\n{system_message}</s>"
@@ -151,7 +131,6 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
      prompt += f"<|user|>\n{message}</s><|assistant|>"
      return prompt

-
  # ---------- robust audio decode (mono via ffmpeg) ----------
  def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
      """
@@ -168,23 +147,17 @@ def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
          pcm = np.frombuffer(out, dtype=np.int16)
          if pcm.size == 0:
              raise RuntimeError("ffmpeg produced empty audio.")
-         wav = (pcm.astype(np.float32) / 32767.0)
-         return wav
+         return (pcm.astype(np.float32) / 32767.0)
      except ffmpeg.Error as e:
          raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e

-
- # ---------- monkey-patch XTTS internal loader to avoid torchaudio/torio ----------
+ # ---------- monkey-patch XTTS internal loader to avoid torchaudio.load() ----------
  def _patched_load_audio(audiopath: str, load_sr: int):
      """
-     Match XTTS' expected return type:
-       - returns a torch.FloatTensor shaped [1, samples], normalized to [-1, 1],
-         already resampled to `load_sr`.
-       - DO NOT return (audio, sr) tuple.
+     Expected by XTTS: return torch.FloatTensor [1, samples] normalized to [-1, 1], resampled to load_sr.
      """
      wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
-     import torch as _torch # local import to avoid any circularities
-     audio = _torch.from_numpy(wav).float().unsqueeze(0) # [1, N] on CPU
+     audio = torch.from_numpy(wav).float().unsqueeze(0) # [1, N] on CPU
      return audio

  xtts_module.load_audio = _patched_load_audio
@@ -194,14 +167,12 @@ try:
  except Exception:
      pass

-
  def _coqui_cache_dir() -> str:
-     # Matches what TTS uses on Linux: ~/.local/share/tts
+     # Coqui cache default on Linux
      return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")

-
  # ===================================================================================
- # 3) PRECACHE & MODEL LOADERS (CPU at startup to avoid ZeroGPU issues)
+ # 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
  # ===================================================================================

  def precache_assets() -> None:
@@ -234,9 +205,8 @@ def precache_assets() -> None:
      except Exception as e:
          print(f"Warning: GGUF pre-cache error: {e}")

-
- def _load_xtts(device: str = "cpu") -> Xtts:
-     """Load XTTS from the local cache. Keep CPU at startup to avoid ZeroGPU device mixups."""
+ def _load_xtts(device: str) -> Xtts:
+     """Load XTTS from the local cache. Always CPU at startup for ZeroGPU compatibility."""
      print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
      model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
      ModelManager().download_model(model_name) # idempotent
@@ -255,20 +225,16 @@ def _load_xtts(device: str = "cpu") -> Xtts:
      print("XTTS model loaded.")
      return model

-
- def _load_llama() -> Llama:
-     """
-     Load Llama (Zephyr GGUF).
-     Keep simple & robust: default to CPU (works everywhere).
-     """
-     print("Loading LLM (Zephyr GGUF)...")
+ def _load_llama_cpu_only() -> Llama:
+     """Load Llama (Zephyr GGUF) on CPU only (ZeroGPU friendly)."""
+     print("Loading LLM (Zephyr GGUF) on CPU...")
      zephyr_model_path = hf_hub_download(
          repo_id="TheBloke/zephyr-7B-beta-GGUF",
          filename="zephyr-7b-beta.Q5_K_M.gguf"
      )
      llm = Llama(
          model_path=zephyr_model_path,
-         n_gpu_layers=0, # CPU-only for reliability across Spaces/ZeroGPU
+         n_gpu_layers=0, # never touch CUDA at startup
          n_ctx=4096,
          n_batch=512,
          verbose=False
@@ -276,27 +242,27 @@ def _load_llama() -> Llama:
      print("LLM loaded (CPU).")
      return llm

-
  def init_models_and_latents() -> None:
      """
-     Preload models on CPU and compute voice latents on CPU.
-     This avoids ZeroGPU's "mixed device" errors from torchaudio-based resampling.
+     Preload TTS and LLM on CPU and compute voice latents on CPU.
+     This avoids any CUDA tensors outside the @spaces.GPU window.
      """
      global tts_model, llm_model, voice_latents

+     # Always CPU here (ZeroGPU rule)
+     target_device = "cpu"
+
      if tts_model is None:
-         tts_model = _load_xtts(device="cpu") # always CPU at startup
+         tts_model = _load_xtts(device=target_device)
+     else:
+         tts_model.to("cpu")

      if llm_model is None:
-         llm_model = _load_llama()
+         llm_model = _load_llama_cpu_only()

+     # Pre-compute latents once on CPU (uses our ffmpeg loader)
      if not voice_latents:
          print("Computing voice conditioning latents (CPU)...")
-         # Ensure the TTS model is on CPU while computing latents
-         orig_dev = _model_device(tts_model)
-         if orig_dev.type != "cpu":
-             tts_model.to("cpu")
-
          with torch.no_grad():
              for role, filename in [
                  ("Cloée", "cloee-1.wav"),
@@ -305,21 +271,11 @@ def init_models_and_latents() -> None:
                  ("Thera", "thera-1.wav"),
              ]:
                  path = os.path.join("voices", filename)
-                 gpt_lat, spk_emb = tts_model.get_conditioning_latents(
+                 # Returns torch tensors; keep them on CPU
+                 voice_latents[role] = tts_model.get_conditioning_latents(
                      audio_path=path, gpt_cond_len=30, max_ref_length=60
                  )
-                 # Store as NumPy on CPU; convert to device on demand later
-                 voice_latents[role] = (
-                     gpt_lat.detach().cpu().numpy(),
-                     spk_emb.detach().cpu().numpy(),
-                 )
-
-         # Return model to original device (keep CPU at startup for safety)
-         if orig_dev.type != "cpu":
-             tts_model.to(orig_dev)
-
-         print("Voice latents ready.")
-
+         print("Voice latents ready (CPU).")

  # Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
  def _close_llm():
@@ -331,7 +287,6 @@ def _close_llm():
          pass
  atexit.register(_close_llm)

-
  # ===================================================================================
  # 4) INFERENCE HELPERS
  # ===================================================================================
@@ -339,17 +294,17 @@ atexit.register(_close_llm)
  def generate_text_stream(llm_instance: Llama, prompt: str,
                           history: List[Tuple[str, str | None]],
                           system_message_text: str) -> Generator[str, None, None]:
-     formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
+     formatted = format_prompt_zephyr(prompt, history, system_message_text)
      stream = llm_instance(
-         formatted_prompt,
+         formatted,
          temperature=0.7,
          max_tokens=512,
          top_p=0.95,
          stop=LLM_STOP_WORDS,
          stream=True
      )
-     for response in stream:
-         ch = response["choices"][0]["text"]
+     for resp in stream:
+         ch = resp["choices"][0]["text"]
          try:
              is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
          except Exception:
@@ -358,29 +313,31 @@ def generate_text_stream(llm_instance: Llama, prompt: str,
              continue
          yield ch

+ def _latents_to_device(latents: Tuple[torch.Tensor, torch.Tensor], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+     g, s = latents
+     if isinstance(g, torch.Tensor):
+         g = g.to(device)
+     if isinstance(s, torch.Tensor):
+         s = s.to(device)
+     return g, s

  def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
-                           latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
-     # Convert stored CPU NumPy latents to tensors on the model's current device
-     device = _model_device(tts_instance)
-     gpt_cond_latent_t, speaker_embedding_t = _latents_for_device(latents, device)
-
+                           latents: Tuple[torch.Tensor, torch.Tensor]) -> Generator[bytes, None, None]:
+     gpt_cond_latent, speaker_embedding = latents
      try:
          for chunk in tts_instance.inference_stream(
              text=text,
              language=language,
-             gpt_cond_latent=gpt_cond_latent_t,
-             speaker_embedding=speaker_embedding_t,
+             gpt_cond_latent=gpt_cond_latent,
+             speaker_embedding=speaker_embedding,
              temperature=0.85,
          ):
              if chunk is None:
                  continue
-             # chunk: torch.FloatTensor [N] or [1, N], float32 in [-1, 1]
-             f32 = chunk.detach().cpu().numpy().squeeze()
-             f32 = np.clip(f32, -1.0, 1.0).astype(np.float32)
+             f32 = chunk.detach().cpu().numpy().squeeze().astype(np.float32)
+             f32 = np.clip(f32, -1.0, 1.0)
              s16 = (f32 * 32767.0).astype(np.int16)
              yield s16.tobytes()
-
      except RuntimeError as e:
          print(f"Error during TTS inference: {e}")
          if "device-side assert" in str(e) and api:
@@ -390,32 +347,34 @@ def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
          except Exception:
              pass

-
  # ===================================================================================
- # 5) ZERO-GPU ENTRYPOINT (safe on native GPU as well)
+ # 5) ZERO-GPU ENTRYPOINT (also works on native GPU)
  # ===================================================================================

- @spaces.GPU(duration=120) # GPU ops must occur inside this function when on ZeroGPU
+ @spaces.GPU(duration=120) # ZeroGPU allocates a GPU only for this function call
  def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
      if secret_token_input != SECRET_TOKEN:
          raise gr.Error("Invalid secret token provided.")
      if not input_text:
          return []

-     # Ensure models/latents exist (loaded on CPU)
+     # Ensure models/latents exist (CPU)
      if tts_model is None or llm_model is None or not voice_latents:
          init_models_and_latents()

-     # During the GPU window, move XTTS to CUDA if available; otherwise stay on CPU
+     # If ZeroGPU granted CUDA for this call, move XTTS to CUDA; keep LLM on CPU.
      try:
          if torch.cuda.is_available():
              tts_model.to("cuda")
+             device = torch.device("cuda")
          else:
              tts_model.to("cpu")
+             device = torch.device("cpu")
      except Exception:
          tts_model.to("cpu")
+         device = torch.device("cpu")

-     # Generate story text (LLM kept CPU for simplicity & reliability)
+     # Generate story text (LLM on CPU)
      history: List[Tuple[str, str | None]] = [(input_text, None)]
      full_story_text = "".join(
          generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
@@ -432,10 +391,13 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
          if not any(c.isalnum() for c in sentence):
              continue

-         audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
+         # Move cached latents to the same device as the model for this call
+         lat_dev = _latents_to_device(voice_latents[chatbot_role], device)
+
+         audio_chunks = generate_audio_stream(tts_model, sentence, lang, lat_dev)
          pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)

-         # Optional noise reduction (best-effort)
+         # Optional noise reduction (best-effort, CPU)
          try:
              data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
              if data_s16.size > 0:
@@ -450,7 +412,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
          b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
          results.append({"text": sentence, "audio": b64_wav})

-     # Leave model on CPU after the ZeroGPU window
+     # Return XTTS to CPU to release GPU instantly
      try:
          tts_model.to("cpu")
      except Exception:
@@ -458,7 +420,6 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_

      return results

-
  # ===================================================================================
  # 6) STARTUP: PRECACHE & UI
  # ===================================================================================
@@ -473,16 +434,16 @@ def build_ui() -> gr.Interface:
          ],
          outputs=gr.JSON(label="Story and Audio Output"),
          title="AI Storyteller with ZeroGPU",
-         description="Enter a prompt to generate a short story with voice narration. Uses GPU only within the generation call when available.",
-         flagging_mode="never",
+         description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
          allow_flagging="never",
+         analytics_enabled=False,
      )

  if __name__ == "__main__":
-     print("===== Startup: pre-cache assets and preload models (CPU) =====")
-     print(f"Python: {sys.version.split()[0]} | Torch CUDA available: {torch.cuda.is_available()}")
-     precache_assets() # 1) download everything to disk
-     init_models_and_latents() # 2) load models on CPU + compute voice latents on CPU
+     print("===== Startup: pre-cache assets and preload models =====")
+     print(f"Python: {sys.version.split()[0]} | Torch CUDA visible: {torch.cuda.is_available()} (will not use at startup)")
+     precache_assets() # 1) download everything to disk
+     init_models_and_latents() # 2) load on CPU + compute voice latents on CPU
      print("Models and assets ready. Launching UI...")

      demo = build_ui()