import gradio as gr
import spaces
import torch
import torchaudio
import io
import base64
import uuid
import os
import time
import re
import threading
import gc
import random
import numpy as np
from einops import rearrange
from huggingface_hub import login
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from gradio_client import Client, handle_file
from contextlib import contextmanager

# Global model storage
model_cache = {}
model_lock = threading.Lock()

@contextmanager
def resource_cleanup():
    """Lightweight context manager - let ZeroGPU handle memory management."""
    try:
        yield
    finally:
        # Minimal cleanup - let ZeroGPU handle the heavy lifting
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # Removed aggressive empty_cache() and gc.collect() calls
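
# Usage sketch for resource_cleanup(): wrap GPU work so the device is synchronized on exit,
# even if generation raises. (The wrapper is deliberately not applied in the hot path below.)
#
#   with resource_cleanup():
#       output = generate_diffusion_cond(model, ...)
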
def load_stable_audio_model():
    """Load stable-audio-open-small model if not already loaded."""
    with model_lock:
        if 'stable_audio_model' not in model_cache:
            print("Loading stable-audio-open-small model...")
            load_start = time.time()

            # Authenticate with HF
            hf_token = os.getenv('HF_TOKEN')
            if hf_token:
                login(token=hf_token)
                print("HF authenticated")

            # Load model
            model, config = get_pretrained_model("stabilityai/stable-audio-open-small")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
            if device == "cuda":
                model = model.half()

            load_time = time.time() - load_start
            print(f"Model loaded on {device} in {load_time:.2f}s")

            # Aggressive model persistence - warm up with a dummy generation
            print("Warming up model...")
            warmup_start = time.time()
            try:
                dummy_conditioning = [{"prompt": "test", "seconds_total": 12}]
                with torch.no_grad():
                    _ = generate_diffusion_cond(
                        model,
                        steps=1,  # Minimal steps for warmup
                        cfg_scale=1.0,
                        conditioning=dummy_conditioning,
                        sample_size=config["sample_size"],
                        sampler_type="pingpong",
                        device=device,
                        seed=42
                    )
                warmup_time = time.time() - warmup_start
                print(f"Model warmed up in {warmup_time:.2f}s")
            except Exception as e:
                print(f"Warmup failed (but continuing): {e}")

            model_cache['stable_audio_model'] = model
            model_cache['stable_audio_config'] = config
            model_cache['stable_audio_device'] = device
            print("Stable Audio model ready for fast generation!")
        else:
            print("Using cached model (should be fast!)")

    return (model_cache['stable_audio_model'],
            model_cache['stable_audio_config'],
            model_cache['stable_audio_device'])

@spaces.GPU  # ZeroGPU: allocate a GPU for this call (the module imports `spaces` for this purpose)
def generate_stable_audio_loop(prompt, loop_type, bpm, bars, seed=-1):
    """Generate a BPM-aware loop using stable-audio-open-small."""
    try:
        total_start = time.time()

        # Model loading timing
        load_start = time.time()
        model, config, device = load_stable_audio_model()
        load_time = time.time() - load_start

        # Calculate loop duration based on BPM and bars
        seconds_per_beat = 60.0 / bpm
        seconds_per_bar = seconds_per_beat * 4  # 4/4 time
        target_loop_duration = seconds_per_bar * bars
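        # Worked example: at 120 BPM, seconds_per_beat = 0.5 s, so one 4/4 bar is 2.0 s
        # and a 4-bar loop is 8.0 s, comfortably inside the model's 12 s generation window.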
        # Enhance prompt based on loop type and BPM
        if loop_type == "drums":
            enhanced_prompt = f"{prompt} drum loop {bpm}bpm"
            negative_prompt = "melody, harmony, pitched instruments, vocals, singing"
        else:  # instruments
            enhanced_prompt = f"{prompt} instrumental loop {bpm}bpm"
            negative_prompt = "drums, percussion, kick, snare, hi-hat"

        # Set seed
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if device == "cuda":
            torch.cuda.manual_seed(seed)

        print(f"Generating {loop_type} loop:")
        print(f"  Enhanced prompt: {enhanced_prompt}")
        print(f"  Target duration: {target_loop_duration:.2f}s ({bars} bars at {bpm}bpm)")
        print(f"  Seed: {seed}")

        # Prepare conditioning
        conditioning_start = time.time()
        conditioning = [{
            "prompt": enhanced_prompt,
            "seconds_total": 12  # Model generates 12s max
        }]
        negative_conditioning = [{
            "prompt": negative_prompt,
            "seconds_total": 12
        }]
        conditioning_time = time.time() - conditioning_start
        # Generation timing
        generation_start = time.time()

        # Removed aggressive resource cleanup wrapper.
        # Clear GPU cache once before generation (not after).
        if device == "cuda":
            torch.cuda.empty_cache()

        with torch.cuda.amp.autocast(enabled=(device == "cuda")):
            output = generate_diffusion_cond(
                model,
                steps=8,        # Fast generation
                cfg_scale=1.0,  # Good balance for loops
                conditioning=conditioning,
                negative_conditioning=negative_conditioning,
                sample_size=config["sample_size"],
                sampler_type="pingpong",
                device=device,
                seed=seed
            )
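        # Note (assumption): 8 ping-pong steps with cfg_scale=1.0 favors speed; raising the
        # step count generally trades extra latency for more detail in the generated audio.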
        generation_time = time.time() - generation_start

        # Post-processing timing
        postproc_start = time.time()

        # Post-process audio
        output = rearrange(output, "b d n -> d (b n)")  # (2, N) stereo
        output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1)

        # Extract the loop portion
        sample_rate = config["sample_rate"]
        loop_samples = int(target_loop_duration * sample_rate)
        available_samples = output.shape[1]

        if loop_samples > available_samples:
            loop_samples = available_samples
            actual_duration = available_samples / sample_rate
            print(f"Requested {target_loop_duration:.2f}s, got {actual_duration:.2f}s")

        # Extract loop from the beginning (cleanest beat alignment)
        loop_output = output[:, :loop_samples]
        loop_output_int16 = loop_output.mul(32767).to(torch.int16).cpu()

        # Save to a temporary file
        loop_filename = f"loop_{loop_type}_{bpm}bpm_{bars}bars_{seed}.wav"
        torchaudio.save(loop_filename, loop_output_int16, sample_rate)

        postproc_time = time.time() - postproc_start
        total_time = time.time() - total_start
        actual_duration = loop_samples / sample_rate

        # Detailed timing breakdown
        print("Timing breakdown:")
        print(f"  Model load: {load_time:.2f}s")
        print(f"  Conditioning: {conditioning_time:.3f}s")
        print(f"  Generation: {generation_time:.2f}s")
        print(f"  Post-processing: {postproc_time:.3f}s")
        print(f"  Total: {total_time:.2f}s")
        print(f"{loop_type.title()} loop: {actual_duration:.2f}s audio in {total_time:.2f}s")

        return loop_filename, f"Generated {actual_duration:.2f}s {loop_type} loop at {bpm}bpm ({bars} bars) in {total_time:.2f}s"

    except Exception as e:
        print(f"Generation error: {str(e)}")
        return None, f"Error: {str(e)}"

def combine_loops(drums_audio, instruments_audio, bpm, bars, num_repeats):
    """Combine drum and instrument loops with the specified repetitions."""
    try:
        if not drums_audio and not instruments_audio:
            return None, "No audio files to combine"

        # Calculate timing
        seconds_per_beat = 60.0 / bpm
        seconds_per_bar = seconds_per_beat * 4
        loop_duration = seconds_per_bar * bars
        total_duration = loop_duration * num_repeats

        print("Combining loops:")
        print(f"  Loop duration: {loop_duration:.2f}s ({bars} bars)")
        print(f"  Repeats: {num_repeats}")
        print(f"  Total duration: {total_duration:.2f}s")

        combined_audio = None
        sample_rate = None

        # Process each audio file
        for audio_path, audio_type in [(drums_audio, "drums"), (instruments_audio, "instruments")]:
            if audio_path:
                # Load audio
                waveform, sr = torchaudio.load(audio_path)
                if sample_rate is None:
                    sample_rate = sr

                # Ensure we have the exact loop duration
                target_samples = int(loop_duration * sr)
                if waveform.shape[1] > target_samples:
                    waveform = waveform[:, :target_samples]
                elif waveform.shape[1] < target_samples:
                    # Pad if necessary
                    padding = target_samples - waveform.shape[1]
                    waveform = torch.cat([waveform, torch.zeros(waveform.shape[0], padding)], dim=1)

                # Repeat the loop
                repeated_waveform = waveform.repeat(1, num_repeats)
                print(f"  {audio_type}: {waveform.shape[1]/sr:.2f}s repeated {num_repeats}x = {repeated_waveform.shape[1]/sr:.2f}s")

                # Add to combined audio
                if combined_audio is None:
                    combined_audio = repeated_waveform
                else:
                    combined_audio = combined_audio + repeated_waveform

        if combined_audio is None:
            return None, "No valid audio to combine"

        # Normalize to prevent clipping
        combined_audio = combined_audio / torch.max(torch.abs(combined_audio))
        combined_audio = combined_audio.clamp(-1, 1)

        # Convert to int16 and save
        combined_audio_int16 = combined_audio.mul(32767).to(torch.int16)
        combined_filename = f"combined_{bpm}bpm_{bars}bars_{num_repeats}loops_{random.randint(1000, 9999)}.wav"
        torchaudio.save(combined_filename, combined_audio_int16, sample_rate)

        actual_duration = combined_audio.shape[1] / sample_rate
        status = f"Combined into {actual_duration:.2f}s audio ({num_repeats} × {bars} bars at {bpm}bpm)"
        print(status)

        return combined_filename, status

    except Exception as e:
        print(f"Combine error: {str(e)}")
        return None, f"Combine error: {str(e)}"

def transform_with_melodyflow_api(audio_path, prompt, solver="euler", flowstep=0.12):
    """Transform audio using the facebook/MelodyFlow space API."""
    if audio_path is None:
        return None, "No audio file provided"

    try:
        # Initialize client for the facebook/MelodyFlow space
        client = Client("facebook/MelodyFlow")

        # Set steps based on solver
        if solver == "midpoint":
            base_steps = 128
            effective_steps = base_steps // 2  # 64 effective steps
        else:  # euler
            base_steps = 125
            effective_steps = base_steps // 5  # 25 effective steps

        print("MelodyFlow transformation:")
        print(f"  Prompt: {prompt}")
        print(f"  Solver: {solver} ({effective_steps} effective steps)")
        print(f"  Flowstep: {flowstep}")

        # Call the MelodyFlow API
        result = client.predict(
            model="facebook/melodyflow-t24-30secs",
            text=prompt,
            solver=solver,
            steps=base_steps,
            target_flowstep=flowstep,
            regularize=solver == "euler",
            regularization_strength=0.2,
            duration=30,
            melody=handle_file(audio_path),
            api_name="/predict"
        )

        if result and len(result) > 0 and result[0]:
            # Save the result locally
            output_filename = f"melodyflow_transformed_{random.randint(1000, 9999)}.wav"
            import shutil
            shutil.copy2(result[0], output_filename)

            status_msg = f"Transformed with prompt: '{prompt}' (flowstep: {flowstep}, {effective_steps} steps)"
            return output_filename, status_msg
        else:
            return None, "MelodyFlow API returned no results"

    except Exception as e:
        return None, f"MelodyFlow API error: {str(e)}"

def calculate_optimal_bars(bpm):
    """Calculate the optimal bar count for a given BPM so the loop fits in ~10s."""
    seconds_per_beat = 60.0 / bpm
    seconds_per_bar = seconds_per_beat * 4
    max_duration = 10.0

    for bars in [8, 4, 2, 1]:
        if seconds_per_bar * bars <= max_duration:
            return bars
    return 1

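# Example: at 120 BPM a bar lasts 2.0 s, so 4 bars (8.0 s) is the largest option under 10 s;
# at 90 BPM a bar lasts about 2.67 s, so 2 bars (about 5.3 s) is suggested.
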
# ========== GRADIO INTERFACE ==========
with gr.Blocks(title="Stable Audio Loop Generator") as iface:
    gr.Markdown("# Stable Audio Loop Generator")
    gr.Markdown("**Generate synchronized drum and instrument loops with stable-audio-open-small, then transform with MelodyFlow!**")

    with gr.Accordion("How This Works", open=False):
        gr.Markdown("""
**Workflow:**
1. **Set global BPM and bars** - affects both drum and instrument generation
2. **Generate drum loop** - creates BPM-aware percussion
3. **Generate instrument loop** - creates melodic/harmonic content
4. **Combine loops** - layer them together with repetitions (up to 30s)
5. **Transform** - use MelodyFlow to stylistically transform the combined result

**Features:**
- BPM-aware generation ensures perfect sync between loops
- Negative prompting separates drums from instruments cleanly
- Smart bar calculation optimizes loop length for the BPM
- MelodyFlow integration for advanced style transfer
""")

    # ========== GLOBAL CONTROLS ==========
    gr.Markdown("## Global Settings")

    with gr.Row():
        global_bpm = gr.Dropdown(
            label="Global BPM",
            choices=[90, 100, 110, 120, 130, 140, 150],
            value=120,
            info="BPM applied to both drum and instrument generation"
        )
        global_bars = gr.Dropdown(
            label="Loop Length (Bars)",
            choices=[1, 2, 4, 8],
            value=4,
            info="Number of bars for each loop"
        )
        base_prompt = gr.Textbox(
            label="Base Prompt",
            value="techno",
            placeholder="e.g., 'techno', 'jazz', 'ambient', 'hip-hop'",
            info="Style applied to both loops"
        )

    # Auto-suggest optimal bars based on BPM
    def update_suggested_bars(bpm):
        optimal = calculate_optimal_bars(bpm)
        return gr.update(info=f"Suggested: {optimal} bars for {bpm}bpm (≤10s)")

    global_bpm.change(update_suggested_bars, inputs=[global_bpm], outputs=[global_bars])
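    # Note: update_suggested_bars only refreshes the dropdown's hint text; the selected
    # bar count itself is left unchanged for the user to adjust.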
    # ========== LOOP GENERATION ==========
    gr.Markdown("## Step 1: Generate Individual Loops")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Drum Loop")
            generate_drums_btn = gr.Button("Generate Drums", variant="primary", size="lg")
            drums_audio = gr.Audio(label="Drum Loop", type="filepath")
            drums_status = gr.Textbox(label="Drums Status", value="Ready to generate")

        with gr.Column():
            gr.Markdown("### Instrument Loop")
            generate_instruments_btn = gr.Button("Generate Instruments", variant="secondary", size="lg")
            instruments_audio = gr.Audio(label="Instrument Loop", type="filepath")
            instruments_status = gr.Textbox(label="Instruments Status", value="Ready to generate")

    # Seed controls
    with gr.Row():
        drums_seed = gr.Number(label="Drums Seed", value=-1, info="-1 for random")
        instruments_seed = gr.Number(label="Instruments Seed", value=-1, info="-1 for random")

    # ========== COMBINATION ==========
    gr.Markdown("## Step 2: Combine Loops")

    with gr.Row():
        num_repeats = gr.Slider(
            label="Number of Repetitions",
            minimum=1,
            maximum=5,
            step=1,
            value=2,
            info="How many times to repeat each loop (creates longer audio)"
        )

    combine_btn = gr.Button("Combine Loops", variant="primary", size="lg")
    combined_audio = gr.Audio(label="Combined Loops", type="filepath")
    combine_status = gr.Textbox(label="Combine Status", value="Generate loops first")

    # ========== MELODYFLOW TRANSFORMATION ==========
    gr.Markdown("## Step 3: Transform with MelodyFlow")

    with gr.Row():
        with gr.Column():
            transform_prompt = gr.Textbox(
                label="Transformation Prompt",
                value="aggressive industrial techno with distorted sounds",
                placeholder="Describe the style transformation",
                lines=2
            )
        with gr.Column():
            transform_solver = gr.Dropdown(
                label="Solver",
                choices=["euler", "midpoint"],
                value="euler",
                info="euler: faster (25 effective steps), midpoint: slower (64 effective steps)"
            )
            transform_flowstep = gr.Slider(
                label="Transform Intensity",
                minimum=0.0,
                maximum=0.15,
                step=0.01,
                value=0.12,
                info="Lower = more dramatic transformation"
            )

    transform_btn = gr.Button("Transform Audio", variant="secondary", size="lg")
    transformed_audio = gr.Audio(label="Transformed Audio", type="filepath")
    transform_status = gr.Textbox(label="Transform Status", value="Combine audio first")
    # ========== EVENT HANDLERS ==========

    # Generate drums
    generate_drums_btn.click(
        generate_stable_audio_loop,
        inputs=[base_prompt, gr.State("drums"), global_bpm, global_bars, drums_seed],
        outputs=[drums_audio, drums_status]
    )

    # Generate instruments
    generate_instruments_btn.click(
        generate_stable_audio_loop,
        inputs=[base_prompt, gr.State("instruments"), global_bpm, global_bars, instruments_seed],
        outputs=[instruments_audio, instruments_status]
    )

    # Combine loops
    combine_btn.click(
        combine_loops,
        inputs=[drums_audio, instruments_audio, global_bpm, global_bars, num_repeats],
        outputs=[combined_audio, combine_status]
    )

    # Transform with MelodyFlow
    transform_btn.click(
        transform_with_melodyflow_api,
        inputs=[combined_audio, transform_prompt, transform_solver, transform_flowstep],
        outputs=[transformed_audio, transform_status]
    )

    # ========== EXAMPLES ==========
    gr.Markdown("## Example Workflows")

    # Example BPM values must be members of the Global BPM dropdown's choices.
    examples = gr.Examples(
        examples=[
            ["techno", 130, 4, "aggressive industrial techno"],
            ["jazz", 110, 2, "smooth lo-fi jazz with vinyl crackle"],
            ["ambient", 90, 8, "ethereal ambient soundscape"],
            ["hip-hop", 100, 4, "classic boom bap hip-hop"],
            ["drum and bass", 140, 4, "liquid drum and bass"],
        ],
        inputs=[base_prompt, global_bpm, global_bars, transform_prompt],
    )

if __name__ == "__main__":
    iface.launch()