Spaces:
Runtime error
Runtime error
# `spaces` (Hugging Face ZeroGPU helper) is deliberately kept as the very first
# import: ZeroGPU expects it to be imported before torch/torchaudio touch CUDA.
import spaces

import json
import os
import shutil
import subprocess
import tempfile
from pathlib import Path

import gradio as gr
import requests  # NOTE(review): appears unused in this file; kept to avoid breaking anything.
import torchaudio

from model import Jamify
from utils import (
    beats_to_text_with_regrouping,
    convert_text_beats_to_time,
    convert_text_beats_to_time_with_regrouping,
    convert_text_time_to_beats,
    json_to_text,
    round_to_quarter_beats,
    text_to_json,
    text_to_words,
)
def crop_audio_to_30_seconds(audio_path, max_seconds=30):
    """Crop audio to its first ``max_seconds`` seconds and save it as a temporary WAV.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        max_seconds: Length of the prefix to keep, in seconds (default 30).

    Returns:
        Path of a new temporary ``.wav`` file, or ``None`` if ``audio_path`` is
        empty, does not exist, or could not be loaded/saved. The temp file is
        created with ``delete=False``, so the caller owns it.
    """
    if not audio_path or not os.path.exists(audio_path):
        return None
    temp_path = None
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
        # int() keeps the slice bound integral even if max_seconds is a float.
        target_samples = int(sample_rate * max_seconds)
        # Slicing past the end keeps everything, so shorter clips stay whole.
        cropped_waveform = waveform[:, :target_samples]
        # Only reserve a unique name here; the file is written after it is closed.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_path = temp_file.name
        torchaudio.save(temp_path, cropped_waveform, sample_rate)
        return temp_path
    except Exception as e:
        # Best-effort preview: report and fall back to "no audio".
        print(f"Error processing audio: {e}")
        # Don't leak the empty/partial temp file when saving failed.
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)
        return None
def _clone_fresh(repo_url, dest):
    """Remove ``dest`` if present, then ``git clone`` ``repo_url`` into it."""
    dest = Path(dest)
    # shutil/unlink replace the old unchecked `rm -rf` subprocess: portable,
    # and a failed removal now raises instead of being silently ignored.
    if dest.is_symlink() or dest.is_file():
        dest.unlink()
    elif dest.exists():
        shutil.rmtree(dest)
    subprocess.run(["git", "clone", repo_url, str(dest)], check=True)


def download_resources():
    """Clone the examples and public-resource repositories from GitHub.

    Any existing ``examples/`` and ``public/`` directories are deleted first,
    so every call re-downloads fresh copies.

    Raises:
        subprocess.CalledProcessError: if a ``git clone`` fails.
        OSError: if an existing directory cannot be removed.
    """
    _clone_fresh("https://github.com/xhhhhang/jam-examples.git", "examples")
    _clone_fresh("https://github.com/xhhhhang/jam-public-resources.git", "public")
# --- One-time startup: fetch assets, build the model, register static files ---
print('Downloading examples data...')
download_resources()

# Built once at import time; generate_song reuses this module-level instance.
print("Initializing Jamify model...")
jamify_model = Jamify()
print("Jamify model ready.")

