Spaces: Running on Zero
import gradio as gr
import torch
import spaces
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.utils import export_to_video
from PIL import Image, ImageOps
from gtts import gTTS
from pydub import AudioSegment
try:
    import whisper
except ImportError:
    whisper = None
import ffmpeg
import requests
from io import BytesIO
import os
import gc
# Load LTX models
ltx_model_id = "Lightricks/LTX-Video-0.9.7-distilled"
upscaler_model_id = "Lightricks/ltxv-spatial-upscaler-0.9.7"
pipe = LTXConditionPipeline.from_pretrained(ltx_model_id, torch_dtype=torch.float16)
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    upscaler_model_id, vae=pipe.vae, torch_dtype=torch.float16
)
pipe.to("cuda")
pipe_upsample.to("cuda")
pipe.vae.enable_tiling()
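# Tiled VAE decoding trades a little speed for a much lower memory peak when decoding the upscaled 512x512 latents.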
def prepare_image_condition(image, size=(512, 512), background=(0, 0, 0)):
    # Letterbox the conditioning image onto a black canvas so it matches the target size without distortion.
    image = ImageOps.contain(image, size)
    canvas = Image.new("RGB", size, background)
    offset = ((size[0] - image.width) // 2, (size[1] - image.height) // 2)
    canvas.paste(image, offset)
    return canvas
# ZeroGPU: attach a GPU to each call of this function (the duration is an assumed upper bound; tune for your Space).
@spaces.GPU(duration=120)
def generate_video(prompt, image_url):
    generator = torch.Generator("cuda").manual_seed(42)
    # Load & prepare the optional conditioning image
    image = None
    if image_url:
        raw_image = Image.open(BytesIO(requests.get(image_url, timeout=30).content)).convert("RGB")
        image = prepare_image_condition(raw_image)
    # Set target resolutions using dimensions that match the expected latent shapes.
    # The LTX VAE uses 32x spatial downsampling, so pixel sizes must be multiples of 32;
    # a latent of shape (1, 128, 8, 16, 16) decodes to 16 * 32 = 512x512.
    base_width, base_height = 512, 512    # final upscaled size (16 * 32)
    down_width, down_height = 256, 256    # initial generation size (8 * 32), upscaled 2x in latent space
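    # Temporal compression is 8x, so the 8 latent frames above correspond to 8 * 7 + 1 = 57 pixel frames
    # (hence num_frames=57 below).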
    # Step 1: Generate latents at lower resolution with improved quality settings
    latents = pipe(
        prompt=prompt,
        image=image,
        width=down_width,
        height=down_height,
        num_frames=57,  # LTX expects 8*k + 1 frames; 57 = 8*7 + 1 keeps 8 latent frames
        num_inference_steps=10,  # Increased from 7 for better quality
        output_type="latent",
        guidance_scale=1.5,  # Slightly increased for better prompt adherence
        decode_timestep=0.08,
        decode_noise_scale=0.05,  # Reduced noise
        generator=generator
    ).frames
    torch.cuda.empty_cache()
    gc.collect()
    # Step 2: Upscale latents
    upscaled_latents = pipe_upsample(latents=latents, output_type="latent").frames
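    # The spatial upsampler works directly on latents and doubles their spatial resolution
    # (here 8x8 -> 16x16), avoiding an intermediate decode to pixels.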
    torch.cuda.empty_cache()
    gc.collect()
    # Step 3: Denoise and decode the upscaled latents to frames
    frames = pipe(
        prompt=prompt,  # Use original prompt for consistency
        latents=upscaled_latents,
        width=base_width,
        height=base_height,
        num_frames=57,  # must match the frame count used in Step 1
        num_inference_steps=12,  # Increased for better decoding quality
        output_type="pil",
        guidance_scale=1.5,  # Consistent with generation
        decode_timestep=0.08,
        decode_noise_scale=0.05,  # Reduced noise
        image_cond_noise_scale=0.02,  # Reduced for cleaner output
        denoise_strength=0.25,  # Only lightly re-noise the upscaled latents so their structure is preserved
        generator=generator
    ).frames[0]
    torch.cuda.empty_cache()
    gc.collect()
    # Step 4: Export video
    video_path = "output.mp4"
    export_to_video(frames, video_path, fps=24)
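    # At 24 fps, 57 frames is roughly 2.4 seconds of video; the -shortest flag used when muxing below
    # trims the final file to the shorter of the video and the TTS audio.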
    # Step 5: TTS voiceover from the prompt
    tts = gTTS(text=prompt, lang='en')
    tts.save("voice.mp3")
    AudioSegment.from_mp3("voice.mp3").export("voice.wav", format="wav")
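    # Note: pydub shells out to an ffmpeg binary for the mp3 -> wav conversion, so ffmpeg must be installed on the Space.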
    # Step 6: Subtitles
    if whisper is not None:
        try:
            model = whisper.load_model("base", device="cpu")
            result = model.transcribe("voice.wav", task="transcribe", language="en")
            # Build SRT subtitles manually since result["srt"] is not provided
            srt_content = ""
            for i, segment in enumerate(result["segments"]):
                start_time = format_time(segment["start"])
                end_time = format_time(segment["end"])
                text = segment["text"].strip()
                srt_content += f"{i + 1}\n{start_time} --> {end_time}\n{text}\n\n"
            with open("subtitles.srt", "w", encoding="utf-8") as f:
                f.write(srt_content)
        except Exception as e:
            print(f"Whisper transcription failed: {e}")
            # Fall back to a single subtitle containing the original prompt
            srt_content = f"1\n00:00:00,000 --> 00:00:05,000\n{prompt}\n\n"
            with open("subtitles.srt", "w", encoding="utf-8") as f:
                f.write(srt_content)
    else:
        print("Whisper not available, using prompt as subtitle")
        # Fall back to a single subtitle containing the original prompt
        srt_content = f"1\n00:00:00,000 --> 00:00:05,000\n{prompt}\n\n"
        with open("subtitles.srt", "w", encoding="utf-8") as f:
            f.write(srt_content)
    # Step 7: Merge video + audio + subtitles with FFmpeg
    final_output = "final_with_audio.mp4"
    try:
        # First, burn the subtitles into the video
        video_with_subs = "video_with_subs.mp4"
        (
            ffmpeg
            .input(video_path)
            .filter('subtitles', 'subtitles.srt')
            .output(video_with_subs, vcodec='libx264', acodec='aac', loglevel='error')
            .overwrite_output()
            .run()
        )
        # Then mux in the audio track (ffmpeg-python needs separate input streams, not chained .input() calls)
        video_in = ffmpeg.input(video_with_subs)
        audio_in = ffmpeg.input('voice.wav')
        (
            ffmpeg
            .output(
                video_in,
                audio_in,
                final_output,
                vcodec='copy',
                acodec='aac',
                shortest=None,
                loglevel='error'
            )
            .overwrite_output()
            .run()
        )
        return final_output
    except Exception as e:
        print(f"FFmpeg error: {e}")
        # Fallback: mux the audio without subtitles
        try:
            video_in = ffmpeg.input(video_path)
            audio_in = ffmpeg.input('voice.wav')
            (
                ffmpeg
                .output(
                    video_in,
                    audio_in,
                    final_output,
                    vcodec='libx264',
                    acodec='aac',
                    shortest=None,
                    loglevel='error'
                )
                .overwrite_output()
                .run()
            )
            return final_output
        except Exception as e2:
            print(f"FFmpeg fallback error: {e2}")
            # Final fallback: return the original silent video
            return video_path
def format_time(seconds):
    """Convert seconds to SRT time format"""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millisecs = int((seconds % 1) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}"
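# e.g. format_time(3.5) -> "00:00:03,500", format_time(75.25) -> "00:01:15,250"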
# Gradio UI
demo = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Describe your scene..."),
        gr.Textbox(label="Optional Image URL (e.g. Pexels)", placeholder="https://...")
    ],
    outputs=gr.Video(label="Generated Video"),
    title="🎬 LTX AI Video Generator",
    description="AI-powered video with voiceover and subtitles. Generates at 256x256 and upscales to 512x512 in latent space."
)
demo.launch()
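# A minimal sketch of the dependencies this Space assumes, inferred from the imports above (exact pins are not
# specified in this file): gradio, spaces, torch, diffusers, transformers, accelerate, Pillow, gTTS, pydub,
# openai-whisper, ffmpeg-python, requests, plus a system-level ffmpeg binary for pydub and the muxing/subtitle steps.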