ruslanmv committed
Commit ead9609 · 1 Parent(s): e1e4a12

First commit

Files changed (2)
  1. README.md +2 -1
  2. app.py +191 -84
README.md CHANGED
@@ -6,7 +6,8 @@ colorTo: purple
  sdk: gradio
  sdk_version: 5.47.2
  app_file: app.py
+ python_version: "3.11"
  pinned: false
  ---

- Check out the configuration reference at https://github.com/ruslanmv/ai-story-server
+ Check out the configuration reference at https://github.com/ruslanmv/ai-story-server
app.py CHANGED
@@ -3,6 +3,7 @@
  # ===================================================================================
  from __future__ import annotations
  import os
+ import sys
  import base64
  import struct
  import textwrap
@@ -14,31 +15,23 @@ from typing import List, Dict, Tuple, Generator
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
  os.environ.setdefault("COQUI_TOS_AGREED", "1")
- os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
+ os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # avoid torchaudio/ffmpeg linkage issues

- # --- Prefer torchaudio sox_io/soundfile backend (avoid FFmpeg/torio bug) ---
- try:
-     import torchaudio
-     _backend_set = False
-     for _cand in ("sox_io", "soundfile"):
-         try:
-             torchaudio.set_audio_backend(_cand)
-             _backend_set = True
-             break
-         except Exception:
-             pass
-     if not _backend_set:
-         os.environ["TORCHAUDIO_USE_FFMPEG"] = "0"
- except Exception:
-     torchaudio = None
-
- # --- Load .env early (HF_TOKEN / SECRET_TOKEN) ---
+ # --- .env early (HF_TOKEN / SECRET_TOKEN) ---
  from dotenv import load_dotenv
  load_dotenv()

- # --- Hugging Face Spaces & ZeroGPU ---
+ # --- NumPy sanity (Torch 2.2.x wants NumPy 1.x) ---
+ import numpy as _np
+ if int(_np.__version__.split(".", 1)[0]) >= 2:
+     raise RuntimeError(
+         f"Detected numpy=={_np.__version__}. Please ensure numpy<2 (e.g., 1.26.4) for this Space."
+     )
+
+ # --- Hugging Face Spaces & ZeroGPU (import BEFORE CUDA libs) ---
  try:
-     import spaces
+     import spaces  # Required for ZeroGPU on HF
  except Exception:
      class _SpacesShim:
          def GPU(self, *args, **kwargs):
@@ -49,17 +42,20 @@ except Exception:

  import gradio as gr

- # --- Core ML & Data Libraries ---
+ # --- Core ML & Data Libraries (after spaces import) ---
  import torch
  import numpy as np
  from huggingface_hub import HfApi, hf_hub_download
  from llama_cpp import Llama

+ # --- Audio decoding (use ffmpeg-python; no torchaudio) ---
+ import ffmpeg
+
  # --- TTS Libraries ---
  from TTS.tts.configs.xtts_config import XttsConfig
  from TTS.tts.models.xtts import Xtts
  from TTS.utils.manage import ModelManager
- from TTS.utils.generic_utils import get_user_data_dir
+ import TTS.tts.models.xtts as xtts_module  # for monkey-patching load_audio

  # --- Text & Audio Processing ---
  import nltk
@@ -71,12 +67,15 @@ import noisereduce as nr
  # 2) GLOBALS & HELPERS
  # ===================================================================================

+ # NLTK data
  nltk.download("punkt", quiet=True)

+ # Cached models & latents
  tts_model: Xtts | None = None
  llm_model: Llama | None = None
  voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

+ # Config
  HF_TOKEN = os.environ.get("HF_TOKEN")
  api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
  repo_id = "ruslanmv/ai-story-server"
@@ -84,6 +83,10 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
  SENTENCE_SPLIT_LENGTH = 250
  LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]

+ # Prefer native GPU if available; otherwise we’ll rely on ZeroGPU (or CPU)
+ PREFER_NATIVE_GPU = torch.cuda.is_available()
+
+ # System prompts and roles
  default_system_message = (
      "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
      "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
@@ -96,16 +99,17 @@ ROLE_PROMPTS["Pirate"] = (
      "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
  )

+ # ---------- small utils ----------
  def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
      if pcm_data.startswith(b"RIFF"):
          return pcm_data
+     byte_rate = sample_rate * channels * bit_depth // 8
+     block_align = channels * bit_depth // 8
      chunk_size = 36 + len(pcm_data)
      header = struct.pack(
          "<4sI4s4sIHHIIHH4sI",
          b"RIFF", chunk_size, b"WAVE", b"fmt ",
-         16, 1, channels, sample_rate,
-         sample_rate * channels * bit_depth // 8,
-         channels * bit_depth // 8, bit_depth,
+         16, 1, channels, sample_rate, byte_rate, block_align, bit_depth,
          b"data", len(pcm_data)
      )
      return header + pcm_data
@@ -128,11 +132,61 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
      prompt += f"<|user|>\n{message}</s><|assistant|>"
      return prompt

+ # ---------- robust audio decode (mono via ffmpeg) ----------
+ def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
+     """
+     Return float32 waveform in [-1, 1], mono, resampled to target_sr.
+     Shape: (samples,)
+     """
+     try:
+         out, _ = (
+             ffmpeg
+             .input(path)
+             .output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=target_sr)
+             .run(capture_stdout=True, capture_stderr=True, cmd="ffmpeg")
+         )
+         pcm = np.frombuffer(out, dtype=np.int16)
+         if pcm.size == 0:
+             raise RuntimeError("ffmpeg produced empty audio.")
+         wav = (pcm.astype(np.float32) / 32767.0)
+         return wav
+     except ffmpeg.Error as e:
+         raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e
+
+ # ---------- monkey-patch XTTS internal loader to avoid torchaudio/torio ----------
+ def _patched_load_audio(audiopath: str, load_sr: int):
+     """
+     Match XTTS' expected return type:
+     - returns a torch.FloatTensor shaped [1, samples], normalized to [-1, 1],
+       already resampled to `load_sr`.
+     - DO NOT return (audio, sr) tuple.
+     """
+     wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
+     import torch as _torch  # local import to avoid any circularities
+     audio = _torch.from_numpy(wav).float().unsqueeze(0)  # [1, N]
+     return audio
+
+ xtts_module.load_audio = _patched_load_audio
+
+ # Also patch the common utility location, in case this version imports from there:
+ try:
+     import TTS.utils.audio as _tts_audio_mod
+     _tts_audio_mod.load_audio = _patched_load_audio
+ except Exception:
+     pass
+
+ # ---------- where Coqui caches models (avoid get_user_data_dir import) ----------
+ def _coqui_cache_dir() -> str:
+     # Matches what TTS uses on Linux: ~/.local/share/tts
+     return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")
+
  # ===================================================================================
  # 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
  # ===================================================================================

  def precache_assets() -> None:
+     """Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
+     # Voices
      print("Pre-caching voice files...")
      file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
      base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
@@ -148,27 +202,31 @@ def precache_assets() -> None:
          except Exception as e:
              print(f"Failed to download {name}: {e}")

+     # XTTS model files
      print("Pre-caching XTTS v2 model files...")
      ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")

+     # LLM GGUF
      print("Pre-caching Zephyr GGUF...")
      try:
          hf_hub_download(
              repo_id="TheBloke/zephyr-7B-beta-GGUF",
-             filename="zephyr-7b-beta.Q5_K_M.gguf"
+             filename="zephyr-7b-beta.Q5_K_M.gguf",
+             force_download=False
          )
      except Exception as e:
          print(f"Warning: GGUF pre-cache error: {e}")

  def _load_xtts(device: str) -> Xtts:
-     print("Loading Coqui XTTS V2 model (CPU first)...")
+     """Load XTTS from the local cache. Use checkpoint_dir to avoid None path bugs."""
+     print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
      model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
-     model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+     ModelManager().download_model(model_name)  # idempotent
+     model_dir = os.path.join(_coqui_cache_dir(), model_name.replace("/", "--"))

      cfg = XttsConfig()
      cfg.load_json(os.path.join(model_dir, "config.json"))
      model = Xtts.init_from_config(cfg)
-
      model.load_checkpoint(
          cfg,
          checkpoint_dir=model_dir,
@@ -180,30 +238,59 @@ def _load_xtts(device: str) -> Xtts:
      return model

  def _load_llama() -> Llama:
-     print("Loading LLM (Zephyr GGUF) on CPU...")
+     """
+     Load Llama (Zephyr GGUF). Prefer GPU offload if native CUDA build is present,
+     otherwise fall back to pure CPU.
+     """
+     print("Loading LLM (Zephyr GGUF)...")
      zephyr_model_path = hf_hub_download(
          repo_id="TheBloke/zephyr-7B-beta-GGUF",
          filename="zephyr-7b-beta.Q5_K_M.gguf"
      )
-     llm = Llama(
-         model_path=zephyr_model_path,
-         n_gpu_layers=0,
-         n_ctx=4096,
-         n_batch=512,
-         verbose=False
-     )
-     print("LLM loaded (CPU).")
-     return llm
+
+     # Heuristic: try to offload a large number of layers if CUDA build exists.
+     gpu_layers_env = int(os.getenv("LLAMA_GPU_LAYERS", "100"))
+     n_gpu_layers = gpu_layers_env if PREFER_NATIVE_GPU else 0
+
+     try:
+         llm = Llama(
+             model_path=zephyr_model_path,
+             n_gpu_layers=n_gpu_layers,  # if CUDA build exists, this offloads layers
+             n_ctx=4096,
+             n_batch=512,
+             verbose=False
+         )
+         used = "GPU-offload" if n_gpu_layers > 0 else "CPU"
+         print(f"LLM loaded ({used}).")
+         return llm
+     except Exception as e:
+         print(f"LLM GPU offload failed ({e}); falling back to CPU.")
+         llm = Llama(
+             model_path=zephyr_model_path,
+             n_gpu_layers=0,
+             n_ctx=4096,
+             n_batch=512,
+             verbose=False
+         )
+         print("LLM loaded (CPU).")
+         return llm

  def init_models_and_latents() -> None:
+     """
+     Preload TTS and LLM. If native GPU is available at startup, load XTTS on CUDA
+     and precompute voice latents there; otherwise do it on CPU (ZeroGPU will move it later).
+     """
      global tts_model, llm_model, voice_latents

+     target_device = "cuda" if PREFER_NATIVE_GPU else "cpu"
+
      if tts_model is None:
-         tts_model = _load_xtts(device="cpu")
+         tts_model = _load_xtts(device=target_device)

      if llm_model is None:
          llm_model = _load_llama()

+     # Pre-compute latents once; uses patched loader (ffmpeg) under the hood
      if not voice_latents:
          print("Computing voice conditioning latents...")
          for role, filename in [
@@ -213,18 +300,20 @@ def init_models_and_latents() -> None:
              ("Thera", "thera-1.wav"),
          ]:
              path = os.path.join("voices", filename)
-             voice_latents[role] = tts_model.get_conditioning_latents(
-                 audio_path=path, gpt_cond_len=30, max_ref_length=60
-             )
+             with torch.no_grad():
+                 voice_latents[role] = tts_model.get_conditioning_latents(
+                     audio_path=path, gpt_cond_len=30, max_ref_length=60
+                 )
          print("Voice latents ready.")

+ # Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
  def _close_llm():
      global llm_model
-     if llm_model is not None:
-         try:
+     try:
+         if llm_model is not None:
              llm_model.close()
-         except Exception:
-             pass
+     except Exception:
+         pass
  atexit.register(_close_llm)

  # ===================================================================================
@@ -264,72 +353,88 @@ def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
              speaker_embedding=speaker_embedding,
              temperature=0.85,
          ):
-             if chunk is not None:
-                 yield chunk.detach().cpu().numpy().squeeze().tobytes()
+             if chunk is None:
+                 continue
+             # chunk: torch.FloatTensor [N] or [1, N], float32 in [-1, 1]
+             f32 = chunk.detach().cpu().numpy().squeeze()
+             f32 = np.clip(f32, -1.0, 1.0).astype(np.float32)
+             s16 = (f32 * 32767.0).astype(np.int16)
+             yield s16.tobytes()
      except RuntimeError as e:
          print(f"Error during TTS inference: {e}")
          if "device-side assert" in str(e) and api:
-             gr.Warning("Critical GPU error. Attempting to restart the Space...")
              try:
+                 gr.Warning("Critical GPU error. Attempting to restart the Space...")
                  api.restart_space(repo_id=repo_id)
              except Exception:
                  pass

  # ===================================================================================
- # 5) ZERO-GPU ENTRYPOINT
+ # 5) ZERO-GPU ENTRYPOINT (also works on native GPU)
  # ===================================================================================

- @spaces.GPU(duration=120)
+ @spaces.GPU(duration=120)  # On native-GPU Spaces this simply runs with the resident GPU.
  def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
      if secret_token_input != SECRET_TOKEN:
          raise gr.Error("Invalid secret token provided.")
      if not input_text:
          return []

+     # Ensure models/latents exist
      if tts_model is None or llm_model is None or not voice_latents:
-         raise gr.Error("Models not initialized. Please restart the Space.")
+         init_models_and_latents()

+     # Prefer GPU if available at call time (ZeroGPU grants CUDA during this function)
      try:
          if torch.cuda.is_available():
              tts_model.to("cuda")
          else:
              tts_model.to("cpu")
+     except Exception:
+         tts_model.to("cpu")

-         history: List[Tuple[str, str | None]] = [(input_text, None)]
-         full_story_text = "".join(
-             generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
-         ).strip()
-         if not full_story_text:
-             return []
+     # Generate story text
+     history: List[Tuple[str, str | None]] = [(input_text, None)]
+     full_story_text = "".join(
+         generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
+     ).strip()
+     if not full_story_text:
+         return []

-         sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
-         lang = langid.classify(sentences[0])[0] if sentences else "en"
+     # Split into TTS-friendly sentences
+     sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
+     lang = langid.classify(sentences[0])[0] if sentences else "en"

-         results: List[Dict[str, str]] = []
-         for sentence in sentences:
-             if not any(c.isalnum() for c in sentence):
-                 continue
+     results: List[Dict[str, str]] = []
+     for sentence in sentences:
+         if not any(c.isalnum() for c in sentence):
+             continue

-             audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
-             pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
+         audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
+         pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)

-             try:
-                 data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
-                 if data_s16.size > 0:
-                     float_data = data_s16.astype(np.float32) / 32767.0
-                     reduced = nr.reduce_noise(y=float_data, sr=24000)
-                     final_pcm = (reduced * 32767).astype(np.int16).tobytes()
-                 else:
-                     final_pcm = pcm_data
-             except Exception:
+         # Optional noise reduction (best-effort)
+         try:
+             data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
+             if data_s16.size > 0:
+                 float_data = (data_s16.astype(np.float32) / 32767.0)
+                 reduced = nr.reduce_noise(y=float_data, sr=24000)
+                 final_pcm = np.clip(reduced * 32767.0, -32768, 32767).astype(np.int16).tobytes()
+             else:
                  final_pcm = pcm_data
+         except Exception:
+             final_pcm = pcm_data

-             b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
-             results.append({"text": sentence, "audio": b64_wav})
+         b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
+         results.append({"text": sentence, "audio": b64_wav})

-         return results
-     finally:
+     # Release GPU immediately if we were in a ZeroGPU window
+     try:
          tts_model.to("cpu")
+     except Exception:
+         pass
+
+     return results

  # ===================================================================================
  # 6) STARTUP: PRECACHE & UI
@@ -345,15 +450,17 @@ def build_ui() -> gr.Interface:
          ],
          outputs=gr.JSON(label="Story and Audio Output"),
          title="AI Storyteller with ZeroGPU",
-         description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
+         description="Enter a prompt to generate a short story with voice narration using on-demand GPU or native GPU when available.",
          flagging_mode="never",
+         allow_flagging="never",
      )

  if __name__ == "__main__":
      print("===== Startup: pre-cache assets and preload models =====")
-     precache_assets()
-     init_models_and_latents()
+     print(f"Python: {sys.version.split()[0]} | Torch CUDA available: {torch.cuda.is_available()}")
+     precache_assets()  # 1) download everything to disk
+     init_models_and_latents()  # 2) load models (prefer native GPU) + compute voice latents
      print("Models and assets ready. Launching UI...")

      demo = build_ui()
-     demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
+ demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
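
For reference, below is a minimal client sketch for the endpoint this commit wires up. It is not part of the commit: the Space id ("ruslanmv/ai-story-server"), the default "/predict" api_name, and the gradio_client dependency are assumptions; the token value matches the SECRET_TOKEN default ("secret") in app.py and should be replaced with the Space's configured secret.

import base64
import json

from gradio_client import Client

client = Client("ruslanmv/ai-story-server")           # assumed Space id
result = client.predict(
    "secret",                                         # SECRET_TOKEN (defaults to "secret" in app.py)
    "Tell a short story about a brave little boat.",  # story prompt
    "Pirate",                                         # one of the roles with cached voice latents
    api_name="/predict",                              # gr.Interface default endpoint name
)

# Depending on the gradio_client version, a gr.JSON output may come back as the
# parsed list or as a path to a .json file; handle both.
if isinstance(result, str):
    with open(result) as f:
        result = json.load(f)

# Each item mirrors the server's output: {"text": sentence, "audio": base64-encoded WAV}
for i, item in enumerate(result):
    with open(f"sentence_{i:02d}.wav", "wb") as f:
        f.write(base64.b64decode(item["audio"]))
    print(item["text"])

In practice SECRET_TOKEN should be set as a Space secret rather than left at its "secret" default, since the token is the only gate on the endpoint.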