# Register the working directory (which holds the cloned examples/ and public/
# folders) as a Gradio static path.
static_root = Path.cwd().absolute()
gr.set_static_paths(paths=[static_root])
# NOTE(review): `spaces` was imported at the top of the file but never used.
# ZeroGPU Spaces expect at least one @spaces.GPU-decorated function (and fail at
# startup without one); on other hardware the decorator is documented as a
# no-op. Long songs may need @spaces.GPU(duration=...) - confirm the budget.
@spaces.GPU
def generate_song(reference_audio, lyrics_text, duration, mode="time", bpm=120, style_prompt=None):
    """Generate a song with the module-level Jamify model.

    Args:
        reference_audio: Filepath of the style reference audio; ``None`` or
            ``""`` means no reference was given.
        lyrics_text: Timestamped lyrics in the text format accepted by
            ``text_to_json`` - seconds, or beats when ``mode == "beats"``.
        duration: Target song length in seconds, forwarded to the model.
        mode: ``"time"`` or ``"beats"``. In beats mode the lyrics are converted
            to seconds with ``convert_text_beats_to_time`` at ``bpm``.
        bpm: Tempo used for the beats-to-seconds conversion.
        style_prompt: Optional text style prompt forwarded to the model.

    Returns:
        The value returned by ``jamify_model.predict`` (shown in the UI's
        output audio player).

    Raises:
        gr.Error: if beat-format lyrics cannot be converted to seconds.
    """
    # Treat an empty string the same as no upload.
    reference_audio = reference_audio or None

    if mode == "beats" and lyrics_text:
        try:
            lyrics_text = convert_text_beats_to_time(lyrics_text, bpm)
        except Exception as e:
            # Carrying on would make the model read beat numbers as seconds,
            # so report the problem to the user instead.
            raise gr.Error(f"Could not convert beat timestamps to seconds: {e}") from e

    lyrics_json = text_to_json(lyrics_text)

    # The model takes a lyrics *file* path, so write the JSON to a temp file.
    # The try/finally also covers the write, so a failed dump doesn't leak it.
    fd, lyrics_file = tempfile.mkstemp(suffix='.json')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(lyrics_json, f, indent=2)
        return jamify_model.predict(
            reference_audio_path=reference_audio,
            lyrics_json_path=lyrics_file,
            style_prompt=style_prompt,
            duration=duration,
        )
    finally:
        # Remove the temp lyrics file whether or not generation succeeded.
        if os.path.exists(lyrics_file):
            os.unlink(lyrics_file)
# Load and cache examples
def load_examples(examples_file="examples/input.json"):
    """Read example metadata and pre-render each example's lyrics as text.

    Args:
        examples_file: Path to a JSON list of example records. Each record may
            carry ``id``, ``audio_path``, ``lrc_path`` (a lyrics JSON file
            rendered with ``json_to_text``), ``duration`` (default 120) and
            ``bpm`` (default 120.0).

    Returns:
        A list of dicts with keys ``id`` (str), ``audio_path`` (``None`` if the
        file is missing), ``lyrics_text`` (``""`` if the lyrics file is missing
        or unreadable), ``duration`` and ``bpm``. It is empty if
        ``examples_file`` does not exist.
    """
    examples = []
    if os.path.exists(examples_file):
        print("Loading and caching examples...")
        # Explicit UTF-8: lyrics may be non-ASCII and the platform default
        # encoding is not guaranteed to be UTF-8.
        with open(examples_file, 'r', encoding='utf-8') as f:
            examples_data = json.load(f)
        for example in examples_data:
            # str(): the UI measures and slices ids when building button labels.
            example_id = str(example.get('id', ''))
            # `or ''` also covers explicit JSON nulls, which os.path.exists
            # would reject with a TypeError.
            audio_path = example.get('audio_path') or ''
            lrc_path = example.get('lrc_path') or ''
            duration = example.get('duration', 120)
            bpm = example.get('bpm', 120.0)

            # Pre-compute the lyrics text once so switching examples is cheap.
            lyrics_text = ""
            if os.path.exists(lrc_path):
                try:
                    with open(lrc_path, 'r', encoding='utf-8') as f:
                        lyrics_text = json_to_text(json.load(f))
                    print(f"Cached example {example_id}: {len(lyrics_text)} chars")
                except Exception as e:
                    # Best effort: a broken lyrics file yields an empty box, not a crash.
                    print(f"Error loading lyrics from {lrc_path}: {e}")

            examples.append({
                'id': example_id,
                'audio_path': audio_path if os.path.exists(audio_path) else None,
                'lyrics_text': lyrics_text,
                'duration': duration,
                'bpm': bpm,
            })
    print(f"Loaded {len(examples)} cached examples")
    return examples
