import gradio as gr
import pretty_midi
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import cv2
import imageio
import sys
import subprocess
import os
import torch
from model import init_ldm_model
from model.model_sdf import Diffpro_SDF
from model.sampler_sdf import SDFSampler
import pickle
from train.train_params import params_chord_lsh_cond
from generation.gen_utils import *
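# Names such as CHORD_DICTIONARY, X0EditFunc, extend_piano_roll,
# piano_roll_to_midi, and save_midi used below are assumed to come from the
# star import of generation.gen_utils above.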
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
chord_list = list(CHORD_DICTIONARY.keys())
def get_shape(file_path):
    if file_path.endswith('.jpg'):
        img = cv2.imread(file_path)
        return img.shape  # (height, width, channels)
    elif file_path.endswith('.mp4'):
        vid = imageio.get_reader(file_path)
        return vid.get_meta_data()['size']  # (width, height)
    else:
        raise ValueError("Unsupported file type")

# Function to convert MIDI to WAV
def midi_to_wav(midi, output_file):
    # Synthesize the waveform from the MIDI using pretty_midi
    audio_data = midi.fluidsynth()
    # Write the waveform to a WAV file
    sf.write(output_file, audio_data, samplerate=44100)
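# Note: pretty_midi's fluidsynth() needs the pyfluidsynth package (and a
# SoundFont) to be installed; the demo below instead renders audio with the
# external `timidity` CLI, so this helper is an optional alternative path.
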
def update_musescore_image(selected_prompt):
    # Return the score image for the selected example prompt,
    # e.g. "example 1" -> "samples/diy_examples/example1/example1.jpg"
    example_id = selected_prompt.replace(" ", "")
    return f"samples/diy_examples/{example_id}/{example_id}.jpg"

# Model for generating music
def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"):
    ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False)
    model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device)
    sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)

    if mode == "example":
        if prompt == "example 1":
            background_condition = np.load("samples/diy_examples/example1/example1.npy")
            tempo = 70  # example 1 uses a fixed, slower tempo
        elif prompt == "example 2":
            background_condition = np.load("samples/diy_examples/example2/example2.npy")
        elif prompt == "example 3":
            background_condition = np.load("samples/diy_examples/example3/example3.npy")
        elif prompt == "example 4":
            background_condition = np.load("samples/diy_examples/example4/example4.npy")
        background_condition = np.tile(background_condition, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)
    else:
        # In non-example mode, `prompt` is the conditioning array itself
        background_condition = np.tile(prompt, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)

    if rhythm_control != "Yes":
        # Without rhythm control, overwrite the rhythm/onset channels with the chord channels
        background_condition[:, 0:2] = background_condition[:, 2:4]

    # generate samples
    output_x = sampler.generate(background_cond=background_condition, batch_size=num_samples,
                                same_noise_all_measure=False, X0EditFunc=X0EditFunc,
                                use_classifier_free_guidance=True, use_lsh=True, reduce_extra_notes=False,
                                rhythm_control=rhythm_control)
    output_x = torch.clamp(output_x, min=0, max=1)
    output_x = output_x.cpu().numpy()

    # save samples
    for i in range(num_samples):
        full_roll = extend_piano_roll(output_x[i])  # accompaniment roll
        full_chd_roll = extend_piano_roll(-background_condition[i, 2:4, :, :].cpu().numpy() - 1)  # chord roll
        full_lsh_roll = None
        if background_condition.shape[1] >= 6:
            if background_condition[:, 4:6, :, :].min() >= 0:
                full_lsh_roll = extend_piano_roll(background_condition[i, 4:6, :, :].cpu().numpy())
        midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo)
        filename = f"output_{i}.mid"
        save_midi(midi_file, filename)
        subprocess.Popen(['timidity', f'output_{i}.mid', '-Ow', '-o', f'output_{i}.wav']).communicate()
    return 'output_0.mid', 'output_0.wav', midi_file
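# Hedged usage sketch for the non-"example" mode, where `prompt` is the raw
# conditioning array itself. The channel layout is an assumption inferred from
# the slicing above (0:2 rhythm/onset, 2:4 chord, 4:6 lead sheet):
#
#   cond = np.load("my_condition.npy")  # hypothetical array of shape (6, 64, 64)
#   midi_path, wav_path, midi = generate_music(cond, tempo=90, mode="diy",
#                                              rhythm_control="No")
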
# Function to visualize MIDI notes
def visualize_midi(midi):
    # Get piano roll from MIDI
    roll = midi.get_piano_roll(fs=100)
    # Plot the piano roll
    plt.figure(figsize=(10, 4))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
    plt.title("Piano Roll")
    plt.xlabel("Time (frames at 100 Hz)")
    plt.ylabel("Pitch")
    plt.colorbar()
    # Save the plot as an image
    output_image_path = "piano_roll.png"
    plt.savefig(output_image_path)
    plt.close()  # free the figure so repeated calls don't accumulate memory
    return output_image_path
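# On a headless host (e.g. a Hugging Face Space) matplotlib must use a
# non-interactive backend; recent matplotlib versions fall back to "Agg"
# automatically when no display is present, so no explicit
# matplotlib.use("Agg") call is made here.
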
# Gradio main function
def generate_from_example(prompt):
    midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control="No")
    piano_roll_image = visualize_midi(midi)
    return audio_output, piano_roll_image

# Prompt list
prompt_list = ["example 1", "example 2", "example 3", "example 4"]

custom_css = """
.custom-purple {
    background-color: #d7bde2;
    padding: 10px;
    border-radius: 5px;
}
.audio_waveform-container {
    display: none !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# <div style='text-align: center;font-size:40px'> Efficient Fine-Grained Guidance for Diffusion Model Based Symbolic Music Generation </div>")
    gr.Markdown("<div style='text-align: center;font-size:20px'>Tingyu Zhu<sup>*</sup>, Haoyu Liu<sup>*</sup>, Ziyu Wang, Zhimin Jiang, Zeyu Zheng</div>")
    gr.Markdown("<div style='text-align: center;font-size:20px'><a href='https://arxiv.org/abs/2410.08435'>[Paper]</a> <a href='https://github.com/huajianduzhuo-code/FGG-music-code'>[Code Repo]</a></div>")
    gr.Markdown(
        "<span style='font-size:25px;'> For detailed information and demonstrations of our method, please visit our [GitHub Pages site](https://huajianduzhuo-code.github.io/FGG-diffusion-music/) to explore:"
        "\n   1. Accompaniment Generation given Melody and Chord"
        "\n   2. Style-Controlled Music Generation"
        "\n   3. Demonstrating the Effectiveness of Sampling Control by Comparison</span>"
    )
    gr.HTML("<div style='height: 50px;'></div>")
    gr.Markdown("\n\n\n")
    gr.Markdown("# <span style='color: red;'> Interactive Demo </span>")
    gr.Markdown(
        "<span style='font-size:20px;'>"
        "🎵 Try out our interactive tool to generate music with our model!<br>"
        "You can create new accompaniments conditioned on a given melody and chord progression."
        "</span>"
    )
    gr.Markdown(
        "<span style='color:blue; font-size:20px;'>"
        "⚠️ This Space currently runs on a Hugging Face-provided CPU. On average, it takes ~15 seconds to generate a 4-measure music segment.<br>"
        "If multiple users are generating at the same time, you may enter a queue, which can cause delays.<br><br>"
        "🚀 On our local server (NVIDIA RTX 6000 Ada GPU), the same generation takes only 0.4 seconds.<br><br>"
        "To speed things up, you can: <br>"
        "• 🔁 Fork this Space and select a different hardware configuration<br>"
        "• 🧑‍💻 Clone our <a href='https://github.com/huajianduzhuo-code/FGG-music-code'>[Code Repo]</a> and run the generation notebooks locally after installing dependencies and downloading the model weights."
        "</span>"
    )

    with gr.Column(elem_classes="custom-purple"):
        gr.Markdown("### Select an example to generate music given melody and chord condition")
        with gr.Row():
            with gr.Column():
                prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1")
                gr.Markdown("### This is the melody to be conditioned on:")
                condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg", label="melody, chord, and rhythm condition")
                prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore)
            with gr.Column():
                generate_button = gr.Button("Generate")
                gr.Markdown("### Generation results:")
                audio_output = gr.Audio(label="Generated Music")
                piano_roll_output = gr.Image(label="Generated Piano Roll")

                generate_button.click(
                    fn=generate_from_example,
                    inputs=[prompt_selector],
                    outputs=[audio_output, piano_roll_output]
                )

# Launch Gradio interface
if __name__ == "__main__":
    demo.launch()
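# Runtime requirements (inferred from the calls above): the TiMidity++ CLI
# (`timidity`) must be on PATH for MIDI-to-WAV rendering, and the trained
# checkpoint must exist at `model_path`. Assuming this file is saved as
# app.py, launch locally with: python app.py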