ACloudCenter committed
Commit d5cf69f · 1 Parent(s): 5dc3e05

Added feedback to UI for inference

Files changed (2):
  1. app.py +52 -11
  2. backend_modal/modal_runner.py +187 -22
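
The substance of this commit is a new streaming contract between the Modal backend and the Gradio UI: instead of (audio, log) tuples, generate_audio now yields structured progress dicts built by _emit_progress, and generate_podcast_wrapper fans each one out to four UI components. A minimal sketch of the payload shape implied by the hunks below (field values are illustrative):

    # Each update streamed back from Modal is a plain dict:
    update = {
        "stage": "loading_voices",              # machine-readable phase name
        "pct": 25,                              # progress percentage for the slider
        "status": "Loading reference voices…",  # one-line human-readable status
        "log": "Generating conference...",      # cumulative log text
        # present only on the final update:
        # "audio": (24000, samples),            # (sample_rate, np.ndarray) for gr.Audio
        # "done": True,
    }
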
app.py CHANGED
@@ -221,6 +221,19 @@ def create_demo_interface():
                 lines=8, max_lines=15,
                 interactive=False,
             )
+            with gr.Row():
+                status_display = gr.Markdown(
+                    value="Status: idle.",
+                    elem_id="status-display",
+                )
+                progress_slider = gr.Slider(
+                    minimum=0,
+                    maximum=100,
+                    value=0,
+                    step=1,
+                    label="Progress",
+                    interactive=False,
+                )

    def update_speaker_visibility(num_speakers):
        return [gr.update(visible=(i < num_speakers)) for i in range(4)]
@@ -303,15 +316,23 @@ def create_demo_interface():

        def generate_podcast_wrapper(model_choice, num_speakers_val, script, *speakers_and_params):
            if remote_generate_function is None:
-                return None, "ERROR: Modal function not deployed. Please contact the space owner."
-
+                error_message = "ERROR: Modal function not deployed. Please contact the space owner."
+                yield None, error_message, "Status: error.", gr.update(value=0)
+                return
+
            # Show a message that we are calling the remote function
-            yield None, "🔄 Calling remote GPU on Modal.com... this may take a moment to start."
+            yield (
+                None,
+                "🔄 Calling remote GPU on Modal.com... this may take a moment to start.",
+                "**Connecting**\nRequesting GPU resources…",
+                gr.update(value=0),
+            )

            try:
                speakers = speakers_and_params[:4]
                cfg_scale_val = speakers_and_params[4]
-
+                current_log = ""
+
                # Stream updates from the Modal function
                for update in remote_generate_function.remote_gen(
                    num_speakers=int(num_speakers_val),
@@ -323,19 +344,39 @@ def create_demo_interface():
                    cfg_scale=cfg_scale_val,
                    model_name=model_choice
                ):
-                    # Each update is a tuple (audio_or_none, log_message)
-                    if update:
-                        audio, log = update
-                        yield audio, log
+                    if not update:
+                        continue
+
+                    audio_payload = update.get("audio")
+                    progress_pct = update.get("pct", 0)
+                    stage_label = update.get("stage", "").replace("_", " ").title() or "Status"
+                    status_line = update.get("status") or "Processing…"
+                    current_log = update.get("log", current_log)
+
+                    status_formatted = f"**{stage_label}**\n{status_line}"
+                    audio_output = audio_payload if audio_payload is not None else gr.update()
+
+                    yield (
+                        audio_output,
+                        current_log,
+                        status_formatted,
+                        gr.update(value=progress_pct),
+                    )
            except Exception as e:
                tb = traceback.format_exc()
                print(f"Error calling Modal: {e}")
-                yield None, f"❌ An error occurred: {e}\n\n{tb}"
+                error_log = f"❌ An error occurred: {e}\n\n{tb}"
+                yield (
+                    None,
+                    error_log,
+                    "**Error**\nInference failed.",
+                    gr.update(value=0),
+                )

        generate_btn.click(
            fn=generate_podcast_wrapper,
            inputs=[model_dropdown, num_speakers, script_input] + speaker_selections + [cfg_scale],
-            outputs=[complete_audio_output, log_output]
+            outputs=[complete_audio_output, log_output, status_display, progress_slider]
        )

        with gr.Tab("Architecture"):
@@ -414,4 +455,4 @@ if __name__ == "__main__":
    else:
        # Launch the full Gradio interface
        interface = create_demo_interface()
-        interface.queue().launch(show_error=True)
+        interface.queue().launch(show_error=True)
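
Since generate_podcast_wrapper is a generator, every yield must be a 4-tuple mapping positionally onto outputs=[complete_audio_output, log_output, status_display, progress_slider]; a bare gr.update() serves as a no-op so intermediate yields leave the audio player untouched. A self-contained sketch of the same pattern (component names here are illustrative, not the app's):

    import numpy as np
    import gradio as gr

    def long_task():
        # Generator event handler: each yield maps onto the outputs list below.
        yield gr.update(), "working...", "**Working**", gr.update(value=50)
        audio = np.zeros(24000, dtype=np.float32)  # 1 s of silence as a stand-in
        yield (24000, audio), "done", "**Complete**", gr.update(value=100)

    with gr.Blocks() as demo:
        audio_out = gr.Audio(label="Audio")
        log_out = gr.Textbox(label="Log")
        status = gr.Markdown("Status: idle.")
        progress = gr.Slider(0, 100, value=0, label="Progress", interactive=False)
        gr.Button("Run").click(long_task, inputs=None,
                               outputs=[audio_out, log_out, status, progress])

    demo.queue().launch()  # queue() is what enables streaming partial results
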
backend_modal/modal_runner.py CHANGED
@@ -5,6 +5,9 @@ import librosa
import soundfile as sf
import torch
from datetime import datetime
+import hashlib
+import json
+import pickle

# Modal-specific imports
import modal
@@ -38,8 +41,14 @@ app = modal.App(
    image=image,
)

+# Create a volume for caching generated audio
+cache_volume = modal.Volume.from_name("vibevoice-cache", create_if_missing=True)

-@app.cls(gpu="A100-40GB", scaledown_window=300)
+@app.cls(
+    gpu="A100-40GB",
+    scaledown_window=300,
+    volumes={"/cache": cache_volume}
+)
class VibeVoiceModel:
    def __init__(self):
        self.model_paths = {
@@ -48,6 +57,8 @@ class VibeVoiceModel:
        }
        self.device = "cuda"
        self.inference_steps = 5
+        self.cache_dir = "/cache"
+        self.max_cache_size_gb = 10  # Limit cache to 10GB

    @modal.enter()
    def load_models(self):
@@ -113,6 +124,95 @@ class VibeVoiceModel:
            self.available_voices[name] = os.path.join(voices_dir, wav_file)
        print(f"Voices loaded: {list(self.available_voices.keys())}")

+    def _emit_progress(self, stage: str, pct: float, status: str, log_text: str,
+                       audio=None, done: bool = False):
+        """Package a structured progress update for streaming back to Gradio."""
+        payload = {
+            "stage": stage,
+            "pct": pct,
+            "status": status,
+            "log": log_text,
+        }
+        if audio is not None:
+            payload["audio"] = audio
+        if done:
+            payload["done"] = True
+        return payload
+
+    def _generate_cache_key(self, script: str, model_name: str, speakers: list, cfg_scale: float) -> str:
+        """Generate a unique cache key for this generation."""
+        cache_data = {
+            "script": script.strip().lower(),  # Normalize script
+            "model": model_name,
+            "speakers": sorted(speakers),  # Sort for consistency
+            "cfg_scale": cfg_scale,
+            "inference_steps": self.inference_steps
+        }
+        cache_str = json.dumps(cache_data, sort_keys=True)
+        return hashlib.sha256(cache_str.encode()).hexdigest()
+
+    def _get_cached_audio(self, cache_key: str):
+        """Check if audio is cached and return it."""
+        cache_path = os.path.join(self.cache_dir, f"{cache_key}.pkl")
+        if os.path.exists(cache_path):
+            try:
+                with open(cache_path, 'rb') as f:
+                    cached_data = pickle.load(f)
+                print(f"Cache hit! Loading from {cache_key}")
+                return cached_data['audio'], cached_data['sample_rate']
+            except Exception as e:
+                print(f"Cache read error: {e}")
+        return None, None
+
+    def _save_to_cache(self, cache_key: str, audio: np.ndarray, sample_rate: int):
+        """Save generated audio to cache."""
+        try:
+            # Check cache size
+            self._cleanup_cache_if_needed()
+
+            cache_path = os.path.join(self.cache_dir, f"{cache_key}.pkl")
+            cached_data = {
+                'audio': audio,
+                'sample_rate': sample_rate,
+                'timestamp': time.time()
+            }
+            with open(cache_path, 'wb') as f:
+                pickle.dump(cached_data, f)
+            print(f"Saved to cache: {cache_key}")
+
+            # Commit the volume changes
+            cache_volume.commit()
+        except Exception as e:
+            print(f"Cache write error: {e}")
+
+    def _cleanup_cache_if_needed(self):
+        """Remove old cache files if cache is too large."""
+        try:
+            cache_files = []
+            total_size = 0
+
+            for filename in os.listdir(self.cache_dir):
+                if filename.endswith('.pkl'):
+                    filepath = os.path.join(self.cache_dir, filename)
+                    size = os.path.getsize(filepath)
+                    mtime = os.path.getmtime(filepath)
+                    cache_files.append((filepath, size, mtime))
+                    total_size += size
+
+            # If cache is too large, remove oldest files
+            max_size = self.max_cache_size_gb * 1024 * 1024 * 1024
+            if total_size > max_size:
+                # Sort by modification time (oldest first)
+                cache_files.sort(key=lambda x: x[2])
+
+                while total_size > max_size * 0.8 and cache_files:  # Keep 80% full
+                    filepath, size, _ = cache_files.pop(0)
+                    os.remove(filepath)
+                    total_size -= size
+                    print(f"Removed old cache: {os.path.basename(filepath)}")
+        except Exception as e:
+            print(f"Cache cleanup error: {e}")
+
    def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
        try:
            wav, sr = sf.read(audio_path)
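
The helpers above implement a content-addressed cache: the key is a SHA-256 digest of a canonical JSON encoding of the request, so equivalent requests deliberately collide. A standalone sketch of the same normalization (mirroring _generate_cache_key; argument values are illustrative):

    import hashlib
    import json

    def make_cache_key(script, model_name, speakers, cfg_scale, inference_steps=5):
        # Whitespace/case-normalize the script and sort the speakers so that
        # equivalent requests serialize identically and hash to the same key.
        payload = {
            "script": script.strip().lower(),
            "model": model_name,
            "speakers": sorted(speakers),
            "cfg_scale": cfg_scale,
            "inference_steps": inference_steps,
        }
        return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()

    a = make_cache_key("Hello world\n", "model-a", ["Alice", "Bob"], 1.3)
    b = make_cache_key("  hello WORLD", "model-a", ["Bob", "Alice"], 1.3)
    assert a == b  # same normalized request, same cache entry

One caveat worth noting: because the speaker list is sorted, two requests that assign the same voices to different speaker slots share a key, even though slot order changes which voice reads which lines.
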
@@ -193,15 +293,36 @@ class VibeVoiceModel:
        Yields progress updates during generation.
        """
        try:
-            # Yield initial status
-            yield None, "🔄 Initializing generation..."
            if model_name not in self.models:
                raise ValueError(f"Unknown model: {model_name}")

+            # Initialize log scaffold
+            selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
+            log_lines = [
+                f"Generating conference with {num_speakers} speakers",
+                f"Model: {model_name}",
+                f"Parameters: CFG Scale={cfg_scale}",
+                f"Speakers: {', '.join(selected_speakers)}",
+            ]
+            log_text = "\n".join(log_lines)
+
+            # Emit initial status before heavy work kicks in
+            yield self._emit_progress(
+                stage="queued",
+                pct=5,
+                status="Queued GPU job and validating inputs…",
+                log_text=log_text,
+            )
+
            # Move the selected model to GPU, others to CPU
-            yield None, "🔄 Loading model to GPU..."
+            yield self._emit_progress(
+                stage="loading_model",
+                pct=15,
+                status=f"Loading {model_name} weights to GPU…",
+                log_text=log_text,
+            )
            self._place_model(model_name)
-
+
            model = self.models[model_name]
            processor = self.processors[model_name]
            model.set_ddpm_inference_steps(num_steps=self.inference_steps)
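
generate_audio is a generator, and the Gradio side consumes it through remote_generate_function.remote_gen(...), Modal's calling convention for streaming a remote generator's yields back as they are produced. A minimal self-contained illustration of that mechanism (the app name and function here are hypothetical):

    import modal

    app = modal.App("progress-streaming-sketch")  # hypothetical app name

    @app.function()
    def progress_stream(n: int = 3):
        # Each yield is streamed back to the caller as soon as it is produced.
        for i in range(1, n + 1):
            yield {"stage": "working", "pct": 100 * i // n, "status": f"step {i}/{n}"}

    @app.local_entrypoint()
    def main():
        for update in progress_stream.remote_gen(n=3):
            print(update["pct"], update["status"])
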
@@ -216,17 +337,18 @@ class VibeVoiceModel:
            if not 1 <= num_speakers <= 4:
                raise ValueError("Error: Number of speakers must be between 1 and 4.")

-            selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
            for i, speaker_name in enumerate(selected_speakers):
                if not speaker_name or speaker_name not in self.available_voices:
                    raise ValueError(f"Error: Please select a valid speaker for Speaker {i+1}.")

-            log = f"Generating conference with {num_speakers} speakers\n"
-            log += f"Model: {model_name}\n"
-            log += f"Parameters: CFG Scale={cfg_scale}\n"
-            log += f"Speakers: {', '.join(selected_speakers)}\n"
-
-            yield None, log + "\n🔄 Loading voice samples..."
+            log_lines.append("Loading voice samples…")
+            log_text = "\n".join(log_lines)
+            yield self._emit_progress(
+                stage="loading_voices",
+                pct=25,
+                status="Loading reference voices…",
+                log_text=log_text,
+            )

            voice_samples = []
            for i, speaker_name in enumerate(selected_speakers):
@@ -235,9 +357,18 @@ class VibeVoiceModel:
                if len(audio_data) == 0:
                    raise ValueError(f"Error: Failed to load audio for {speaker_name}")
                voice_samples.append(audio_data)
-                yield None, log + f"\n✓ Loaded voice {i+1}/{len(selected_speakers)}: {speaker_name}"
+                voice_pct = 25 + ((i + 1) / len(selected_speakers)) * 15
+                log_lines.append(f"Loaded voice {i+1}/{len(selected_speakers)}: {speaker_name}")
+                log_text = "\n".join(log_lines)
+                yield self._emit_progress(
+                    stage="loading_voices",
+                    pct=voice_pct,
+                    status=f"Loaded {speaker_name}",
+                    log_text=log_text,
+                )

-            log += f"\nLoaded {len(voice_samples)} voice samples"
+            log_lines.append(f"Loaded {len(voice_samples)} voice samples")
+            log_text = "\n".join(log_lines)

            lines = script.strip().split('\n')
            formatted_script_lines = []
@@ -251,8 +382,14 @@ class VibeVoiceModel:
                    formatted_script_lines.append(f"Speaker {speaker_id}: {line}")

            formatted_script = '\n'.join(formatted_script_lines)
-            log += f"\nFormatted script with {len(formatted_script_lines)} turns"
-            yield None, log + "\n🔄 Processing script with VibeVoice..."
+            log_lines.append(f"Formatted script with {len(formatted_script_lines)} turns")
+            log_text = "\n".join(log_lines)
+            yield self._emit_progress(
+                stage="preparing_inputs",
+                pct=50,
+                status="Formatting script and preparing tensors…",
+                log_text=log_text,
+            )

            inputs = processor(
                text=[formatted_script],
@@ -262,7 +399,14 @@ class VibeVoiceModel:
                return_attention_mask=True,
            ).to(self.device)

-            yield None, log + "\n🎯 Starting audio generation (this may take 1-2 minutes)..."
+            log_lines.append("Inputs prepared; starting diffusion generation…")
+            log_text = "\n".join(log_lines)
+            yield self._emit_progress(
+                stage="generating_audio",
+                pct=70,
+                status="Running VibeVoice diffusion (this may take 1-2 minutes)…",
+                log_text=log_text,
+            )
            start_time = time.time()

            with torch.inference_mode():
@@ -276,7 +420,15 @@ class VibeVoiceModel:
            )
            generation_time = time.time() - start_time

-            yield None, log + f"\n✓ Generation completed in {generation_time:.2f} seconds\n🔄 Processing audio..."
+            log_lines.append(f"Generation completed in {generation_time:.2f} seconds")
+            log_lines.append("Processing audio output…")
+            log_text = "\n".join(log_lines)
+            yield self._emit_progress(
+                stage="processing_audio",
+                pct=90,
+                status="Post-processing audio output…",
+                log_text=log_text,
+            )

            if hasattr(outputs, 'speech_outputs') and outputs.speech_outputs[0] is not None:
                audio_tensor = outputs.speech_outputs[0]
@@ -289,15 +441,28 @@ class VibeVoiceModel:

            sample_rate = 24000
            total_duration = len(audio) / sample_rate
-            log += f"\n✓ Generation completed in {generation_time:.2f} seconds"
-            log += f"\n✓ Audio duration: {total_duration:.2f} seconds"
+            log_lines.append(f"Audio duration: {total_duration:.2f} seconds")
+            log_lines.append("Complete!")
+            log_text = "\n".join(log_lines)

            # Final yield with both audio and complete log
-            yield (sample_rate, audio), log + "\n✅ Complete!"
+            yield self._emit_progress(
+                stage="complete",
+                pct=100,
+                status="Conference ready to download.",
+                log_text=log_text,
+                audio=(sample_rate, audio),
+                done=True,
+            )

        except Exception as e:
            import traceback
            error_msg = f"❌ An unexpected error occurred on Modal: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            # Yield error state
-            yield None, error_msg
+            yield self._emit_progress(
+                stage="error",
+                pct=0,
+                status="Generation failed.",
+                log_text=error_msg,
+            )
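
Note that the hunks above define the cache helpers but never call them: generate_audio neither checks _get_cached_audio before the diffusion pass nor calls _save_to_cache afterwards. A sketch of plausible wiring inside generate_audio, after speaker validation (entirely hypothetical, not part of this diff):

    # Hypothetical call site inside generate_audio (assumes self is in scope):
    cache_key = self._generate_cache_key(script, model_name, selected_speakers, cfg_scale)
    cached_audio, cached_sr = self._get_cached_audio(cache_key)
    if cached_audio is not None:
        yield self._emit_progress(
            stage="complete", pct=100, status="Served from cache.",
            log_text="Cache hit.", audio=(cached_sr, cached_audio), done=True,
        )
        return
    # ...and after a successful generation:
    # self._save_to_cache(cache_key, audio, sample_rate)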