def load_example(example_idx, examples, mode="time"):
    """Return the form values ``(audio_path, lyrics_text, duration, bpm)`` for one example.

    In ``"beats"`` mode the cached lyrics are passed through
    ``beats_to_text_with_regrouping`` (with ``round_to_quarters=True``) at the
    example's BPM. If that fails, the error is printed and the cached text is
    returned unchanged. An out-of-range ``example_idx`` returns the blank form
    ``(None, "", 120, 120.0)``.
    """
    if not 0 <= example_idx < len(examples):
        return None, "", 120, 120.0

    chosen = examples[example_idx]
    tempo = chosen.get('bpm', 120.0)
    text = chosen['lyrics_text']

    if text and mode == "beats":
        try:
            text = beats_to_text_with_regrouping(text, tempo, round_to_quarters=True)
        except Exception as err:
            print(f"Error converting to beats format: {err}")

    return (chosen['audio_path'], text, chosen['duration'], tempo)
def clear_form():
    """Return blank form values: no audio, empty lyrics, 120 s duration, 120 BPM."""
    blank_audio, blank_lyrics = None, ""
    blank_duration, blank_bpm = 120, 120.0
    return blank_audio, blank_lyrics, blank_duration, blank_bpm
def update_button_styles(selected_idx, total_examples):
    """Build ``gr.update`` objects that highlight the currently selected button.

    Returns one update per example button (indices ``0 .. total_examples - 1``)
    followed by one for the "Make Your Own" button, which counts as selected
    when ``selected_idx == -1``. The selected button gets ``variant="primary"``;
    every other button gets ``variant="secondary"``.
    """
    def _variant_for(is_selected):
        return gr.update(variant="primary" if is_selected else "secondary")

    per_example = [_variant_for(idx == selected_idx) for idx in range(total_examples)]
    return per_example + [_variant_for(selected_idx == -1)]
# Load the examples once at startup; the first one (if any) pre-fills the form.
examples = load_examples()
first_example = examples[0] if examples else None
if first_example is not None:
    default_audio = first_example['audio_path']
    default_lyrics = first_example['lyrics_text']
    default_duration = first_example['duration']
    default_bpm = first_example['bpm']
else:
    default_audio, default_lyrics, default_duration, default_bpm = None, "", 120, 120.0

# The preview player shows only a cropped copy of the default reference audio.
default_audio_display = crop_audio_to_30_seconds(default_audio) if default_audio else None
# Gradio interface

# Lyrics-box help text for each timing mode. It is shared by the initial layout
# and switch_mode so the two copies cannot drift apart.
TIME_MODE_PLACEHOLDER = "Enter lyrics with timestamps: word[start_time:end_time] word[start_time:end_time]...\n\nExample: Hello[0.0:1.2] world[1.5:2.8] this[3.0:3.8] is[4.2:4.6] my[5.0:5.8] song[6.2:7.0]\n\nFormat: Each word followed by [start_seconds:end_seconds] in brackets\nTimestamps should be in seconds with up to 2 decimal places"
BEATS_MODE_PLACEHOLDER = "Enter lyrics with beat timestamps: word[start_beat:end_beat] word[start_beat:end_beat]...\n\nExample: Hello[0:1] world[1.5:2.75] this[3:3.75] is[4.25:4.5] my[5:5.75] song[6.25:7]\n\nFormat: Each word followed by [start_beat:end_beat] in brackets\nBeats are in quarter notes (1 beat = quarter note, 0.25 = sixteenth note)"

