import gradio as gr
import pretty_midi
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import cv2
import imageio

import subprocess
import torch
from model import init_ldm_model
from model.model_sdf import Diffpro_SDF
from model.sampler_sdf import SDFSampler

from train.train_params import params_chord_lsh_cond
from generation.gen_utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
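# NOTE: the checkpoint must already exist at this path; it is not downloaded at runtime.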
chord_list = list(CHORD_DICTIONARY.keys())

def get_shape(file_path):
    """Return the shape of an image or video file.

    Note the mixed conventions: .jpg files yield (height, width, channels)
    from OpenCV, while .mp4 files yield (width, height) from imageio metadata.
    """
    if file_path.endswith('.jpg'):
        img = cv2.imread(file_path)
        return img.shape  # (height, width, channels)

    elif file_path.endswith('.mp4'):
        vid = imageio.get_reader(file_path)
        return vid.get_meta_data()['size']  # (width, height)

    else:
        raise ValueError("Unsupported file type")

# Function to convert MIDI to WAV
def midi_to_wav(midi, output_file):
    # Synthesize the waveform from the MIDI using pretty_midi; this requires
    # the fluidsynth library and a SoundFont to be installed.
    audio_data = midi.fluidsynth(fs=44100)

    # Write the waveform to a WAV file at the same sample rate
    sf.write(output_file, audio_data, samplerate=44100)
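
# Example usage (hypothetical; assumes output_0.mid exists and fluidsynth is set up):
#   midi = pretty_midi.PrettyMIDI("output_0.mid")
#   midi_to_wav(midi, "output_0.wav")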

def update_musescore_image(selected_prompt):
    # Map "example N" to its rendered score image,
    # e.g. "example 1" -> "samples/diy_examples/example1/example1.jpg".
    idx = selected_prompt.split()[-1]
    return f"samples/diy_examples/example{idx}/example{idx}.jpg"

# Model for generating music
def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"):
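    # NOTE: the model and sampler are rebuilt on every call. That is fine for a
    # demo, but caching them at module level would reduce per-request latency.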
    ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False)
    model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device)
    sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)

    if mode == "example":
        if prompt == "example 1":
            background_condition = np.load("samples/diy_examples/example1/example1.npy")
            tempo = 70  # example 1 is authored at 70 BPM, so override the input tempo
        elif prompt == "example 2":
            background_condition = np.load("samples/diy_examples/example2/example2.npy")
        elif prompt == "example 3":
            background_condition = np.load("samples/diy_examples/example3/example3.npy")
        elif prompt == "example 4":
            background_condition = np.load("samples/diy_examples/example4/example4.npy")
        else:
            raise ValueError(f"Unknown example prompt: {prompt}")

        background_condition = np.tile(background_condition, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)
    else:
        # In DIY mode, `prompt` is already a condition array rather than an example name.
        background_condition = np.tile(prompt, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)

    if rhythm_control != "Yes":
        # Without rhythm control, overwrite the rhythm/onset channels (0:2) with
        # the chord channels (2:4), effectively removing the rhythm condition.
        background_condition[:, 0:2] = background_condition[:, 2:4]
    # generate samples
    output_x = sampler.generate(background_cond=background_condition, batch_size=num_samples, 
                                same_noise_all_measure=False, X0EditFunc=X0EditFunc, 
                                use_classifier_free_guidance=True, use_lsh=True, reduce_extra_notes=False,
                                rhythm_control=rhythm_control)
    output_x = torch.clamp(output_x, min=0, max=1)
    output_x = output_x.cpu().numpy()

    # save samples
    for i in range(num_samples):
        full_roll = extend_piano_roll(output_x[i])  # accompaniment roll
        # the chord channels are stored with a negative encoding; undo it before extending
        full_chd_roll = extend_piano_roll(-background_condition[i, 2:4, :, :].cpu().numpy() - 1)  # chord roll
        full_lsh_roll = None
        if background_condition.shape[1] >= 6 and background_condition[:, 4:6, :, :].min() >= 0:
            full_lsh_roll = extend_piano_roll(background_condition[i, 4:6, :, :].cpu().numpy())
        midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo)
        filename = f"output_{i}.mid"
        save_midi(midi_file, filename)
        # Render the MIDI to WAV with TiMidity++, which must be installed on the host.
        subprocess.run(['timidity', filename, '-Ow', '-o', f'output_{i}.wav'], check=True)

    # Return the first sample's files; `midi_file` is the last generated MIDI object.
    return 'output_0.mid', 'output_0.wav', midi_file
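
# Hypothetical direct call, bypassing the UI (weights and example .npy files must be present):
#   mid_path, wav_path, midi = generate_music("example 2", tempo=80, mode="example", rhythm_control="No")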

# Function to visualize MIDI notes
def visualize_midi(midi):
    # Sample the piano roll at 100 frames per second
    roll = midi.get_piano_roll(fs=100)

    # Plot the piano roll
    plt.figure(figsize=(10, 4))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
    plt.title("Piano Roll")
    plt.xlabel("Time (frames, 100 per second)")
    plt.ylabel("Pitch (MIDI note number)")
    plt.colorbar()

    # Save the plot as an image
    output_image_path = "piano_roll.png"
    plt.savefig(output_image_path)
    plt.close()  # free the figure so repeated calls don't accumulate memory
    return output_image_path

# Callback for the Gradio "Generate" button
def generate_from_example(prompt):
    midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control="No")
    piano_roll_image = visualize_midi(midi)
    return audio_output, piano_roll_image

# Prompt list
prompt_list = ["example 1", "example 2", "example 3", "example 4"]

custom_css = """
/* Purple panel behind the interactive demo */
.custom-purple {
    background-color: #d7bde2;
    padding: 10px;
    border-radius: 5px;
}
/* Hide Gradio's default waveform rendering for audio outputs */
.audio_waveform-container {
    display: none !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# <div style='text-align: center;font-size:40px'> Efficient Fine-Grained Guidance for Diffusion Model Based Symbolic Music Generation <div style='text-align: center;'>")

    gr.Markdown("<div style='text-align: center;font-size:20px'>Tingyu Zhu<sup>*</sup>, Haoyu Liu<sup>*</sup>, Ziyu Wang, Zhimin Jiang, Zeyu Zheng</div>")
    gr.Markdown("<div style='text-align: center;font-size:20px'><a href='https://arxiv.org/abs/2410.08435'>[Paper]</a> <a href='https://github.com/huajianduzhuo-code/FGG-music-code'>[Code Repo]</a></div>")

    gr.Markdown("<span style='font-size:25px;'> For detailed information and demonstrations of our method, please visit our [GitHub Pages site](https://huajianduzhuo-code.github.io/FGG-diffusion-music/) to explore:\
                \n &emsp; 1. Accompaniment Generation given Melody and Chord\
                \n &emsp; 2. Style-Controlled Music Generation\
                \n &emsp; 3. Demonstrating the Effectiveness of Sampling Control by Comparison</span>")

    gr.HTML("<div style='height: 50px;'></div>")
    gr.Markdown("\n\n\n")
    gr.Markdown("# <span style='color: red;'> Interactive Demo </span>")
    gr.Markdown(
        "<span style='font-size:20px;'>"
        "🎡 Try out our interactive tool to generate music with our model!<br>"
        "You can create new accompaniments conditioned on a given melody and chord progression."
        "</span>"
    )

    gr.Markdown(
        "<span style='color:blue; font-size:20px;'>"
        "⚠️ This Space currently runs on a Hugging Face-provided CPU. On average, it takes ~15 seconds to generate a 4-measure music segment.<br>"
        "If multiple users are generating at the same time, you may enter a queue, which can cause delays.<br><br>"
        "πŸš€ On our local server (NVIDIA RTX 6000 Ada GPU), the same generation takes only 0.4 seconds.<br><br>"
        "To speed things up, you can: <br>"
        "β€’ πŸ” Fork this Space and select a different hardware configuration<br>"
        "β€’ πŸ§‘β€πŸ’» Clone our <a href='https://github.com/huajianduzhuo-code/FGG-music-code'>[Code Repo]</a> and run the generation notebooks locally after installing dependencies and downloading the model weights."
        "</span>"
    )


    with gr.Column(elem_classes="custom-purple"):
        gr.Markdown("### Select an example to generate music given melody and chord condition")
        with gr.Row():
            with gr.Column():
                prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1")
                gr.Markdown("### This is the melody to be conditioned on:")
                condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg", label="melody, chord, and rhythm condition")
                prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore)

            with gr.Column():
                generate_button = gr.Button("Generate")
                gr.Markdown("### Generation results:")
                audio_output = gr.Audio(label="Generated Music")
                piano_roll_output = gr.Image(label="Generated Piano Roll")

                generate_button.click(
                    fn=generate_from_example,
                    inputs=[prompt_selector],
                    outputs=[audio_output, piano_roll_output]
                )

# Launch Gradio interface
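# (Optional) chaining demo.queue() before launch() lets Gradio queue concurrent
# requests, which produces the waiting behavior described in the notice above.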
if __name__ == "__main__":
    demo.launch()