Spaces: Running on L40S
import functools
import os
import subprocess
import sys

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration
def setup_flash_attention(marker_path="/tmp/flash_attn_installed"):
    """Best-effort, one-time installation of flash-attn 2.7.3.

    Does nothing when ``flash_attn`` is already importable, or when
    *marker_path* exists. The marker is written only after a previous
    successful install.

    Otherwise it runs ``pip install flash-attn==2.7.3 --no-build-isolation``
    with the current interpreter, then ``pip uninstall apex -y``, then writes
    the marker file. A failing uninstall is ignored.

    Problems are reported with a printed warning and never raised, so
    importing this module continues even when the install cannot be done.

    Args:
        marker_path: File whose existence records a completed install.
            Defaults to the previously hard-coded ``/tmp`` location.
    """
    try:
        import flash_attn  # noqa: F401  -- imported only to probe availability
    except ImportError:
        pass
    else:
        print("flash-attn already installed")
        return

    # The marker is only written after a successful install (below), so its
    # presence means "installed earlier in this environment", not "attempted".
    if os.path.exists(marker_path):
        return

    try:
        print("Installing flash-attn with --no-build-isolation...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install",
             "flash-attn==2.7.3", "--no-build-isolation"],
            check=True,
        )
        # Uninstall apex if it exists; check=False so a missing apex is fine.
        subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", "apex", "-y"],
            check=False,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers failing to launch the interpreter/pip at all.
        # Previously that propagated and aborted the module import.
        print(f"Warning: Failed to install flash-attn: {e}")
        # Continue anyway - the model might work without it
        return

    # Record success. An unwritable marker must not crash the import, and
    # must not be reported as an install failure.
    try:
        with open(marker_path, "w", encoding="utf-8") as marker:
            marker.write("installed")
    except OSError as e:
        print(f"Warning: could not write install marker {marker_path}: {e}")
    print("flash-attn installation completed")
# Best-effort, one-time flash-attn install; runs as a side effect of import.
setup_flash_attention()


@functools.lru_cache(maxsize=1)
def load_model():
    """Load the MusicGen Large processor and model once per process.

    The result is memoized, so the checkpoint is downloaded/loaded on the first
    request only instead of on every click. The model is placed on CUDA when a
    GPU is available, otherwise on the CPU.

    Returns:
        tuple: ``(processor, model)`` for ``facebook/musicgen-large``.
    """
    processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-large")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    return processor, model


def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
    """Generate music from a text prompt.

    Args:
        text_prompt: Free-text description of the desired music.
        duration: Target length in seconds. It is converted to
            ``int(duration * 50)`` new tokens (approx. 50 tokens per second).
        temperature: Sampling temperature passed to ``generate``.
        top_k: Top-k cutoff, cast to ``int`` because UI sliders may deliver
            floats.
        top_p: Nucleus-sampling threshold. Only values strictly between 0 and
            1 are applied; 0 (the UI default) means "disabled".

    Returns:
        tuple: ``(sample_rate, audio)``. *audio* is the first batch item's
        first channel as a NumPy array, peak-normalized (left unscaled when all
        zeros). This is the form ``gr.Audio(type="numpy")`` accepts.

    Raises:
        gr.Error: If loading or generation fails; Gradio shows the message in
            the UI.
    """
    try:
        processor, model = load_model()

        # Tokenize the prompt and put the tensors on the model's device.
        inputs = processor(
            text=[text_prompt],
            padding=True,
            return_tensors="pt",
        ).to(model.device)

        generation_kwargs = {
            # Approximate tokens per second; generate() needs integer lengths.
            "max_new_tokens": int(duration * 50),
            "do_sample": True,
            "temperature": temperature,
            "top_k": int(top_k),
        }
        # In transformers a top_p below 1.0 enables nucleus filtering, and
        # top_p == 0 would keep only the single most likely token (greedy),
        # overriding top-k and temperature. Forward it only when it is a real
        # threshold.
        if top_p is not None and 0.0 < top_p < 1.0:
            generation_kwargs["top_p"] = float(top_p)

        with torch.no_grad():
            audio_values = model.generate(**inputs, **generation_kwargs)

        audio_data = audio_values[0, 0].cpu().numpy()
        # The documented location of MusicGen's output rate is the
        # audio-encoder config.
        sample_rate = model.config.audio_encoder.sampling_rate

        # Peak-normalize; skip all-zero output to avoid 0/0 -> NaN.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak
        return sample_rate, audio_data
    except Exception as e:
        # A (None, message) tuple cannot be rendered by gr.Audio; gr.Error is
        # how Gradio surfaces a message to the user.
        raise gr.Error(f"Error generating music: {str(e)}") from e
# Create Gradio interface.
# The emoji labels were mojibake ("π΅", UTF-8 🎵 decoded as ISO-8859-7)
# and are restored to 🎵.
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model.")

    with gr.Row():
        # Left column: prompt and sampling controls.
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate (e.g., 'upbeat jazz with piano and drums')",
                lines=3,
            )
            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature (creativity)",
                )
            with gr.Row():
                top_k = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=250,
                    step=1,
                    label="Top-k",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    label="Top-p",
                )
            generate_btn = gr.Button("🎵 Generate Music", variant="primary")

        # Right column: generated audio and usage tips.
        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy",
            )
            gr.Markdown("### Tips:")
            gr.Markdown("""
- Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
- Higher temperature = more creative/random results
- Lower temperature = more predictable results
- Duration is limited to 30 seconds for faster generation
""")

    # Example prompts; clicking one fills the description textbox.
    gr.Examples(
        examples=[
            ["upbeat jazz with piano and drums"],
            ["relaxing acoustic guitar melody"],
            ["electronic dance music with heavy bass"],
            ["classical violin concerto"],
            ["reggae with steel drums and bass"],
            ["rock ballad with electric guitar solo"],
        ],
        inputs=text_input,
        label="Example Prompts",
    )

    # Connect the generate button to the function; input order matches
    # generate_music's positional parameters.
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, temperature, top_k, top_p],
        outputs=audio_output,
    )

if __name__ == "__main__":
    demo.launch()