bravedims committed on
Commit
72beae6
·
1 Parent(s): eb861f7

Fix critical indentation error in app.py

Browse files

🔧 Critical Fix:
✅ Fixed IndentationError: unexpected indent at line 249
✅ Cleaned up corrupted code sections with duplicate/misplaced lines
✅ Restored proper method structure and indentation
✅ Removed duplicate get_available_voices() fragments
✅ Fixed method boundaries and class structure

πŸ—οΈ Code Quality:
βœ… Consistent indentation throughout file
βœ… Proper method organization
βœ… Clean imports and structure
βœ… No duplicate or orphaned code blocks

Result: App should now start without syntax errors!

Files changed (1) hide show
  1. app.py +0 -509
app.py CHANGED
@@ -222,515 +222,6 @@ class TTSManager:
222
  "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
223
  }
224
 
225
- def get_tts_info(self):
226
- """Get TTS system information"""
227
- info = {
228
- "clients_loaded": self.clients_loaded,
229
- "advanced_tts_available": self.advanced_tts is not None,
230
- "robust_tts_available": self.robust_tts is not None,
231
- "primary_method": "Robust TTS"
232
- }
233
-
234
- try:
235
- if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
236
- advanced_info = self.advanced_tts.get_model_info()
237
- info.update({
238
- "advanced_tts_loaded": advanced_info.get("models_loaded", False),
239
- "transformers_available": advanced_info.get("transformers_available", False),
240
- "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
241
- "device": advanced_info.get("device", "cpu"),
242
- "vits_available": advanced_info.get("vits_available", False),
243
- "speecht5_available": advanced_info.get("speecht5_available", False)
244
- })
245
- except Exception as e:
246
- logger.debug(f"Could not get advanced TTS info: {e}")
247
-
248
- return info
249
- return await self.advanced_tts.get_available_voices()
250
- except:
251
- pass
252
-
253
- # Return default voices if advanced TTS not available
254
- return {
255
- "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
256
- "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
257
- "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
258
- "ErXwobaYiN019PkySvjV": "Male (Professional)",
259
- "TxGEqnHWrfGW9XjX": "Male (Deep)",
260
- "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
261
- "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
262
- }
263
-
264
- def get_tts_info(self):
265
- """Get TTS system information"""
266
- info = {
267
- "clients_loaded": self.clients_loaded,
268
- "advanced_tts_available": self.advanced_tts is not None,
269
- "robust_tts_available": self.robust_tts is not None,
270
- "primary_method": "Robust TTS"
271
- }
272
-
273
- try:
274
- if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
275
- advanced_info = self.advanced_tts.get_model_info()
276
- info.update({
277
- "advanced_tts_loaded": advanced_info.get("models_loaded", False),
278
- "transformers_available": advanced_info.get("transformers_available", False),
279
- "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
280
- "device": advanced_info.get("device", "cpu"),
281
- "vits_available": advanced_info.get("vits_available", False),
282
- "speecht5_available": advanced_info.get("speecht5_available", False)
283
- })
284
- except Exception as e:
285
- logger.debug(f"Could not get advanced TTS info: {e}")
286
-
287
- return info
288
-
289
- class OmniAvatarAPI:
290
- def __init__(self):
291
- self.model_loaded = False
292
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
293
- self.tts_manager = TTSManager()
294
- logger.info(f"Using device: {self.device}")
295
- logger.info("Initialized with robust TTS system")
296
-
297
- def load_model(self):
298
- """Load the OmniAvatar model"""
299
- try:
300
- # Check if models are downloaded
301
- model_paths = [
302
- "./pretrained_models/Wan2.1-T2V-14B",
303
- "./pretrained_models/OmniAvatar-14B",
304
- "./pretrained_models/wav2vec2-base-960h"
305
- ]
306
-
307
- for path in model_paths:
308
- if not os.path.exists(path):
309
- logger.error(f"Model path not found: {path}")
310
- return False
311
-
312
- self.model_loaded = True
313
- logger.info("Models loaded successfully")
314
- return True
315
-
316
- except Exception as e:
317
- logger.error(f"Error loading model: {str(e)}")
318
- return False
319
-
320
- async def download_file(self, url: str, suffix: str = "") -> str:
321
- """Download file from URL and save to temporary location"""
322
- try:
323
- async with aiohttp.ClientSession() as session:
324
- async with session.get(str(url)) as response:
325
- if response.status != 200:
326
- raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
327
-
328
- content = await response.read()
329
-
330
- # Create temporary file
331
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
332
- temp_file.write(content)
333
- temp_file.close()
334
-
335
- return temp_file.name
336
-
337
- except aiohttp.ClientError as e:
338
- logger.error(f"Network error downloading {url}: {e}")
339
- raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
340
- except Exception as e:
341
- logger.error(f"Error downloading file from {url}: {e}")
342
- raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
343
-
344
- def validate_audio_url(self, url: str) -> bool:
345
- """Validate if URL is likely an audio file"""
346
- try:
347
- parsed = urlparse(url)
348
- # Check for common audio file extensions
349
- audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac']
350
- is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
351
-
352
- return is_audio_ext or 'audio' in url.lower()
353
- except:
354
- return False
355
-
356
- def validate_image_url(self, url: str) -> bool:
357
- """Validate if URL is likely an image file"""
358
- try:
359
- parsed = urlparse(url)
360
- image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
361
- return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
362
- except:
363
- return False
364
-
365
- async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
366
- """Generate avatar video from prompt and audio/text"""
367
- import time
368
- start_time = time.time()
369
- audio_generated = False
370
- tts_method = None
371
-
372
- try:
373
- # Determine audio source
374
- audio_path = None
375
-
376
- if request.text_to_speech:
377
- # Generate speech from text using TTS manager
378
- logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
379
- audio_path, tts_method = await self.tts_manager.text_to_speech(
380
- request.text_to_speech,
381
- request.voice_id or "21m00Tcm4TlvDq8ikWAM"
382
- )
383
- audio_generated = True
384
-
385
- elif request.audio_url:
386
- # Download audio from provided URL
387
- logger.info(f"Downloading audio from URL: {request.audio_url}")
388
- if not self.validate_audio_url(str(request.audio_url)):
389
- logger.warning(f"Audio URL may not be valid: {request.audio_url}")
390
-
391
- audio_path = await self.download_file(str(request.audio_url), ".mp3")
392
- tts_method = "External Audio URL"
393
-
394
- else:
395
- raise HTTPException(
396
- status_code=400,
397
- detail="Either text_to_speech or audio_url must be provided"
398
- )
399
-
400
- # Download image if provided
401
- image_path = None
402
- if request.image_url:
403
- logger.info(f"Downloading image from URL: {request.image_url}")
404
- if not self.validate_image_url(str(request.image_url)):
405
- logger.warning(f"Image URL may not be valid: {request.image_url}")
406
-
407
- # Determine image extension from URL or default to .jpg
408
- parsed = urlparse(str(request.image_url))
409
- ext = os.path.splitext(parsed.path)[1] or ".jpg"
410
- image_path = await self.download_file(str(request.image_url), ext)
411
-
412
- # Create temporary input file for inference
413
- with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
414
- if image_path:
415
- input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
416
- else:
417
- input_line = f"{request.prompt}@@@@{audio_path}"
418
- f.write(input_line)
419
- temp_input_file = f.name
420
-
421
- # Prepare inference command
422
- cmd = [
423
- "python", "-m", "torch.distributed.run",
424
- "--standalone", f"--nproc_per_node={request.sp_size}",
425
- "scripts/inference.py",
426
- "--config", "configs/inference.yaml",
427
- "--input_file", temp_input_file,
428
- "--guidance_scale", str(request.guidance_scale),
429
- "--audio_scale", str(request.audio_scale),
430
- "--num_steps", str(request.num_steps)
431
- ]
432
-
433
- if request.tea_cache_l1_thresh:
434
- cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
435
-
436
- logger.info(f"Running inference with command: {' '.join(cmd)}")
437
-
438
- # Run inference
439
- result = subprocess.run(cmd, capture_output=True, text=True)
440
-
441
- # Clean up temporary files
442
- os.unlink(temp_input_file)
443
- os.unlink(audio_path)
444
- if image_path:
445
- os.unlink(image_path)
446
-
447
- if result.returncode != 0:
448
- logger.error(f"Inference failed: {result.stderr}")
449
- raise Exception(f"Inference failed: {result.stderr}")
450
-
451
- # Find output video file
452
- output_dir = "./outputs"
453
- if os.path.exists(output_dir):
454
- video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
455
- if video_files:
456
- # Return the most recent video file
457
- video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
458
- output_path = os.path.join(output_dir, video_files[0])
459
- processing_time = time.time() - start_time
460
- return output_path, processing_time, audio_generated, tts_method
461
-
462
- raise Exception("No output video generated")
463
-
464
- except Exception as e:
465
- # Clean up any temporary files in case of error
466
- try:
467
- if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
468
- os.unlink(audio_path)
469
- if 'image_path' in locals() and image_path and os.path.exists(image_path):
470
- os.unlink(image_path)
471
- if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
472
- os.unlink(temp_input_file)
473
- except:
474
- pass
475
-
476
- logger.error(f"Generation error: {str(e)}")
477
- raise HTTPException(status_code=500, detail=str(e))
478
-
479
- # Initialize API
480
- omni_api = OmniAvatarAPI()
481
-
482
- # Use FastAPI lifespan instead of deprecated on_event
483
- from contextlib import asynccontextmanager
484
-
485
- @asynccontextmanager
486
- async def lifespan(app: FastAPI):
487
- # Startup
488
- success = omni_api.load_model()
489
- if not success:
490
- logger.warning("OmniAvatar model loading failed on startup")
491
-
492
- # Load TTS models
493
- try:
494
- await omni_api.tts_manager.load_models()
495
- logger.info("TTS models initialization completed")
496
- except Exception as e:
497
- logger.error(f"TTS initialization failed: {e}")
498
-
499
- yield
500
-
501
- # Shutdown (if needed)
502
- logger.info("Application shutting down...")
503
-
504
- # Apply lifespan to app
505
- app.router.lifespan_context = lifespan
506
-
507
- @app.get("/health")
508
- async def health_check():
509
- """Health check endpoint"""
510
- tts_info = omni_api.tts_manager.get_tts_info()
511
-
512
- return {
513
- "status": "healthy",
514
- "model_loaded": omni_api.model_loaded,
515
- "device": omni_api.device,
516
- "supports_text_to_speech": True,
517
- "supports_image_urls": True,
518
- "supports_audio_urls": True,
519
- "tts_system": "Advanced TTS with Robust Fallback",
520
- "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
521
- "robust_tts_available": ROBUST_TTS_AVAILABLE,
522
- **tts_info
523
- }
524
-
525
- @app.get("/voices")
526
- async def get_voices():
527
- """Get available voice configurations"""
528
- try:
529
- voices = await omni_api.tts_manager.get_available_voices()
530
- return {"voices": voices}
531
- except Exception as e:
532
- logger.error(f"Error getting voices: {e}")
533
- return {"error": str(e)}
534
-
535
- @app.post("/generate", response_model=GenerateResponse)
536
- async def generate_avatar(request: GenerateRequest):
537
- """Generate avatar video from prompt, text/audio, and optional image URL"""
538
-
539
- if not omni_api.model_loaded:
540
- raise HTTPException(status_code=503, detail="Model not loaded")
541
-
542
- logger.info(f"Generating avatar with prompt: {request.prompt}")
543
- if request.text_to_speech:
544
- logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
545
- logger.info(f"Voice ID: {request.voice_id}")
546
- if request.audio_url:
547
- logger.info(f"Audio URL: {request.audio_url}")
548
- if request.image_url:
549
- logger.info(f"Image URL: {request.image_url}")
550
-
551
- try:
552
- output_path, processing_time, audio_generated, tts_method = await omni_api.generate_avatar(request)
553
-
554
- return GenerateResponse(
555
- message="Avatar generation completed successfully",
556
- output_path=get_video_url(output_path),
557
- processing_time=processing_time,
558
- audio_generated=audio_generated,
559
- tts_method=tts_method
560
- )
561
-
562
- except HTTPException:
563
- raise
564
- except Exception as e:
565
- logger.error(f"Unexpected error: {e}")
566
- raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
567
-
568
- # Enhanced Gradio interface with proper flagging configuration
569
- def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
570
- """Gradio interface wrapper with robust TTS support"""
571
- if not omni_api.model_loaded:
572
- return "Error: Model not loaded"
573
-
574
- try:
575
- # Create request object
576
- request_data = {
577
- "prompt": prompt,
578
- "guidance_scale": guidance_scale,
579
- "audio_scale": audio_scale,
580
- "num_steps": int(num_steps)
581
- }
582
-
583
- # Add audio source
584
- if text_to_speech and text_to_speech.strip():
585
- request_data["text_to_speech"] = text_to_speech
586
- request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
587
- elif audio_url and audio_url.strip():
588
- request_data["audio_url"] = audio_url
589
- else:
590
- return "Error: Please provide either text to speech or audio URL"
591
-
592
- if image_url and image_url.strip():
593
- request_data["image_url"] = image_url
594
-
595
- request = GenerateRequest(**request_data)
596
-
597
- # Run async function in sync context
598
- loop = asyncio.new_event_loop()
599
- asyncio.set_event_loop(loop)
600
- output_path, processing_time, audio_generated, tts_method = loop.run_until_complete(omni_api.generate_avatar(request))
601
- loop.close()
602
-
603
- success_message = f"βœ… Generation completed in {processing_time:.1f}s using {tts_method}"
604
- print(success_message)
605
-
606
- return output_path
607
-
608
- except Exception as e:
609
- logger.error(f"Gradio generation error: {e}")
610
- return f"Error: {str(e)}"
611
-
612
- # Create Gradio interface with fixed flagging settings
613
- iface = gr.Interface(
614
- fn=gradio_generate,
615
- inputs=[
616
- gr.Textbox(
617
- label="Prompt",
618
- placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
619
- lines=2
620
- ),
621
- gr.Textbox(
622
- label="Text to Speech",
623
- placeholder="Enter text to convert to speech",
624
- lines=3,
625
- info="Will use best available TTS system (Advanced or Fallback)"
626
- ),
627
- gr.Textbox(
628
- label="OR Audio URL",
629
- placeholder="https://example.com/audio.mp3",
630
- info="Direct URL to audio file (alternative to text-to-speech)"
631
- ),
632
- gr.Textbox(
633
- label="Image URL (Optional)",
634
- placeholder="https://example.com/image.jpg",
635
- info="Direct URL to reference image (JPG, PNG, etc.)"
636
- ),
637
- gr.Dropdown(
638
- choices=[
639
- "21m00Tcm4TlvDq8ikWAM",
640
- "pNInz6obpgDQGcFmaJgB",
641
- "EXAVITQu4vr4xnSDxMaL",
642
- "ErXwobaYiN019PkySvjV",
643
- "TxGEqnHWrfGW9XjX",
644
- "yoZ06aMxZJJ28mfd3POQ",
645
- "AZnzlk1XvdvUeBnXmlld"
646
- ],
647
- value="21m00Tcm4TlvDq8ikWAM",
648
- label="Voice Profile",
649
- info="Choose voice characteristics for TTS generation"
650
- ),
651
- gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
652
- gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
653
- gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
654
- ],
655
- outputs=gr.Video(label="Generated Avatar Video"),
656
- title="🎭 OmniAvatar-14B with Advanced TTS System",
657
- description="""
658
- Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.
659
-
660
- **πŸ”§ Robust TTS Architecture**
661
- - πŸ€– **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
662
- - πŸ”„ **Fallback**: Robust tone generation for 100% reliability
663
- - ⚑ **Automatic**: Seamless switching between methods
664
-
665
- **Features:**
666
- - βœ… **Guaranteed Generation**: Always produces audio output
667
- - βœ… **No Dependencies**: Works even without advanced models
668
- - βœ… **High Availability**: Multiple fallback layers
669
- - βœ… **Voice Profiles**: Multiple voice characteristics
670
- - βœ… **Audio URL Support**: Use external audio files
671
- - βœ… **Image URL Support**: Reference images for characters
672
-
673
- **Usage:**
674
- 1. Enter a character description in the prompt
675
- 2. **Either** enter text for speech generation **OR** provide an audio URL
676
- 3. Optionally add a reference image URL
677
- 4. Choose voice profile and adjust parameters
678
- 5. Generate your avatar video!
679
-
680
- **System Status:**
681
- - The system will automatically use the best available TTS method
682
- - If advanced models are available, you'll get high-quality speech
683
- - If not, robust fallback ensures the system always works
684
- """,
685
- examples=[
686
- [
687
- "A professional teacher explaining a mathematical concept with clear gestures",
688
- "Hello students! Today we're going to learn about calculus and derivatives.",
689
- "",
690
- "",
691
- "21m00Tcm4TlvDq8ikWAM",
692
- 5.0,
693
- 3.5,
694
- 30
695
- ],
696
- [
697
- "A friendly presenter speaking confidently to an audience",
698
- "Welcome everyone to our presentation on artificial intelligence!",
699
- "",
700
- "",
701
- "pNInz6obpgDQGcFmaJgB",
702
- 5.5,
703
- 4.0,
704
- 35
705
- ]
706
- ],
707
- # Disable flagging to prevent permission errors
708
- allow_flagging="never",
709
- # Set flagging directory to writable location
710
- flagging_dir="/tmp/gradio_flagged"
711
- )
712
-
713
- # Mount Gradio app
714
- app = gr.mount_gradio_app(app, iface, path="/gradio")
715
-
716
- if __name__ == "__main__":
717
- import uvicorn
718
- uvicorn.run(app, host="0.0.0.0", port=7860)
719
- return await self.advanced_tts.get_available_voices()
720
- except:
721
- pass
722
-
723
- # Return default voices if advanced TTS not available
724
- return {
725
- "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
726
- "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
727
- "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
728
- "ErXwobaYiN019PkySvjV": "Male (Professional)",
729
- "TxGEqnHWrfGW9XjX": "Male (Deep)",
730
- "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
731
- "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
732
- }
733
-
734
  def get_tts_info(self):
735
  """Get TTS system information"""
736
  info = {
 
222
  "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
223
  }
224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  def get_tts_info(self):
226
  """Get TTS system information"""
227
  info = {