import spaces
import logging
import os
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile
import time
import typing as tp
import subprocess as sp

import torch
import gradio as gr

from audiocraft.data.audio_utils import f32_pcm, normalize_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import JASCO

MODEL = None
MAX_BATCH_SIZE = 12
INTERRUPTING = False

# Wrap subprocess.call so shelled-out tools don't spam stdout/stderr in the logs.
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    # Propagate the return code, which the original wrapper dropped.
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr

# Preallocate the process pool; entered once and intentionally never exited,
# so the worker processes live for the lifetime of the app.
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    """Tracks generated files and deletes them once older than `file_lifetime` seconds."""

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        now = time.time()
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                if path.exists():
                    path.unlink()
                self.files.pop(0)
            else:
                break


file_cleaner = FileCleaner()
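# Illustrative usage (hypothetical path): anything registered here is removed
# on a later add() call once it is more than an hour old, so temporary wavs
# don't accumulate on disk:
#   file_cleaner.add('/tmp/example.wav')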


def chords_string_to_list(chords: str):
    if chords == '':
        return []
    chords = chords.replace('[', '').replace(']', '').replace(' ', '')
    chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
    return [(x[0], float(x[1])) for x in chrd_times]
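# For reference, this parses the UI chord string into (chord, start_time) tuples:
#   chords_string_to_list("(C, 0.0), (D, 2.0)")  ->  [('C', 0.0), ('D', 2.0)]
# ('N' in the example progressions below appears to denote "no chord".)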


def load_model(version='facebook/jasco-chords-drums-400M'):
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        # Drop the reference first so the old model can be freed before loading the new one.
        MODEL = None
        MODEL = JASCO.get_pretrained(version)
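# Note: JASCO.get_pretrained() fetches the checkpoint from the Hugging Face Hub
# on first use; repeated calls with the same version reuse the cached global MODEL.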


def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
    MODEL.set_generation_params(**gen_kwargs)
    be = time.time()
    chords = chords_string_to_list(chords)

    if melody_matrix is not None:
        melody_matrix = torch.load(melody_matrix.name, weights_only=True)
        if len(melody_matrix.shape) != 2:
            raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {melody_matrix.shape}")
        # Accept a transposed [T, n_melody_bins] matrix as well.
        if melody_matrix.shape[0] > melody_matrix.shape[1]:
            melody_matrix = melody_matrix.permute(1, 0)

    if drum_prompt is None:
        preprocessed_drums_wav = None
        drums_sr = 32000
    else:
        # Gradio returns (sample_rate, int PCM ndarray); convert to float PCM, channels-first.
        drums_sr, drums = drum_prompt[0], f32_pcm(torch.from_numpy(drum_prompt[1])).t()
        if drums.dim() == 1:
            drums = drums[None]
        drums = normalize_audio(drums, strategy="loudness", loudness_headroom_db=16, sample_rate=drums_sr)
        preprocessed_drums_wav = drums

    try:
        outputs = MODEL.generate_music(descriptions=texts, chords=chords,
                                       drums_wav=preprocessed_drums_wav,
                                       melody_salience_matrix=melody_matrix,
                                       drums_sample_rate=drums_sr, progress=progress)
    except RuntimeError as e:
        raise gr.Error("Error while generating: " + e.args[0])

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
            file_cleaner.add(file.name)
    print("Generation took", time.time() - be)
    return out_wavs
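# A minimal programmatic sketch (bypassing the UI), with hypothetical prompt
# values; the generation kwargs mirror what predict_full passes below:
#
#   load_model('facebook/jasco-chords-drums-400M')
#   wavs = _do_predictions(texts=['punchy funk groove'],
#                          chords="(C, 0.0), (F, 4.0)",
#                          melody_matrix=None, drum_prompt=None,
#                          cfg_coef_all=1.25, cfg_coef_txt=2.5,
#                          ode_rtol=1e-4, ode_atol=1e-4,
#                          euler=True, euler_steps=10)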


@spaces.GPU  # ZeroGPU (Hugging Face Spaces): request a GPU only for the duration of this call
def predict_full(model, text, chords_sym, melody_file,
                 drums_file, drums_mic, drum_input_src,
                 cfg_coef_all, cfg_coef_txt,
                 ode_rtol, ode_atol,
                 ode_solver, ode_steps,
                 progress=gr.Progress()):
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    load_model(model)

    max_generated = 0

    def _progress(generated, to_generate):
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    MODEL.set_custom_progress_callback(_progress)

    drums = drums_mic if drum_input_src == "mic" else drums_file
    wavs = _do_predictions(
        texts=[text] * 2,  # generate two variations per request
        chords=chords_sym,
        drum_prompt=drums,
        melody_matrix=melody_file,
        progress=True,
        gradio_progress=progress,
        cfg_coef_all=cfg_coef_all,
        cfg_coef_txt=cfg_coef_txt,
        ode_rtol=ode_rtol,
        ode_atol=ode_atol,
        euler=ode_solver == 'euler',
        euler_steps=ode_steps)
    return wavs


with gr.Blocks() as demo:
    gr.Markdown("""
    # JASCO - Text-to-Music Generation with Temporal Control
    Generate 10-second music clips using text descriptions and temporal controls (chords, drums, melody).
    """)
    with gr.Row():
        with gr.Column():
            submit = gr.Button("Generate")
            interrupt_btn = gr.Button("Interrupt")
        with gr.Column():
            audio_output_0 = gr.Audio(label="Generated Audio 1", type='filepath')
            audio_output_1 = gr.Audio(label="Generated Audio 2", type='filepath')
    with gr.Row():
        with gr.Column():
            text = gr.Text(label="Input Text",
                           value="Strings, woodwind, orchestral, symphony.",
                           interactive=True)
        with gr.Column():
            model = gr.Radio([
                'facebook/jasco-chords-drums-400M',
                'facebook/jasco-chords-drums-1B',
                'facebook/jasco-chords-drums-melody-400M',
                'facebook/jasco-chords-drums-melody-1B',
            ], label="Model", value='facebook/jasco-chords-drums-melody-400M')
| gr.Markdown("### Chords Conditions") | |
| chords_sym = gr.Text( | |
| label="Chord Progression", | |
| value="(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)", | |
| interactive=True | |
| ) | |
| gr.Markdown("### Drums Conditions") | |
| with gr.Row(): | |
| drum_input_src = gr.Radio(["file", "mic"], value="file", label="Drums Input Source") | |
| drums_file = gr.Audio(sources=["upload"], type="numpy", label="Drums File") | |
| drums_mic = gr.Audio(sources=["microphone"], type="numpy", label="Drums Mic") | |
| gr.Markdown("### Melody Conditions") | |
| melody_file = gr.File(label="Melody File") | |
| with gr.Row(): | |
| cfg_coef_all = gr.Number(label="CFG ALL", value=1.25, step=0.25) | |
| cfg_coef_txt = gr.Number(label="CFG TEXT", value=2.5, step=0.25) | |
| ode_tol = gr.Number(label="ODE Tolerance", value=1e-4, step=1e-5) | |
| ode_solver = gr.Radio(['euler', 'dopri5'], label="ODE Solver", value='euler') | |
| ode_steps = gr.Number(label="Euler Steps", value=10, step=1) | |
    submit.click(
        fn=predict_full,
        inputs=[
            model, text, chords_sym, melody_file,
            drums_file, drums_mic, drum_input_src,
            cfg_coef_all, cfg_coef_txt,
            ode_tol, ode_tol,  # same value is passed as both ode_rtol and ode_atol
            ode_solver, ode_steps,
        ],
        outputs=[audio_output_0, audio_output_1],
    )
    interrupt_btn.click(fn=interrupt, queue=False)
    gr.Examples(
        examples=[
            [
                "80s pop with groovy synth bass and electric piano",
                "(N, 0.0), (C, 0.32), (Dm7, 3.456), (Am, 4.608), (F, 8.32), (C, 9.216)",
                None,
                None,
            ],
            [
                "Strings, woodwind, orchestral, symphony.",
                "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
                None,
                None,
            ],
        ],
        inputs=[text, chords_sym, melody_file, drums_file],
        outputs=[audio_output_0, audio_output_1],
    )

demo.queue().launch()