with gr.Blocks() as demo:
    gr.Markdown("# Jamify: Music Generation from Lyrics and Style")
    # NOTE(review): this text mentions a text-prompt style option, but only an
    # audio reference is wired up below; generate_song's style_prompt is never set.
    gr.Markdown("Provide your lyrics, a style reference (either an audio file or a text prompt), and a desired duration to generate a song.")

    # Selected example: -1 means "Make Your Own", otherwise an index into `examples`.
    selected_example = gr.State(0 if examples else -1)
    # Lyrics timing format ("time" or "beats") and the BPM used for conversions
    # and passed to generate_song.
    input_mode = gr.State("time")
    current_bpm = gr.State(default_bpm)

    # Sample buttons section
    if examples:
        gr.Markdown("### Sample Examples")
        with gr.Row():
            example_buttons = []
            for i, example in enumerate(examples):
                # str() so a numeric id from input.json can still be measured/sliced.
                example_id = str(example['id'])
                # Truncate long IDs so the buttons keep a consistent width.
                button_text = example_id[:12] + "..." if len(example_id) > 15 else example_id
                # The first example starts selected (primary); the rest are secondary.
                initial_variant = "primary" if i == 0 else "secondary"
                button = gr.Button(
                    button_text,
                    variant=initial_variant,
                    size="sm",
                    scale=1,  # Equal width for all buttons
                    min_width=80,  # Minimum consistent width
                )
                example_buttons.append(button)
            # Same sizing as the example buttons. It starts secondary because the
            # first example is selected initially.
            make_your_own_button = gr.Button(
                "🎵 Make Your Own",
                variant="secondary",
                size="sm",
                scale=1,
                min_width=80,
            )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Inputs")
            # Mode switcher
            mode_radio = gr.Radio(
                choices=["Time Mode", "Beats Mode"],
                value="Time Mode",
                label="Input Format",
                info="Choose how to specify timing: seconds or musical beats",
            )
            # BPM input: hidden until Beats Mode is chosen (see switch_mode).
            bpm_input = gr.Number(
                label="BPM (Beats Per Minute)",
                value=default_bpm,
                minimum=60,
                maximum=200,
                step=1,
                visible=False,
                info="Tempo for converting beats to time",
            )
            lyrics_text = gr.Textbox(
                label="Lyrics",
                lines=10,
                placeholder=TIME_MODE_PLACEHOLDER,
                value=default_lyrics,
            )
            duration_slider = gr.Slider(minimum=120, maximum=230, value=default_duration, step=1, label="Duration (seconds)")

        with gr.Column():
            gr.Markdown("### Style & Generation")
            with gr.Tab("Style from Audio"):
                reference_audio = gr.File(label="Reference Audio (.mp3, .wav)", type="filepath", value=default_audio)
                reference_audio_display = gr.Audio(
                    label="Reference Audio (Only first 30 seconds will be used for generation)",
                    value=default_audio_display,
                    visible=default_audio_display is not None,
                )
            generate_button = gr.Button("Generate Song", variant="primary")
            gr.Markdown("### Output")
            output_audio = gr.Audio(label="Generated Song")

    def switch_mode(mode_choice, current_lyrics, current_bpm_val):
        """Switch the lyrics editor between seconds ("Time Mode") and beats ("Beats Mode").

        Existing lyrics are converted to the new format at ``current_bpm_val``.
        If conversion fails, the error is printed and the text is kept unchanged.

        Returns:
            Updates for ``(bpm_input, lyrics_text, input_mode)``.
        """
        mode = "beats" if mode_choice == "Beats Mode" else "time"
        bpm_visible = mode == "beats"
        # Guard against a None value so .strip() below cannot fail.
        current_lyrics = current_lyrics or ""

        if mode == "time":
            placeholder = TIME_MODE_PLACEHOLDER
            label = "Lyrics"
        else:
            placeholder = BEATS_MODE_PLACEHOLDER
            label = "Lyrics (Beats Format)"

        converted_lyrics = current_lyrics
        if current_lyrics.strip():
            try:
                if mode == "time":
                    converted_lyrics = convert_text_beats_to_time_with_regrouping(current_lyrics, current_bpm_val)
                else:
                    converted_lyrics = beats_to_text_with_regrouping(current_lyrics, current_bpm_val, round_to_quarters=True)
            except Exception as e:
                # Best effort: keep the unconverted text so the user doesn't lose it.
                direction = "beats to time" if mode == "time" else "time to beats"
                print(f"Error converting {direction}: {e}")

        return (
            gr.update(visible=bpm_visible),  # bpm_input visibility
            gr.update(placeholder=placeholder, label=label, value=converted_lyrics),  # lyrics_text
            mode,  # input_mode state
        )

    def update_bpm_state(bpm_val):
        """Copy the visible BPM field into the ``current_bpm`` state."""
        return bpm_val

    def update_reference_audio_display(audio_file):
        """Show a cropped preview of ``audio_file``, or hide the player if there is none.

        NOTE(review): each call creates a new temp WAV through
        crop_audio_to_30_seconds, and nothing here deletes it.
        """
        if audio_file is None:
            return gr.update(visible=False, value=None)
        cropped_path = crop_audio_to_30_seconds(audio_file)
        if cropped_path:
            return gr.update(visible=True, value=cropped_path)
        return gr.update(visible=False, value=None)

    # Connect mode switching
    mode_radio.change(
        fn=switch_mode,
        inputs=[mode_radio, lyrics_text, current_bpm],
        outputs=[bpm_input, lyrics_text, input_mode],
    )
    # Keep the BPM state in sync with the visible BPM field.
    bpm_input.change(
        fn=update_bpm_state,
        inputs=[bpm_input],
        outputs=[current_bpm],
    )
    # Refresh the cropped preview whenever the reference file changes.
    reference_audio.change(
        fn=update_reference_audio_display,
        inputs=[reference_audio],
        outputs=[reference_audio_display],
    )
    generate_button.click(
        fn=generate_song,
        inputs=[reference_audio, lyrics_text, duration_slider, input_mode, current_bpm],
        outputs=output_audio,
        api_name="generate_song",
    )

    # Connect example buttons to load data and update selection
    if examples:
        # Every example / "Make Your Own" click refreshes these components. The
        # BPM goes to both the state (used for generation) and the visible field.
        # Previously only the state was updated, so the displayed BPM could
        # disagree with the BPM actually used.
        selection_outputs = (
            [reference_audio, lyrics_text, duration_slider, current_bpm, bpm_input,
             selected_example, reference_audio_display]
            + example_buttons
            + [make_your_own_button]
        )

        def load_example_and_update_selection(idx, current_mode):
            """Fill the form from example ``idx`` and highlight its button."""
            mode = "beats" if current_mode == "Beats Mode" else "time"
            audio, lyrics, duration, bpm = load_example(idx, examples, mode)
            button_updates = update_button_styles(idx, len(examples))
            audio_display_update = update_reference_audio_display(audio)
            return [audio, lyrics, duration, bpm, bpm, idx, audio_display_update] + button_updates

        def clear_form_and_update_selection():
            """Blank the form and highlight the "Make Your Own" button."""
            audio, lyrics, duration, bpm = clear_form()
            button_updates = update_button_styles(-1, len(examples))
            audio_display_update = update_reference_audio_display(audio)
            return [audio, lyrics, duration, bpm, bpm, -1, audio_display_update] + button_updates

        for i, button in enumerate(example_buttons):
            button.click(
                # idx=i binds the current loop value (avoids late-binding closures).
                fn=lambda current_mode, idx=i: load_example_and_update_selection(idx, current_mode),
                inputs=[mode_radio],
                outputs=selection_outputs,
            )

        # Connect "Make Your Own" button to clear form and update selection
        make_your_own_button.click(
            fn=clear_form_and_update_selection,
            outputs=selection_outputs,
        )
# Make sure Gradio's scratch directory exists before the server starts; a
# failure here is only reported as a warning.
print("Creating temporary directories...")
gradio_tmp_dir = Path("/tmp/gradio")
try:
    gradio_tmp_dir.mkdir(parents=True, exist_ok=True)
except Exception as e:
    print(f"Warning: Could not create temporary directories: {e}")
else:
    print("Temporary directories created successfully.")

demo.queue().launch(share=True)