diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..7b9d33741ed04ce8ac5299ff580d34195ae35713
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..f2c7676a5e74c09421a40065fc4fc09398c04feb
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,6 @@
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
diff --git a/Aptfile b/Aptfile
new file mode 100644
index 0000000000000000000000000000000000000000..78e026b394005cf20d0877c5a1d66bed8e191edd
--- /dev/null
+++ b/Aptfile
@@ -0,0 +1,2 @@
+lilypond
+fluidsynth
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a44cf5baca4c0750592a7f58f387b4d15bdd2e7e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,18 @@
+---
+title: Interactive Symbolic Music Demo
+emoji: 🖼
+colorFrom: purple
+colorTo: red
+sdk: gradio
+sdk_version: 4.42.0
+app_file: app.py
+pinned: false
+python_version: 3.9.19
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+Find the `mido.MidiFile` calls inside the `pretty_midi` package and pass `clip=True` to each, so that out-of-range data bytes are clipped instead of raising an error.
+
+Use `set_seed(42)` in `sampler_sdf.py`; the generation result from chord slice (index = 2) is a good example (a wrong note is shifted to a correct one).
\ No newline at end of file
diff --git a/__pycache__/app.cpython-39.pyc b/__pycache__/app.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6aff3ff983bea3bbd41aec052861a96665ba366
Binary files /dev/null and b/__pycache__/app.cpython-39.pyc differ
diff --git a/__pycache__/learner.cpython-39.pyc b/__pycache__/learner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8fc32d512444df684b54631f94c7152bf90d8bd
Binary files /dev/null and b/__pycache__/learner.cpython-39.pyc differ
diff --git a/__pycache__/params.cpython-39.pyc b/__pycache__/params.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24afc4dfebe39dbbfba6830b669704c0fd12040c
Binary files /dev/null and b/__pycache__/params.cpython-39.pyc differ
diff --git a/__pycache__/train_params.cpython-39.pyc b/__pycache__/train_params.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..875f8c1b15a3d93333550a3d86c21e1f03077a23
Binary files /dev/null and b/__pycache__/train_params.cpython-39.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c89499fa849543f32ab04c6433b625262cb04abd
--- /dev/null
+++ b/app.py
@@ -0,0 +1,508 @@
+import gradio as gr
+import pretty_midi
+import matplotlib.pyplot as plt
+import numpy as np
+import soundfile as sf
+import cv2
+import imageio
+
+import sys
+import subprocess
+import os
+import torch
+from model import init_ldm_model
+from model.model_sdf import Diffpro_SDF
+from model.sampler_sdf import SDFSampler
+
+import pickle
+from train.train_params import params_chord_lsh_cond
+from generation.gen_utils import *
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
+chord_list = list(CHORD_DICTIONARY.keys())
+
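A minimal sketch of the clip=True behaviour referenced in the README above (editorial illustration, not part of this PR; the MIDI path is hypothetical):

    import mido

    # clip=True makes mido clip data bytes above 127 (e.g. out-of-range velocities) to 127
    # instead of raising "data byte must be in range 0..127"; the README asks for the same
    # flag inside pretty_midi's internal mido.MidiFile call.
    midi = mido.MidiFile("example.mid", clip=True)  # "example.mid" is a hypothetical path
    print(len(midi.tracks))
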
+def get_shape(file_path): + if file_path.endswith('.jpg'): + img = cv2.imread(file_path) + return img.shape # (height, width, channels) + + elif file_path.endswith('.mp4'): + vid = imageio.get_reader(file_path) + return vid.get_meta_data()['size'] # (width, height) + + else: + raise ValueError("Unsupported file type") + +# Function to convert MIDI to WAV +def midi_to_wav(midi, output_file): + # Synthesize the waveform from the MIDI using pretty_midi + audio_data = midi.fluidsynth() + + # Write the waveform to a WAV file + sf.write(output_file, audio_data, samplerate=44100) + +def update_musescore_image(selected_prompt): + # Logic to return the correct image file based on the selected prompt + if selected_prompt == "example 1": + return "samples/diy_examples/example1/example1.jpg" + elif selected_prompt == "example 2": + return "samples/diy_examples/example2/example2.jpg" + elif selected_prompt == "example 3": + return "samples/diy_examples/example3/example3.jpg" + elif selected_prompt == "example 4": + return "samples/diy_examples/example4/example4.jpg" + elif selected_prompt == "example 5": + return "samples/diy_examples/example5/example5.jpg" + elif selected_prompt == "example 6": + return "samples/diy_examples/example6/example6.jpg" + + +# Model for generating music (example) +def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"): + + ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False) + model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device) + sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False) + + if mode=="example": + if prompt == "example 1": + background_condition = np.load("samples/diy_examples/example1/example1.npy") + tempo=70 + elif prompt == "example 2": + background_condition = np.load("samples/diy_examples/example2/example2.npy") + elif prompt == "example 3": + background_condition = np.load("samples/diy_examples/example3/example3.npy") + elif prompt == "example 4": + background_condition = np.load("samples/diy_examples/example4/example4.npy") + + background_condition = np.tile(background_condition, (num_samples,1,1,1)) + background_condition = torch.Tensor(background_condition).to(device) + else: + background_condition = np.tile(prompt, (num_samples,1,1,1)) + background_condition = torch.Tensor(background_condition).to(device) + + if rhythm_control!="Yes": + background_condition[:,0:2] = background_condition[:,2:4] + # generate samples + output_x = sampler.generate(background_cond=background_condition, batch_size=num_samples, + same_noise_all_measure=False, X0EditFunc=X0EditFunc, + use_classifier_free_guidance=True, use_lsh=True, reduce_extra_notes=False, + rhythm_control=rhythm_control) + output_x = torch.clamp(output_x, min=0, max=1) + output_x = output_x.cpu().numpy() + + # save samples + for i in range(num_samples): + full_roll = extend_piano_roll(output_x[i]) # accompaniment roll + full_chd_roll = extend_piano_roll(-background_condition[i,2:4,:,:].cpu().numpy()-1) # chord roll + full_lsh_roll = None + if background_condition.shape[1]>=6: + if background_condition[:,4:6,:,:].min()>=0: + full_lsh_roll = extend_piano_roll(background_condition[i,4:6,:,:].cpu().numpy()) + midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo) + # filename = f'DDIM_w_rhythm_onset_0to10_{i}_edit_x0_and_eps'+'.mid' + filename = f"output_{i}.mid" + save_midi(midi_file, filename) + 
subprocess.Popen(['timidity',f'output_{i}.mid','-Ow','-o',f'output_{i}.wav']).communicate() + + return 'output_0.mid', 'output_0.wav', midi_file + +# Function to visualize MIDI notes +def visualize_midi(midi): + # Get piano roll from MIDI + roll = midi.get_piano_roll(fs=100) + + # Plot the piano roll + plt.figure(figsize=(10, 4)) + plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest') + plt.title("Piano Roll") + plt.xlabel("Time") + plt.ylabel("Pitch") + plt.colorbar() + + # Save the plot as an image + output_image_path = "piano_roll.png" + plt.savefig(output_image_path) + return output_image_path + +def plot_rhythm(rhythm_str, label): + if rhythm_str=="null rhythm": + return None + fig, ax = plt.subplots(figsize=(6, 2)) + + # Ensure it's a 16-bit string + rhythm_str = rhythm_str[:16] + + # Convert string to a list of 0s and 1s + rhythm = [0 if bit=="0" else 1 for bit in rhythm_str] + + # Define the x axis for the 16 sixteenth notes + x = list(range(1, 17)) # 1 to 16 sixteenth notes + + # Plot each note (1 as filled circle, 0 as empty circle) + for i, bit in enumerate(rhythm): + if bit == 1: + ax.scatter(i + 1, 1, color='black', s=100, label="Note" if i == 0 else "") + else: + ax.scatter(i + 1, 1, edgecolor='black', facecolor='none', s=100, label="Rest" if i == 0 else "") + + # Distinguish groups of 4 using vertical dashed lines (no solid grid lines) + for i in range(4, 17, 4): + ax.axvline(x=i + 0.5, color='grey', linestyle='--') + + # Remove solid vertical grid lines by setting the grid off + ax.grid(False) + + # Formatting the plot + ax.set_xlim(0.5, 16.5) + ax.set_ylim(0.8, 1.2) + ax.set_xticks(x) + ax.set_yticks([]) + ax.set_xlabel("16th Notes") + ax.set_title("Rhythm Pattern") + + fig.savefig(f'samples/diy_examples/rhythm_plot_{label}.png') + plt.close(fig) + return f'samples/diy_examples/rhythm_plot_{label}.png' + +def adjust_rhythm_string(s): + # Truncate if longer than 16 characters + if len(s) > 16: + return s[:16] + # Pad with zeros if shorter than 16 characters + else: + return s.ljust(16, '0') +def rhythm_string_to_array(s): + # Ensure the string is 16 characters long + s = s[:16].ljust(16, '0') # Truncate or pad with '0' to make it 16 characters + # Convert to numpy array, treating non-'0' as '1' + arr = np.array([1 if char != '0' else 0 for char in s], dtype=int) + arr = arr*np.array([3,1,2,1,3,1,2,1,3,1,2,1,3,1,2,1]) + print(arr) + return arr + +# Gradio main function +def generate_from_example(prompt): + midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control=False) + piano_roll_image = visualize_midi(midi) + return audio_output, piano_roll_image + +def generate_diy(m1_chord, m2_chord, m3_chord, m4_chord, + m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm, tempo): + print("\n\n\n",m1_chord,type(m1_chord), "\n\n\n") + test_chd_roll = np.concatenate([np.tile(CHORD_DICTIONARY[m1_chord], (16, 1)), + np.tile(CHORD_DICTIONARY[m2_chord], (16, 1)), + np.tile(CHORD_DICTIONARY[m3_chord], (16, 1)), + np.tile(CHORD_DICTIONARY[m4_chord], (16, 1))]) + rhythms = [m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm] + + chd_roll = np.concatenate([test_chd_roll[np.newaxis,:,:], test_chd_roll[np.newaxis,:,:]], axis=0) + + chd_roll = circular_extend(chd_roll) + chd_roll = -chd_roll-1 + + real_chd_roll = chd_roll + + melody_roll = -np.ones_like(chd_roll) + + if "null rhythm" not in rhythms: + rhythm_full = [] + for i in range(len(rhythms)): + rhythm = adjust_rhythm_string(rhythms[i]) + rhythm = rhythm_string_to_array(rhythm) + 
rhythm_full.append(rhythm) + rhythm_full = np.concatenate(rhythm_full, axis=0) + + onset_roll = test_chd_roll*rhythm_full[:, np.newaxis] + sustain_roll = np.zeros_like(onset_roll) + no_onset_pos = np.all(onset_roll == 0, axis=-1) + sustain_roll[no_onset_pos] = test_chd_roll[no_onset_pos] + + real_chd_roll = np.concatenate([onset_roll[np.newaxis,:,:], sustain_roll[np.newaxis,:,:]], axis=0) + real_chd_roll = circular_extend(real_chd_roll) + + background_condition = np.concatenate([real_chd_roll, chd_roll, melody_roll], axis=0) + + midi_output, audio_output, midi = generate_music(background_condition, tempo, mode="diy") + piano_roll_image = visualize_midi(midi) + return midi_output, audio_output, piano_roll_image + +# Prompt list +prompt_list = ["example 1", "example 2", "example 3", "example 4"] +rhythm_list = ["null rhythm", "1010101010101010", "1011101010111010","1111101010111010","1010001010101010","1010101000101010"] + + +custom_css = """ +.custom-row1 { + background-color: #fdebd0; + padding: 10px; + border-radius: 5px; +} +.custom-row2 { + background-color: #d1f2eb; + padding: 10px; + border-radius: 5px; +} +.custom-grey { + background-color: #f0f0f0; + padding: 10px; + border-radius: 5px; +} +.custom-purple { + background-color: #d7bde2; + padding: 10px; + border-radius: 5px; +} +.audio_waveform-container { + display: none !important; +} +""" + + +with gr.Blocks(css=custom_css) as demo: + gr.Markdown("#
Efficient Fine-Grained Guidance for Diffusion-Based Symbolic Music Generation
") + + gr.Markdown(" We introduce **Fine-Grained Guidance (FG)**, an efficient approach for symbolic music generation using **diffusion models**. Our method enhances guidance through:\ + \n   (1) Fine-grained conditioning during training,\ + \n   (2) Fine-grained control during the diffusion sampling process.\ + \n In particular, **sampling control** ensures tonal accuracy in every generated sample, allowing our model to produce music with high precision, consistent rhythmic patterns,\ + and even stylistic variations that align with user intent.") + gr.Markdown(" At the bottom of this page, we provide an interactive space for you to try our model by yourself! ") + + + gr.Markdown("\n\n\n") + gr.Markdown("# 1. Accompaniment Generation given Melody and Chord") + gr.Markdown(" In each example, the left column displays the melody provided as inputs to the model.\ + The right column showcases music samples generated by the model.") + + with gr.Column(elem_classes="custom-row1"): + gr.Markdown("## Example 1") + with gr.Row(): + with gr.Column(): + gr.Markdown(" With the following melody as condition ") + example1_mel = gr.Audio(value="samples/diy_examples/example1/example_1_mel.wav", label="Melody", scale = 5) + with gr.Column(): + gr.Markdown(" Generated Accompaniments ") + example1_audio = gr.Audio(value="samples/diy_examples/example1/sample1.wav", label="Generated Accompaniment", scale = 5) + + with gr.Column(elem_classes="custom-row2"): + gr.Markdown("## Example 2") + with gr.Row(): + with gr.Column(): + gr.Markdown(" With the following melody as condition ") + example1_mel = gr.Audio(value="samples/diy_examples/example2/example_2_mel.wav", label="Melody", scale = 5) + with gr.Column(): + gr.Markdown(" Generated Accompaniments ") + example1_audio = gr.Audio(value="samples/diy_examples/example2/sample1.wav", label="Generated Accompaniment", scale = 5) + + with gr.Column(elem_classes="custom-row1"): + gr.Markdown("## Example 3") + with gr.Row(): + with gr.Column(): + gr.Markdown(" With the following melody as condition ") + example1_mel = gr.Audio(value="samples/diy_examples/example3/example_3_mel.wav", label="Melody", scale = 5) + with gr.Column(): + gr.Markdown(" Generated Accompaniments ") + example1_audio = gr.Audio(value="samples/diy_examples/example3/sample1.wav", label="Generated Accompaniment", scale = 5) + + with gr.Column(elem_classes="custom-row2"): + gr.Markdown("## Example 4") + with gr.Row(): + with gr.Column(): + gr.Markdown(" With the following melody as condition ") + example1_mel = gr.Audio(value="samples/diy_examples/example4/example_4_mel.wav", label="Melody", scale = 5) + with gr.Column(): + gr.Markdown(" Generated Accompaniments ") + example1_audio = gr.Audio(value="samples/diy_examples/example4/sample1.wav", label="Generated Accompaniment", scale = 5) + + gr.HTML("
") + gr.Markdown("# \n\n\n") + gr.Markdown("# 2. Style-Controlled Music Generation") + gr.Markdown("Our approach enables controllable stylization in music generation. The sampling control is able to\ + ensure that all generated notes strictly adhere to the target musical style's scale.\ + This allows the model to generate music in specific styles — even those that were not present in \ + the training data.") + gr.Markdown(" Below, we demonstrate several examples of style-controlled music generation for:\ + \n   (1) Dorian Mode: (with scale being A-B-C-D-E-F#-G);\ + \n   (2) Chinese Style: (with scale being C-D-E-G-A). ") + + with gr.Column(elem_classes="custom-row1"): + gr.Markdown("## Dorian Mode") + gr.Markdown(" The following are two examples generated by our method ") + with gr.Row(): + with gr.Column(elem_classes="custom-grey"): + gr.Markdown(" Example 1 ") + example1_mel = gr.Audio(value="samples/different_styles/dorian_1.wav", scale = 5) + with gr.Column(elem_classes="custom-grey"): + gr.Markdown(" Example 2 ") + example1_audio = gr.Audio(value="samples/different_styles/dorian_2.wav", scale = 5) + + with gr.Column(elem_classes="custom-row2"): + gr.Markdown("## Chinese Style") + gr.Markdown(" The following are two examples generated by our method ") + with gr.Row(): + with gr.Column(elem_classes="custom-grey"): + gr.Markdown(" Example 1 ") + example1_mel = gr.Audio(value="samples/different_styles/chinese_1.wav", scale = 5) + with gr.Column(elem_classes="custom-grey"): + gr.Markdown(" Example 2 ") + example1_audio = gr.Audio(value="samples/different_styles/chinese_2.wav", scale = 5) + + gr.HTML("
") + gr.Markdown("\n\n\n") + gr.Markdown("# 3. Demonstrating the Effectiveness of Sampling Control by Comparison") + + gr.Markdown(" We demonstrate the impact of sampling control in an **accompaniment generation** task, given a melody and chord progression.\ + \n Each example generates accompaniments with and without sampling control using the same random seed, ensuring that the two results are comparable.\ + \n Sampling control effectively removes or replaces harmonically conflicting notes, ensuring tonal consistency.\ + \n We provide music sheets and audio files for both versions.") + + gr.Markdown(" Comparison of the results indicates that sampling control not only eliminates out-of-key notes but also enhances \ + the overall coherence and harmonic consistency of the accompaniments.\ + This highlights the effectiveness of our approach in maintaining musical coherence. ") + + + with gr.Column(elem_classes="custom-row1"): + gr.Markdown("## Example 1") + + with gr.Row(elem_classes="custom-grey"): + gr.Markdown(" With pre-defined melody and chord as follows") + with gr.Column(scale=2, min_width=10, ): + gr.Markdown("Melody Sheet") + example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + with gr.Column(scale=1, min_width=10, ): + gr.Markdown("Melody Audio") + example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10) + + gr.Markdown("## Generated Accompaniments") + with gr.Row(elem_classes="custom-grey"): + gr.Markdown(" Without sampling control") + with gr.Column(scale=2, min_width=300): + gr.Markdown("Music Sheet") + example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + with gr.Column(scale=1, min_width=150): + gr.Markdown("Audio") + example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10) + gr.Markdown("\n\n\n") + with gr.Row(elem_classes="custom-grey"): + with gr.Column(scale=1, min_width=150): + gr.Markdown("With sampling control") + with gr.Column(scale=2, min_width=300): + gr.Markdown("Music Sheet") + example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_acc_control.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + with gr.Column(scale=1, min_width=150): + gr.Markdown("Audio") + example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10) + + + with gr.Column(elem_classes="custom-row2"): + gr.Markdown("## Example 2") + + with gr.Row(elem_classes="custom-grey"): + gr.Markdown(" With pre-defined melody and chord as follows") + with gr.Column(scale=2, min_width=10, ): + gr.Markdown("Melody Sheet") + example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + with gr.Column(scale=1, min_width=10, ): + gr.Markdown("Melody Audio") + example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10) + + gr.Markdown("## Generated Accompaniments") + with gr.Row(elem_classes="custom-grey"): + gr.Markdown(" Without sampling control") + with gr.Column(scale=2, min_width=300): + 
gr.Markdown("Music Sheet") + example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + with gr.Column(scale=1, min_width=150): + gr.Markdown("Audio") + example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_acc_uncontrol.wav", scale = 1, min_width=10) + gr.Markdown("\n\n\n") + with gr.Row(elem_classes="custom-grey"): + with gr.Column(scale=1, min_width=150): + gr.Markdown("With sampling control") + with gr.Column(scale=2, min_width=300): + gr.Markdown("Music Sheet") + example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_acc_control.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + with gr.Column(scale=1, min_width=150): + gr.Markdown("Audio") + example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_acc_control.wav", scale = 1, min_width=10) + + # with gr.Row(): + # with gr.Column(scale=1, min_width=300, elem_classes="custom-row1"): + # gr.Markdown("## Example 1") + # gr.Markdown(" With pre-defined melody and chord as follows") + # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + # # Audio component to play the audio + # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10) + + # gr.Markdown("## Generated Accompaniments") + # with gr.Row(): + # with gr.Column(scale=1, min_width=150): + # gr.Markdown(" without sampling control") + # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10) + # with gr.Column(scale=1, min_width=150): + # gr.Markdown(" with sampling control") + # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10) + # with gr.Column(scale=1, min_width=300, elem_classes="custom-row2"): + # gr.Markdown("## Example 2") + # gr.Markdown(" With pre-defined melody and chord as follows") + # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + # # Audio component to play the audio + # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10) + + # gr.Markdown("## Generated Accompaniments") + # with gr.Row(): + # with gr.Column(scale=1, min_width=150): + # gr.Markdown(" without sampling control") + # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10) + # with gr.Column(scale=1, min_width=150): + # gr.Markdown(" with sampling control") + # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10) + # example1_melody = 
gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10) + + + + + + ''' Try to generate by users ''' + gr.HTML("
") + gr.Markdown("\n\n\n") + gr.Markdown("# 4. DIY in real time! ") + gr.Markdown(" Here is an interactive tool for you to try our model and generate by yourself.\ + You can generate new accompaniments for given melody and chord conditions ") + + gr.Markdown("### Currently this space is supported with Hugging Face CPU and on average,\ + it takes about 15 seconds to generate a 4-measure music piece. However, if other users are generating\ + music at the same time, one may enter a queue, which could slow down the process significantly.\ + If that happens, feel free to refresh the page. We appreciate your patience and understanding.\ + ") + + with gr.Column(elem_classes="custom-purple"): + gr.Markdown("### Select an example to generate music given melody and chord condition") + with gr.Row(): + with gr.Column(): + prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1") + gr.Markdown("### This is the melody to be conditioned on:") + condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg", label="melody, chord, and rhythm condition") + prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore) + + with gr.Column(): + generate_button = gr.Button("Generate") + gr.Markdown("### Generation results:") + audio_output = gr.Audio(label="Generated Music") + piano_roll_output = gr.Image(label="Generated Piano Roll") + + generate_button.click( + fn=generate_from_example, + inputs=[prompt_selector], + outputs=[audio_output, piano_roll_output] + ) + + +# Launch Gradio interface +if __name__ == "__main__": + demo.launch() diff --git a/filter_data/filter_by_instrument.ipynb b/filter_data/filter_by_instrument.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2a9b8472a793d457ca3bc130e26b80d1097cfc21 --- /dev/null +++ b/filter_data/filter_by_instrument.ipynb @@ -0,0 +1,353 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from midi_utils import is_timesig_44, gather_full_instr, gather_instr\n", + "from midi_utils import has_brass\n", + "from midi_utils import has_piano, has_string, has_guitar, has_drums\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Program number 0: Acoustic Grand Piano\n", + "Program number 1: Bright Acoustic Piano\n", + "Program number 2: Electric Grand Piano\n", + "Program number 3: Honky-tonk Piano\n", + "Program number 4: Electric Piano 1\n", + "Program number 5: Electric Piano 2\n", + "Program number 6: Harpsichord\n", + "Program number 7: Clavinet\n", + "Program number 8: Celesta\n", + "Program number 9: Glockenspiel\n", + "Program number 10: Music Box\n", + "Program number 11: Vibraphone\n", + "Program number 12: Marimba\n", + "Program number 13: Xylophone\n", + "Program number 14: Tubular Bells\n", + "Program number 15: Dulcimer\n", + "Program number 16: Drawbar Organ\n", + "Program number 17: Percussive Organ\n", + "Program number 18: Rock Organ\n", + "Program number 19: Church Organ\n", + "Program number 20: Reed Organ\n", + "Program number 21: Accordion\n", + "Program number 22: Harmonica\n", + "Program number 23: Tango Accordion\n", + "Program number 24: Acoustic Guitar (nylon)\n", + "Program number 25: Acoustic Guitar (steel)\n", + "Program number 26: Electric Guitar (jazz)\n", + "Program number 27: Electric Guitar (clean)\n", + "Program number 28: Electric Guitar 
(muted)\n", + "Program number 29: Overdriven Guitar\n", + "Program number 30: Distortion Guitar\n", + "Program number 31: Guitar Harmonics\n", + "Program number 32: Acoustic Bass\n", + "Program number 33: Electric Bass (finger)\n", + "Program number 34: Electric Bass (pick)\n", + "Program number 35: Fretless Bass\n", + "Program number 36: Slap Bass 1\n", + "Program number 37: Slap Bass 2\n", + "Program number 38: Synth Bass 1\n", + "Program number 39: Synth Bass 2\n", + "Program number 40: Violin\n", + "Program number 41: Viola\n", + "Program number 42: Cello\n", + "Program number 43: Contrabass\n", + "Program number 44: Tremolo Strings\n", + "Program number 45: Pizzicato Strings\n", + "Program number 46: Orchestral Harp\n", + "Program number 47: Timpani\n", + "Program number 48: String Ensemble 1\n", + "Program number 49: String Ensemble 2\n", + "Program number 50: Synth Strings 1\n", + "Program number 51: Synth Strings 2\n", + "Program number 52: Choir Aahs\n", + "Program number 53: Voice Oohs\n", + "Program number 54: Synth Choir\n", + "Program number 55: Orchestra Hit\n", + "Program number 56: Trumpet\n", + "Program number 57: Trombone\n", + "Program number 58: Tuba\n", + "Program number 59: Muted Trumpet\n", + "Program number 60: French Horn\n", + "Program number 61: Brass Section\n", + "Program number 62: Synth Brass 1\n", + "Program number 63: Synth Brass 2\n", + "Program number 64: Soprano Sax\n", + "Program number 65: Alto Sax\n", + "Program number 66: Tenor Sax\n", + "Program number 67: Baritone Sax\n", + "Program number 68: Oboe\n", + "Program number 69: English Horn\n", + "Program number 70: Bassoon\n", + "Program number 71: Clarinet\n", + "Program number 72: Piccolo\n", + "Program number 73: Flute\n", + "Program number 74: Recorder\n", + "Program number 75: Pan Flute\n", + "Program number 76: Blown bottle\n", + "Program number 77: Shakuhachi\n", + "Program number 78: Whistle\n", + "Program number 79: Ocarina\n", + "Program number 80: Lead 1 (square)\n", + "Program number 81: Lead 2 (sawtooth)\n", + "Program number 82: Lead 3 (calliope)\n", + "Program number 83: Lead 4 chiff\n", + "Program number 84: Lead 5 (charang)\n", + "Program number 85: Lead 6 (voice)\n", + "Program number 86: Lead 7 (fifths)\n", + "Program number 87: Lead 8 (bass + lead)\n", + "Program number 88: Pad 1 (new age)\n", + "Program number 89: Pad 2 (warm)\n", + "Program number 90: Pad 3 (polysynth)\n", + "Program number 91: Pad 4 (choir)\n", + "Program number 92: Pad 5 (bowed)\n", + "Program number 93: Pad 6 (metallic)\n", + "Program number 94: Pad 7 (halo)\n", + "Program number 95: Pad 8 (sweep)\n", + "Program number 96: FX 1 (rain)\n", + "Program number 97: FX 2 (soundtrack)\n", + "Program number 98: FX 3 (crystal)\n", + "Program number 99: FX 4 (atmosphere)\n", + "Program number 100: FX 5 (brightness)\n", + "Program number 101: FX 6 (goblins)\n", + "Program number 102: FX 7 (echoes)\n", + "Program number 103: FX 8 (sci-fi)\n", + "Program number 104: Sitar\n", + "Program number 105: Banjo\n", + "Program number 106: Shamisen\n", + "Program number 107: Koto\n", + "Program number 108: Kalimba\n", + "Program number 109: Bagpipe\n", + "Program number 110: Fiddle\n", + "Program number 111: Shanai\n", + "Program number 112: Tinkle Bell\n", + "Program number 113: Agogo\n", + "Program number 114: Steel Drums\n", + "Program number 115: Woodblock\n", + "Program number 116: Taiko Drum\n", + "Program number 117: Melodic Tom\n", + "Program number 118: Synth Drum\n", + "Program number 119: Reverse Cymbal\n", + "Program 
number 120: Guitar Fret Noise\n", + "Program number 121: Breath Noise\n", + "Program number 122: Seashore\n", + "Program number 123: Bird Tweet\n", + "Program number 124: Telephone Ring\n", + "Program number 125: Helicopter\n", + "Program number 126: Applause\n", + "Program number 127: Gunshot\n" + ] + } + ], + "source": [ + "import pretty_midi\n", + "\n", + "def display_all_instrument_names():\n", + " for program_number in range(128):\n", + " instrument_name = pretty_midi.program_to_instrument_name(program_number)\n", + " print(f\"Program number {program_number}: {instrument_name}\")\n", + "\n", + "# Call the function to display all instrument names\n", + "display_all_instrument_names()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "#import mido\n", + "import pretty_midi\n", + "#from mido import KeySignatureError\n", + "\n", + "def filter_midi(file_path, max_track = 8, max_time = 300):\n", + " try:\n", + " pm = pretty_midi.PrettyMIDI(file_path)\n", + " except Exception as e:\n", + " return False\n", + " \n", + " # time signature 4/4\n", + " #if is_timesig_44(pm) == False:\n", + " #print(\"timesig\")\n", + " #return False\n", + " \n", + " # number of tracks\n", + " if len(pm.instruments)>max_track:\n", + " #print(\"tracks\")\n", + " return False\n", + " \n", + " # length of song\n", + " #if pm.get_end_time()>max_time:\n", + " #print(\"length\")\n", + " #return False\n", + "\n", + " # now filter by instruments\n", + "\n", + " # filter out the ones without drums\n", + " #if has_drums(pm)==False:\n", + " #print(\"no drums\")\n", + " #return False\n", + " \n", + " # filter out the ones with brass\n", + " instr = gather_instr(pm)\n", + " if has_brass(instr):\n", + " #print(\"has brass\")\n", + " return False\n", + " \n", + " # filter out the ones without full string and piano\n", + " full_instr = gather_full_instr(pm, threshold=0.7)\n", + " \n", + " if has_piano(full_instr)== False:\n", + " #print(\"no piano\")\n", + " return False\n", + "\n", + " if has_guitar(full_instr)== False:\n", + " return False\n", + " \n", + " #if has_string(full_instr)== False:\n", + " #print(\"no string\")\n", + " #return False\n", + " \n", + " return True\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "#import pretty_midi\n", + "#midi_path = '/home/ubuntu/lakh-pianoroll-dataset/data/samples_with_strings/c10e69ec7f8212c68ff3658cceef5b9b.mid'\n", + "#pm = pretty_midi.PrettyMIDI(midi_path)\n", + "#mid = mido.MidiFile(midi_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "import glob\n", + "import os\n", + "from tqdm import tqdm\n", + "\n", + "def find_midi_files_upto(root_dir, sample_size):\n", + " midi_files = glob.glob(os.path.join(root_dir, '**/*.mid'), recursive=True)\n", + " matching_files = []\n", + " match_count = 0\n", + "\n", + " pbar = tqdm(total=len(midi_files), desc=\"Processing MIDI files\")\n", + " for midi_file in midi_files:\n", + " if filter_midi(midi_file):\n", + " matching_files.append(midi_file)\n", + " match_count += 1\n", + " pbar.set_postfix({'Matching files': match_count})\n", + " if match_count >= sample_size:\n", + " break\n", + " pbar.update(1)\n", + " pbar.close()\n", + " return matching_files\n", + "\n", + "def copy_files(files, target_dir):\n", + " os.makedirs(target_dir, exist_ok=True)\n", + " for file in files:\n", + " shutil.copy(file, target_dir)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing MIDI files: 0%| | 14/178561 [00:02<11:02:01, 4.49it/s]/home/ubuntu/.local/lib/python3.10/site-packages/pretty_midi/pretty_midi.py:100: RuntimeWarning: Tempo, Key or Time signature change events found on non-zero tracks. This is not a valid type 0 or type 1 MIDI file. Tempo, Key or Time Signature may be wrong.\n", + " warnings.warn(\n", + "Processing MIDI files: 0%| | 148/178561 [00:18<4:21:09, 11.39it/s, Matching files=4] " + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_47712/2263633239.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mROOT_DIR\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/home/ubuntu/lakh-pianoroll-dataset/data/lmd/lmd_full'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msample_files\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_midi_files_upto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mROOT_DIR\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1500\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtgt_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/home/ubuntu/lakh-pianoroll-dataset/data/instrument_samples'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcopy_files\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_files\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtgt_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_47712/2439092612.py\u001b[0m in \u001b[0;36mfind_midi_files_upto\u001b[0;34m(root_dir, sample_size)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mpbar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_files\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Processing MIDI files\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmidi_file\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmidi_files\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mfilter_midi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mmatching_files\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mmatch_count\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_47712/3402906515.py\u001b[0m in \u001b[0;36mfilter_midi\u001b[0;34m(file_path, max_track, max_time)\u001b[0m\n\u001b[1;32m 5\u001b[0m 
\u001b[0;32mdef\u001b[0m \u001b[0mfilter_midi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_track\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m300\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mpm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpretty_midi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPrettyMIDI\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/pretty_midi/pretty_midi.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, midi_file, resolution, initial_tempo)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstring_types\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;31m# If a string was given, pass it as the string filename\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mmidi_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmido\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMidiFile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# Otherwise, try passing it in as a file pointer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, filename, file, type, ticks_per_beat, charset, debug, clip, tracks)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilename\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 320\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 321\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 322\u001b[0m 
\u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(self, infile)\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[0m_dbg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Track {i}:'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 370\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 371\u001b[0;31m self.tracks.append(read_track(infile,\n\u001b[0m\u001b[1;32m 372\u001b[0m \u001b[0mdebug\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 373\u001b[0m clip=self.clip))\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36mread_track\u001b[0;34m(infile, debug, clip)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_sysex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdelta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclip\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 218\u001b[0;31m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus_byte\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpeek_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdelta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclip\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0mtrack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36mread_message\u001b[0;34m(infile, status_byte, peek_data, delta, clip)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mOSError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'data byte must be in range 0..127'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 133\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mMessage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdata_bytes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdelta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/messages.py\u001b[0m in \u001b[0;36mfrom_bytes\u001b[0;34m(cl, data, time)\u001b[0m\n\u001b[1;32m 161\u001b[0m \"\"\"\n\u001b[1;32m 162\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mcl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__new__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0mmsgdict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecode_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'data'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmsgdict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0mmsgdict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'data'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSysexData\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsgdict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'data'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36mdecode_message\u001b[0;34m(msg_bytes, time, check)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_SPECIAL_CASES\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_decode_data_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36m_decode_data_bytes\u001b[0;34m(status_byte, data, spec)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# TODO: better name than args?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnames\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'value_names'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'channel'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# TODO: better name than args?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnames\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'value_names'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'channel'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing MIDI files: 0%| | 150/178561 [00:30<4:21:08, 11.39it/s, Matching files=4]" + ] + } + ], + "source": [ + "ROOT_DIR = '/home/ubuntu/lakh-pianoroll-dataset/data/lmd/lmd_full'\n", + "sample_files = find_midi_files_upto(ROOT_DIR, sample_size=1500)\n", + "tgt_dir = '/home/ubuntu/lakh-pianoroll-dataset/data/instrument_samples'\n", + "copy_files(sample_files, tgt_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/filter_data/midi_utils.py b/filter_data/midi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70b530f47b641ebe3092fade5b09d699d14de13d --- /dev/null +++ b/filter_data/midi_utils.py @@ -0,0 +1,139 @@ + +import pretty_midi + +PIANO = (0,1) +STRING = (40,41,42,48,49,50,51) +GUITAR = (24,25,27) +BRASS = (56,57,58,59,61,62,63,64,65,66,67) + +############################## For single track analysis + +def calculate_active_duration(instrument): + # Collect all note intervals + intervals = [(note.start, note.end) for note in instrument.notes] + + # Sort intervals by start time + intervals.sort() + + # Merge overlapping intervals and calculate the active duration + active_duration = 0 + current_start, current_end = intervals[0] + + for start, end in intervals[1:]: + if start <= current_end: # There is an overlap + current_end = max(current_end, end) + else: # No overlap, add the previous interval duration and start a new interval + active_duration += 
current_end - current_start + current_start, current_end = start, end + + # Add the last interval + active_duration += current_end - current_start + + return active_duration + +def is_full_track(midi, instrument, threshold=0.6): + # Calculate the total duration of the track + total_duration = midi.get_end_time() + + # Calculate the active duration (time during which notes are playing) + active_duration = calculate_active_duration(instrument) + + # Calculate the percentage of active duration + active_percentage = active_duration / total_duration + + #print(f"Total duration: {total_duration:.2f} seconds") + #print(f"Active duration: {active_duration:.2f} seconds") + #print(f"Active percentage: {active_percentage:.2%}") + + # Check if the active duration meets or exceeds the threshold + return active_percentage >= threshold + +#################################### For gathering full tracks + +def gather_instr(pm): + # Gather all the program indexes of the instrument tracks + program_indexes = [instrument.program for instrument in pm.instruments] + + # Sort the program indexes + program_indexes.sort() + + # Convert the sorted list of program indexes to a tuple + program_indexes_tuple = tuple(program_indexes) + return program_indexes_tuple + +def gather_full_instr(pm, threshold = 0.6): + # Gather all the program indexes of the instrument tracks that exceed the duration threshold + program_indexes = [] + for instrument in pm.instruments: + if is_full_track(pm, instrument, threshold): + program_indexes.append(instrument.program) + program_indexes.sort() + # Convert the list of program indexes to a tuple + program_indexes_tuple = tuple(program_indexes) + + return program_indexes_tuple + +####################################### For finding instruments + +def has_intersection(wanted_instr, exist_instr): + # Convert both the tuple and the group of integers to sets + tuple_set = set(wanted_instr) + group_set = set(exist_instr) + + # Check if there is any intersection + return not tuple_set.isdisjoint(group_set) + +# The functions checking instruments in the midi file tracks +def has_piano(exist_instr): + wanted_instr = PIANO + return has_intersection(wanted_instr, exist_instr) + +def has_string(exist_instr): + wanted_instr = STRING + return has_intersection(wanted_instr, exist_instr) + +def has_guitar(exist_instr): + wanted_instr = GUITAR + return has_intersection(wanted_instr, exist_instr) + +def has_brass(exist_instr): + wanted_instr = BRASS + return has_intersection(wanted_instr, exist_instr) + +def has_drums(pm): + for instrument in pm.instruments: + if instrument.is_drum: + return True + return False + + +def print_track_details(instrument): + """ + For visualizing the information in a midi track + """ + print(f"Instrument: {pretty_midi.program_to_instrument_name(instrument.program)}") + print(f"Is drum: {instrument.is_drum}") + + print("\nNotes:") + for note in instrument.notes: + print(f"Start: {note.start:.2f}, End: {note.end:.2f}, Pitch: {note.pitch}, Velocity: {note.velocity}") + + print("\nControl Changes:") + for cc in instrument.control_changes: + print(f"Time: {cc.time:.2f}, Number: {cc.number}, Value: {cc.value}") + + print("\nPitch Bends:") + for pb in instrument.pitch_bends: + print(f"Time: {pb.time:.2f}, Pitch: {pb.pitch}") + +def is_timesig_44(pm): + for time_signature in pm.time_signature_changes: + if time_signature.numerator != 4 or time_signature.denominator != 4: + return False + return True + +def is_timesig_34(pm): + for time_signature in pm.time_signature_changes: + if 
time_signature.numerator != 4 or time_signature.denominator != 4: + return False + return True \ No newline at end of file diff --git a/generation/__pycache__/gen_utils.cpython-39.pyc b/generation/__pycache__/gen_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..772c1a785843c16d7934dc014e4afd29f36ac139 Binary files /dev/null and b/generation/__pycache__/gen_utils.cpython-39.pyc differ diff --git a/generation/gen_utils.py b/generation/gen_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..64024b389db798c137aaad5c1c691001c7e5e204 --- /dev/null +++ b/generation/gen_utils.py @@ -0,0 +1,302 @@ +import torch +import numpy as np +import pretty_midi as pm + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +CHORD_DICTIONARY = { + "C:major": np.array([1,0,0,0,1,0,0,1,0,0,0,0]), + "C#:major": np.array([0,1,0,0,0,1,0,0,1,0,0,0]), + "D:major": np.array([0,0,1,0,0,0,1,0,0,1,0,0]), + "Eb:major": np.array([0,0,0,1,0,0,0,1,0,0,1,0]), + "E:major": np.array([0,0,0,0,1,0,0,0,1,0,0,1]), + "F:major": np.array([1,0,0,0,0,1,0,0,0,1,0,0]), + "F#:major": np.array([0,1,0,0,0,0,1,0,0,0,1,0]), + "G:major": np.array([0,0,1,0,0,0,0,1,0,0,0,1]), + "Ab:major": np.array([1,0,0,1,0,0,0,0,1,0,0,0]), + "A:major": np.array([0,1,0,0,1,0,0,0,0,1,0,0]), + "Bb:major": np.array([0,0,1,0,0,1,0,0,0,0,1,0]), + "B:major": np.array([0,0,0,1,0,0,1,0,0,0,0,1]), + + "c:minor": np.array([1,0,0,1,0,0,0,1,0,0,0,0]), + "c#:minor": np.array([0,1,0,0,1,0,0,0,1,0,0,0]), + "d:minor": np.array([0,0,1,0,0,1,0,0,0,1,0,0]), + "eb:minor": np.array([0,0,0,1,0,0,1,0,0,0,1,0]), + "e:minor": np.array([0,0,0,0,1,0,0,1,0,0,0,1]), + "f:minor": np.array([1,0,0,0,0,1,0,0,1,0,0,0]), + "f#:minor": np.array([0,1,0,0,0,0,1,0,0,1,0,0]), + "g:minor": np.array([0,0,1,0,0,0,0,1,0,0,1,0]), + "g#:minor": np.array([0,0,0,1,0,0,0,0,1,0,0,1]), + "a:minor": np.array([1,0,0,0,1,0,0,0,0,1,0,0]), + "bb:minor": np.array([0,1,0,0,0,1,0,0,0,0,1,0]), + "b:minor": np.array([0,0,1,0,0,0,1,0,0,0,0,1]), +} + + +def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True): + ''' + piano_roll_full: a tensor with shape (batch_size, 2, length, h) # length=64 is length of roll, h is number of possible pitch + num_notes_onset: a tensor with shape (batch_size, length) + mask_full: a tensor with shape the same as piano_roll, corresponding to the 7 notes chroma + reduce_extra_notes: True if want to reduce extra notes + ''' + ########## for those greater than the threshold, if num of notes exceed num_notes[i], + ########## will keep the first ones and set others to threshold + print("Coming") + # we only edit onset + onset_roll = piano_roll_full[:,0,:,:] + mask = mask_full[:,0,:,:] + shape = onset_roll.shape + + onset_roll = onset_roll.reshape(-1,shape[-1]) + mask = mask.reshape(-1,shape[-1]) + num_notes = num_notes_onset.reshape(-1) + + reduce_note_threshold = 0.499 + increase_note_threshold = 0.501 + + # Initialize a tensor to store the modified values + final_onset_roll = onset_roll.clone() + + ########### if number of notes > required, remove the extra notes ############### + if reduce_extra_notes: + threshold_mask = onset_roll > reduce_note_threshold + # Set all values <= reduce_note_threshold to -inf to exclude them from top-k selection + values_above_threshold = torch.where(threshold_mask & (mask == 1), onset_roll, torch.tensor(-float('inf')).to(onset_roll.device)) + + # Get the top num_notes.max() values for each row + num_notes_max = int(num_notes.max().item()) # Maximum 
number of notes needed in any row + topk_values, topk_indices = torch.topk(values_above_threshold, num_notes_max, dim=1) + + # Create a mask for the top num_notes[i] values for each row + col_indices = torch.arange(num_notes_max, device=onset_roll.device).expand(len(onset_roll), num_notes_max) + topk_mask = (col_indices < num_notes.unsqueeze(1)) & (topk_values > -float("inf")) + + # Set all values greater than reduce_note_threshold to reduce_note_threshold initially + final_onset_roll[threshold_mask & (mask == 1)] = reduce_note_threshold + + # Create a flattened index to scatter the top values back into final_onset_roll + flat_row_indices = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_indices) + flat_row_indices = flat_row_indices[topk_mask] + + # Gather the valid topk_indices and corresponding values + valid_topk_indices = topk_indices[topk_mask] + valid_topk_values = topk_values[topk_mask] + + # Use scatter to place the top num_notes[i] values back to their original positions + final_onset_roll = final_onset_roll.index_put_((flat_row_indices, valid_topk_indices), valid_topk_values) + + ########### if number of notes < required, add some notes ############### + pitch_less_84_mask = torch.ones_like(mask) + pitch_less_84_mask[:,51:] = 0 + + # Count how many values >= increase_note_threshold for each row + threshold_mask_2 = (final_onset_roll >= increase_note_threshold)&(mask==1) + greater_than_threshold2_count = threshold_mask_2.sum(dim=1) + + # For those rows, find the remaining number of values needed to be set to increase_note_threshold + remaining_needed = num_notes - greater_than_threshold2_count + remaining_needed_max = int(remaining_needed.max().item()) + print("\n\n\n",remaining_needed_max,"\n\n\n") + if remaining_needed_max>=0: # need to add notes + # Find the values in each row that are < increase_note_threshold but are the highest (so we can set them to increase_note_threshold) + values_below_threshold2 = torch.where((final_onset_roll < increase_note_threshold)&(mask==1)&(pitch_less_84_mask==1), final_onset_roll, torch.tensor(-float('inf')).to(onset_roll.device)) + topk_below_threshold2_values, topk_below_threshold2_indices = torch.topk(values_below_threshold2, remaining_needed_max, dim=1) + + # Mask to only adjust the needed number of values in each row + col_indices_below_threshold2 = torch.arange(remaining_needed_max, device=onset_roll.device).expand(len(onset_roll), remaining_needed_max) + adjust_mask = (col_indices_below_threshold2 < remaining_needed.unsqueeze(1)) & (topk_below_threshold2_values > -float("inf")) + + # Flatten row indices for the new top-k below increase_note_threshold + flat_row_indices_below_threshold2 = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_below_threshold2_indices) + flat_row_indices_below_threshold2 = flat_row_indices_below_threshold2[adjust_mask] + + # Gather the valid indices and set them to increase_note_threshold + valid_below_threshold2_indices = topk_below_threshold2_indices[adjust_mask] + + # Update the final_onset_roll to make sure we now have exactly num_notes[i] values >= increase_note_threshold + final_onset_roll = final_onset_roll.index_put_((flat_row_indices_below_threshold2, valid_below_threshold2_indices), torch.tensor(increase_note_threshold, device=onset_roll.device)) + + final_onset_roll = final_onset_roll.reshape(shape) + piano_roll_full[:,0,:,:] = final_onset_roll + return piano_roll_full + +def X0EditFunc(x0, background_condition, 
sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"): + # precompute all rotations of the major and minor chord templates + maj_chd = torch.tensor([[1.,0,0,0,1,0,0,1,0,0,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device) + maj_chd = torch.tile(maj_chd, (1, 64 // maj_chd.size(1) + 1)) + min_chd = torch.tensor([[1.,0,0,0,1,0,0,0,0,1,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device) + min_chd = torch.tile(min_chd, (1, 64 // min_chd.size(1) + 1)) + + # all chords, with rotation + maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64] + min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64] + + # combine all chords + # chd_scale_map is a tensor with shape (N, 2, 64), N is total number of chord types, + # 2 is (chord_chroma, corresponding_scale_chroma), 64 is number of possible notes + chd_scale_map = torch.concat([maj_chd_rotations, min_chd_rotations], axis=0) + + # if using null rhythm condition, have to convert -2 to 1 and -1 to 0 + if background_condition[:,:2,:,:].min()<0: + correct_chord_condition = -background_condition[:,:2,:,:]-1 + else: + correct_chord_condition = background_condition[:,:2,:,:] + merged_chd_roll = torch.max(correct_chord_condition[:,0,:,:], correct_chord_condition[:,1,:,:]) # chd roll of our bg_cond + chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0) # chd chroma of our bg_cond + shape = chd_chroma_ours.shape + chd_chroma_ours = chd_chroma_ours.reshape(-1,64) + matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1)>=0).all(dim=-1) + seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape) + seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1,2,1,1)) + + no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1) + seven_notes_chroma_ours[no_chd_match] = 1.
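+    # each row of chd_scale_map pairs a chord chroma with its parent scale chroma; `matches` marks every
+    # template whose chord tones contain the conditioned chroma, and the einsum sums the scale chromas of the
+    # matching templates (which acts as a union for the `==0` test below, e.g. a C:major chord condition
+    # yields the C major scale as the allowed pitch set); rows with no matching template fall back to an
+    # all-ones mask, i.e. no pitches are filtered at those positions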
+ + # edit notes based on chroma + x0 = torch.where((seven_notes_chroma_ours==0)&(x0>0), 0.0 , x0) + print("See Coming?") + # edit rhythm + if (background_condition[:,:2,:,:].min()>=0) and (rhythm_control=="Yes"): # only edit if rhythm is provided + num_onset_notes, _ = torch.max(background_condition[:,0,:,:], axis=-1) + x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes) + + return x0 + +def expand_roll(roll, unit=4, contain_onset=False): + # roll: (Channel, T, H) -> (Channel, T * unit, H) + n_channel, length, height = roll.shape + + expanded_roll = roll.repeat(unit, axis=1) + if contain_onset: + expanded_roll = expanded_roll.reshape((n_channel, length, unit, height)) + expanded_roll[1::2, :, 1:] = np.maximum(expanded_roll[::2, :, 1:], expanded_roll[1::2, :, 1:]) + + expanded_roll[::2, :, 1:] = 0 + expanded_roll = expanded_roll.reshape((n_channel, length * unit, height)) + return expanded_roll + +def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96): + piano_roll_cut = piano_roll[:,:,lowest:highest+1] + return piano_roll_cut + +def circular_extend(chd_roll, lowest=33, highest=96): + #chd_roll: 6*L*12->6*L*64 + C4 = 60-lowest + C3 = C4-12 + shape = chd_roll.shape + ext_chd = np.zeros((shape[0],shape[1],highest+1-lowest)) + ext_chd[:,:,C4:C4+12] = chd_roll + ext_chd[:,:,C3:C3+12] = chd_roll + return ext_chd + + +def default_quantization(v): + return 1 if v > 0.5 else 0 + +def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96): + ## this function is for extending the cutted piano rolls into the full 128 piano rolls + ## recall that the piano rolls are of dimensions (2,L,64), we add zeros and fill it into (2,L,128) + padded_roll = np.pad(piano_roll, ((0, 0), (0, 0), (lowest, 127-highest)), mode='constant', constant_values=0) + return padded_roll + + + +def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None): + """ + piano_roll: (2, L, 128), onset and sustain channel. 
+ raise_chord: whether pitch below 48 (mel-chd boundary) will be raised an octave + """ + def convert_p(p_, note_list): + edit_note_flag = False + for t in range(n_step): + onset_state = quantization_func(piano_roll[0, t, p_]) + sustain_state = quantization_func(piano_roll[1, t, p_]) + + is_onset = bool(onset_state) + is_sustain = bool(sustain_state) and not is_onset + + pitch = p_ + + if is_onset: + edit_note_flag = True + note_list.append([t, pitch, 1]) + elif is_sustain: + if edit_note_flag: + note_list[-1][-1] += 1 + else: + edit_note_flag = False + return note_list + + quantization_func = default_quantization if quantization_func is None else quantization_func + assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128, f"{piano_roll.shape}" + + n_step = piano_roll.shape[1] + + notes = [] + for p in range(128): + convert_p(p, notes) + + return notes + + +def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100): + """Default use shift beat""" + + beat_alpha = 60 / bpm + step_alpha = unit * beat_alpha + + notes = [] + + shift_sec = shift_sec if shift_beat is None else shift_beat * beat_alpha + + for note in note_mat: + onset, pitch, dur = note + start = onset * step_alpha + shift_sec + end = (onset + dur) * step_alpha + shift_sec + + notes.append(pm.Note(vel, int(pitch), start, end)) + + return notes + + +def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None): + midi = pm.PrettyMIDI(initial_tempo=bpm) + + piano_program = pm.instrument_name_to_program('Acoustic Grand Piano') + piano = pm.Instrument(program=piano_program) + piano.notes+=piano_notes_list + midi.instruments.append(piano) + + # chd_program = pm.instrument_name_to_program('Violin') + # chd = pm.Instrument(program=chd_program) + # chd.notes+=chd_notes_list + # midi.instruments.append(chd) + + if lsh_notes_list is not None: + lsh_program = pm.instrument_name_to_program('Acoustic Grand Piano') + lsh = pm.Instrument(program=lsh_program) + lsh.notes+=lsh_notes_list + midi.instruments.append(lsh) + + return midi + +def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80): + piano_mat = piano_roll_to_note_mat(piano_roll) + piano_notes = note_mat_to_notes(piano_mat, bpm) + + chd_mat = piano_roll_to_note_mat(chd_roll) + chd_notes = note_mat_to_notes(chd_mat, bpm) + + if lsh_roll is not None: + lsh_mat = piano_roll_to_note_mat(lsh_roll) + lsh_notes = note_mat_to_notes(lsh_mat, bpm) + else: + lsh_notes=None + + piano_pm = create_pm_object(bpm = 80, piano_notes_list=piano_notes, + chd_notes_list=chd_notes, lsh_notes_list=lsh_notes) + return piano_pm + +def save_midi(pm, filename): + pm.write(filename) \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5dc75f1e000aa98775ead2b4430a4eb00dc9e7 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,59 @@ +from .latent_diffusion import LatentDiffusion +from .model_sdf import Diffpro_SDF +from .architecture.unet import UNetModel +import os + + +def init_ldm_model(params, debug_mode=False): + unet_model = UNetModel( + in_channels=params.in_channels, + out_channels=params.out_channels, + channels=params.channels, + attention_levels=params.attention_levels, + n_res_blocks=params.n_res_blocks, + channel_multipliers=params.channel_multipliers, + n_heads=params.n_heads, + tf_layers=params.tf_layers, + #d_cond=params.d_cond, + ) + + + ldm_model = LatentDiffusion( + unet_model=unet_model, 
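+        # note: no conditioning encoder is wired in here; the chord / lsh background rolls are
+        # concatenated to x_t along the channel axis before eps_model is called
+        # (see get_eps in sampler_sdf.py and loss in latent_diffusion.py)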
+ #autoencoder=None, + #autoreg_cond_enc=autoreg_cond_enc, + #external_cond_enc=external_cond_enc, + latent_scaling_factor=params.latent_scaling_factor, + n_steps=params.n_steps, + linear_start=params.linear_start, + linear_end=params.linear_end, + debug_mode=debug_mode + ) + + return ldm_model + + + +def init_diff_pro_sdf(ldm_model, params, device): + return Diffpro_SDF(ldm_model).to(device) + + +def get_model_path(model_dir, model_id=None): + model_desc = os.path.basename(model_dir) + if model_id is None: + model_path = os.path.join(model_dir, 'chkpts', 'weights.pt') + + # retrieve real model_id from the actual file weights.pt is pointing to + model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0] + + elif model_id == 'best': + model_path = os.path.join(model_dir, 'chkpts', 'weights_best.pt') + # retrieve real model_id from the actual file weights.pt is pointing to + model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0] + elif model_id == 'default': + model_path = os.path.join(model_dir, 'chkpts', 'weights_default.pt') + if not os.path.exists(model_path): + return get_model_path(model_dir, 'best') + else: + model_path = f"{model_dir}/chkpts/weights-{model_id}.pt" + return model_path, model_id, model_desc diff --git a/model/__pycache__/__init__.cpython-39.pyc b/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0082a4fba238c21cde1a333fe198bf447b2c974e Binary files /dev/null and b/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/model/__pycache__/latent_diffusion.cpython-39.pyc b/model/__pycache__/latent_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b78ebcb9440337d801824127fedea456a2bf3e6 Binary files /dev/null and b/model/__pycache__/latent_diffusion.cpython-39.pyc differ diff --git a/model/__pycache__/model_sdf.cpython-39.pyc b/model/__pycache__/model_sdf.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fa2acd035f522389722dcd927709b804b87cece Binary files /dev/null and b/model/__pycache__/model_sdf.cpython-39.pyc differ diff --git a/model/__pycache__/sampler_sdf.cpython-39.pyc b/model/__pycache__/sampler_sdf.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aef313488a1362b53822ae9b8239101bddc5b041 Binary files /dev/null and b/model/__pycache__/sampler_sdf.cpython-39.pyc differ diff --git a/model/architecture/__pycache__/unet.cpython-39.pyc b/model/architecture/__pycache__/unet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9aa5056279a5082643512082ac01ccecf2ffe101 Binary files /dev/null and b/model/architecture/__pycache__/unet.cpython-39.pyc differ diff --git a/model/architecture/__pycache__/unet_attention.cpython-39.pyc b/model/architecture/__pycache__/unet_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09005f959581d6013eab2ff0be9a496342c778fe Binary files /dev/null and b/model/architecture/__pycache__/unet_attention.cpython-39.pyc differ diff --git a/model/architecture/unet.py b/model/architecture/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..af7f656db30b07d7e8955a114b0a8a609cdd70b9 --- /dev/null +++ b/model/architecture/unet.py @@ -0,0 +1,364 @@ +""" +--- +title: U-Net for Stable Diffusion +summary: > + Annotated PyTorch implementation/tutorial of the U-Net in stable diffusion. 
+--- + +# U-Net for [Stable Diffusion](../index.html) + +This implements the U-Net that + gives $\epsilon_\text{cond}(x_t, c)$ + +We have kept to the model definition and naming unchanged from +[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) +so that we can load the checkpoints directly. +""" + +import math +from typing import List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .unet_attention import SpatialTransformer + + +class UNetModel(nn.Module): + """ + ## U-Net model + """ + def __init__( + self, + *, + in_channels: int, + out_channels: int, + channels: int, + n_res_blocks: int, + attention_levels: List[int], + channel_multipliers: List[int], + n_heads: int, + tf_layers: int = 1, + #d_cond: int = 768 + ): + """ + :param in_channels: is the number of channels in the input feature map + :param out_channels: is the number of channels in the output feature map + :param channels: is the base channel count for the model + :param n_res_blocks: number of residual blocks at each level + :param attention_levels: are the levels at which attention should be performed + :param channel_multipliers: are the multiplicative factors for number of channels for each level + :param n_heads: the number of attention heads in the transformers + """ + super().__init__() + self.channels = channels + self.out_channels = out_channels + #self.d_cond = d_cond + + # Number of levels + levels = len(channel_multipliers) + # Size time embeddings + d_time_emb = channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels, d_time_emb), + nn.SiLU(), + nn.Linear(d_time_emb, d_time_emb), + ) + + # Input half of the U-Net + self.input_blocks = nn.ModuleList() + # Initial $3 \times 3$ convolution that maps the input to `channels`. + # The blocks are wrapped in `TimestepEmbedSequential` module because + # different modules have different forward function signatures; + # for example, convolution only accepts the feature map and + # residual blocks accept the feature map and time embedding. + # `TimestepEmbedSequential` calls them accordingly. 
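+        # The initial convolution below only lifts the channel count; with a 3x3 kernel and padding 1
+        # the (time, pitch) grid of the input roll is left unchanged (presumably 64 x 64 here, given the
+        # 64-step rolls and the 33-96 pitch window used in gen_utils.py).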
+ self.input_blocks.append( + TimestepEmbedSequential(nn.Conv2d(in_channels, channels, 3, padding=1)) + ) + # Number of channels at each block in the input half of U-Net + input_block_channels = [channels] + # Number of channels at each level + channels_list = [channels * m for m in channel_multipliers] + # Prepare levels + for i in range(levels): + # Add the residual blocks and attentions + for _ in range(n_res_blocks): + # Residual block maps from previous number of channels to the number of + # channels in the current level + layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])] + channels = channels_list[i] + # Add transformer + if i in attention_levels: + layers.append( + SpatialTransformer(channels, n_heads, tf_layers) + ) + # Add them to the input half of the U-Net and keep track of the number of channels of + # its output + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_channels.append(channels) + # Down sample at all levels except last + if i != levels - 1: + self.input_blocks.append(TimestepEmbedSequential(DownSample(channels))) + input_block_channels.append(channels) + + # The middle of the U-Net + self.middle_block = TimestepEmbedSequential( + ResBlock(channels, d_time_emb), + SpatialTransformer(channels, n_heads, tf_layers), + ResBlock(channels, d_time_emb), + ) + + # Second half of the U-Net + self.output_blocks = nn.ModuleList([]) + # Prepare levels in reverse order + for i in reversed(range(levels)): + # Add the residual blocks and attentions + for j in range(n_res_blocks + 1): + # Residual block maps from previous number of channels plus the + # skip connections from the input half of U-Net to the number of + # channels in the current level. + layers = [ + ResBlock( + channels + input_block_channels.pop(), + d_time_emb, + out_channels=channels_list[i] + ) + ] + channels = channels_list[i] + # Add transformer + if i in attention_levels: + layers.append( + SpatialTransformer(channels, n_heads, tf_layers) + ) + # Up-sample at every level after last residual block + # except the last one. + # Note that we are iterating in reverse; i.e. `i == 0` is the last. + if i != 0 and j == n_res_blocks: + layers.append(UpSample(channels)) + # Add to the output half of the U-Net + self.output_blocks.append(TimestepEmbedSequential(*layers)) + + # Final normalization and $3 \times 3$ convolution + self.out = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv2d(channels, out_channels, 3, padding=1), + ) + + def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000): + """ + ## Create sinusoidal time step embeddings + + :param time_steps: are the time steps of shape `[batch_size]` + :param max_period: controls the minimum frequency of the embeddings. 
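+
+        The returned tensor has shape `[batch_size, channels]` (assuming `channels` is even): the
+        first half of the features are cosines and the second half sines of `t` scaled by
+        geometrically spaced frequencies, so e.g. `t = 0` maps to ones in the cosine half and zeros
+        in the sine half.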
+ """ + # $\frac{c}{2}$; half the channels are sin and the other half is cos, + half = self.channels // 2 + # $\frac{1}{10000^{\frac{2i}{c}}}$ + frequencies = torch.exp( + -math.log(max_period) * + torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=time_steps.device) + # $\frac{t}{10000^{\frac{2i}{c}}}$ + args = time_steps[:, None].float() * frequencies[None] + # $\cos\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$ and $\sin\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$ + return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + def forward(self, x: torch.Tensor, time_steps: torch.Tensor): + """ + :param x: is the input feature map of shape `[batch_size, channels, width, height]` + :param time_steps: are the time steps of shape `[batch_size]` + :param cond: conditioning of shape `[batch_size, n_cond, d_cond]` + """ + # To store the input half outputs for skip connections + x_input_block = [] + + # Get time step embeddings + t_emb = self.time_step_embedding(time_steps) + t_emb = self.time_embed(t_emb) + + # Input half of the U-Net + for module in self.input_blocks: + ########################## + #print("x:", x.dtype,"t_emb:",t_emb.dtype) + ########################## + #x = x.to(torch.float16) + x = module(x, t_emb) + x_input_block.append(x) + # Middle of the U-Net + x = self.middle_block(x, t_emb) + # Output half of the U-Net + for module in self.output_blocks: + # print(x.size(), 'a') + x = torch.cat([x, x_input_block.pop()], dim=1) + # print(x.size(), 'b') + x = module(x, t_emb) + + # Final normalization and $3 \times 3$ convolution + return self.out(x) + + +class TimestepEmbedSequential(nn.Sequential): + """ + ### Sequential block for modules with different inputs + + This sequential module can compose of different modules suck as `ResBlock`, + `nn.Conv` and `SpatialTransformer` and calls them with the matching signatures + """ + def forward(self, x, t_emb, cond=None): + for layer in self: + if isinstance(layer, ResBlock): + x = layer(x, t_emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x) + else: + x = layer(x) + return x + + +class UpSample(nn.Module): + """ + ### Up-sampling layer + """ + def __init__(self, channels: int): + """ + :param channels: is the number of channels + """ + super().__init__() + # $3 \times 3$ convolution mapping + self.conv = nn.Conv2d(channels, channels, 3, padding=1) + + def forward(self, x: torch.Tensor): + """ + :param x: is the input feature map with shape `[batch_size, channels, height, width]` + """ + # Up-sample by a factor of $2$ + x = F.interpolate(x, scale_factor=2, mode="nearest") + # Apply convolution + return self.conv(x) + + +class DownSample(nn.Module): + """ + ## Down-sampling layer + """ + def __init__(self, channels: int): + """ + :param channels: is the number of channels + """ + super().__init__() + # $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$ + self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1) + + def forward(self, x: torch.Tensor): + """ + :param x: is the input feature map with shape `[batch_size, channels, height, width]` + """ + # Apply convolution + return self.op(x) + + +class ResBlock(nn.Module): + """ + ## ResNet Block + """ + def __init__(self, channels: int, d_t_emb: int, *, out_channels=None): + """ + :param channels: the number of input channels + :param d_t_emb: the size of timestep embeddings + :param out_channels: is the number of out channels. defaults to `channels. 
+ """ + super().__init__() + # `out_channels` not specified + if out_channels is None: + out_channels = channels + + # First normalization and convolution + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv2d(channels, out_channels, 3, padding=1), + ) + + # Time step embeddings + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(d_t_emb, out_channels), + ) + # Final convolution layer + self.out_layers = nn.Sequential( + normalization(out_channels), nn.SiLU(), nn.Dropout(0.), + nn.Conv2d(out_channels, out_channels, 3, padding=1) + ) + + # `channels` to `out_channels` mapping layer for residual connection + if out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = nn.Conv2d(channels, out_channels, 1) + + def forward(self, x: torch.Tensor, t_emb: torch.Tensor): + """ + :param x: is the input feature map with shape `[batch_size, channels, height, width]` + :param t_emb: is the time step embeddings of shape `[batch_size, d_t_emb]` + """ + # Initial convolution + h = self.in_layers(x) + # Time step embeddings + t_emb = self.emb_layers(t_emb).type(h.dtype) + # Add time step embeddings + h = h + t_emb[:, :, None, None] + # Final convolution + h = self.out_layers(h) + # Add skip connection + return self.skip_connection(x) + h + + +class GroupNorm32(nn.GroupNorm): + """ + ### Group normalization with float32 casting + """ + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + ### Group normalization + + This is a helper function, with fixed number of groups.. + """ + return GroupNorm32(32, channels) + + +def _test_time_embeddings(): + """ + Test sinusoidal time step embeddings + """ + import matplotlib.pyplot as plt + + plt.figure(figsize=(15, 5)) + m = UNetModel( + in_channels=1, + out_channels=1, + channels=320, + n_res_blocks=1, + attention_levels=[], + channel_multipliers=[], + n_heads=1, + tf_layers=1, + d_cond=1 + ) + te = m.time_step_embedding(torch.arange(0, 1000)) + plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy()) + plt.legend(["dim %d" % p for p in [50, 100, 190, 260]]) + plt.title("Time embeddings") + plt.show() + + +# +if __name__ == '__main__': + _test_time_embeddings() diff --git a/model/architecture/unet_attention.py b/model/architecture/unet_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..84910964e624a04f3c83055455e9501ed56358a8 --- /dev/null +++ b/model/architecture/unet_attention.py @@ -0,0 +1,321 @@ +""" +--- +title: Transformer for Stable Diffusion U-Net +summary: > + Annotated PyTorch implementation/tutorial of the transformer + for U-Net in stable diffusion. +--- + +# Transformer for Stable Diffusion [U-Net](unet.html) + +This implements the transformer module used in [U-Net](unet.html) that + gives $\epsilon_\text{cond}(x_t, c)$ + +We have kept to the model definition and naming unchanged from +[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) +so that we can load the checkpoints directly. 
+""" + +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + + +class SpatialTransformer(nn.Module): + """ + ## Spatial Transformer + """ + def __init__(self, channels: int, n_heads: int, n_layers: int): + """ + :param channels: is the number of channels in the feature map + :param n_heads: is the number of attention heads + :param n_layers: is the number of transformer layers + :param d_cond: is the size of the conditional embedding + """ + super().__init__() + # Initial group normalization + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=channels, eps=1e-6, affine=True + ) + # Initial $1 \times 1$ convolution + self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) + + # Transformer layers + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + channels, n_heads, channels // n_heads + ) for _ in range(n_layers) + ] + ) + + # Final $1 \times 1$ convolution + self.proj_out = nn.Conv2d( + channels, channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x: torch.Tensor): + """ + :param x: is the feature map of shape `[batch_size, channels, height, width]` + :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` + """ + # Get shape `[batch_size, channels, height, width]` + b, c, h, w = x.shape + # For residual connection + x_in = x + # Normalize + x = self.norm(x) + # Initial $1 \times 1$ convolution + x = self.proj_in(x) + # Transpose and reshape from `[batch_size, channels, height, width]` + # to `[batch_size, height * width, channels]` + x = x.permute(0, 2, 3, 1).view(b, h * w, c) + # Apply the transformer layers + for block in self.transformer_blocks: + x = block(x) + # Reshape and transpose from `[batch_size, height * width, channels]` + # to `[batch_size, channels, height, width]` + x = x.view(b, h, w, c).permute(0, 3, 1, 2) + # Final $1 \times 1$ convolution + x = self.proj_out(x) + # Add residual + return x + x_in + + +class BasicTransformerBlock(nn.Module): + """ + ### Transformer Layer + """ + def __init__(self, d_model: int, n_heads: int, d_head: int): + """ + :param d_model: is the input embedding size + :param n_heads: is the number of attention heads + :param d_head: is the size of a attention head + :param d_cond: is the size of the conditional embeddings + """ + super().__init__() + # Self-attention layer and pre-norm layer + self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head) + self.norm1 = nn.LayerNorm(d_model) + # Cross attention layer and pre-norm layer + #self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head) + self.norm2 = nn.LayerNorm(d_model) + # Feed-forward network and pre-norm layer + self.ff = FeedForward(d_model) + self.norm3 = nn.LayerNorm(d_model) + + def forward(self, x: torch.Tensor): + """ + :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` + :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` + """ + # Self attention + x = self.attn1(self.norm1(x)) + x + # Cross-attention with conditioning + # x = self.attn2(self.norm2(x), cond=cond) + x + # Feed-forward network + x = self.ff(self.norm3(x)) + x + # + return x + + +class CrossAttention(nn.Module): + """ + ### Cross Attention Layer + + This falls-back to self-attention when conditional embeddings are not specified. 
+ """ + + use_flash_attention: bool = False + + def __init__( + self, + d_model: int, + d_cond: int, + n_heads: int, + d_head: int, + is_inplace: bool = True + ): + """ + :param d_model: is the input embedding size + :param n_heads: is the number of attention heads + :param d_head: is the size of a attention head + :param d_cond: is the size of the conditional embeddings + :param is_inplace: specifies whether to perform the attention softmax computation inplace to + save memory + """ + super().__init__() + + self.is_inplace = is_inplace + self.n_heads = n_heads + self.d_head = d_head + + # Attention scaling factor + self.scale = d_head**-0.5 + + # Query, key and value mappings + d_attn = d_head * n_heads + self.to_q = nn.Linear(d_model, d_attn, bias=False) + self.to_k = nn.Linear(d_cond, d_attn, bias=False) + self.to_v = nn.Linear(d_cond, d_attn, bias=False) + + # Final linear layer + self.to_out = nn.Sequential(nn.Linear(d_attn, d_model)) + + # Setup [flash attention](https://github.com/HazyResearch/flash-attention). + # Flash attention is only used if it's installed + # and `CrossAttention.use_flash_attention` is set to `True`. + try: + # You can install flash attention by cloning their Github repo, + # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) + # and then running `python setup.py install` + from flash_attn.flash_attention import FlashAttention + self.flash = FlashAttention() + # Set the scale for scaled dot-product attention. + self.flash.softmax_scale = self.scale + # Set to `None` if it's not installed + except ImportError: + self.flash = None + + def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None): + """ + :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` + :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` + """ + + # If `cond` is `None` we perform self attention + has_cond = cond is not None + if not has_cond: + cond = x + + # Get query, key and value vectors + q = self.to_q(x) + k = self.to_k(cond) + v = self.to_v(cond) + + # Use flash attention if it's available and the head size is less than or equal to `128` + if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128: + return self.flash_attention(q, k, v) + # Otherwise, fallback to normal attention + else: + return self.normal_attention(q, k, v) + + def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Flash Attention + + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Get batch size and number of elements along sequence axis (`width * height`) + batch_size, seq_len, _ = q.shape + + # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of + # shape `[batch_size, seq_len, 3, n_heads * d_head]` + qkv = torch.stack((q, k, v), dim=2) + # Split the heads + qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) + + # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to + # fit this size. 
+ if self.d_head <= 32: + pad = 32 - self.d_head + elif self.d_head <= 64: + pad = 64 - self.d_head + elif self.d_head <= 128: + pad = 128 - self.d_head + else: + raise ValueError(f'Head size ${self.d_head} too large for Flash Attention') + + # Pad the heads + if pad: + qkv = torch.cat( + (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 + ) + + # Compute attention + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` + out, _ = self.flash(qkv) + # Truncate the extra head size + out = out[:, :, :, : self.d_head] + # Reshape to `[batch_size, seq_len, n_heads * d_head]` + out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) + + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Normal Attention + + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` + q = q.view(*q.shape[: 2], self.n_heads, -1) + k = k.view(*k.shape[: 2], self.n_heads, -1) + v = v.view(*v.shape[: 2], self.n_heads, -1) + + # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ + attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale + + # Compute softmax + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ + if self.is_inplace: + half = attn.shape[0] // 2 + attn[half :] = attn[half :].softmax(dim=-1) + attn[: half] = attn[: half].softmax(dim=-1) + else: + attn = attn.softmax(dim=-1) + + # Compute attention output + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + out = torch.einsum('bhij,bjhd->bihd', attn, v) + # Reshape to `[batch_size, height * width, n_heads * d_head]` + out = out.reshape(*out.shape[: 2], -1) + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + +class FeedForward(nn.Module): + """ + ### Feed-Forward Network + """ + def __init__(self, d_model: int, d_mult: int = 4): + """ + :param d_model: is the input embedding size + :param d_mult: is multiplicative factor for the hidden layer size + """ + super().__init__() + self.net = nn.Sequential( + GeGLU(d_model, d_model * d_mult), nn.Dropout(0.), + nn.Linear(d_model * d_mult, d_model) + ) + + def forward(self, x: torch.Tensor): + return self.net(x) + + +class GeGLU(nn.Module): + """ + ### GeGLU Activation + + $$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$ + """ + def __init__(self, d_in: int, d_out: int): + super().__init__() + # Combined linear projections $xW + b$ and $xV + c$ + self.proj = nn.Linear(d_in, d_out * 2) + + def forward(self, x: torch.Tensor): + # Get $xW + b$ and $xV + c$ + x, gate = self.proj(x).chunk(2, dim=-1) + # $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$ + return x * F.gelu(gate) diff --git a/model/latent_diffusion.py b/model/latent_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..a73bad5e1b0635cdbe2b96e47c924036a55c5217 --- /dev/null +++ b/model/latent_diffusion.py @@ -0,0 +1,222 @@ +""" +--- +title: Latent Diffusion Models +summary: > + Annotated PyTorch implementation/tutorial of latent diffusion models from paper + 
High-Resolution Image Synthesis with Latent Diffusion Models +--- + +# Latent Diffusion Models + +Latent diffusion models use an auto-encoder to map between image space and +latent space. The diffusion model works on the diffusion space, which makes it +a lot easier to train. +It is based on paper +[High-Resolution Image Synthesis with Latent Diffusion Models](https://papers.labml.ai/paper/2112.10752). + +They use a pre-trained auto-encoder and train the diffusion U-Net on the latent +space of the pre-trained auto-encoder. + +For a simpler diffusion implementation refer to our [DDPM implementation](../ddpm/index.html). +We use same notations for $\alpha_t$, $\beta_t$ schedules, etc. +""" + +from typing import List, Tuple, Optional, Union +import torch +import torch.nn as nn +import torch.nn.functional as F +from .architecture.unet import UNetModel +import random + + +def gather(consts: torch.Tensor, t: torch.Tensor): + """Gather consts for $t$ and reshape to feature map shape""" + c = consts.gather(-1, t) + return c.reshape(-1, 1, 1, 1) + + +class LatentDiffusion(nn.Module): + """ + ## Latent diffusion model + + This contains following components: + + * [AutoEncoder](model/autoencoder.html) + * [U-Net](model/unet.html) with [attention](model/unet_attention.html) + """ + eps_model: UNetModel + #first_stage_model: Optional[Autoencoder] = None + + def __init__( + self, + unet_model: UNetModel, + latent_scaling_factor: float, + n_steps: int, + linear_start: float, + linear_end: float, + debug_mode: Optional[bool] = False + ): + """ + :param unet_model: is the [U-Net](model/unet.html) that predicts noise + $\epsilon_\text{cond}(x_t, c)$, in latent space + :param autoencoder: is the [AutoEncoder](model/autoencoder.html) + :param latent_scaling_factor: is the scaling factor for the latent space. The encodings of + the autoencoder are scaled by this before feeding into the U-Net. + :param n_steps: is the number of diffusion steps $T$. + :param linear_start: is the start of the $\beta$ schedule. + :param linear_end: is the end of the $\beta$ schedule. + """ + super().__init__() + # Wrap the [U-Net](model/unet.html) to keep the same model structure as + # [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion). + self.eps_model = unet_model + self.latent_scaling_factor = latent_scaling_factor + + # Number of steps $T$ + self.n_steps = n_steps + + # $\beta$ schedule + beta = torch.linspace( + linear_start**0.5, linear_end**0.5, n_steps, dtype=torch.float64 + ) ** 2 + # $\alpha_t = 1 - \beta_t$ + alpha = 1. - beta + # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$ + alpha_bar = torch.cumprod(alpha, dim=0) + self.alpha = nn.Parameter(alpha.to(torch.float32), requires_grad=False) + self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False) + self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False) + self.alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]]) + self.sigma_ddim = torch.sqrt((1-self.alpha_bar_prev)/(1-self.alpha_bar)*(1-self.alpha_bar/self.alpha_bar_prev)) + self.sigma2 = self.beta + + self.debug_mode = debug_mode + + @property + def device(self): + """ + ### Get model device + """ + return next(iter(self.eps_model.parameters())).device + + + + def forward(self, x: torch.Tensor, t: torch.Tensor): + """ + ### Predict noise + + Predict noise given the latent representation $x_t$, time step $t$, and the + conditioning context $c$. 
+ + $$\epsilon_\text{cond}(x_t, c)$$ + """ + return self.eps_model(x, t) + + def q_xt_x0(self, x0: torch.Tensor, + t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + #### Get $q(x_t|x_0)$ distribution + """ + + # [gather](utils.html) $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$ + mean = gather(self.alpha_bar, t)**0.5 * x0 + # $(1-\bar\alpha_t) \mathbf{I}$ + var = 1 - gather(self.alpha_bar, t) + # + return mean, var + + def q_sample( + self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None + ): + """ + #### Sample from $q(x_t|x_0)$ + """ + + # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ + if eps is None: + eps = torch.randn_like(x0) + + # get $q(x_t|x_0)$ + mean, var = self.q_xt_x0(x0, t) + # Sample from $q(x_t|x_0)$ + return mean + (var**0.5) * eps + + def p_sample(self, xt: torch.Tensor, t: torch.Tensor): + """ + #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$ + """ + + # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ + eps_theta = self.eps_model(xt, t) + # [gather](utils.html) $\bar\alpha_t$ + alpha_bar = gather(self.alpha_bar, t) + # [gather](utils.html) $\bar\alpha_t-1$ + alpha_bar_prev = gather(self.alpha_bar_prev, t) + # [gather](utils.html) $\sigma_t$ + sigma_ddim = gather(self.sigma_ddim, t) + + # DDIM sampling + # $\frac{x_t-\sqrt{1-\bar\alpha_t}\epsilon}{\sqrt{\bar\alpha_t}}$ + predicted_x0 = (xt - (1-alpha_bar)**0.5 * eps_theta) / (alpha_bar)**.5 + # $\sqrt{1-\alpha_{t-1}-\sigma_t^2}$ + direction_to_xt = (1 - alpha_bar_prev - sigma_ddim**2)**0.5 * eps_theta + + # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ + eps = torch.randn(xt.shape, device=xt.device) + + # Sample + x_tm_1 = alpha_bar_prev**0.5 * predicted_x0 + direction_to_xt + sigma_ddim * eps + return x_tm_1 + + def loss( + self, + x0: torch.Tensor, + #autoreg_cond: Union[torch.Tensor, None], #This means it can be either a tensor or none + #external_cond: Union[torch.Tensor, None], + noise: Optional[torch.Tensor] = None, + ): + """ + #### Simplified Loss + """ + # Get batch size + batch_size = x0.shape[0] + # Get random $t$ for each sample in the batch + t = torch.randint( + 0, self.n_steps, (batch_size, ), device=x0.device, dtype=torch.long + ) + + + #autoreg_cond = -torch.ones(x0.size(0), 1, self.eps_model.d_cond, device=x0.device, dtype=x0.dtype) + #cond = autoreg_cond + + if x0.size(1) == self.eps_model.out_channels: # generating form + if self.debug_mode: + print('In the mode of root level:', x0.size()) + if noise is None: + x0 = x0.to(torch.float32) + noise = torch.randn_like(x0) + + xt = self.q_sample(x0, t, eps=noise) + + eps_theta = self.eps_model(xt, t) + + loss = F.mse_loss(noise, eps_theta) + else: + if self.debug_mode: + print('In the mode of non-root level:', x0.size()) + + if noise is None: + noise = torch.randn_like(x0[:, 0: 2]) + + front_t = self.q_sample(x0[:, 0: 2], t, eps=noise) + + background_cond = x0[:, 2:] + + xt = torch.cat([front_t, background_cond], 1) + + eps_theta = self.eps_model(xt, t) + + loss = F.mse_loss(noise, eps_theta) + if self.debug_mode: + print('loss:', loss) + return loss diff --git a/model/model_sdf.py b/model/model_sdf.py new file mode 100644 index 0000000000000000000000000000000000000000..ac5a0f3603ba536528ec8d9b3b056315341d0ce5 --- /dev/null +++ b/model/model_sdf.py @@ -0,0 +1,55 @@ +import torch +import os +import torch.nn as nn +from .latent_diffusion import LatentDiffusion + + +class Diffpro_SDF(nn.Module): + + def __init__( + self, + ldm: LatentDiffusion, + ): + """ + cond_type: {chord, texture} + 
cond_mode: {cond, mix, uncond} + mix: use a special condition for unconditional learning with probability of 0.2 + use_enc: whether to use pretrained chord encoder to generate encoded condition + """ + super(Diffpro_SDF, self).__init__() + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.ldm = ldm + + @classmethod + def load_trained( + cls, + ldm, + chkpt_fpath, + ): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = cls(ldm) + trained_leaner = torch.load(chkpt_fpath, map_location=device) + try: + model.load_state_dict(trained_leaner["model"]) + except RuntimeError: + model_dict = trained_leaner["model"] + model_dict = {k.replace('cond_enc', 'autoreg_cond_enc'): v for k, v in model_dict.items()} + model_dict = {k.replace('style_enc', 'external_cond_enc'): v for k, v in model_dict.items()} + model.load_state_dict(model_dict) + return model + + def p_sample(self, xt: torch.Tensor, t: torch.Tensor): + return self.ldm.p_sample(xt, t) + + def q_sample(self, x0: torch.Tensor, t: torch.Tensor): + return self.ldm.q_sample(x0, t) + + def get_loss_dict(self, batch, step): + """ + z_y is the stuff the diffusion model needs to learn + """ + # x = batch.float().to(self.device) + + x= batch + loss = self.ldm.loss(x) + return {"loss": loss} diff --git a/model/sampler_sdf.py b/model/sampler_sdf.py new file mode 100644 index 0000000000000000000000000000000000000000..b5c40682b7b298261adfc697c803d58b1aa099ef --- /dev/null +++ b/model/sampler_sdf.py @@ -0,0 +1,538 @@ +from typing import Optional, List, Union +import numpy as np +import torch +from labml import monit +from .latent_diffusion import LatentDiffusion + +def set_seed(seed): + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +# Call the function to set the seed +# set_seed(42) + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class DiffusionSampler: + """ + ## Base class for sampling algorithms + """ + model: LatentDiffusion + + def __init__(self, model: LatentDiffusion): + """ + :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$ + """ + super().__init__() + # Set the model $\epsilon_\text{cond}(x_t, c)$ + self.model = model + # Get number of steps the model was trained with $T$ + self.n_steps = model.n_steps + + +class SDFSampler(DiffusionSampler): + """ + ## DDPM Sampler + + This extends the [`DiffusionSampler` base class](index.html). 
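+
+    Note: despite the DDPM name, `p_sample` below writes the update in DDIM form (a predicted
+    x_0 term, a direction term and `sigma_ddim` noise) and runs it over a shortened schedule of
+    selected time steps (`self.tau`).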
+ + DDPM samples images by repeatedly removing noise by sampling step by step from + $p_\theta(x_{t-1} | x_t)$, + + \begin{align} + + p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\ + + \mu_t(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0 + + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\ + + \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\ + + x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta \\ + + \end{align} + """ + + model: LatentDiffusion + + def __init__( + self, + model: LatentDiffusion, + max_l, + h, + is_autocast=False, + is_show_image=False, + device=None, + debug_mode=False + ): + """ + :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$ + """ + super().__init__(model) + if device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = device + # selected time steps ($\tau$) $1, 2, \dots, T$ + # self.time_steps = np.asarray(list(range(self.n_steps)), dtype=np.int32) + self.tau = torch.tensor([13, 53, 116, 193, 310, 443, 587, 730, 845, 999], device=self.device) # torch.tensor([999, 845, 730, 587, 443, 310, 193, 116, 53, 13]) + # self.tau = torch.tensor(np.asarray(list(range(self.n_steps)), dtype=np.int32), device=self.device) + self.used_n_steps = len(self.tau) + + self.is_show_image = is_show_image + + self.autocast = torch.cuda.amp.autocast(enabled=is_autocast) + + self.out_channel = self.model.eps_model.out_channels + self.max_l = max_l + self.h = h + self.debug_mode = debug_mode + self.guidance_scale = 7.5 + self.guidance_rescale = 0.7 + + # now, we set the coefficients + with torch.no_grad(): + # $\bar\alpha_t$ + self.alpha_bar = self.model.alpha_bar + # $\beta_t$ schedule + beta = self.model.beta + # $\bar\alpha_{t-1}$ + self.alpha_bar_prev = torch.cat([self.alpha_bar.new_tensor([1.]), self.alpha_bar[:-1]]) + # $\sigma_t$ in DDIM + self.sigma_ddim = torch.sqrt((1-self.alpha_bar_prev)/(1-self.alpha_bar)*(1-self.alpha_bar/self.alpha_bar_prev)) # DDPM noise schedule + + # $\frac{1}{\sqrt{\bar\alpha}}$ + self.one_over_sqrt_alpha_bar = 1 / (self.alpha_bar ** 0.5) + # $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$ + self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar = (1 - self.alpha_bar)**0.5 / self.alpha_bar**0.5 + + # $\sqrt{\bar\alpha}$ + self.sqrt_alpha_bar = self.alpha_bar ** 0.5 + # $\sqrt{1 - \bar\alpha}$ + self.sqrt_1m_alpha_bar = (1 - self.alpha_bar) ** 0.5 + # # $\sqrt{\bar\alpha_{t-1}}$ + # self.sqrt_alpha_bar_prev = self.alpha_bar_prev ** 0.5 + # # $\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}$ + # self.sqrt_1m_alpha_bar_prev_m_sigma2 = (1 - self.alpha_bar_prev - self.sigma_ddim ** 2) ** 0.5 + + #@property + # def d_cond(self): + #return self.model.eps_model.d_cond + + def get_eps( + self, + x: torch.Tensor, + t: torch.Tensor, + background_cond: Optional[torch.Tensor], + + uncond_scale: Optional[float], + ): + """ + ## Get $\epsilon(x_t, c)$ + + :param x: is $x_t$ of shape `[batch_size, channels, height, width]` + :param t: is $t$ of shape `[batch_size]` + :param background_cond: background condition + :param autoreg_cond: autoregressive condition + :param external_cond: external condition + :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]` + :param uncond_scale: is the unconditional guidance scale $s$. 
This is used for + $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ + :param uncond_cond: is the conditional embedding for empty prompt $c_u$ + """ + # When the scale $s = 1$ + # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$ + + batch_size = x.size(0) + + # if hasattr(self.model, 'style_enc'): + # if external_cond is not None: + # external_cond = self.model.external_cond_enc(external_cond) + # if uncond_scale is None or uncond_scale == 1: + # external_uncond = (-torch.ones_like(external_cond)).to(self.device) + # else: + # external_uncond = None + # # if random.random() < 0.2: + # # external_cond = (-torch.ones_like(external_cond)).to(self.device) + # else: + # external_cond = -torch.ones(batch_size, 4, self.d_cond, device=x.device, dtype=x.dtype) + # external_uncond = None + # cond = torch.cat([autoreg_cond, external_cond], 1) + # if external_uncond is None: + # uncond = None + # else: + # uncond = torch.cat([autoreg_cond, external_uncond], 1) + # else: + # cond = autoreg_cond + # uncond = None + + if background_cond is not None: + x = torch.cat([x, background_cond], 1) if background_cond is not None else x + + # if uncond is None: + # e_t = self.model(x, t, cond) + # else: + # e_t_cond = self.model(x, t, cond) + # e_t_uncond = self.model(x, t, uncond) + # e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond) + + e_t = self.model(x,t) + return e_t + + @torch.no_grad() + def p_sample( + self, + x: torch.Tensor, + background_cond: Optional[torch.Tensor], + #autoreg_cond: Optional[torch.Tensor], + #external_cond: Optional[torch.Tensor], + t: torch.Tensor, + step: int, + repeat_noise: bool = False, + temperature: float = 1., + uncond_scale: float = 1., + same_noise_all_measure: bool = False, + X0EditFunc = None, + use_classifier_free_guidance = False, + use_lsh = False, + reduce_extra_notes=True, + rhythm_control="Yes", + ): + print("p_sample") + """ + ### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$ + + :param x: is $x_t$ of shape `[batch_size, channels, height, width]` + :param background_cond: background condition + :param autoreg_cond: autoregressive condition + :param external_cond: external condition + :param t: is $t$ of shape `[batch_size]` + :param step: is the step $t$ as an integer + :param repeat_noise: specified whether the noise should be same for all samples in the batch + :param temperature: is the noise temperature (random noise gets multiplied by this) + :param uncond_scale: is the unconditional guidance scale $s$. 
This is used for + $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ + """ + # Get current tau_i and tau_{i-1} + tau_i = self.tau[t] + step_tau_i = self.tau[step] + + # Get $\epsilon_\theta$ + with self.autocast: + if use_classifier_free_guidance: + if use_lsh: + assert background_cond.shape[1] == 6 # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain, lsh_onset, lsh_sustain + null_lsh = -torch.ones_like(background_cond[:,4:,:,:]) + null_background_cond = torch.cat([background_cond[:,2:4,:,:], null_lsh], axis=1) + real_background_cond = torch.cat([background_cond[:,:2,:,:], background_cond[:,4:,:,:]], axis=1) + + e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale) + e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale) + e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real-e_tau_i_null) + if self.guidance_rescale > 0: + e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale) + else: + assert background_cond.shape[1] == 4 # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain + null_background_cond = background_cond[:,2:,:,:] + real_background_cond = background_cond[:,:2,:,:] + e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale) + e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale) + e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real-e_tau_i_null) + if self.guidance_rescale > 0: + e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale) + else: + if use_lsh: + assert background_cond.shape[1] == 4 # chd_onset, chd_sustain, lsh_onset, lsh_sustain + e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale) + else: + assert background_cond.shape[1] == 2 # chd_onset, chd_sustain + e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale) + + # Get batch size + bs = x.shape[0] + + # $\frac{1}{\sqrt{\bar\alpha}}$ + one_over_sqrt_alpha_bar = x.new_full( + (bs, 1, 1, 1), self.one_over_sqrt_alpha_bar[step_tau_i] + ) + # $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$ + sqrt_1m_alpha_bar_over_sqrt_alpha_bar = x.new_full( + (bs, 1, 1, 1), self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar[step_tau_i] + ) + + # $\sigma_t$ in DDIM + sigma_ddim = x.new_full( + (bs, 1, 1, 1), self.sigma_ddim[step_tau_i] + ) + + + # Calculate $x_0$ with current $\epsilon_\theta$ + # + # predicted x_0 in DDIM + predicted_x0 = one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - sqrt_1m_alpha_bar_over_sqrt_alpha_bar * e_tau_i + + # edit predicted x_0 + if X0EditFunc is not None: + predicted_x0 = X0EditFunc(predicted_x0, background_cond, reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control) + e_tau_i = (one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - predicted_x0) / sqrt_1m_alpha_bar_over_sqrt_alpha_bar + + # Do not add noise when $t = 1$ (final step sampling process). 
+ # Note that `step` is `0` when $t = 1$) + if step == 0: + noise = 0 + # If same noise is used for all samples in the batch + elif repeat_noise: + if same_noise_all_measure: + noise = torch.randn((1, predicted_x0.shape[1], 16, predicted_x0.shape[3]), device=self.device).repeat(1,1,int(predicted_x0.shape[2]/16),1) + else: + noise = torch.randn((1, *predicted_x0.shape[1:]), device=self.device) + # Different noise for each sample + else: + if same_noise_all_measure: + noise = torch.randn(predicted_x0.shape[0], predicted_x0.shape[1], 16, predicted_x0.shape[3], device=self.device).repeat(1,1,int(predicted_x0.shape[2]/16),1) + else: + noise = torch.randn(predicted_x0.shape, device=self.device) + + # Multiply noise by the temperature + noise = noise * temperature + + if step > 0: + step_tau_i_m_1 = self.tau[step-1] + # $\sqrt{\bar\alpha_{\tau_i-1}}$ + sqrt_alpha_bar_prev = x.new_full( + (bs, 1, 1, 1), self.sqrt_alpha_bar[step_tau_i_m_1] + ) + # $\sqrt{1-\bar\alpha_{\tau_i-1}-\sigma_\tau^2}$ + sqrt_1m_alpha_bar_prev_m_sigma2 = x.new_full( + (bs, 1, 1, 1), (1 - self.alpha_bar[step_tau_i_m_1] - self.sigma_ddim[step_tau_i] ** 2) ** 0.5 + ) + direction_to_xt = sqrt_1m_alpha_bar_prev_m_sigma2 * e_tau_i + x_prev = sqrt_alpha_bar_prev * predicted_x0 + direction_to_xt + sigma_ddim * noise + else: + x_prev = predicted_x0 + sigma_ddim * noise + + # Sample from, + # + # $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$ + + # + return x_prev, predicted_x0, e_tau_i + + @torch.no_grad() + def q_sample( + self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None + ): + """ + ### Sample from $q(x_t|x_0)$ + + $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$ + + :param x0: is $x_0$ of shape `[batch_size, channels, height, width]` + :param index: is the time step $t$ index + :param noise: is the noise, $\epsilon$ + """ + + # Random noise, if noise is not specified + if noise is None: + noise = torch.randn_like(x0, device=self.device) + + # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$ + return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise + + @torch.no_grad() + def sample( + self, + shape: List[int], + background_cond: Optional[torch.Tensor] = None, + #autoreg_cond: Optional[torch.Tensor] = None, + #external_cond: Optional[torch.Tensor] = None, + repeat_noise: bool = False, + temperature: float = 1., + uncond_scale: float = 1., + x_last: Optional[torch.Tensor] = None, + t_start: int = 0, + same_noise_all_measure: bool = False, + X0EditFunc = None, + use_classifier_free_guidance = False, + use_lsh = False, + reduce_extra_notes=True, + rhythm_control="Yes", + ): + """ + ### Sampling Loop + + :param shape: is the shape of the generated images in the + form `[batch_size, channels, height, width]` + :param background_cond: background condition + :param autoreg_cond: autoregressive condition + :param external_cond: external condition + :param repeat_noise: specified whether the noise should be same for all samples in the batch + :param temperature: is the noise temperature (random noise gets multiplied by this) + :param x_last: is $x_T$. If not provided random noise will be used. + :param uncond_scale: is the unconditional guidance scale $s$. 
This is used for + $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ + :param t_start: t_start + """ + + # Get device and batch size + bs = shape[0] + + ###### + print(shape) + ###### + + + # Get $x_T$ + if same_noise_all_measure: + x = x_last if x_last is not None else torch.randn(shape[0],shape[1],16,shape[3], device=self.device).repeat(1,1,int(shape[2]/16),1) + else: + x = x_last if x_last is not None else torch.randn(shape, device=self.device) + + # Time steps to sample at $T - t', T - t' - 1, \dots, 1$ + time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:] + + # Sampling loop + for step in monit.iterate('Sample', time_steps): + # Time step $t$ + ts = x.new_full((bs, ), step, dtype=torch.long) + + x, pred_x0, e_t = self.p_sample( + x, + background_cond, + #autoreg_cond, + #external_cond, + ts, + step, + repeat_noise=repeat_noise, + temperature=temperature, + uncond_scale=uncond_scale, + same_noise_all_measure=same_noise_all_measure, + X0EditFunc = X0EditFunc, + use_classifier_free_guidance = use_classifier_free_guidance, + use_lsh=use_lsh, + reduce_extra_notes=reduce_extra_notes, + rhythm_control=rhythm_control + ) + + s1 = step + 1 + + # if self.is_show_image: + # if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0): + # show_image(x, f"exp/img/x{s1}.png") + + # Return $x_0$ + # if self.is_show_image: + # show_image(x, f"exp/img/x0.png") + + return x + + @torch.no_grad() + def paint( + self, + x: Optional[torch.Tensor] = None, + background_cond: Optional[torch.Tensor] = None, + #autoreg_cond: Optional[torch.Tensor] = None, + #external_cond: Optional[torch.Tensor] = None, + t_start: int = 0, + orig: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + orig_noise: Optional[torch.Tensor] = None, + uncond_scale: float = 1., + same_noise_all_measure: bool = False, + X0EditFunc = None, + use_classifier_free_guidance = False, + use_lsh = False, + ): + """ + ### Painting Loop + + :param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]` + :param background_cond: background condition + :param autoreg_cond: autoregressive condition + :param external_cond: external condition + :param t_start: is the sampling step to start from, $S'$ + :param orig: is the original image in latent page which we are in paining. + If this is not provided, it'll be an image to image transformation. + :param mask: is the mask to keep the original image. + :param orig_noise: is fixed noise to be added to the original image. + :param uncond_scale: is the unconditional guidance scale $s$. 
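The `orig`/`mask` behaviour described above reduces to: re-noise the original latent to the current step and keep it wherever the mask is 1. A stand-alone sketch with made-up shapes (two kept measures and two regenerated ones on a 16-steps-per-measure grid):

```python
import torch

# Hypothetical tensors: 4 measures of 16 steps, 2 channels, 128 pitches.
orig = torch.randn(1, 2, 64, 128)   # original latent to keep partially
x = torch.randn_like(orig)          # current sample being denoised

# Keep measures 0 and 1, regenerate measures 2 and 3.
mask = torch.zeros_like(orig)
mask[:, :, :32, :] = 1.0

# Re-noise the original to the current step (q_sample written out inline,
# with a made-up alpha_bar value for illustration).
alpha_bar_t = torch.tensor(0.3)
orig_t = alpha_bar_t.sqrt() * orig + (1 - alpha_bar_t).sqrt() * torch.randn_like(orig)

# Same blend as in `paint`: masked area from the original, the rest from the sample.
x = orig_t * mask + x * (1 - mask)
```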
This is used for + $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ + """ + # Get batch size + bs = orig.size(0) + + if x is None: + x = torch.randn(orig.shape, device=self.device) + + # Time steps to sample at $\tau_{S`}, \tau_{S' - 1}, \dots, \tau_1$ + # time_steps = np.flip(self.time_steps[: t_start]) + time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:] + + for i, step in monit.enum('Paint', time_steps): + # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$ + # index = len(time_steps) - i - 1 + # Time step $\tau_i$ + ts = x.new_full((bs, ), step, dtype=torch.long) + + # Sample $x_{\tau_{i-1}}$ + x, _, _ = self.p_sample( + x, + background_cond, + #autoreg_cond, + #external_cond, + t=ts, + step=step, + uncond_scale=uncond_scale, + same_noise_all_measure=same_noise_all_measure, + X0EditFunc = X0EditFunc, + use_classifier_free_guidance = use_classifier_free_guidance, + use_lsh=use_lsh, + ) + + # Replace the masked area with original image + if orig is not None: + assert mask is not None + # Get the $q_{\sigma,\tau}(x_{\tau_i}|x_0)$ for original image in latent space + orig_t = self.q_sample(orig, self.tau[step], noise=orig_noise) + # Replace the masked area + x = orig_t * mask + x * (1 - mask) + + s1 = step + 1 + + # if self.is_show_image: + # if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0): + # show_image(x, f"exp/img/x{s1}.png") + + # if self.is_show_image: + # show_image(x, f"exp/img/x0.png") + return x + + def generate(self, background_cond=None, batch_size=1, uncond_scale=None, + same_noise_all_measure=False, X0EditFunc=None, + use_classifier_free_guidance=False, use_lsh=False, reduce_extra_notes=True, rhythm_control="Yes"): + + shape = [batch_size, self.out_channel, self.max_l, self.h] + + if self.debug_mode: + return torch.randn(shape, dtype=torch.float) + + return self.sample(shape, background_cond, uncond_scale=uncond_scale, same_noise_all_measure=same_noise_all_measure, + X0EditFunc=X0EditFunc, use_classifier_free_guidance=use_classifier_free_guidance, use_lsh=use_lsh, + reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control + ) + diff --git a/output_0.mid b/output_0.mid new file mode 100644 index 0000000000000000000000000000000000000000..26a8e00407f4dfa8be73d4643ab5a6560dbc4462 Binary files /dev/null and b/output_0.mid differ diff --git a/output_0.wav b/output_0.wav new file mode 100644 index 0000000000000000000000000000000000000000..86b791aa3e6a84ec9d2732490dbb0908b4af8969 --- /dev/null +++ b/output_0.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9abffb8b039f86161f025cab6419eecf93ec741ea67e66964ca8e79d333c9d4 +size 2469720 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..5cb2b5f4bb78a09b9b44fdc10a858d09f83a5b07 --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +TiMidity++ \ No newline at end of file diff --git a/piano_roll.png b/piano_roll.png new file mode 100644 index 0000000000000000000000000000000000000000..584ba9d69def93476be0a7bf04cedc490060da97 --- /dev/null +++ b/piano_roll.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf4d433689089c7895ed3ebf569e2dda9284a8d443481f6d8aa9f9575089cc37 +size 16389 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1628aed3bcc7a7c613f97c177380a9295e56eed5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +gradio#==4.44.0 +imageio#==2.35.1 
+imageio[ffmpeg] +labml#==0.5.3 +librosa +mir_eval +matplotlib +music21 +numba==0.53.1 +numpy==1.19.5 +opencv-python +# pandas==1.2.5 +pretty_midi +pydub +requests +soundfile +fluidsynth +scikit-learn +torch==2.4.1 +torchvision +tqdm +tensorboard \ No newline at end of file diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt new file mode 100644 index 0000000000000000000000000000000000000000..f1e3c65efd105af1dfb49178192f4a1c01647c56 --- /dev/null +++ b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:852cfd7b011bb1ba0ce1d0d05a7acd672c7cd4934756b2e0d357d8002b5ecb6b +size 441623592 diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0 new file mode 100644 index 0000000000000000000000000000000000000000..4973f25dbe7aff4d6b1de8911ea7fdb6c9402412 Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0 differ diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2 new file mode 100644 index 0000000000000000000000000000000000000000..2282c4934cc1887fd2bcbb353a3f38a358f5ad73 Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2 differ diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1 new file mode 100644 index 0000000000000000000000000000000000000000..64251a673711c650dbdb300800991466aca755c4 Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1 differ diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4 new file mode 100644 index 0000000000000000000000000000000000000000..5b10b32f0a0fd3ab8d9aefa118dfae26b9bc0a04 Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4 differ diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3 new file mode 100644 index 0000000000000000000000000000000000000000..e2781e652ddfeb83f30c1dfc006e7d67a4081eca 
Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3 differ diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json new file mode 100644 index 0000000000000000000000000000000000000000..d5eff601a3fbd0fce6396ef1b785330777b4e0b8 --- /dev/null +++ b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json @@ -0,0 +1 @@ +{"batch_size": 16, "max_epoch": 10, "learning_rate": 5e-05, "max_grad_norm": 10, "fp16": true, "in_channels": 6, "out_channels": 2, "channels": 64, "attention_levels": [2, 3], "n_res_blocks": 2, "channel_multipliers": [1, 2, 4, 4], "n_heads": 4, "tf_layers": 1, "d_cond": 2, "linear_start": 0.00085, "linear_end": 0.012, "n_steps": 1000, "latent_scaling_factor": 0.18215} \ No newline at end of file diff --git a/rhythm_plot_0.png b/rhythm_plot_0.png new file mode 100644 index 0000000000000000000000000000000000000000..0513bc832e5434f65f06b11a2039a3c83cfea7ad --- /dev/null +++ b/rhythm_plot_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:967207ba25f584c5f1559beea33807766b5c7472a025e4e9e182b82e3876e143 +size 11476 diff --git a/runtime.txt b/runtime.txt new file mode 100644 index 0000000000000000000000000000000000000000..032aea2dbcb01c94cd1199869e1e97d9dc351e2e --- /dev/null +++ b/runtime.txt @@ -0,0 +1 @@ +python-3.9 \ No newline at end of file diff --git a/samples/control_vs_uncontrol/example_1_acc_control.jpg b/samples/control_vs_uncontrol/example_1_acc_control.jpg new file mode 100644 index 0000000000000000000000000000000000000000..72932aacab9f1c895623c59988f9a97c4608f689 --- /dev/null +++ b/samples/control_vs_uncontrol/example_1_acc_control.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b1addf1c8a495c2cd6dc3f4c2fe2d86139e6808ad839ab8fbc7070bf8d02314 +size 122732 diff --git a/samples/control_vs_uncontrol/example_1_acc_control.wav b/samples/control_vs_uncontrol/example_1_acc_control.wav new file mode 100644 index 0000000000000000000000000000000000000000..932a8f6f3180e528abd10fdd806a1e4a51e3cfe9 --- /dev/null +++ b/samples/control_vs_uncontrol/example_1_acc_control.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98254caf2f2943c0f4bf4a07cfddf0c0dfa8bc33f41e5aafcd2626f33763b680 +size 2469720 diff --git a/samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg b/samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c5f84c39af65b24d2813462c7efdf99bda35062e --- /dev/null +++ b/samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c857985742375d62e30b6a7ad86c2ac27978a5d2f9417a01d85cadf7aa87042 +size 136185 diff --git a/samples/control_vs_uncontrol/example_1_acc_uncontrol.wav b/samples/control_vs_uncontrol/example_1_acc_uncontrol.wav new file mode 100644 index 0000000000000000000000000000000000000000..7775b7cef3c71484f71351381c36ad89ecfc4d9c --- /dev/null +++ b/samples/control_vs_uncontrol/example_1_acc_uncontrol.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:582e9d559424e40216f3d7e2ded5b07844311da1783fd77792b58a27fa27ba8c +size 2469720 diff --git a/samples/control_vs_uncontrol/example_1_mel_chd.jpg b/samples/control_vs_uncontrol/example_1_mel_chd.jpg new file mode 100644 
index 0000000000000000000000000000000000000000..e91525f4de913dc0b0c1e1cd9676491647a8dae7 --- /dev/null +++ b/samples/control_vs_uncontrol/example_1_mel_chd.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7308a7c0255420a712837437cbb37a567b2c01606c0d020caa4d44f94a1f8464 +size 71507 diff --git a/samples/control_vs_uncontrol/example_1_mel_chd.wav b/samples/control_vs_uncontrol/example_1_mel_chd.wav new file mode 100644 index 0000000000000000000000000000000000000000..59df9f6eee415112786a72c92103881fd9af7a86 --- /dev/null +++ b/samples/control_vs_uncontrol/example_1_mel_chd.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cedfa10bc44605e9cd1c0a04a82d90f7b6c31dfe68d790a7d9c8e6e27117c90e +size 5292046 diff --git a/samples/control_vs_uncontrol/example_2_acc_control.jpg b/samples/control_vs_uncontrol/example_2_acc_control.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8dce49aec16af5785380c1fbf83b20b5b709162e --- /dev/null +++ b/samples/control_vs_uncontrol/example_2_acc_control.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e7a9691db8538d58c3d4fff42184c3aac32a73849cc4f2c14d4ec540e9f5d30 +size 96583 diff --git a/samples/control_vs_uncontrol/example_2_acc_control.wav b/samples/control_vs_uncontrol/example_2_acc_control.wav new file mode 100644 index 0000000000000000000000000000000000000000..717fb8d86981d66e25122fdaa5f96421ea6b509a --- /dev/null +++ b/samples/control_vs_uncontrol/example_2_acc_control.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6e70d13d7d97c0cb3cbf38d9ad2473f5df5eb04c69a54ddb228aca1134f819a +size 2364108 diff --git a/samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg b/samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eae50dc0c93519922b38abbabcec7d2969d89050 --- /dev/null +++ b/samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01da8567d60a0ab49c4d0e9a78b21ab1a102e196c2eccb39c30a5a721f1998d9 +size 108920 diff --git a/samples/control_vs_uncontrol/example_2_acc_uncontrol.wav b/samples/control_vs_uncontrol/example_2_acc_uncontrol.wav new file mode 100644 index 0000000000000000000000000000000000000000..ff4b8970c5d4f8a86e39ee413160b4e052b6b0c3 --- /dev/null +++ b/samples/control_vs_uncontrol/example_2_acc_uncontrol.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a60053aca94dac302379e536774832a3f90bf2fd7a0d7a23c2dc97b742f78910 +size 2364108 diff --git a/samples/control_vs_uncontrol/example_2_mel_chd.jpg b/samples/control_vs_uncontrol/example_2_mel_chd.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c9bdbd0c363b6cf36204117081d97b75679e03f9 --- /dev/null +++ b/samples/control_vs_uncontrol/example_2_mel_chd.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad9825c1cff525d4f348d60e77e8a28769e58a4ef93730a4c8fcfc8394f40e92 +size 60218 diff --git a/samples/control_vs_uncontrol/example_2_mel_chd.wav b/samples/control_vs_uncontrol/example_2_mel_chd.wav new file mode 100644 index 0000000000000000000000000000000000000000..2b0657befab4806dc202f3fe80ad49c605ac039f --- /dev/null +++ b/samples/control_vs_uncontrol/example_2_mel_chd.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cafc0f0a3c375155457479eb2403e17aead89cd78aa021cdfeec78301930ee4f +size 5292046 diff --git 
a/samples/different_styles/chinese_1.wav b/samples/different_styles/chinese_1.wav new file mode 100644 index 0000000000000000000000000000000000000000..e131cf54c032e1fc1eb0601baae6042aa5b09aaa --- /dev/null +++ b/samples/different_styles/chinese_1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cb5284cedced19d23ea5471ada1b3a9357326c64440170ecf05eec74fecf53f +size 2364108 diff --git a/samples/different_styles/chinese_2.wav b/samples/different_styles/chinese_2.wav new file mode 100644 index 0000000000000000000000000000000000000000..8ebc18741528e863de67ed8b532fa7de9819322b --- /dev/null +++ b/samples/different_styles/chinese_2.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84ae2c4c480d225239fd28aec180d83ff13b2a3d58074d22fc78e018ebf2d790 +size 2364108 diff --git a/samples/different_styles/chinese_scale.jpg b/samples/different_styles/chinese_scale.jpg new file mode 100644 index 0000000000000000000000000000000000000000..83c0d10e24c84ddc19bd3df6634b3bebc2ad060a --- /dev/null +++ b/samples/different_styles/chinese_scale.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bafa09570c3f9413ddc6e2f643e935ec6932e0e5a20c070f1a617ace53e6f08d +size 23100 diff --git a/samples/different_styles/dorian_1.wav b/samples/different_styles/dorian_1.wav new file mode 100644 index 0000000000000000000000000000000000000000..43fff5448aa6534540147fd2f9045b4dadf39871 --- /dev/null +++ b/samples/different_styles/dorian_1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8913ce841a7b2e8b6eed4c4ffc87da3e22478908fecf8d2965d46790c44d29de +size 2364108 diff --git a/samples/different_styles/dorian_2.wav b/samples/different_styles/dorian_2.wav new file mode 100644 index 0000000000000000000000000000000000000000..02a2b6a05e66a3631ff3d57ef1820365834df828 --- /dev/null +++ b/samples/different_styles/dorian_2.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8746fe2cabc9997ace5584e49dc42d96f71f30fe869eed286310e71ffc1442ed +size 2364108 diff --git a/samples/different_styles/dorian_scale.jpg b/samples/different_styles/dorian_scale.jpg new file mode 100644 index 0000000000000000000000000000000000000000..604cd0d906592469b94e7a5fa0904fc16795a97f --- /dev/null +++ b/samples/different_styles/dorian_scale.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78380800a624c05b867e63f92b13f27008b4ca5217d622849a79fbe7f2c3a58b +size 26351 diff --git a/samples/diy_examples/example1/example1.jpg b/samples/diy_examples/example1/example1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1817ef2b7cfe25ecdeeb3f330c6772b063d6a211 --- /dev/null +++ b/samples/diy_examples/example1/example1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e93fb967cd7b52ea6435d775f9533053cc917b90745b3a56eedc4e4a16b9dc27 +size 80422 diff --git a/samples/diy_examples/example1/example1.npy b/samples/diy_examples/example1/example1.npy new file mode 100644 index 0000000000000000000000000000000000000000..341c0740f47d5ff51a4b48eadc3fd1ba9651d0ae Binary files /dev/null and b/samples/diy_examples/example1/example1.npy differ diff --git a/samples/diy_examples/example1/example_1_mel.wav b/samples/diy_examples/example1/example_1_mel.wav new file mode 100644 index 0000000000000000000000000000000000000000..7b279f604b1cfefbd98a72acc0c1068a641c86ba --- /dev/null +++ b/samples/diy_examples/example1/example_1_mel.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:99de59651e0df5796d02cd08eb3980b09c11319f52d4d4c074a61cfc9f6f4e06 +size 2626764 diff --git a/samples/diy_examples/example1/sample1.mid b/samples/diy_examples/example1/sample1.mid new file mode 100644 index 0000000000000000000000000000000000000000..bd85386415644767c3239663970b6df26b654bdf Binary files /dev/null and b/samples/diy_examples/example1/sample1.mid differ diff --git a/samples/diy_examples/example1/sample1.wav b/samples/diy_examples/example1/sample1.wav new file mode 100644 index 0000000000000000000000000000000000000000..2424674521564a3206ed1ade75faf5912066d9db --- /dev/null +++ b/samples/diy_examples/example1/sample1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d81863d650333e13918c6f39b550fc301e268370c9f257ae9a03c4d0cb24708 +size 2626764 diff --git a/samples/diy_examples/example2/example2.jpg b/samples/diy_examples/example2/example2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4f758687e062b1d70cda2a42137be5dd5002a516 --- /dev/null +++ b/samples/diy_examples/example2/example2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:136c3f4e30ac8c1a3860f689f82f0f59e2f2c1666671113d193ba4413163c864 +size 87430 diff --git a/samples/diy_examples/example2/example2.mid b/samples/diy_examples/example2/example2.mid new file mode 100644 index 0000000000000000000000000000000000000000..8b4df24b1be2d570379974985fd2d3737a763174 Binary files /dev/null and b/samples/diy_examples/example2/example2.mid differ diff --git a/samples/diy_examples/example2/example2.npy b/samples/diy_examples/example2/example2.npy new file mode 100644 index 0000000000000000000000000000000000000000..d988a4c733ba4801ec59b82ae0c15d0e28a2f905 Binary files /dev/null and b/samples/diy_examples/example2/example2.npy differ diff --git a/samples/diy_examples/example2/example_2_mel.wav b/samples/diy_examples/example2/example_2_mel.wav new file mode 100644 index 0000000000000000000000000000000000000000..ed67a8b610939ab66988c17887a34a97f9b7862d --- /dev/null +++ b/samples/diy_examples/example2/example_2_mel.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f545731a0d7b01b73df05fe2c5e1aba995365e685b276c81a3073d7dc054929b +size 2364108 diff --git a/samples/diy_examples/example2/sample1.mid b/samples/diy_examples/example2/sample1.mid new file mode 100644 index 0000000000000000000000000000000000000000..ab0aaae10c65ba748bd6c5f7a1f02987c0cd51bc Binary files /dev/null and b/samples/diy_examples/example2/sample1.mid differ diff --git a/samples/diy_examples/example2/sample1.wav b/samples/diy_examples/example2/sample1.wav new file mode 100644 index 0000000000000000000000000000000000000000..807079ce13e1af6e7927044d7e5ec446ae22e336 --- /dev/null +++ b/samples/diy_examples/example2/sample1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae68eb033e06bc43f9c26eceaed67bac4745a481c42492a66ad8400ca56b3a9d +size 2364108 diff --git a/samples/diy_examples/example3/example3.jpg b/samples/diy_examples/example3/example3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9da8c41b79887955a4933dba653826f85aad63fa --- /dev/null +++ b/samples/diy_examples/example3/example3.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45deba5c0642d156ba05df0f5ebacda82cfc3d786704560278fcefec9f94c4a0 +size 101831 diff --git a/samples/diy_examples/example3/example3.mid b/samples/diy_examples/example3/example3.mid new file mode 100644 index 
0000000000000000000000000000000000000000..5216951964b117ca02104e155e737b7d6b6dc53c Binary files /dev/null and b/samples/diy_examples/example3/example3.mid differ diff --git a/samples/diy_examples/example3/example3.npy b/samples/diy_examples/example3/example3.npy new file mode 100644 index 0000000000000000000000000000000000000000..65780f67cfedb735b6fbf11132a6ae33d71d3013 Binary files /dev/null and b/samples/diy_examples/example3/example3.npy differ diff --git a/samples/diy_examples/example3/example_3_mel.wav b/samples/diy_examples/example3/example_3_mel.wav new file mode 100644 index 0000000000000000000000000000000000000000..7f13b8aa125923bf8a3733ab1466e2900a3621c8 --- /dev/null +++ b/samples/diy_examples/example3/example_3_mel.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a1206053ff85d6ccb5449206a2395bf220dcd88c5cad430eccdefea9c4956d9 +size 2364108 diff --git a/samples/diy_examples/example3/sample1.mid b/samples/diy_examples/example3/sample1.mid new file mode 100644 index 0000000000000000000000000000000000000000..5e8e3593d1458c68ab301ae36b800a89e5afeb29 Binary files /dev/null and b/samples/diy_examples/example3/sample1.mid differ diff --git a/samples/diy_examples/example3/sample1.wav b/samples/diy_examples/example3/sample1.wav new file mode 100644 index 0000000000000000000000000000000000000000..659e97eaceac9022ac37f4663360d39be7cc756a --- /dev/null +++ b/samples/diy_examples/example3/sample1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75c4e5568759f78e30f346f6a928ea839063259068d9628dea5b47db249a1361 +size 2364108 diff --git a/samples/diy_examples/example4/example4.jpg b/samples/diy_examples/example4/example4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..036f138c273a02b9c7755987aa52cb818122768a --- /dev/null +++ b/samples/diy_examples/example4/example4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:974035fcdc0f415c6cc41acb7ffd013e5499930169c156f65ed645ed5f24eccf +size 77774 diff --git a/samples/diy_examples/example4/example4.npy b/samples/diy_examples/example4/example4.npy new file mode 100644 index 0000000000000000000000000000000000000000..334289164de39ad1283ec21e41ddfc75218ea648 Binary files /dev/null and b/samples/diy_examples/example4/example4.npy differ diff --git a/samples/diy_examples/example4/example_4_mel.wav b/samples/diy_examples/example4/example_4_mel.wav new file mode 100644 index 0000000000000000000000000000000000000000..4b53f0f93530a7069302bd6068df8ec3d0f54f9d --- /dev/null +++ b/samples/diy_examples/example4/example_4_mel.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d71a72473f89275da9f3947bb702ea037f17938cb71ccb504aa693b833a90212 +size 2364108 diff --git a/samples/diy_examples/example4/sample1.mid b/samples/diy_examples/example4/sample1.mid new file mode 100644 index 0000000000000000000000000000000000000000..df3c989fa9d75069fd030bc08e0d531e212cd763 Binary files /dev/null and b/samples/diy_examples/example4/sample1.mid differ diff --git a/samples/diy_examples/example4/sample1.wav b/samples/diy_examples/example4/sample1.wav new file mode 100644 index 0000000000000000000000000000000000000000..a6bf2312a42fbe9044be9e0069ee3c7cfe9c7b5a --- /dev/null +++ b/samples/diy_examples/example4/sample1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:666b08d944c809662f3781f373bbfbd83af18925d6acc9004944598da532aa09 +size 2364108 diff --git a/samples/diy_examples/rhythm_plot_1.png 
b/samples/diy_examples/rhythm_plot_1.png new file mode 100644 index 0000000000000000000000000000000000000000..bfcb7d5623206b2441c4b890e95e35a68465edf9 --- /dev/null +++ b/samples/diy_examples/rhythm_plot_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe87b8025fb47e2a41cae493aeb69bc9690784349e646d1db80bf2e100a47632 +size 11844 diff --git a/samples/diy_examples/rhythm_plot_2.png b/samples/diy_examples/rhythm_plot_2.png new file mode 100644 index 0000000000000000000000000000000000000000..bc0ceda8f6bdda7b24635d89f061fee119f89691 --- /dev/null +++ b/samples/diy_examples/rhythm_plot_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13e5ee8c7bad342216a57b22486680022112a605e1b45e6b951b148daab6152d +size 10835 diff --git a/samples/diy_examples/rhythm_plot_3.png b/samples/diy_examples/rhythm_plot_3.png new file mode 100644 index 0000000000000000000000000000000000000000..961d8753080221d7875b51708f7088e2926cdece --- /dev/null +++ b/samples/diy_examples/rhythm_plot_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b57d1e070f9cd784ffb505c2c0d05593ea9ffd3cf07a718ead783645ce3562c +size 12093 diff --git a/samples/diy_examples/rhythm_plot_4.png b/samples/diy_examples/rhythm_plot_4.png new file mode 100644 index 0000000000000000000000000000000000000000..bfcb7d5623206b2441c4b890e95e35a68465edf9 --- /dev/null +++ b/samples/diy_examples/rhythm_plot_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe87b8025fb47e2a41cae493aeb69bc9690784349e646d1db80bf2e100a47632 +size 11844 diff --git a/samples/diy_examples/rhythm_plot_default.png b/samples/diy_examples/rhythm_plot_default.png new file mode 100644 index 0000000000000000000000000000000000000000..9d258ef3cf18ae16a480b9f1997134356b41e4e6 --- /dev/null +++ b/samples/diy_examples/rhythm_plot_default.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:931d7148bbb618c0befbab6d0c26fa3d8b65acdc87ca71b94686404d3b389cb3 +size 11216 diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..535b969dee506b444a93a6421142452abd0f47b2 --- /dev/null +++ b/train/__init__.py @@ -0,0 +1,45 @@ +import torch +import json +import os +from datetime import datetime +from torch.utils.data import DataLoader +from torch.optim import Optimizer +from .learner import DiffproLearner + + +class TrainConfig: + + model: torch.nn.Module + train_dl: DataLoader + val_dl: DataLoader + optimizer: Optimizer + + def __init__(self, params, param_scheduler, output_dir) -> None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.params = params + self.param_scheduler = param_scheduler + self.output_dir = output_dir + + def train(self): + # collect and display total parameters + total_parameters = sum( + p.numel() for p in self.model.parameters() if p.requires_grad + ) + print(f"Total parameters: {total_parameters}") + + # dealing with the output storing + output_dir = self.output_dir + if os.path.exists(f"{output_dir}/chkpts/weights.pt"): + print("Checkpoint already exists.") + if input("Resume training? 
(y/n)") != "y": + return + else: + output_dir = f"{output_dir}/{datetime.now().strftime('%m-%d_%H%M%S')}" + print(f"Creating new log folder as {output_dir}") + + # prepare the learner structure and parameters + learner = DiffproLearner( + output_dir, self.model, self.train_dl, self.val_dl, self.optimizer, + self.params, self.param_scheduler + ) + learner.train(max_epoch=self.params.max_epoch) diff --git a/train/__pycache__/__init__.cpython-39.pyc b/train/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bfd6354b8745c9dd31372b585d2769b58c515e9 Binary files /dev/null and b/train/__pycache__/__init__.cpython-39.pyc differ diff --git a/train/__pycache__/learner.cpython-39.pyc b/train/__pycache__/learner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14829af006935184e207fdac73839197af8d5a25 Binary files /dev/null and b/train/__pycache__/learner.cpython-39.pyc differ diff --git a/train/__pycache__/train_config.cpython-39.pyc b/train/__pycache__/train_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0588530e141a4d1ca0b80c7fe20b44c0cdb2b98 Binary files /dev/null and b/train/__pycache__/train_config.cpython-39.pyc differ diff --git a/train/__pycache__/train_params.cpython-39.pyc b/train/__pycache__/train_params.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6423405379640e108400c53c89db11e8451057b3 Binary files /dev/null and b/train/__pycache__/train_params.cpython-39.pyc differ diff --git a/train/learner.py b/train/learner.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6a9a7cfd5ca60366a81cb346ef4b34008fe851 --- /dev/null +++ b/train/learner.py @@ -0,0 +1,211 @@ +import torch +import json +import torch.nn as nn +from tqdm import tqdm +from torch.utils.tensorboard.writer import SummaryWriter +from typing import Optional +import os + + +def nested_map(struct, map_fn): + """This is for trasfering into cuda device""" + if isinstance(struct, tuple): + return tuple(nested_map(x, map_fn) for x in struct) + if isinstance(struct, list): + return [nested_map(x, map_fn) for x in struct] + if isinstance(struct, dict): + return {k: nested_map(v, map_fn) for k, v in struct.items()} + return map_fn(struct) + + +class DiffproLearner: + def __init__( + self, output_dir, model, train_dl, val_dl, optimizer, params + ): + # model output + self.output_dir = output_dir + self.log_dir = f"{output_dir}/logs" + self.checkpoint_dir = f"{output_dir}/chkpts" + # model (architecture and loss) + self.model = model + # data loader + self.train_dl = train_dl + self.val_dl = val_dl + # optimizer + self.optimizer = optimizer + # what is this ???? + self.params = params + # current time recoder + self.step = 0 + self.epoch = 0 + self.grad_norm = 0. 
+ # other information + self.summary_writer = None + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.autocast = torch.cuda.amp.autocast(enabled=params.fp16) + self.scaler = torch.cuda.amp.GradScaler(enabled=params.fp16) + + self.best_val_loss = torch.tensor([1e10], device=self.device) + + # restore if directory exists + if os.path.exists(self.output_dir): + self.restore_from_checkpoint() + else: + os.makedirs(self.output_dir) + os.makedirs(self.log_dir) + os.makedirs(self.checkpoint_dir) + with open(f"{output_dir}/params.json", "w") as params_file: + json.dump(self.params, params_file) + + print(json.dumps(self.params, sort_keys=True, indent=4)) + + def _write_summary(self, losses: dict, scheduled_params: Optional[dict], type): + """type: train or val""" + summary_losses = losses + summary_losses["grad_norm"] = self.grad_norm + if scheduled_params is not None: + for k, v in scheduled_params.items(): + summary_losses[f"sched_{k}"] = v + writer = self.summary_writer or SummaryWriter( + self.log_dir, purge_step=self.step + ) + writer.add_scalars(type, summary_losses, self.step) + writer.flush() + self.summary_writer = writer + + def state_dict(self): + # state dictionary + model_state = self.model.state_dict() + return { + "step": self.step, + "epoch": self.epoch, + "model": + { + k: v.cpu() if isinstance(v, torch.Tensor) else v + for k, v in model_state.items() + }, + "optimizer": + { + k: v.cpu() if isinstance(v, torch.Tensor) else v + for k, v in self.optimizer.state_dict().items() + }, + "scaler": self.scaler.state_dict(), + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.epoch = state_dict["epoch"] + self.model.load_state_dict(state_dict["model"]) + self.optimizer.load_state_dict(state_dict["optimizer"]) + self.scaler.load_state_dict(state_dict["scaler"]) + + def restore_from_checkpoint(self, fname="weights"): + try: + fpath = f"{self.checkpoint_dir}/{fname}.pt" + checkpoint = torch.load(fpath) + self.load_state_dict(checkpoint) + print(f"Restored from checkpoint {fpath} --> {fname}-{self.epoch}.pt!") + return True + except FileNotFoundError: + print("No checkpoint found. 
Starting from scratch...") + return False + + def _link_checkpoint(self, save_name, link_fpath): + if os.path.islink(link_fpath): + os.unlink(link_fpath) + os.symlink(save_name, link_fpath) + + def save_to_checkpoint(self, fname="weights", is_best=False): + save_name = f"{fname}-{self.epoch}.pt" + save_fpath = f"{self.checkpoint_dir}/{save_name}" + link_best_fpath = f"{self.checkpoint_dir}/{fname}_best.pt" + link_fpath = f"{self.checkpoint_dir}/{fname}.pt" + torch.save(self.state_dict(), save_fpath) + self._link_checkpoint(save_name, link_fpath) + if is_best: + self._link_checkpoint(save_name, link_best_fpath) + + def train(self, max_epoch=None): + self.model.train() + + while True: + self.epoch = self.step // len(self.train_dl) + if max_epoch is not None and self.epoch >= max_epoch: + return + + for batch in tqdm(self.train_dl, desc=f"Epoch {self.epoch}"): + #print("type of batch:", type(batch)) + batch = nested_map( + batch, lambda x: x.to(self.device) + if isinstance(x, torch.Tensor) else x + ) + #print("type of batch:", type(batch)) + losses, scheduled_params = self.train_step(batch) + # check NaN + for loss_value in list(losses.values()): + if isinstance(loss_value, + torch.Tensor) and torch.isnan(loss_value).any(): + raise RuntimeError( + f"Detected NaN loss at step {self.step}, epoch {self.epoch}" + ) + if self.step % 50 == 0: + self._write_summary(losses, scheduled_params, "train") + if self.step % 5000 == 0 and self.step != 0 \ + and self.epoch != 0: + self.valid() + self.step += 1 + + # valid + self.valid() + + def valid(self): + # self.model.eval() + losses = None + for batch in self.val_dl: + batch = nested_map( + batch, lambda x: x.to(self.device) if isinstance(x, torch.Tensor) else x + ) + current_losses, _ = self.val_step(batch) + losses = losses or current_losses + for k, v in current_losses.items(): + losses[k] += v + assert losses is not None + for k, v in losses.items(): + losses[k] /= len(self.val_dl) + self._write_summary(losses, None, "val") + + if self.best_val_loss >= losses["loss"]: + self.best_val_loss = losses["loss"] + self.save_to_checkpoint(is_best=True) + else: + self.save_to_checkpoint(is_best=False) + + def train_step(self, batch): + # people say this is the better way to set zero grad + # instead of self.optimizer.zero_grad() + for param in self.model.parameters(): + param.grad = None + + # here forward the model + with self.autocast: + scheduled_params = None + loss_dict = self.model.get_loss_dict(batch, self.step) + + loss = loss_dict["loss"] + self.scaler.scale(loss).backward() + self.scaler.unscale_(self.optimizer) + self.grad_norm = nn.utils.clip_grad.clip_grad_norm_( + self.model.parameters(), self.params.max_grad_norm or 1e9 + ) + self.scaler.step(self.optimizer) + self.scaler.update() + return loss_dict, scheduled_params + + def val_step(self, batch): + with torch.no_grad(): + with self.autocast: + + scheduled_params = None + loss_dict = self.model.get_loss_dict(batch, self.step) + + return loss_dict, scheduled_params diff --git a/train/train_config.py b/train/train_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0efe1e2524af469e0d9466fbd67c1b4106121c24 --- /dev/null +++ b/train/train_config.py @@ -0,0 +1,45 @@ +from . 
import * + +import sys +import os + +# Determine the absolute path to the external folder +current_directory = os.path.dirname(os.path.abspath(__file__)) +external_directory = os.path.abspath(os.path.join(current_directory, '../data')) + +# Add the external folder to sys.path +sys.path.append(external_directory) + +# Now you can import the external module +from data_utils import load_datasets, create_train_valid_dataloaders +from model import init_ldm_model, init_diff_pro_sdf + + +class LdmTrainConfig(TrainConfig): + + def __init__(self, params, output_dir, mode, + mask_background, multi_phrase_label, random_pitch_aug, debug_mode=False) -> None: + super().__init__(params, None, output_dir) + self.debug_mode = debug_mode + #self.use_autoreg_cond = use_autoreg_cond + #self.use_external_cond = use_external_cond + self.mask_background = mask_background + self.multi_phrase_label = multi_phrase_label + self.random_pitch_aug = random_pitch_aug + + # create model + self.ldm_model = init_ldm_model(mode, params, debug_mode) + self.model = init_diff_pro_sdf(self.ldm_model, params, self.device) + + # Create dataloader + load_first_n = 10 if self.debug_mode else None + train_set, valid_set = load_datasets( + mode, multi_phrase_label, random_pitch_aug, + mask_background, load_first_n + ) + self.train_dl, self.val_dl = create_train_valid_dataloaders(params.batch_size, train_set, valid_set) + + # Create optimizer + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=params.learning_rate + ) diff --git a/train/train_params.py b/train/train_params.py new file mode 100644 index 0000000000000000000000000000000000000000..79b227cf33a0c70292f7ceefe82abae867bd2f6f --- /dev/null +++ b/train/train_params.py @@ -0,0 +1,97 @@ +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + def override(self, attrs): + if isinstance(attrs, dict): + self.__dict__.update(**attrs) + elif isinstance(attrs, (list, tuple, set)): + for attr in attrs: + self.override(attr) + elif attrs is not None: + raise NotImplementedError + return self + + + +params_chord = AttrDict( + # Training params + batch_size=16, + max_epoch=10, + learning_rate=5e-5, + max_grad_norm=10, + fp16=True, + + # unet + in_channels=2, + out_channels=2, + channels=64, + attention_levels=[2, 3], + n_res_blocks=2, + channel_multipliers=[1, 2, 4, 4], + n_heads=4, + tf_layers=1, + d_cond=12, + + # ldm + linear_start=0.00085, + linear_end=0.0120, + n_steps=1000, + latent_scaling_factor=0.18215 +) + + + +params_chord_cond = AttrDict( + # Training params + batch_size=16, + max_epoch=10, + learning_rate=5e-5, + max_grad_norm=10, + fp16=True, + + # unet + in_channels=4, + out_channels=2, + channels=64, + attention_levels=[2, 3], + n_res_blocks=2, + channel_multipliers=[1, 2, 4, 4], + n_heads=4, + tf_layers=1, + d_cond=2, + + # ldm + linear_start=0.00085, + linear_end=0.0120, + n_steps=1000, + latent_scaling_factor=0.18215 +) + + +params_chord_lsh_cond = AttrDict( + # Training params + batch_size=16, + max_epoch=10, + learning_rate=5e-5, + max_grad_norm=10, + fp16=True, + + # unet + in_channels=6, + out_channels=2, + channels=64, + attention_levels=[2, 3], + n_res_blocks=2, + channel_multipliers=[1, 2, 4, 4], + n_heads=4, + tf_layers=1, + d_cond=2, + + # ldm + linear_start=0.00085, + linear_end=0.0120, + n_steps=1000, + latent_scaling_factor=0.18215 +) \ No newline at end of file diff --git a/train_cond_w_rhythm_onset.ipynb b/train_cond_w_rhythm_onset.ipynb new file mode 100644 
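All three parameter sets above are `AttrDict` instances, so keys double as attributes and `override` merges nested updates; a short usage sketch (the overridden value is made up):

```python
class AttrDict(dict):
    """dict whose keys are also attributes, with a recursive override helper."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

    def override(self, attrs):
        if isinstance(attrs, dict):
            self.__dict__.update(**attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for attr in attrs:
                self.override(attr)
        elif attrs is not None:
            raise NotImplementedError
        return self

params = AttrDict(batch_size=16, learning_rate=5e-5, fp16=True)
params.override({"batch_size": 8})                    # hypothetical override for a smaller GPU
print(params.batch_size, params["learning_rate"])     # attribute and key access are interchangeable
```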
index 0000000000000000000000000000000000000000..9f02d2814bf74cf41d810b755ccb9e11520dd648 --- /dev/null +++ b/train_cond_w_rhythm_onset.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from torch.optim import Optimizer\n", + "import os\n", + "from datetime import datetime\n", + "from train.learner import DiffproLearner\n", + "\n", + "class TrainConfig:\n", + "\n", + " model: torch.nn.Module\n", + " train_dl: DataLoader\n", + " val_dl: DataLoader\n", + " optimizer: Optimizer\n", + "\n", + " def __init__(self, params, param_scheduler, output_dir) -> None:\n", + " self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " self.params = params\n", + " self.param_scheduler = param_scheduler\n", + " self.output_dir = output_dir\n", + "\n", + " def train(self):\n", + " # collect and display total parameters\n", + " total_parameters = sum(\n", + " p.numel() for p in self.model.parameters() if p.requires_grad\n", + " )\n", + " print(f\"Total parameters: {total_parameters}\")\n", + "\n", + " # dealing with the output storing\n", + " output_dir = self.output_dir\n", + " if os.path.exists(f\"{output_dir}/chkpts/weights.pt\"):\n", + " print(\"Checkpoint already exists.\")\n", + " if input(\"Resume training? (y/n)\") != \"y\":\n", + " return\n", + " else:\n", + " output_dir = f\"{output_dir}/{datetime.now().strftime('%m-%d_%H%M%S')}\"\n", + " print(f\"Creating new log folder as {output_dir}\")\n", + "\n", + " # prepare the learner structure and parameters\n", + " learner = DiffproLearner(\n", + " output_dir, self.model, self.train_dl, self.val_dl, self.optimizer,\n", + " self.params\n", + " )\n", + " learner.train(max_epoch=self.params.max_epoch)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from model import init_ldm_model, init_diff_pro_sdf\n", + "from data.dataset_loading import load_datasets, create_dataloader\n", + "\n", + "WITH_RHYTHM = \"onset\"\n", + "\n", + "class LdmTrainConfig(TrainConfig):\n", + "\n", + " def __init__(self, params, output_dir, debug_mode=False) -> None:\n", + " super().__init__(params, None, output_dir)\n", + " self.debug_mode = debug_mode\n", + " #self.use_autoreg_cond = use_autoreg_cond\n", + " #self.use_external_cond = use_external_cond\n", + " #self.mask_background = mask_background\n", + " #self.random_pitch_aug = random_pitch_aug\n", + "\n", + " # create model\n", + " self.ldm_model = init_ldm_model(params, debug_mode)\n", + " self.model = init_diff_pro_sdf(self.ldm_model, params, self.device)\n", + "\n", + " # Create dataloader\n", + " train_set = load_datasets(with_rhythm=WITH_RHYTHM)\n", + " self.train_dl = create_dataloader(params.batch_size, train_set)\n", + " self.val_dl = create_dataloader(params.batch_size, train_set) # we temporarily use train_set for validation\n", + "\n", + " # Create optimizer4\n", + " self.optimizer = torch.optim.Adam(\n", + " self.model.parameters(), lr=params.learning_rate\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/music/chord_trainer/train/learner.py:45: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. 
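The FutureWarnings in the notebook output below come from the older `torch.cuda.amp` entry points; the replacement the warning itself suggests would look like this (a sketch only, not a change the notebook actually makes):

```python
import torch

fp16 = True
# Newer spelling suggested by the warning; behaviour is equivalent on CUDA devices.
autocast = torch.amp.autocast("cuda", enabled=fp16)
scaler = torch.amp.GradScaler("cuda", enabled=fp16)
```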
Please use `torch.amp.autocast('cuda', args...)` instead.\n", + " self.autocast = torch.cuda.amp.autocast(enabled=params.fp16)\n", + "/home/music/chord_trainer/train/learner.py:46: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n", + " self.scaler = torch.cuda.amp.GradScaler(enabled=params.fp16)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total parameters: 36755330\n", + "Creating new log folder as results/test/09-13_171940\n", + "{\n", + " \"attention_levels\": [\n", + " 2,\n", + " 3\n", + " ],\n", + " \"batch_size\": 16,\n", + " \"channel_multipliers\": [\n", + " 1,\n", + " 2,\n", + " 4,\n", + " 4\n", + " ],\n", + " \"channels\": 64,\n", + " \"d_cond\": 2,\n", + " \"fp16\": true,\n", + " \"in_channels\": 4,\n", + " \"latent_scaling_factor\": 0.18215,\n", + " \"learning_rate\": 5e-05,\n", + " \"linear_end\": 0.012,\n", + " \"linear_start\": 0.00085,\n", + " \"max_epoch\": 10,\n", + " \"max_grad_norm\": 10,\n", + " \"n_heads\": 4,\n", + " \"n_res_blocks\": 2,\n", + " \"n_steps\": 1000,\n", + " \"out_channels\": 2,\n", + " \"tf_layers\": 1\n", + "}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0: 100%|██████████| 1141/1141 [00:51<00:00, 22.08it/s]\n", + "Epoch 1: 100%|██████████| 1141/1141 [00:50<00:00, 22.43it/s]\n", + "Epoch 2: 100%|██████████| 1141/1141 [00:47<00:00, 24.02it/s]\n", + "Epoch 3: 100%|██████████| 1141/1141 [00:47<00:00, 24.07it/s]\n", + "Epoch 4: 100%|██████████| 1141/1141 [01:04<00:00, 17.70it/s]\n", + "Epoch 5: 100%|██████████| 1141/1141 [00:50<00:00, 22.42it/s]\n", + "Epoch 6: 100%|██████████| 1141/1141 [00:50<00:00, 22.38it/s]\n", + "Epoch 7: 100%|██████████| 1141/1141 [00:50<00:00, 22.38it/s]\n", + "Epoch 8: 100%|██████████| 1141/1141 [01:05<00:00, 17.38it/s]\n", + "Epoch 9: 100%|██████████| 1141/1141 [00:49<00:00, 22.83it/s]\n" + ] + } + ], + "source": [ + "\n", + "# Import necessary libraries\n", + "from train.train_params import params_chord_cond, params_chord\n", + "import os\n", + "\n", + "# Set the argument values directly\n", + "args = {\n", + " 'output_dir': 'results',\n", + " 'uniform_pitch_shift': False,\n", + " # 'debug': False,\n", + " # 'data_source': \"lmd\",\n", + " # 'load_chkpt_from': None,\n", + " # 'dataset_path': \"data/lmd_sample/no_drum_sample\",\n", + "}\n", + "\n", + "# Determine random pitch augmentation\n", + "random_pitch_aug = not args['uniform_pitch_shift']\n", + "\n", + "# Generate the filename based on argument settings\n", + "fn = 'test'\n", + "\n", + "# Set the output directory\n", + "output_dir = os.path.join(args['output_dir'], fn)\n", + "\n", + "# Create the training configuration\n", + "config = LdmTrainConfig(params_chord_cond, output_dir)\n", + "\n", + "config.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "music_demo", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/utility.ipynb b/utility.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9ef29f12074bed96dac825de4080e6506593f3b5 --- /dev/null +++ b/utility.ipynb 
@@ -0,0 +1,172 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/music/.conda/envs/music_demo/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(256, 320) (514, 1880, 3)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from app import * \n", + "\n", + "test_chd_roll = np.concatenate([np.tile(CHORD_DICTIONARY[\"C:major\"], (16, 1)), \n", + " np.tile(CHORD_DICTIONARY[\"C:major\"], (16, 1)), \n", + " np.tile(CHORD_DICTIONARY[\"C:major\"], (16, 1)), \n", + " np.tile(CHORD_DICTIONARY[\"C:major\"], (16, 1))])\n", + "\n", + "rhythms = [m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm]\n", + "\n", + "chd_roll = np.concatenate([test_chd_roll[np.newaxis,:,:], test_chd_roll[np.newaxis,:,:]], axis=0)\n", + "\n", + "chd_roll = circular_extend(chd_roll)\n", + "chd_roll = -chd_roll-1\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "example3 = np.load(\"samples/diy_examples/example3.npy\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "example3[:2,:,:] = example3[2:4,:,:]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"samples/diy_examples/example3.npy\", example3)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_chd_roll.min(axis=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "example0 = np.load(\"samples/diy_examples/example3.npy\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "example0[2,:,:] = np.min(example0[2:4,:,:], axis=0)\n", + "example0[3,:,:] = np.min(example0[2:4,:,:], axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,\n", + " -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,\n", + " -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,\n", + " -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,\n", + " -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.])" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "example0[2,:,:].min(axis=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"samples/diy_examples/example3.npy\", example0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + 
"display_name": "music_demo", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/values_below_threshold2.npy b/values_below_threshold2.npy new file mode 100644 index 0000000000000000000000000000000000000000..fc12ed54bb62f03ccc479731cbaeff8270c87898 Binary files /dev/null and b/values_below_threshold2.npy differ