diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..7b9d33741ed04ce8ac5299ff580d34195ae35713
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..f2c7676a5e74c09421a40065fc4fc09398c04feb
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,6 @@
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
diff --git a/Aptfile b/Aptfile
new file mode 100644
index 0000000000000000000000000000000000000000..78e026b394005cf20d0877c5a1d66bed8e191edd
--- /dev/null
+++ b/Aptfile
@@ -0,0 +1,2 @@
+lilypond
+fluidsynth
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a44cf5baca4c0750592a7f58f387b4d15bdd2e7e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,18 @@
+---
+title: Interactive Symbolic Music Demo
+emoji: 🖼
+colorFrom: purple
+colorTo: red
+sdk: gradio
+sdk_version: 4.42.0
+app_file: app.py
+pinned: false
+python_version: 3.9.19
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+Setup note: inside the installed pretty_midi package, find every call to mido.MidiFile and set its clip argument to clip=True, so that out-of-range MIDI data bytes are clipped instead of raising an error.
+
+To reproduce the demo behaviour, use set_seed(42) in sampler_sdf.py; the generation result from the chord slice at index 2 is a good example (a wrong note is shifted to a correct one).
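+
+If you prefer not to edit the installed package by hand, here is a minimal sketch (an assumption on our part, not code shipped in this repository) that wraps `mido.MidiFile` so every file `pretty_midi` loads is read with `clip=True`:
+
+```python
+import mido
+
+_OriginalMidiFile = mido.MidiFile
+
+def _clipped_midi_file(*args, **kwargs):
+    # Clamp out-of-range data bytes instead of raising "data byte must be in range 0..127"
+    kwargs.setdefault("clip", True)
+    return _OriginalMidiFile(*args, **kwargs)
+
+mido.MidiFile = _clipped_midi_file  # pretty_midi.PrettyMIDI looks this attribute up at call time
+
+import pretty_midi
+midi = pretty_midi.PrettyMIDI("example.mid")  # hypothetical path
+```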
\ No newline at end of file
diff --git a/__pycache__/app.cpython-39.pyc b/__pycache__/app.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6aff3ff983bea3bbd41aec052861a96665ba366
Binary files /dev/null and b/__pycache__/app.cpython-39.pyc differ
diff --git a/__pycache__/learner.cpython-39.pyc b/__pycache__/learner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8fc32d512444df684b54631f94c7152bf90d8bd
Binary files /dev/null and b/__pycache__/learner.cpython-39.pyc differ
diff --git a/__pycache__/params.cpython-39.pyc b/__pycache__/params.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24afc4dfebe39dbbfba6830b669704c0fd12040c
Binary files /dev/null and b/__pycache__/params.cpython-39.pyc differ
diff --git a/__pycache__/train_params.cpython-39.pyc b/__pycache__/train_params.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..875f8c1b15a3d93333550a3d86c21e1f03077a23
Binary files /dev/null and b/__pycache__/train_params.cpython-39.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c89499fa849543f32ab04c6433b625262cb04abd
--- /dev/null
+++ b/app.py
@@ -0,0 +1,508 @@
+import gradio as gr
+import pretty_midi
+import matplotlib.pyplot as plt
+import numpy as np
+import soundfile as sf
+import cv2
+import imageio
+
+import sys
+import subprocess
+import os
+import torch
+from model import init_ldm_model
+from model.model_sdf import Diffpro_SDF
+from model.sampler_sdf import SDFSampler
+
+import pickle
+from train.train_params import params_chord_lsh_cond
+from generation.gen_utils import *
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
+chord_list = list(CHORD_DICTIONARY.keys())
+
+def get_shape(file_path):
+ if file_path.endswith('.jpg'):
+ img = cv2.imread(file_path)
+ return img.shape # (height, width, channels)
+
+ elif file_path.endswith('.mp4'):
+ vid = imageio.get_reader(file_path)
+ return vid.get_meta_data()['size'] # (width, height)
+
+ else:
+ raise ValueError("Unsupported file type")
+
+# Function to convert MIDI to WAV
+def midi_to_wav(midi, output_file):
+ # Synthesize the waveform from the MIDI using pretty_midi
+ audio_data = midi.fluidsynth()
+
+ # Write the waveform to a WAV file
+ sf.write(output_file, audio_data, samplerate=44100)
+
+def update_musescore_image(selected_prompt):
+ # Logic to return the correct image file based on the selected prompt
+ if selected_prompt == "example 1":
+ return "samples/diy_examples/example1/example1.jpg"
+ elif selected_prompt == "example 2":
+ return "samples/diy_examples/example2/example2.jpg"
+ elif selected_prompt == "example 3":
+ return "samples/diy_examples/example3/example3.jpg"
+ elif selected_prompt == "example 4":
+ return "samples/diy_examples/example4/example4.jpg"
+ elif selected_prompt == "example 5":
+ return "samples/diy_examples/example5/example5.jpg"
+ elif selected_prompt == "example 6":
+ return "samples/diy_examples/example6/example6.jpg"
+
+
+# Model for generating music (example)
+def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"):
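+    # Layout of background_condition as used below (and built in generate_diy): channels 0-1 are
+    # the rhythm-specific chord roll (onset / sustain), channels 2-3 are the chord roll stored in
+    # the -x-1 encoding, and channels 4-5 are the melody (lead-sheet) roll, or -1 when absent.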
+
+ ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False)
+ model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device)
+ sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)
+
+ if mode=="example":
+ if prompt == "example 1":
+ background_condition = np.load("samples/diy_examples/example1/example1.npy")
+ tempo=70
+ elif prompt == "example 2":
+ background_condition = np.load("samples/diy_examples/example2/example2.npy")
+ elif prompt == "example 3":
+ background_condition = np.load("samples/diy_examples/example3/example3.npy")
+ elif prompt == "example 4":
+ background_condition = np.load("samples/diy_examples/example4/example4.npy")
+
+ background_condition = np.tile(background_condition, (num_samples,1,1,1))
+ background_condition = torch.Tensor(background_condition).to(device)
+ else:
+ background_condition = np.tile(prompt, (num_samples,1,1,1))
+ background_condition = torch.Tensor(background_condition).to(device)
+
+ if rhythm_control!="Yes":
+ background_condition[:,0:2] = background_condition[:,2:4]
+ # generate samples
+ output_x = sampler.generate(background_cond=background_condition, batch_size=num_samples,
+ same_noise_all_measure=False, X0EditFunc=X0EditFunc,
+ use_classifier_free_guidance=True, use_lsh=True, reduce_extra_notes=False,
+ rhythm_control=rhythm_control)
+ output_x = torch.clamp(output_x, min=0, max=1)
+ output_x = output_x.cpu().numpy()
+
+ # save samples
+ for i in range(num_samples):
+ full_roll = extend_piano_roll(output_x[i]) # accompaniment roll
+ full_chd_roll = extend_piano_roll(-background_condition[i,2:4,:,:].cpu().numpy()-1) # chord roll
+ full_lsh_roll = None
+ if background_condition.shape[1]>=6:
+ if background_condition[:,4:6,:,:].min()>=0:
+ full_lsh_roll = extend_piano_roll(background_condition[i,4:6,:,:].cpu().numpy())
+ midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo)
+ # filename = f'DDIM_w_rhythm_onset_0to10_{i}_edit_x0_and_eps'+'.mid'
+ filename = f"output_{i}.mid"
+ save_midi(midi_file, filename)
+ subprocess.Popen(['timidity',f'output_{i}.mid','-Ow','-o',f'output_{i}.wav']).communicate()
+
+ return 'output_0.mid', 'output_0.wav', midi_file
+
+# Function to visualize MIDI notes
+def visualize_midi(midi):
+ # Get piano roll from MIDI
+ roll = midi.get_piano_roll(fs=100)
+
+ # Plot the piano roll
+ plt.figure(figsize=(10, 4))
+ plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
+ plt.title("Piano Roll")
+ plt.xlabel("Time")
+ plt.ylabel("Pitch")
+ plt.colorbar()
+
+ # Save the plot as an image
+ output_image_path = "piano_roll.png"
+ plt.savefig(output_image_path)
+ return output_image_path
+
+def plot_rhythm(rhythm_str, label):
+ if rhythm_str=="null rhythm":
+ return None
+ fig, ax = plt.subplots(figsize=(6, 2))
+
+    # Keep at most the first 16 characters (one per sixteenth note)
+ rhythm_str = rhythm_str[:16]
+
+ # Convert string to a list of 0s and 1s
+ rhythm = [0 if bit=="0" else 1 for bit in rhythm_str]
+
+ # Define the x axis for the 16 sixteenth notes
+ x = list(range(1, 17)) # 1 to 16 sixteenth notes
+
+ # Plot each note (1 as filled circle, 0 as empty circle)
+ for i, bit in enumerate(rhythm):
+ if bit == 1:
+ ax.scatter(i + 1, 1, color='black', s=100, label="Note" if i == 0 else "")
+ else:
+ ax.scatter(i + 1, 1, edgecolor='black', facecolor='none', s=100, label="Rest" if i == 0 else "")
+
+ # Distinguish groups of 4 using vertical dashed lines (no solid grid lines)
+ for i in range(4, 17, 4):
+ ax.axvline(x=i + 0.5, color='grey', linestyle='--')
+
+ # Remove solid vertical grid lines by setting the grid off
+ ax.grid(False)
+
+ # Formatting the plot
+ ax.set_xlim(0.5, 16.5)
+ ax.set_ylim(0.8, 1.2)
+ ax.set_xticks(x)
+ ax.set_yticks([])
+ ax.set_xlabel("16th Notes")
+ ax.set_title("Rhythm Pattern")
+
+ fig.savefig(f'samples/diy_examples/rhythm_plot_{label}.png')
+ plt.close(fig)
+ return f'samples/diy_examples/rhythm_plot_{label}.png'
+
+def adjust_rhythm_string(s):
+ # Truncate if longer than 16 characters
+ if len(s) > 16:
+ return s[:16]
+ # Pad with zeros if shorter than 16 characters
+ else:
+ return s.ljust(16, '0')
+def rhythm_string_to_array(s):
+ # Ensure the string is 16 characters long
+ s = s[:16].ljust(16, '0') # Truncate or pad with '0' to make it 16 characters
+ # Convert to numpy array, treating non-'0' as '1'
+ arr = np.array([1 if char != '0' else 0 for char in s], dtype=int)
+ arr = arr*np.array([3,1,2,1,3,1,2,1,3,1,2,1,3,1,2,1])
+ print(arr)
+ return arr
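+# For example (illustration only): rhythm_string_to_array("1010000010100000") returns
+# [3,0,2,0,0,0,0,0,3,0,2,0,0,0,0,0]: onsets on quarter-note beats get weight 3,
+# eighth-note offbeats get weight 2, sixteenth offbeats get weight 1, and rests stay 0.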
+
+# Gradio main function
+def generate_from_example(prompt):
+ midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control=False)
+ piano_roll_image = visualize_midi(midi)
+ return audio_output, piano_roll_image
+
+def generate_diy(m1_chord, m2_chord, m3_chord, m4_chord,
+ m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm, tempo):
+ print("\n\n\n",m1_chord,type(m1_chord), "\n\n\n")
+ test_chd_roll = np.concatenate([np.tile(CHORD_DICTIONARY[m1_chord], (16, 1)),
+ np.tile(CHORD_DICTIONARY[m2_chord], (16, 1)),
+ np.tile(CHORD_DICTIONARY[m3_chord], (16, 1)),
+ np.tile(CHORD_DICTIONARY[m4_chord], (16, 1))])
+ rhythms = [m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm]
+
+ chd_roll = np.concatenate([test_chd_roll[np.newaxis,:,:], test_chd_roll[np.newaxis,:,:]], axis=0)
+
+ chd_roll = circular_extend(chd_roll)
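+    # Store the chord roll in the -x-1 encoding; generate_music recovers it by applying
+    # -x-1 again (the transform is its own inverse).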
+ chd_roll = -chd_roll-1
+
+ real_chd_roll = chd_roll
+
+ melody_roll = -np.ones_like(chd_roll)
+
+ if "null rhythm" not in rhythms:
+ rhythm_full = []
+ for i in range(len(rhythms)):
+ rhythm = adjust_rhythm_string(rhythms[i])
+ rhythm = rhythm_string_to_array(rhythm)
+ rhythm_full.append(rhythm)
+ rhythm_full = np.concatenate(rhythm_full, axis=0)
+
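+        # Split the chord into onset and sustain channels: onset steps carry the metrical
+        # weights from rhythm_full, and any step with no onset becomes a sustain.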
+ onset_roll = test_chd_roll*rhythm_full[:, np.newaxis]
+ sustain_roll = np.zeros_like(onset_roll)
+ no_onset_pos = np.all(onset_roll == 0, axis=-1)
+ sustain_roll[no_onset_pos] = test_chd_roll[no_onset_pos]
+
+ real_chd_roll = np.concatenate([onset_roll[np.newaxis,:,:], sustain_roll[np.newaxis,:,:]], axis=0)
+ real_chd_roll = circular_extend(real_chd_roll)
+
+ background_condition = np.concatenate([real_chd_roll, chd_roll, melody_roll], axis=0)
+
+ midi_output, audio_output, midi = generate_music(background_condition, tempo, mode="diy")
+ piano_roll_image = visualize_midi(midi)
+ return midi_output, audio_output, piano_roll_image
+
+# Prompt list
+prompt_list = ["example 1", "example 2", "example 3", "example 4"]
+rhythm_list = ["null rhythm", "1010101010101010", "1011101010111010","1111101010111010","1010001010101010","1010101000101010"]
+
+
+custom_css = """
+.custom-row1 {
+ background-color: #fdebd0;
+ padding: 10px;
+ border-radius: 5px;
+}
+.custom-row2 {
+ background-color: #d1f2eb;
+ padding: 10px;
+ border-radius: 5px;
+}
+.custom-grey {
+ background-color: #f0f0f0;
+ padding: 10px;
+ border-radius: 5px;
+}
+.custom-purple {
+ background-color: #d7bde2;
+ padding: 10px;
+ border-radius: 5px;
+}
+.audio_waveform-container {
+ display: none !important;
+}
+"""
+
+
+with gr.Blocks(css=custom_css) as demo:
+ gr.Markdown("#
Efficient Fine-Grained Guidance for Diffusion-Based Symbolic Music Generation
")
+
+ gr.Markdown("
We introduce **Fine-Grained Guidance (FG)**, an efficient approach for symbolic music generation using **diffusion models**. Our method enhances guidance through:\
+ \n (1) Fine-grained conditioning during training,\
+ \n (2) Fine-grained control during the diffusion sampling process.\
+ \n In particular, **sampling control** ensures tonal accuracy in every generated sample, allowing our model to produce music with high precision, consistent rhythmic patterns,\
+ and even stylistic variations that align with user intent.")
+ gr.Markdown(" At the bottom of this page, we provide an interactive space for you to try our model by yourself! ")
+
+
+ gr.Markdown("\n\n\n")
+ gr.Markdown("# 1. Accompaniment Generation given Melody and Chord")
+ gr.Markdown(" In each example, the left column displays the melody provided as inputs to the model.\
+ The right column showcases music samples generated by the model.")
+
+ with gr.Column(elem_classes="custom-row1"):
+ gr.Markdown("## Example 1")
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(" With the following melody as condition ")
+ example1_mel = gr.Audio(value="samples/diy_examples/example1/example_1_mel.wav", label="Melody", scale = 5)
+ with gr.Column():
+ gr.Markdown(" Generated Accompaniments ")
+ example1_audio = gr.Audio(value="samples/diy_examples/example1/sample1.wav", label="Generated Accompaniment", scale = 5)
+
+ with gr.Column(elem_classes="custom-row2"):
+ gr.Markdown("## Example 2")
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(" With the following melody as condition ")
+ example1_mel = gr.Audio(value="samples/diy_examples/example2/example_2_mel.wav", label="Melody", scale = 5)
+ with gr.Column():
+ gr.Markdown(" Generated Accompaniments ")
+ example1_audio = gr.Audio(value="samples/diy_examples/example2/sample1.wav", label="Generated Accompaniment", scale = 5)
+
+ with gr.Column(elem_classes="custom-row1"):
+ gr.Markdown("## Example 3")
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(" With the following melody as condition ")
+ example1_mel = gr.Audio(value="samples/diy_examples/example3/example_3_mel.wav", label="Melody", scale = 5)
+ with gr.Column():
+ gr.Markdown(" Generated Accompaniments ")
+ example1_audio = gr.Audio(value="samples/diy_examples/example3/sample1.wav", label="Generated Accompaniment", scale = 5)
+
+ with gr.Column(elem_classes="custom-row2"):
+ gr.Markdown("## Example 4")
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(" With the following melody as condition ")
+ example1_mel = gr.Audio(value="samples/diy_examples/example4/example_4_mel.wav", label="Melody", scale = 5)
+ with gr.Column():
+ gr.Markdown(" Generated Accompaniments ")
+ example1_audio = gr.Audio(value="samples/diy_examples/example4/sample1.wav", label="Generated Accompaniment", scale = 5)
+
+ gr.HTML("")
+ gr.Markdown("# \n\n\n")
+ gr.Markdown("# 2. Style-Controlled Music Generation")
+ gr.Markdown("Our approach enables controllable stylization in music generation. The sampling control is able to\
+ ensure that all generated notes strictly adhere to the target musical style's scale.\
+ This allows the model to generate music in specific styles — even those that were not present in \
+ the training data.")
+ gr.Markdown(" Below, we demonstrate several examples of style-controlled music generation for:\
+ \n (1) Dorian Mode: (with scale being A-B-C-D-E-F#-G);\
+ \n (2) Chinese Style: (with scale being C-D-E-G-A). ")
+
+ with gr.Column(elem_classes="custom-row1"):
+ gr.Markdown("## Dorian Mode")
+ gr.Markdown(" The following are two examples generated by our method ")
+ with gr.Row():
+ with gr.Column(elem_classes="custom-grey"):
+ gr.Markdown(" Example 1 ")
+ example1_mel = gr.Audio(value="samples/different_styles/dorian_1.wav", scale = 5)
+ with gr.Column(elem_classes="custom-grey"):
+ gr.Markdown(" Example 2 ")
+ example1_audio = gr.Audio(value="samples/different_styles/dorian_2.wav", scale = 5)
+
+ with gr.Column(elem_classes="custom-row2"):
+ gr.Markdown("## Chinese Style")
+ gr.Markdown(" The following are two examples generated by our method ")
+ with gr.Row():
+ with gr.Column(elem_classes="custom-grey"):
+ gr.Markdown(" Example 1 ")
+ example1_mel = gr.Audio(value="samples/different_styles/chinese_1.wav", scale = 5)
+ with gr.Column(elem_classes="custom-grey"):
+ gr.Markdown(" Example 2 ")
+ example1_audio = gr.Audio(value="samples/different_styles/chinese_2.wav", scale = 5)
+
+ gr.HTML("")
+ gr.Markdown("\n\n\n")
+ gr.Markdown("# 3. Demonstrating the Effectiveness of Sampling Control by Comparison")
+
+ gr.Markdown(" We demonstrate the impact of sampling control in an **accompaniment generation** task, given a melody and chord progression.\
+ \n Each example generates accompaniments with and without sampling control using the same random seed, ensuring that the two results are comparable.\
+ \n Sampling control effectively removes or replaces harmonically conflicting notes, ensuring tonal consistency.\
+ \n We provide music sheets and audio files for both versions.")
+
+ gr.Markdown(" Comparison of the results indicates that sampling control not only eliminates out-of-key notes but also enhances \
+                the overall harmonic consistency of the accompaniments.\
+ This highlights the effectiveness of our approach in maintaining musical coherence. ")
+
+
+ with gr.Column(elem_classes="custom-row1"):
+ gr.Markdown("## Example 1")
+
+ with gr.Row(elem_classes="custom-grey"):
+ gr.Markdown(" With pre-defined melody and chord as follows")
+ with gr.Column(scale=2, min_width=10, ):
+ gr.Markdown("Melody Sheet")
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ with gr.Column(scale=1, min_width=10, ):
+ gr.Markdown("Melody Audio")
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)
+
+ gr.Markdown("## Generated Accompaniments")
+ with gr.Row(elem_classes="custom-grey"):
+ gr.Markdown(" Without sampling control")
+ with gr.Column(scale=2, min_width=300):
+ gr.Markdown("Music Sheet")
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ with gr.Column(scale=1, min_width=150):
+ gr.Markdown("Audio")
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
+ gr.Markdown("\n\n\n")
+ with gr.Row(elem_classes="custom-grey"):
+ with gr.Column(scale=1, min_width=150):
+ gr.Markdown("With sampling control")
+ with gr.Column(scale=2, min_width=300):
+ gr.Markdown("Music Sheet")
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_acc_control.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ with gr.Column(scale=1, min_width=150):
+ gr.Markdown("Audio")
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)
+
+
+ with gr.Column(elem_classes="custom-row2"):
+ gr.Markdown("## Example 2")
+
+ with gr.Row(elem_classes="custom-grey"):
+ gr.Markdown(" With pre-defined melody and chord as follows")
+ with gr.Column(scale=2, min_width=10, ):
+ gr.Markdown("Melody Sheet")
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ with gr.Column(scale=1, min_width=10, ):
+ gr.Markdown("Melody Audio")
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)
+
+ gr.Markdown("## Generated Accompaniments")
+ with gr.Row(elem_classes="custom-grey"):
+ gr.Markdown(" Without sampling control")
+ with gr.Column(scale=2, min_width=300):
+ gr.Markdown("Music Sheet")
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ with gr.Column(scale=1, min_width=150):
+ gr.Markdown("Audio")
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_acc_uncontrol.wav", scale = 1, min_width=10)
+ gr.Markdown("\n\n\n")
+ with gr.Row(elem_classes="custom-grey"):
+ with gr.Column(scale=1, min_width=150):
+ gr.Markdown("With sampling control")
+ with gr.Column(scale=2, min_width=300):
+ gr.Markdown("Music Sheet")
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_acc_control.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ with gr.Column(scale=1, min_width=150):
+ gr.Markdown("Audio")
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_acc_control.wav", scale = 1, min_width=10)
+
+ # with gr.Row():
+ # with gr.Column(scale=1, min_width=300, elem_classes="custom-row1"):
+ # gr.Markdown("## Example 1")
+ # gr.Markdown(" With pre-defined melody and chord as follows")
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ # # Audio component to play the audio
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)
+
+ # gr.Markdown("## Generated Accompaniments")
+ # with gr.Row():
+ # with gr.Column(scale=1, min_width=150):
+ # gr.Markdown(" without sampling control")
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
+ # with gr.Column(scale=1, min_width=150):
+ # gr.Markdown(" with sampling control")
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)
+ # with gr.Column(scale=1, min_width=300, elem_classes="custom-row2"):
+ # gr.Markdown("## Example 2")
+ # gr.Markdown(" With pre-defined melody and chord as follows")
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ # # Audio component to play the audio
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)
+
+ # gr.Markdown("## Generated Accompaniments")
+ # with gr.Row():
+ # with gr.Column(scale=1, min_width=150):
+ # gr.Markdown(" without sampling control")
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
+ # with gr.Column(scale=1, min_width=150):
+ # gr.Markdown(" with sampling control")
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)
+
+
+
+
+
+ ''' Try to generate by users '''
+ gr.HTML("")
+ gr.Markdown("\n\n\n")
+ gr.Markdown("# 4. DIY in real time! ")
+ gr.Markdown(" Here is an interactive tool for you to try our model and generate by yourself.\
+ You can generate new accompaniments for given melody and chord conditions ")
+
+ gr.Markdown("### Currently this space is supported with Hugging Face CPU and on average,\
+ it takes about 15 seconds to generate a 4-measure music piece. However, if other users are generating\
+ music at the same time, one may enter a queue, which could slow down the process significantly.\
+ If that happens, feel free to refresh the page. We appreciate your patience and understanding.\
+ ")
+
+ with gr.Column(elem_classes="custom-purple"):
+ gr.Markdown("### Select an example to generate music given melody and chord condition")
+ with gr.Row():
+ with gr.Column():
+ prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1")
+ gr.Markdown("### This is the melody to be conditioned on:")
+ condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg", label="melody, chord, and rhythm condition")
+ prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore)
+
+ with gr.Column():
+ generate_button = gr.Button("Generate")
+ gr.Markdown("### Generation results:")
+ audio_output = gr.Audio(label="Generated Music")
+ piano_roll_output = gr.Image(label="Generated Piano Roll")
+
+ generate_button.click(
+ fn=generate_from_example,
+ inputs=[prompt_selector],
+ outputs=[audio_output, piano_roll_output]
+ )
+
+
+# Launch Gradio interface
+if __name__ == "__main__":
+ demo.launch()
diff --git a/filter_data/filter_by_instrument.ipynb b/filter_data/filter_by_instrument.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..2a9b8472a793d457ca3bc130e26b80d1097cfc21
--- /dev/null
+++ b/filter_data/filter_by_instrument.ipynb
@@ -0,0 +1,353 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from midi_utils import is_timesig_44, gather_full_instr, gather_instr\n",
+ "from midi_utils import has_brass\n",
+ "from midi_utils import has_piano, has_string, has_guitar, has_drums\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Program number 0: Acoustic Grand Piano\n",
+ "Program number 1: Bright Acoustic Piano\n",
+ "Program number 2: Electric Grand Piano\n",
+ "Program number 3: Honky-tonk Piano\n",
+ "Program number 4: Electric Piano 1\n",
+ "Program number 5: Electric Piano 2\n",
+ "Program number 6: Harpsichord\n",
+ "Program number 7: Clavinet\n",
+ "Program number 8: Celesta\n",
+ "Program number 9: Glockenspiel\n",
+ "Program number 10: Music Box\n",
+ "Program number 11: Vibraphone\n",
+ "Program number 12: Marimba\n",
+ "Program number 13: Xylophone\n",
+ "Program number 14: Tubular Bells\n",
+ "Program number 15: Dulcimer\n",
+ "Program number 16: Drawbar Organ\n",
+ "Program number 17: Percussive Organ\n",
+ "Program number 18: Rock Organ\n",
+ "Program number 19: Church Organ\n",
+ "Program number 20: Reed Organ\n",
+ "Program number 21: Accordion\n",
+ "Program number 22: Harmonica\n",
+ "Program number 23: Tango Accordion\n",
+ "Program number 24: Acoustic Guitar (nylon)\n",
+ "Program number 25: Acoustic Guitar (steel)\n",
+ "Program number 26: Electric Guitar (jazz)\n",
+ "Program number 27: Electric Guitar (clean)\n",
+ "Program number 28: Electric Guitar (muted)\n",
+ "Program number 29: Overdriven Guitar\n",
+ "Program number 30: Distortion Guitar\n",
+ "Program number 31: Guitar Harmonics\n",
+ "Program number 32: Acoustic Bass\n",
+ "Program number 33: Electric Bass (finger)\n",
+ "Program number 34: Electric Bass (pick)\n",
+ "Program number 35: Fretless Bass\n",
+ "Program number 36: Slap Bass 1\n",
+ "Program number 37: Slap Bass 2\n",
+ "Program number 38: Synth Bass 1\n",
+ "Program number 39: Synth Bass 2\n",
+ "Program number 40: Violin\n",
+ "Program number 41: Viola\n",
+ "Program number 42: Cello\n",
+ "Program number 43: Contrabass\n",
+ "Program number 44: Tremolo Strings\n",
+ "Program number 45: Pizzicato Strings\n",
+ "Program number 46: Orchestral Harp\n",
+ "Program number 47: Timpani\n",
+ "Program number 48: String Ensemble 1\n",
+ "Program number 49: String Ensemble 2\n",
+ "Program number 50: Synth Strings 1\n",
+ "Program number 51: Synth Strings 2\n",
+ "Program number 52: Choir Aahs\n",
+ "Program number 53: Voice Oohs\n",
+ "Program number 54: Synth Choir\n",
+ "Program number 55: Orchestra Hit\n",
+ "Program number 56: Trumpet\n",
+ "Program number 57: Trombone\n",
+ "Program number 58: Tuba\n",
+ "Program number 59: Muted Trumpet\n",
+ "Program number 60: French Horn\n",
+ "Program number 61: Brass Section\n",
+ "Program number 62: Synth Brass 1\n",
+ "Program number 63: Synth Brass 2\n",
+ "Program number 64: Soprano Sax\n",
+ "Program number 65: Alto Sax\n",
+ "Program number 66: Tenor Sax\n",
+ "Program number 67: Baritone Sax\n",
+ "Program number 68: Oboe\n",
+ "Program number 69: English Horn\n",
+ "Program number 70: Bassoon\n",
+ "Program number 71: Clarinet\n",
+ "Program number 72: Piccolo\n",
+ "Program number 73: Flute\n",
+ "Program number 74: Recorder\n",
+ "Program number 75: Pan Flute\n",
+ "Program number 76: Blown bottle\n",
+ "Program number 77: Shakuhachi\n",
+ "Program number 78: Whistle\n",
+ "Program number 79: Ocarina\n",
+ "Program number 80: Lead 1 (square)\n",
+ "Program number 81: Lead 2 (sawtooth)\n",
+ "Program number 82: Lead 3 (calliope)\n",
+ "Program number 83: Lead 4 chiff\n",
+ "Program number 84: Lead 5 (charang)\n",
+ "Program number 85: Lead 6 (voice)\n",
+ "Program number 86: Lead 7 (fifths)\n",
+ "Program number 87: Lead 8 (bass + lead)\n",
+ "Program number 88: Pad 1 (new age)\n",
+ "Program number 89: Pad 2 (warm)\n",
+ "Program number 90: Pad 3 (polysynth)\n",
+ "Program number 91: Pad 4 (choir)\n",
+ "Program number 92: Pad 5 (bowed)\n",
+ "Program number 93: Pad 6 (metallic)\n",
+ "Program number 94: Pad 7 (halo)\n",
+ "Program number 95: Pad 8 (sweep)\n",
+ "Program number 96: FX 1 (rain)\n",
+ "Program number 97: FX 2 (soundtrack)\n",
+ "Program number 98: FX 3 (crystal)\n",
+ "Program number 99: FX 4 (atmosphere)\n",
+ "Program number 100: FX 5 (brightness)\n",
+ "Program number 101: FX 6 (goblins)\n",
+ "Program number 102: FX 7 (echoes)\n",
+ "Program number 103: FX 8 (sci-fi)\n",
+ "Program number 104: Sitar\n",
+ "Program number 105: Banjo\n",
+ "Program number 106: Shamisen\n",
+ "Program number 107: Koto\n",
+ "Program number 108: Kalimba\n",
+ "Program number 109: Bagpipe\n",
+ "Program number 110: Fiddle\n",
+ "Program number 111: Shanai\n",
+ "Program number 112: Tinkle Bell\n",
+ "Program number 113: Agogo\n",
+ "Program number 114: Steel Drums\n",
+ "Program number 115: Woodblock\n",
+ "Program number 116: Taiko Drum\n",
+ "Program number 117: Melodic Tom\n",
+ "Program number 118: Synth Drum\n",
+ "Program number 119: Reverse Cymbal\n",
+ "Program number 120: Guitar Fret Noise\n",
+ "Program number 121: Breath Noise\n",
+ "Program number 122: Seashore\n",
+ "Program number 123: Bird Tweet\n",
+ "Program number 124: Telephone Ring\n",
+ "Program number 125: Helicopter\n",
+ "Program number 126: Applause\n",
+ "Program number 127: Gunshot\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pretty_midi\n",
+ "\n",
+ "def display_all_instrument_names():\n",
+ " for program_number in range(128):\n",
+ " instrument_name = pretty_midi.program_to_instrument_name(program_number)\n",
+ " print(f\"Program number {program_number}: {instrument_name}\")\n",
+ "\n",
+ "# Call the function to display all instrument names\n",
+ "display_all_instrument_names()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#import mido\n",
+ "import pretty_midi\n",
+ "#from mido import KeySignatureError\n",
+ "\n",
+ "def filter_midi(file_path, max_track = 8, max_time = 300):\n",
+ " try:\n",
+ " pm = pretty_midi.PrettyMIDI(file_path)\n",
+ " except Exception as e:\n",
+ " return False\n",
+ " \n",
+ " # time signature 4/4\n",
+ " #if is_timesig_44(pm) == False:\n",
+ " #print(\"timesig\")\n",
+ " #return False\n",
+ " \n",
+ " # number of tracks\n",
+ " if len(pm.instruments)>max_track:\n",
+ " #print(\"tracks\")\n",
+ " return False\n",
+ " \n",
+ " # length of song\n",
+ " #if pm.get_end_time()>max_time:\n",
+ " #print(\"length\")\n",
+ " #return False\n",
+ "\n",
+ " # now filter by instruments\n",
+ "\n",
+ " # filter out the ones without drums\n",
+ " #if has_drums(pm)==False:\n",
+ " #print(\"no drums\")\n",
+ " #return False\n",
+ " \n",
+ " # filter out the ones with brass\n",
+ " instr = gather_instr(pm)\n",
+ " if has_brass(instr):\n",
+ " #print(\"has brass\")\n",
+ " return False\n",
+ " \n",
+ " # filter out the ones without full string and piano\n",
+ " full_instr = gather_full_instr(pm, threshold=0.7)\n",
+ " \n",
+ " if has_piano(full_instr)== False:\n",
+ " #print(\"no piano\")\n",
+ " return False\n",
+ "\n",
+ " if has_guitar(full_instr)== False:\n",
+ " return False\n",
+ " \n",
+ " #if has_string(full_instr)== False:\n",
+ " #print(\"no string\")\n",
+ " #return False\n",
+ " \n",
+ " return True\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#import pretty_midi\n",
+ "#midi_path = '/home/ubuntu/lakh-pianoroll-dataset/data/samples_with_strings/c10e69ec7f8212c68ff3658cceef5b9b.mid'\n",
+ "#pm = pretty_midi.PrettyMIDI(midi_path)\n",
+ "#mid = mido.MidiFile(midi_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import shutil\n",
+ "import glob\n",
+ "import os\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "def find_midi_files_upto(root_dir, sample_size):\n",
+ " midi_files = glob.glob(os.path.join(root_dir, '**/*.mid'), recursive=True)\n",
+ " matching_files = []\n",
+ " match_count = 0\n",
+ "\n",
+ " pbar = tqdm(total=len(midi_files), desc=\"Processing MIDI files\")\n",
+ " for midi_file in midi_files:\n",
+ " if filter_midi(midi_file):\n",
+ " matching_files.append(midi_file)\n",
+ " match_count += 1\n",
+ " pbar.set_postfix({'Matching files': match_count})\n",
+ " if match_count >= sample_size:\n",
+ " break\n",
+ " pbar.update(1)\n",
+ " pbar.close()\n",
+ " return matching_files\n",
+ "\n",
+ "def copy_files(files, target_dir):\n",
+ " os.makedirs(target_dir, exist_ok=True)\n",
+ " for file in files:\n",
+ " shutil.copy(file, target_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processing MIDI files: 0%| | 14/178561 [00:02<11:02:01, 4.49it/s]/home/ubuntu/.local/lib/python3.10/site-packages/pretty_midi/pretty_midi.py:100: RuntimeWarning: Tempo, Key or Time signature change events found on non-zero tracks. This is not a valid type 0 or type 1 MIDI file. Tempo, Key or Time Signature may be wrong.\n",
+ " warnings.warn(\n",
+ "Processing MIDI files: 0%| | 148/178561 [00:18<4:21:09, 11.39it/s, Matching files=4] "
+ ]
+ },
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/tmp/ipykernel_47712/2263633239.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mROOT_DIR\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/home/ubuntu/lakh-pianoroll-dataset/data/lmd/lmd_full'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msample_files\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_midi_files_upto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mROOT_DIR\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1500\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtgt_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/home/ubuntu/lakh-pianoroll-dataset/data/instrument_samples'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcopy_files\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_files\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtgt_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/tmp/ipykernel_47712/2439092612.py\u001b[0m in \u001b[0;36mfind_midi_files_upto\u001b[0;34m(root_dir, sample_size)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mpbar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_files\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Processing MIDI files\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmidi_file\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmidi_files\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mfilter_midi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mmatching_files\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mmatch_count\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/tmp/ipykernel_47712/3402906515.py\u001b[0m in \u001b[0;36mfilter_midi\u001b[0;34m(file_path, max_track, max_time)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfilter_midi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_track\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m300\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mpm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpretty_midi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPrettyMIDI\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/pretty_midi/pretty_midi.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, midi_file, resolution, initial_tempo)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstring_types\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;31m# If a string was given, pass it as the string filename\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mmidi_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmido\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMidiFile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# Otherwise, try passing it in as a file pointer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, filename, file, type, ticks_per_beat, charset, debug, clip, tracks)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilename\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 320\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 321\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(self, infile)\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[0m_dbg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Track {i}:'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 370\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 371\u001b[0;31m self.tracks.append(read_track(infile,\n\u001b[0m\u001b[1;32m 372\u001b[0m \u001b[0mdebug\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 373\u001b[0m clip=self.clip))\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36mread_track\u001b[0;34m(infile, debug, clip)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_sysex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdelta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclip\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 218\u001b[0;31m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus_byte\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpeek_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdelta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclip\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0mtrack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36mread_message\u001b[0;34m(infile, status_byte, peek_data, delta, clip)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mOSError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'data byte must be in range 0..127'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 133\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mMessage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdata_bytes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdelta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/messages.py\u001b[0m in \u001b[0;36mfrom_bytes\u001b[0;34m(cl, data, time)\u001b[0m\n\u001b[1;32m 161\u001b[0m \"\"\"\n\u001b[1;32m 162\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__new__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0mmsgdict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecode_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'data'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmsgdict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0mmsgdict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'data'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSysexData\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsgdict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'data'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36mdecode_message\u001b[0;34m(msg_bytes, time, check)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_SPECIAL_CASES\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_decode_data_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36m_decode_data_bytes\u001b[0;34m(status_byte, data, spec)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# TODO: better name than args?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnames\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'value_names'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'channel'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# TODO: better name than args?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnames\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'value_names'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'channel'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processing MIDI files: 0%| | 150/178561 [00:30<4:21:08, 11.39it/s, Matching files=4]"
+ ]
+ }
+ ],
+ "source": [
+ "ROOT_DIR = '/home/ubuntu/lakh-pianoroll-dataset/data/lmd/lmd_full'\n",
+ "sample_files = find_midi_files_upto(ROOT_DIR, sample_size=1500)\n",
+ "tgt_dir = '/home/ubuntu/lakh-pianoroll-dataset/data/instrument_samples'\n",
+ "copy_files(sample_files, tgt_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/filter_data/midi_utils.py b/filter_data/midi_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..70b530f47b641ebe3092fade5b09d699d14de13d
--- /dev/null
+++ b/filter_data/midi_utils.py
@@ -0,0 +1,139 @@
+
+import pretty_midi
+
+PIANO = (0,1)
+STRING = (40,41,42,48,49,50,51)
+GUITAR = (24,25,27)
+BRASS = (56,57,58,59,61,62,63,64,65,66,67)  # GM brass programs (plus 64-67, which are saxophones), used by the brass filter
+
+############################## For single track analysis
+
+def calculate_active_duration(instrument):
+    # Collect all note intervals
+    intervals = [(note.start, note.end) for note in instrument.notes]
+
+    # Guard against empty tracks before indexing into intervals[0] below
+    if not intervals:
+        return 0.0
+
+ # Sort intervals by start time
+ intervals.sort()
+
+ # Merge overlapping intervals and calculate the active duration
+ active_duration = 0
+ current_start, current_end = intervals[0]
+
+ for start, end in intervals[1:]:
+ if start <= current_end: # There is an overlap
+ current_end = max(current_end, end)
+ else: # No overlap, add the previous interval duration and start a new interval
+ active_duration += current_end - current_start
+ current_start, current_end = start, end
+
+ # Add the last interval
+ active_duration += current_end - current_start
+
+ return active_duration
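+# For example: note intervals [(0.0, 2.0), (1.0, 3.0), (5.0, 6.0)] merge to
+# [(0.0, 3.0), (5.0, 6.0)], giving an active duration of 4.0 seconds.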
+
+def is_full_track(midi, instrument, threshold=0.6):
+ # Calculate the total duration of the track
+ total_duration = midi.get_end_time()
+
+ # Calculate the active duration (time during which notes are playing)
+ active_duration = calculate_active_duration(instrument)
+
+ # Calculate the percentage of active duration
+ active_percentage = active_duration / total_duration
+
+ #print(f"Total duration: {total_duration:.2f} seconds")
+ #print(f"Active duration: {active_duration:.2f} seconds")
+ #print(f"Active percentage: {active_percentage:.2%}")
+
+ # Check if the active duration meets or exceeds the threshold
+ return active_percentage >= threshold
+
+#################################### For gathering full tracks
+
+def gather_instr(pm):
+ # Gather all the program indexes of the instrument tracks
+ program_indexes = [instrument.program for instrument in pm.instruments]
+
+ # Sort the program indexes
+ program_indexes.sort()
+
+ # Convert the sorted list of program indexes to a tuple
+ program_indexes_tuple = tuple(program_indexes)
+ return program_indexes_tuple
+
+def gather_full_instr(pm, threshold = 0.6):
+ # Gather all the program indexes of the instrument tracks that exceed the duration threshold
+ program_indexes = []
+ for instrument in pm.instruments:
+ if is_full_track(pm, instrument, threshold):
+ program_indexes.append(instrument.program)
+ program_indexes.sort()
+ # Convert the list of program indexes to a tuple
+ program_indexes_tuple = tuple(program_indexes)
+
+ return program_indexes_tuple
+
+####################################### For finding instruments
+
+def has_intersection(wanted_instr, exist_instr):
+ # Convert both the tuple and the group of integers to sets
+ tuple_set = set(wanted_instr)
+ group_set = set(exist_instr)
+
+ # Check if there is any intersection
+ return not tuple_set.isdisjoint(group_set)
+
+# The functions checking instruments in the midi file tracks
+def has_piano(exist_instr):
+ wanted_instr = PIANO
+ return has_intersection(wanted_instr, exist_instr)
+
+def has_string(exist_instr):
+ wanted_instr = STRING
+ return has_intersection(wanted_instr, exist_instr)
+
+def has_guitar(exist_instr):
+ wanted_instr = GUITAR
+ return has_intersection(wanted_instr, exist_instr)
+
+def has_brass(exist_instr):
+ wanted_instr = BRASS
+ return has_intersection(wanted_instr, exist_instr)
+
+def has_drums(pm):
+ for instrument in pm.instruments:
+ if instrument.is_drum:
+ return True
+ return False
+
+
+def print_track_details(instrument):
+ """
+ For visualizing the information in a midi track
+ """
+ print(f"Instrument: {pretty_midi.program_to_instrument_name(instrument.program)}")
+ print(f"Is drum: {instrument.is_drum}")
+
+ print("\nNotes:")
+ for note in instrument.notes:
+ print(f"Start: {note.start:.2f}, End: {note.end:.2f}, Pitch: {note.pitch}, Velocity: {note.velocity}")
+
+ print("\nControl Changes:")
+ for cc in instrument.control_changes:
+ print(f"Time: {cc.time:.2f}, Number: {cc.number}, Value: {cc.value}")
+
+ print("\nPitch Bends:")
+ for pb in instrument.pitch_bends:
+ print(f"Time: {pb.time:.2f}, Pitch: {pb.pitch}")
+
+def is_timesig_44(pm):
+ for time_signature in pm.time_signature_changes:
+ if time_signature.numerator != 4 or time_signature.denominator != 4:
+ return False
+ return True
+
+def is_timesig_34(pm):
+ for time_signature in pm.time_signature_changes:
+        if time_signature.numerator != 3 or time_signature.denominator != 4:
+ return False
+ return True
\ No newline at end of file
diff --git a/generation/__pycache__/gen_utils.cpython-39.pyc b/generation/__pycache__/gen_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..772c1a785843c16d7934dc014e4afd29f36ac139
Binary files /dev/null and b/generation/__pycache__/gen_utils.cpython-39.pyc differ
diff --git a/generation/gen_utils.py b/generation/gen_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..64024b389db798c137aaad5c1c691001c7e5e204
--- /dev/null
+++ b/generation/gen_utils.py
@@ -0,0 +1,302 @@
+import torch
+import numpy as np
+import pretty_midi as pm
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
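+# Each chord is a 12-dimensional chroma vector over the pitch classes C, C#, ..., B;
+# a 1 marks a chord tone (e.g. C:major -> C, E, G at indices 0, 4, 7).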
+CHORD_DICTIONARY = {
+ "C:major": np.array([1,0,0,0,1,0,0,1,0,0,0,0]),
+ "C#:major": np.array([0,1,0,0,0,1,0,0,1,0,0,0]),
+ "D:major": np.array([0,0,1,0,0,0,1,0,0,1,0,0]),
+ "Eb:major": np.array([0,0,0,1,0,0,0,1,0,0,1,0]),
+ "E:major": np.array([0,0,0,0,1,0,0,0,1,0,0,1]),
+ "F:major": np.array([1,0,0,0,0,1,0,0,0,1,0,0]),
+ "F#:major": np.array([0,1,0,0,0,0,1,0,0,0,1,0]),
+ "G:major": np.array([0,0,1,0,0,0,0,1,0,0,0,1]),
+ "Ab:major": np.array([1,0,0,1,0,0,0,0,1,0,0,0]),
+ "A:major": np.array([0,1,0,0,1,0,0,0,0,1,0,0]),
+ "Bb:major": np.array([0,0,1,0,0,1,0,0,0,0,1,0]),
+ "B:major": np.array([0,0,0,1,0,0,1,0,0,0,0,1]),
+
+ "c:minor": np.array([1,0,0,1,0,0,0,1,0,0,0,0]),
+ "c#:minor": np.array([0,1,0,0,1,0,0,0,1,0,0,0]),
+ "d:minor": np.array([0,0,1,0,0,1,0,0,0,1,0,0]),
+ "eb:minor": np.array([0,0,0,1,0,0,1,0,0,0,1,0]),
+ "e:minor": np.array([0,0,0,0,1,0,0,1,0,0,0,1]),
+ "f:minor": np.array([1,0,0,0,0,1,0,0,1,0,0,0]),
+ "f#:minor": np.array([0,1,0,0,0,0,1,0,0,1,0,0]),
+ "g:minor": np.array([0,0,1,0,0,0,0,1,0,0,1,0]),
+ "g#:minor": np.array([0,0,0,1,0,0,0,0,1,0,0,1]),
+ "a:minor": np.array([1,0,0,0,1,0,0,0,0,1,0,0]),
+ "bb:minor": np.array([0,1,0,0,0,1,0,0,0,0,1,0]),
+ "b:minor": np.array([0,0,1,0,0,0,1,0,0,0,0,1]),
+}
+
+
+def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True):
+    '''
+    piano_roll_full: tensor of shape (batch_size, 2, length, h); length=64 is the roll length, h is the number of possible pitches
+    num_notes_onset: tensor of shape (batch_size, length), the required number of onset notes at each step
+    mask_full: tensor with the same shape as piano_roll_full, the seven-note (scale) chroma mask
+    reduce_extra_notes: if True, onset notes beyond the required count are suppressed
+    '''
+    ########## For rows above the threshold: if the number of notes exceeds num_notes[i],
+    ########## keep the strongest ones and clamp the rest to the threshold.
+ # we only edit onset
+ onset_roll = piano_roll_full[:,0,:,:]
+ mask = mask_full[:,0,:,:]
+ shape = onset_roll.shape
+
+ onset_roll = onset_roll.reshape(-1,shape[-1])
+ mask = mask.reshape(-1,shape[-1])
+ num_notes = num_notes_onset.reshape(-1)
+
+ reduce_note_threshold = 0.499
+ increase_note_threshold = 0.501
+
+ # Initialize a tensor to store the modified values
+ final_onset_roll = onset_roll.clone()
+
+ ########### if number of notes > required, remove the extra notes ###############
+ if reduce_extra_notes:
+ threshold_mask = onset_roll > reduce_note_threshold
+ # Set all values <= reduce_note_threshold to -inf to exclude them from top-k selection
+ values_above_threshold = torch.where(threshold_mask & (mask == 1), onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))
+
+ # Get the top num_notes.max() values for each row
+ num_notes_max = int(num_notes.max().item()) # Maximum number of notes needed in any row
+ topk_values, topk_indices = torch.topk(values_above_threshold, num_notes_max, dim=1)
+
+ # Create a mask for the top num_notes[i] values for each row
+ col_indices = torch.arange(num_notes_max, device=onset_roll.device).expand(len(onset_roll), num_notes_max)
+ topk_mask = (col_indices < num_notes.unsqueeze(1)) & (topk_values > -float("inf"))
+
+ # Set all values greater than reduce_note_threshold to reduce_note_threshold initially
+ final_onset_roll[threshold_mask & (mask == 1)] = reduce_note_threshold
+
+ # Create a flattened index to scatter the top values back into final_onset_roll
+ flat_row_indices = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_indices)
+ flat_row_indices = flat_row_indices[topk_mask]
+
+ # Gather the valid topk_indices and corresponding values
+ valid_topk_indices = topk_indices[topk_mask]
+ valid_topk_values = topk_values[topk_mask]
+
+ # Use scatter to place the top num_notes[i] values back to their original positions
+ final_onset_roll = final_onset_roll.index_put_((flat_row_indices, valid_topk_indices), valid_topk_values)
+
+ ########### if number of notes < required, add some notes ###############
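+    # zero out columns >= 51 so extra notes are only added below MIDI pitch 84 (C6),
+    # assuming the 64-pitch roll starts at MIDI pitch 33 as in the other utilities in this file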
+ pitch_less_84_mask = torch.ones_like(mask)
+ pitch_less_84_mask[:,51:] = 0
+
+ # Count how many values >= increase_note_threshold for each row
+ threshold_mask_2 = (final_onset_roll >= increase_note_threshold)&(mask==1)
+ greater_than_threshold2_count = threshold_mask_2.sum(dim=1)
+
+ # For those rows, find the remaining number of values needed to be set to increase_note_threshold
+ remaining_needed = num_notes - greater_than_threshold2_count
+ remaining_needed_max = int(remaining_needed.max().item())
+ print("\n\n\n",remaining_needed_max,"\n\n\n")
+ if remaining_needed_max>=0: # need to add notes
+ # Find the values in each row that are < increase_note_threshold but are the highest (so we can set them to increase_note_threshold)
+ values_below_threshold2 = torch.where((final_onset_roll < increase_note_threshold)&(mask==1)&(pitch_less_84_mask==1), final_onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))
+ topk_below_threshold2_values, topk_below_threshold2_indices = torch.topk(values_below_threshold2, remaining_needed_max, dim=1)
+
+ # Mask to only adjust the needed number of values in each row
+ col_indices_below_threshold2 = torch.arange(remaining_needed_max, device=onset_roll.device).expand(len(onset_roll), remaining_needed_max)
+ adjust_mask = (col_indices_below_threshold2 < remaining_needed.unsqueeze(1)) & (topk_below_threshold2_values > -float("inf"))
+
+ # Flatten row indices for the new top-k below increase_note_threshold
+ flat_row_indices_below_threshold2 = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_below_threshold2_indices)
+ flat_row_indices_below_threshold2 = flat_row_indices_below_threshold2[adjust_mask]
+
+ # Gather the valid indices and set them to increase_note_threshold
+ valid_below_threshold2_indices = topk_below_threshold2_indices[adjust_mask]
+
+ # Update the final_onset_roll to make sure we now have exactly num_notes[i] values >= increase_note_threshold
+ final_onset_roll = final_onset_roll.index_put_((flat_row_indices_below_threshold2, valid_below_threshold2_indices), torch.tensor(increase_note_threshold, device=onset_roll.device))
+
+ final_onset_roll = final_onset_roll.reshape(shape)
+ piano_roll_full[:,0,:,:] = final_onset_roll
+ return piano_roll_full
+
+def X0EditFunc(x0, background_condition, sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"):
+    # Precompute all rotations of the major and minor chord/scale templates
+ maj_chd = torch.tensor([[1.,0,0,0,1,0,0,1,0,0,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
+ maj_chd = torch.tile(maj_chd, (1, 64 // maj_chd.size(1) + 1))
+ min_chd = torch.tensor([[1.,0,0,0,1,0,0,0,0,1,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
+ min_chd = torch.tile(min_chd, (1, 64 // min_chd.size(1) + 1))
+
+ # all chords, with rotation
+ maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
+ min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
+
+ # combine all chords
+ # chd_scale_map is a tensor with shape (N, 2, 64), N is total number of chord types,
+ # 2 is (chord_chroma, corresponding_scale_chroma), 64 is number of possible notes
+ chd_scale_map = torch.concat([maj_chd_rotations, min_chd_rotations], axis=0)
+
+ # if using null rhythm condition, have to convert -2 to 1 and -1 to 0
+ if background_condition[:,:2,:,:].min()<0:
+ correct_chord_condition = -background_condition[:,:2,:,:]-1
+ else:
+ correct_chord_condition = background_condition[:,:2,:,:]
+ merged_chd_roll = torch.max(correct_chord_condition[:,0,:,:], correct_chord_condition[:,1,:,:]) # chd roll of our bg_cond
+ chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0) # chd chroma of our bg_cond
+ shape = chd_chroma_ours.shape
+ chd_chroma_ours = chd_chroma_ours.reshape(-1,64)
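+    # matches[i, j] is True when chord template j covers every active pitch class of segment i;
+    # the einsum then sums the scale chromas of all matching templates to get the allowed (seven-note) chroma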
+ matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1)>=0).all(dim=-1)
+ seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
+ seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1,2,1,1))
+
+ no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
+ seven_notes_chroma_ours[no_chd_match] = 1.
+
+ # edit notes based on chroma
+ x0 = torch.where((seven_notes_chroma_ours==0)&(x0>0), 0.0 , x0)
+ print("See Coming?")
+ # edit rhythm
+ if (background_condition[:,:2,:,:].min()>=0) and (rhythm_control=="Yes"): # only edit if rhythm is provided
+ num_onset_notes, _ = torch.max(background_condition[:,0,:,:], axis=-1)
+ x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes)
+
+ return x0
+
+def expand_roll(roll, unit=4, contain_onset=False):
+ # roll: (Channel, T, H) -> (Channel, T * unit, H)
+ n_channel, length, height = roll.shape
+
+ expanded_roll = roll.repeat(unit, axis=1)
+ if contain_onset:
+ expanded_roll = expanded_roll.reshape((n_channel, length, unit, height))
+ expanded_roll[1::2, :, 1:] = np.maximum(expanded_roll[::2, :, 1:], expanded_roll[1::2, :, 1:])
+
+ expanded_roll[::2, :, 1:] = 0
+ expanded_roll = expanded_roll.reshape((n_channel, length * unit, height))
+ return expanded_roll
+
+def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96):
+ piano_roll_cut = piano_roll[:,:,lowest:highest+1]
+ return piano_roll_cut
+
+def circular_extend(chd_roll, lowest=33, highest=96):
+    # chd_roll: (6, L, 12) -> (6, L, 64)
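+    # the 12-dim chord chroma is written at both the C3 and C4 octaves of the trimmed roll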
+ C4 = 60-lowest
+ C3 = C4-12
+ shape = chd_roll.shape
+ ext_chd = np.zeros((shape[0],shape[1],highest+1-lowest))
+ ext_chd[:,:,C4:C4+12] = chd_roll
+ ext_chd[:,:,C3:C3+12] = chd_roll
+ return ext_chd
+
+
+def default_quantization(v):
+ return 1 if v > 0.5 else 0
+
+def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96):
+    ## extend the cut piano rolls back to the full 128-pitch range
+    ## the cut rolls have shape (2, L, 64); zero padding turns them into (2, L, 128)
+ padded_roll = np.pad(piano_roll, ((0, 0), (0, 0), (lowest, 127-highest)), mode='constant', constant_values=0)
+ return padded_roll
+
+
+
+def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None):
+ """
+ piano_roll: (2, L, 128), onset and sustain channel.
+ raise_chord: whether pitch below 48 (mel-chd boundary) will be raised an octave
+ """
+ def convert_p(p_, note_list):
+ edit_note_flag = False
+ for t in range(n_step):
+ onset_state = quantization_func(piano_roll[0, t, p_])
+ sustain_state = quantization_func(piano_roll[1, t, p_])
+
+ is_onset = bool(onset_state)
+ is_sustain = bool(sustain_state) and not is_onset
+
+ pitch = p_
+
+ if is_onset:
+ edit_note_flag = True
+ note_list.append([t, pitch, 1])
+ elif is_sustain:
+ if edit_note_flag:
+ note_list[-1][-1] += 1
+ else:
+ edit_note_flag = False
+ return note_list
+
+ quantization_func = default_quantization if quantization_func is None else quantization_func
+ assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128, f"{piano_roll.shape}"
+
+ n_step = piano_roll.shape[1]
+
+ notes = []
+ for p in range(128):
+ convert_p(p, notes)
+
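+    # each entry of `notes` is [onset_step, pitch, duration_in_steps]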
+ return notes
+
+
+def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100):
+ """Default use shift beat"""
+
+ beat_alpha = 60 / bpm
+ step_alpha = unit * beat_alpha
+
+ notes = []
+
+ shift_sec = shift_sec if shift_beat is None else shift_beat * beat_alpha
+
+ for note in note_mat:
+ onset, pitch, dur = note
+ start = onset * step_alpha + shift_sec
+ end = (onset + dur) * step_alpha + shift_sec
+
+ notes.append(pm.Note(vel, int(pitch), start, end))
+
+ return notes
+
+
+def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None):
+ midi = pm.PrettyMIDI(initial_tempo=bpm)
+
+ piano_program = pm.instrument_name_to_program('Acoustic Grand Piano')
+ piano = pm.Instrument(program=piano_program)
+ piano.notes+=piano_notes_list
+ midi.instruments.append(piano)
+
+ # chd_program = pm.instrument_name_to_program('Violin')
+ # chd = pm.Instrument(program=chd_program)
+ # chd.notes+=chd_notes_list
+ # midi.instruments.append(chd)
+
+ if lsh_notes_list is not None:
+ lsh_program = pm.instrument_name_to_program('Acoustic Grand Piano')
+ lsh = pm.Instrument(program=lsh_program)
+ lsh.notes+=lsh_notes_list
+ midi.instruments.append(lsh)
+
+ return midi
+
+def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80):
+ piano_mat = piano_roll_to_note_mat(piano_roll)
+ piano_notes = note_mat_to_notes(piano_mat, bpm)
+
+ chd_mat = piano_roll_to_note_mat(chd_roll)
+ chd_notes = note_mat_to_notes(chd_mat, bpm)
+
+ if lsh_roll is not None:
+ lsh_mat = piano_roll_to_note_mat(lsh_roll)
+ lsh_notes = note_mat_to_notes(lsh_mat, bpm)
+ else:
+ lsh_notes=None
+
+    piano_pm = create_pm_object(bpm=bpm, piano_notes_list=piano_notes,
+                                chd_notes_list=chd_notes, lsh_notes_list=lsh_notes)
+ return piano_pm
+
+def save_midi(pm, filename):
+ pm.write(filename)
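+
+# A minimal end-to-end usage sketch (an assumption for illustration, not called anywhere in this file):
+#   full_roll = extend_piano_roll(cut_roll)                      # hypothetical (2, L, 64) roll -> (2, L, 128)
+#   midi = piano_roll_to_midi(full_roll, chd_full_roll, bpm=80)  # chd_full_roll is a hypothetical (2, L, 128) chord roll
+#   save_midi(midi, 'sample.mid')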
\ No newline at end of file
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e5dc75f1e000aa98775ead2b4430a4eb00dc9e7
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,59 @@
+from .latent_diffusion import LatentDiffusion
+from .model_sdf import Diffpro_SDF
+from .architecture.unet import UNetModel
+import os
+
+
+def init_ldm_model(params, debug_mode=False):
+ unet_model = UNetModel(
+ in_channels=params.in_channels,
+ out_channels=params.out_channels,
+ channels=params.channels,
+ attention_levels=params.attention_levels,
+ n_res_blocks=params.n_res_blocks,
+ channel_multipliers=params.channel_multipliers,
+ n_heads=params.n_heads,
+ tf_layers=params.tf_layers,
+ #d_cond=params.d_cond,
+ )
+
+
+ ldm_model = LatentDiffusion(
+ unet_model=unet_model,
+ #autoencoder=None,
+ #autoreg_cond_enc=autoreg_cond_enc,
+ #external_cond_enc=external_cond_enc,
+ latent_scaling_factor=params.latent_scaling_factor,
+ n_steps=params.n_steps,
+ linear_start=params.linear_start,
+ linear_end=params.linear_end,
+ debug_mode=debug_mode
+ )
+
+ return ldm_model
+
+
+
+def init_diff_pro_sdf(ldm_model, params, device):
+ return Diffpro_SDF(ldm_model).to(device)
+
+
+def get_model_path(model_dir, model_id=None):
+ model_desc = os.path.basename(model_dir)
+ if model_id is None:
+ model_path = os.path.join(model_dir, 'chkpts', 'weights.pt')
+
+ # retrieve real model_id from the actual file weights.pt is pointing to
+ model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0]
+
+ elif model_id == 'best':
+ model_path = os.path.join(model_dir, 'chkpts', 'weights_best.pt')
+        # retrieve real model_id from the actual file weights_best.pt is pointing to
+ model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0]
+ elif model_id == 'default':
+ model_path = os.path.join(model_dir, 'chkpts', 'weights_default.pt')
+ if not os.path.exists(model_path):
+ return get_model_path(model_dir, 'best')
+ else:
+ model_path = f"{model_dir}/chkpts/weights-{model_id}.pt"
+ return model_path, model_id, model_desc
diff --git a/model/__pycache__/__init__.cpython-39.pyc b/model/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0082a4fba238c21cde1a333fe198bf447b2c974e
Binary files /dev/null and b/model/__pycache__/__init__.cpython-39.pyc differ
diff --git a/model/__pycache__/latent_diffusion.cpython-39.pyc b/model/__pycache__/latent_diffusion.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b78ebcb9440337d801824127fedea456a2bf3e6
Binary files /dev/null and b/model/__pycache__/latent_diffusion.cpython-39.pyc differ
diff --git a/model/__pycache__/model_sdf.cpython-39.pyc b/model/__pycache__/model_sdf.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fa2acd035f522389722dcd927709b804b87cece
Binary files /dev/null and b/model/__pycache__/model_sdf.cpython-39.pyc differ
diff --git a/model/__pycache__/sampler_sdf.cpython-39.pyc b/model/__pycache__/sampler_sdf.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aef313488a1362b53822ae9b8239101bddc5b041
Binary files /dev/null and b/model/__pycache__/sampler_sdf.cpython-39.pyc differ
diff --git a/model/architecture/__pycache__/unet.cpython-39.pyc b/model/architecture/__pycache__/unet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9aa5056279a5082643512082ac01ccecf2ffe101
Binary files /dev/null and b/model/architecture/__pycache__/unet.cpython-39.pyc differ
diff --git a/model/architecture/__pycache__/unet_attention.cpython-39.pyc b/model/architecture/__pycache__/unet_attention.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09005f959581d6013eab2ff0be9a496342c778fe
Binary files /dev/null and b/model/architecture/__pycache__/unet_attention.cpython-39.pyc differ
diff --git a/model/architecture/unet.py b/model/architecture/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..af7f656db30b07d7e8955a114b0a8a609cdd70b9
--- /dev/null
+++ b/model/architecture/unet.py
@@ -0,0 +1,364 @@
+"""
+---
+title: U-Net for Stable Diffusion
+summary: >
+ Annotated PyTorch implementation/tutorial of the U-Net in stable diffusion.
+---
+
+# U-Net for [Stable Diffusion](../index.html)
+
+This implements the U-Net that
+ gives $\epsilon_\text{cond}(x_t, c)$
+
+We have kept to the model definition and naming unchanged from
+[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
+so that we can load the checkpoints directly.
+"""
+
+import math
+from typing import List
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .unet_attention import SpatialTransformer
+
+
+class UNetModel(nn.Module):
+ """
+ ## U-Net model
+ """
+ def __init__(
+ self,
+ *,
+ in_channels: int,
+ out_channels: int,
+ channels: int,
+ n_res_blocks: int,
+ attention_levels: List[int],
+ channel_multipliers: List[int],
+ n_heads: int,
+ tf_layers: int = 1,
+ #d_cond: int = 768
+ ):
+ """
+ :param in_channels: is the number of channels in the input feature map
+ :param out_channels: is the number of channels in the output feature map
+ :param channels: is the base channel count for the model
+ :param n_res_blocks: number of residual blocks at each level
+ :param attention_levels: are the levels at which attention should be performed
+ :param channel_multipliers: are the multiplicative factors for number of channels for each level
+ :param n_heads: the number of attention heads in the transformers
+ """
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels
+ #self.d_cond = d_cond
+
+ # Number of levels
+ levels = len(channel_multipliers)
+ # Size time embeddings
+ d_time_emb = channels * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(channels, d_time_emb),
+ nn.SiLU(),
+ nn.Linear(d_time_emb, d_time_emb),
+ )
+
+ # Input half of the U-Net
+ self.input_blocks = nn.ModuleList()
+ # Initial $3 \times 3$ convolution that maps the input to `channels`.
+ # The blocks are wrapped in `TimestepEmbedSequential` module because
+ # different modules have different forward function signatures;
+ # for example, convolution only accepts the feature map and
+ # residual blocks accept the feature map and time embedding.
+ # `TimestepEmbedSequential` calls them accordingly.
+ self.input_blocks.append(
+ TimestepEmbedSequential(nn.Conv2d(in_channels, channels, 3, padding=1))
+ )
+ # Number of channels at each block in the input half of U-Net
+ input_block_channels = [channels]
+ # Number of channels at each level
+ channels_list = [channels * m for m in channel_multipliers]
+ # Prepare levels
+ for i in range(levels):
+ # Add the residual blocks and attentions
+ for _ in range(n_res_blocks):
+ # Residual block maps from previous number of channels to the number of
+ # channels in the current level
+ layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
+ channels = channels_list[i]
+ # Add transformer
+ if i in attention_levels:
+ layers.append(
+ SpatialTransformer(channels, n_heads, tf_layers)
+ )
+ # Add them to the input half of the U-Net and keep track of the number of channels of
+ # its output
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ input_block_channels.append(channels)
+ # Down sample at all levels except last
+ if i != levels - 1:
+ self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
+ input_block_channels.append(channels)
+
+ # The middle of the U-Net
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(channels, d_time_emb),
+ SpatialTransformer(channels, n_heads, tf_layers),
+ ResBlock(channels, d_time_emb),
+ )
+
+ # Second half of the U-Net
+ self.output_blocks = nn.ModuleList([])
+ # Prepare levels in reverse order
+ for i in reversed(range(levels)):
+ # Add the residual blocks and attentions
+ for j in range(n_res_blocks + 1):
+ # Residual block maps from previous number of channels plus the
+ # skip connections from the input half of U-Net to the number of
+ # channels in the current level.
+ layers = [
+ ResBlock(
+ channels + input_block_channels.pop(),
+ d_time_emb,
+ out_channels=channels_list[i]
+ )
+ ]
+ channels = channels_list[i]
+ # Add transformer
+ if i in attention_levels:
+ layers.append(
+ SpatialTransformer(channels, n_heads, tf_layers)
+ )
+ # Up-sample at every level after last residual block
+ # except the last one.
+ # Note that we are iterating in reverse; i.e. `i == 0` is the last.
+ if i != 0 and j == n_res_blocks:
+ layers.append(UpSample(channels))
+ # Add to the output half of the U-Net
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+
+ # Final normalization and $3 \times 3$ convolution
+ self.out = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ nn.Conv2d(channels, out_channels, 3, padding=1),
+ )
+
+ def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
+ """
+ ## Create sinusoidal time step embeddings
+
+ :param time_steps: are the time steps of shape `[batch_size]`
+ :param max_period: controls the minimum frequency of the embeddings.
+ """
+ # $\frac{c}{2}$; half the channels are sin and the other half is cos,
+ half = self.channels // 2
+ # $\frac{1}{10000^{\frac{2i}{c}}}$
+ frequencies = torch.exp(
+ -math.log(max_period) *
+ torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=time_steps.device)
+ # $\frac{t}{10000^{\frac{2i}{c}}}$
+ args = time_steps[:, None].float() * frequencies[None]
+ # $\cos\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$ and $\sin\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$
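+        # the resulting embedding has shape `[batch_size, channels]`: cosine terms in the first half, sine terms in the second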
+ return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+
+ def forward(self, x: torch.Tensor, time_steps: torch.Tensor):
+ """
+ :param x: is the input feature map of shape `[batch_size, channels, width, height]`
+ :param time_steps: are the time steps of shape `[batch_size]`
+ :param cond: conditioning of shape `[batch_size, n_cond, d_cond]`
+ """
+ # To store the input half outputs for skip connections
+ x_input_block = []
+
+ # Get time step embeddings
+ t_emb = self.time_step_embedding(time_steps)
+ t_emb = self.time_embed(t_emb)
+
+ # Input half of the U-Net
+ for module in self.input_blocks:
+ ##########################
+ #print("x:", x.dtype,"t_emb:",t_emb.dtype)
+ ##########################
+ #x = x.to(torch.float16)
+ x = module(x, t_emb)
+ x_input_block.append(x)
+ # Middle of the U-Net
+ x = self.middle_block(x, t_emb)
+ # Output half of the U-Net
+ for module in self.output_blocks:
+ # print(x.size(), 'a')
+ x = torch.cat([x, x_input_block.pop()], dim=1)
+ # print(x.size(), 'b')
+ x = module(x, t_emb)
+
+ # Final normalization and $3 \times 3$ convolution
+ return self.out(x)
+
+
+class TimestepEmbedSequential(nn.Sequential):
+ """
+ ### Sequential block for modules with different inputs
+
+    This sequential module can be composed of different modules such as `ResBlock`,
+    `nn.Conv` and `SpatialTransformer`, and calls them with the matching signatures.
+ """
+ def forward(self, x, t_emb, cond=None):
+ for layer in self:
+ if isinstance(layer, ResBlock):
+ x = layer(x, t_emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x)
+ else:
+ x = layer(x)
+ return x
+
+
+class UpSample(nn.Module):
+ """
+ ### Up-sampling layer
+ """
+ def __init__(self, channels: int):
+ """
+ :param channels: is the number of channels
+ """
+ super().__init__()
+ # $3 \times 3$ convolution mapping
+ self.conv = nn.Conv2d(channels, channels, 3, padding=1)
+
+ def forward(self, x: torch.Tensor):
+ """
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
+ """
+ # Up-sample by a factor of $2$
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ # Apply convolution
+ return self.conv(x)
+
+
+class DownSample(nn.Module):
+ """
+ ## Down-sampling layer
+ """
+ def __init__(self, channels: int):
+ """
+ :param channels: is the number of channels
+ """
+ super().__init__()
+ # $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
+ self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
+
+ def forward(self, x: torch.Tensor):
+ """
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
+ """
+ # Apply convolution
+ return self.op(x)
+
+
+class ResBlock(nn.Module):
+ """
+ ## ResNet Block
+ """
+ def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
+ """
+ :param channels: the number of input channels
+ :param d_t_emb: the size of timestep embeddings
+        :param out_channels: is the number of output channels; defaults to `channels`.
+ """
+ super().__init__()
+ # `out_channels` not specified
+ if out_channels is None:
+ out_channels = channels
+
+ # First normalization and convolution
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ nn.Conv2d(channels, out_channels, 3, padding=1),
+ )
+
+ # Time step embeddings
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(d_t_emb, out_channels),
+ )
+ # Final convolution layer
+ self.out_layers = nn.Sequential(
+ normalization(out_channels), nn.SiLU(), nn.Dropout(0.),
+ nn.Conv2d(out_channels, out_channels, 3, padding=1)
+ )
+
+ # `channels` to `out_channels` mapping layer for residual connection
+ if out_channels == channels:
+ self.skip_connection = nn.Identity()
+ else:
+ self.skip_connection = nn.Conv2d(channels, out_channels, 1)
+
+ def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
+ """
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
+ :param t_emb: is the time step embeddings of shape `[batch_size, d_t_emb]`
+ """
+ # Initial convolution
+ h = self.in_layers(x)
+ # Time step embeddings
+ t_emb = self.emb_layers(t_emb).type(h.dtype)
+ # Add time step embeddings
+ h = h + t_emb[:, :, None, None]
+ # Final convolution
+ h = self.out_layers(h)
+ # Add skip connection
+ return self.skip_connection(x) + h
+
+
+class GroupNorm32(nn.GroupNorm):
+ """
+ ### Group normalization with float32 casting
+ """
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def normalization(channels):
+ """
+ ### Group normalization
+
+    This is a helper function, with a fixed number of groups.
+ """
+ return GroupNorm32(32, channels)
+
+
+def _test_time_embeddings():
+ """
+ Test sinusoidal time step embeddings
+ """
+ import matplotlib.pyplot as plt
+
+ plt.figure(figsize=(15, 5))
+ m = UNetModel(
+ in_channels=1,
+ out_channels=1,
+ channels=320,
+ n_res_blocks=1,
+ attention_levels=[],
+ channel_multipliers=[],
+ n_heads=1,
+ tf_layers=1,
+        #d_cond=1
+ )
+ te = m.time_step_embedding(torch.arange(0, 1000))
+ plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
+ plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
+ plt.title("Time embeddings")
+ plt.show()
+
+
+#
+if __name__ == '__main__':
+ _test_time_embeddings()
diff --git a/model/architecture/unet_attention.py b/model/architecture/unet_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..84910964e624a04f3c83055455e9501ed56358a8
--- /dev/null
+++ b/model/architecture/unet_attention.py
@@ -0,0 +1,321 @@
+"""
+---
+title: Transformer for Stable Diffusion U-Net
+summary: >
+ Annotated PyTorch implementation/tutorial of the transformer
+ for U-Net in stable diffusion.
+---
+
+# Transformer for Stable Diffusion [U-Net](unet.html)
+
+This implements the transformer module used in [U-Net](unet.html) that
+ gives $\epsilon_\text{cond}(x_t, c)$
+
+We have kept to the model definition and naming unchanged from
+[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
+so that we can load the checkpoints directly.
+"""
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class SpatialTransformer(nn.Module):
+ """
+ ## Spatial Transformer
+ """
+ def __init__(self, channels: int, n_heads: int, n_layers: int):
+ """
+ :param channels: is the number of channels in the feature map
+ :param n_heads: is the number of attention heads
+ :param n_layers: is the number of transformer layers
+ """
+ super().__init__()
+ # Initial group normalization
+ self.norm = torch.nn.GroupNorm(
+ num_groups=32, num_channels=channels, eps=1e-6, affine=True
+ )
+ # Initial $1 \times 1$ convolution
+ self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
+
+ # Transformer layers
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ channels, n_heads, channels // n_heads
+ ) for _ in range(n_layers)
+ ]
+ )
+
+ # Final $1 \times 1$ convolution
+ self.proj_out = nn.Conv2d(
+ channels, channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x: torch.Tensor):
+ """
+ :param x: is the feature map of shape `[batch_size, channels, height, width]`
+ :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+ """
+ # Get shape `[batch_size, channels, height, width]`
+ b, c, h, w = x.shape
+ # For residual connection
+ x_in = x
+ # Normalize
+ x = self.norm(x)
+ # Initial $1 \times 1$ convolution
+ x = self.proj_in(x)
+ # Transpose and reshape from `[batch_size, channels, height, width]`
+ # to `[batch_size, height * width, channels]`
+ x = x.permute(0, 2, 3, 1).view(b, h * w, c)
+ # Apply the transformer layers
+ for block in self.transformer_blocks:
+ x = block(x)
+ # Reshape and transpose from `[batch_size, height * width, channels]`
+ # to `[batch_size, channels, height, width]`
+ x = x.view(b, h, w, c).permute(0, 3, 1, 2)
+ # Final $1 \times 1$ convolution
+ x = self.proj_out(x)
+ # Add residual
+ return x + x_in
+
+
+class BasicTransformerBlock(nn.Module):
+ """
+ ### Transformer Layer
+ """
+ def __init__(self, d_model: int, n_heads: int, d_head: int):
+ """
+ :param d_model: is the input embedding size
+ :param n_heads: is the number of attention heads
+        :param d_head: is the size of an attention head
+ """
+ super().__init__()
+ # Self-attention layer and pre-norm layer
+ self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
+ self.norm1 = nn.LayerNorm(d_model)
+ # Cross attention layer and pre-norm layer
+ #self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
+ self.norm2 = nn.LayerNorm(d_model)
+ # Feed-forward network and pre-norm layer
+ self.ff = FeedForward(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ def forward(self, x: torch.Tensor):
+ """
+ :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+ :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+ """
+ # Self attention
+ x = self.attn1(self.norm1(x)) + x
+ # Cross-attention with conditioning
+ # x = self.attn2(self.norm2(x), cond=cond) + x
+ # Feed-forward network
+ x = self.ff(self.norm3(x)) + x
+ #
+ return x
+
+
+class CrossAttention(nn.Module):
+ """
+ ### Cross Attention Layer
+
+    This falls back to self-attention when conditional embeddings are not specified.
+ """
+
+ use_flash_attention: bool = False
+
+ def __init__(
+ self,
+ d_model: int,
+ d_cond: int,
+ n_heads: int,
+ d_head: int,
+ is_inplace: bool = True
+ ):
+ """
+ :param d_model: is the input embedding size
+ :param n_heads: is the number of attention heads
+        :param d_head: is the size of an attention head
+ :param d_cond: is the size of the conditional embeddings
+ :param is_inplace: specifies whether to perform the attention softmax computation inplace to
+ save memory
+ """
+ super().__init__()
+
+ self.is_inplace = is_inplace
+ self.n_heads = n_heads
+ self.d_head = d_head
+
+ # Attention scaling factor
+ self.scale = d_head**-0.5
+
+ # Query, key and value mappings
+ d_attn = d_head * n_heads
+ self.to_q = nn.Linear(d_model, d_attn, bias=False)
+ self.to_k = nn.Linear(d_cond, d_attn, bias=False)
+ self.to_v = nn.Linear(d_cond, d_attn, bias=False)
+
+ # Final linear layer
+ self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
+
+ # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
+ # Flash attention is only used if it's installed
+ # and `CrossAttention.use_flash_attention` is set to `True`.
+ try:
+ # You can install flash attention by cloning their Github repo,
+ # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
+ # and then running `python setup.py install`
+ from flash_attn.flash_attention import FlashAttention
+ self.flash = FlashAttention()
+ # Set the scale for scaled dot-product attention.
+ self.flash.softmax_scale = self.scale
+ # Set to `None` if it's not installed
+ except ImportError:
+ self.flash = None
+
+ def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
+ """
+ :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+ :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+ """
+
+ # If `cond` is `None` we perform self attention
+ has_cond = cond is not None
+ if not has_cond:
+ cond = x
+
+ # Get query, key and value vectors
+ q = self.to_q(x)
+ k = self.to_k(cond)
+ v = self.to_v(cond)
+
+ # Use flash attention if it's available and the head size is less than or equal to `128`
+ if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
+ return self.flash_attention(q, k, v)
+ # Otherwise, fallback to normal attention
+ else:
+ return self.normal_attention(q, k, v)
+
+ def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+ """
+ #### Flash Attention
+
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+ """
+
+ # Get batch size and number of elements along sequence axis (`width * height`)
+ batch_size, seq_len, _ = q.shape
+
+ # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
+ # shape `[batch_size, seq_len, 3, n_heads * d_head]`
+ qkv = torch.stack((q, k, v), dim=2)
+ # Split the heads
+ qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+
+ # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
+ # fit this size.
+ if self.d_head <= 32:
+ pad = 32 - self.d_head
+ elif self.d_head <= 64:
+ pad = 64 - self.d_head
+ elif self.d_head <= 128:
+ pad = 128 - self.d_head
+ else:
+            raise ValueError(f'Head size {self.d_head} too large for Flash Attention')
+
+ # Pad the heads
+ if pad:
+ qkv = torch.cat(
+ (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
+ )
+
+ # Compute attention
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+ # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
+ out, _ = self.flash(qkv)
+ # Truncate the extra head size
+ out = out[:, :, :, : self.d_head]
+ # Reshape to `[batch_size, seq_len, n_heads * d_head]`
+ out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
+ return self.to_out(out)
+
+ def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+ """
+ #### Normal Attention
+
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+ """
+
+ # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
+ q = q.view(*q.shape[: 2], self.n_heads, -1)
+ k = k.view(*k.shape[: 2], self.n_heads, -1)
+ v = v.view(*v.shape[: 2], self.n_heads, -1)
+
+ # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
+ attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
+
+ # Compute softmax
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
+ if self.is_inplace:
+ half = attn.shape[0] // 2
+ attn[half :] = attn[half :].softmax(dim=-1)
+ attn[: half] = attn[: half].softmax(dim=-1)
+ else:
+ attn = attn.softmax(dim=-1)
+
+ # Compute attention output
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+ out = torch.einsum('bhij,bjhd->bihd', attn, v)
+ # Reshape to `[batch_size, height * width, n_heads * d_head]`
+ out = out.reshape(*out.shape[: 2], -1)
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
+ return self.to_out(out)
+
+
+class FeedForward(nn.Module):
+ """
+ ### Feed-Forward Network
+ """
+ def __init__(self, d_model: int, d_mult: int = 4):
+ """
+ :param d_model: is the input embedding size
+ :param d_mult: is multiplicative factor for the hidden layer size
+ """
+ super().__init__()
+ self.net = nn.Sequential(
+ GeGLU(d_model, d_model * d_mult), nn.Dropout(0.),
+ nn.Linear(d_model * d_mult, d_model)
+ )
+
+ def forward(self, x: torch.Tensor):
+ return self.net(x)
+
+
+class GeGLU(nn.Module):
+ """
+ ### GeGLU Activation
+
+ $$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
+ """
+ def __init__(self, d_in: int, d_out: int):
+ super().__init__()
+ # Combined linear projections $xW + b$ and $xV + c$
+ self.proj = nn.Linear(d_in, d_out * 2)
+
+ def forward(self, x: torch.Tensor):
+ # Get $xW + b$ and $xV + c$
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ # $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
+ return x * F.gelu(gate)
diff --git a/model/latent_diffusion.py b/model/latent_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..a73bad5e1b0635cdbe2b96e47c924036a55c5217
--- /dev/null
+++ b/model/latent_diffusion.py
@@ -0,0 +1,222 @@
+"""
+---
+title: Latent Diffusion Models
+summary: >
+ Annotated PyTorch implementation/tutorial of latent diffusion models from paper
+ High-Resolution Image Synthesis with Latent Diffusion Models
+---
+
+# Latent Diffusion Models
+
+Latent diffusion models use an auto-encoder to map between image space and
+latent space. The diffusion model works on the latent space, which makes it
+a lot easier to train.
+It is based on paper
+[High-Resolution Image Synthesis with Latent Diffusion Models](https://papers.labml.ai/paper/2112.10752).
+
+They use a pre-trained auto-encoder and train the diffusion U-Net on the latent
+space of the pre-trained auto-encoder.
+
+For a simpler diffusion implementation refer to our [DDPM implementation](../ddpm/index.html).
+We use the same notations for $\alpha_t$, $\beta_t$ schedules, etc.
+"""
+
+from typing import List, Tuple, Optional, Union
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .architecture.unet import UNetModel
+import random
+
+
+def gather(consts: torch.Tensor, t: torch.Tensor):
+ """Gather consts for $t$ and reshape to feature map shape"""
+ c = consts.gather(-1, t)
+ return c.reshape(-1, 1, 1, 1)
+
+
+class LatentDiffusion(nn.Module):
+ """
+ ## Latent diffusion model
+
+    This contains the following components:
+
+ * [AutoEncoder](model/autoencoder.html)
+ * [U-Net](model/unet.html) with [attention](model/unet_attention.html)
+ """
+ eps_model: UNetModel
+ #first_stage_model: Optional[Autoencoder] = None
+
+ def __init__(
+ self,
+ unet_model: UNetModel,
+ latent_scaling_factor: float,
+ n_steps: int,
+ linear_start: float,
+ linear_end: float,
+ debug_mode: Optional[bool] = False
+ ):
+ """
+ :param unet_model: is the [U-Net](model/unet.html) that predicts noise
+ $\epsilon_\text{cond}(x_t, c)$, in latent space
+ :param latent_scaling_factor: is the scaling factor for the latent space. The encodings of
+ the autoencoder are scaled by this before feeding into the U-Net.
+ :param n_steps: is the number of diffusion steps $T$.
+ :param linear_start: is the start of the $\beta$ schedule.
+ :param linear_end: is the end of the $\beta$ schedule.
+ """
+ super().__init__()
+ # Wrap the [U-Net](model/unet.html) to keep the same model structure as
+ # [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion).
+ self.eps_model = unet_model
+ self.latent_scaling_factor = latent_scaling_factor
+
+ # Number of steps $T$
+ self.n_steps = n_steps
+
+ # $\beta$ schedule
+ beta = torch.linspace(
+ linear_start**0.5, linear_end**0.5, n_steps, dtype=torch.float64
+ ) ** 2
+ # $\alpha_t = 1 - \beta_t$
+ alpha = 1. - beta
+ # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
+ alpha_bar = torch.cumprod(alpha, dim=0)
+ self.alpha = nn.Parameter(alpha.to(torch.float32), requires_grad=False)
+ self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
+ self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
+ self.alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])
+ self.sigma_ddim = torch.sqrt((1-self.alpha_bar_prev)/(1-self.alpha_bar)*(1-self.alpha_bar/self.alpha_bar_prev))
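+        # $\sigma_t$ here is the DDPM posterior noise level, i.e. DDIM sampling with $\eta = 1$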
+ self.sigma2 = self.beta
+
+ self.debug_mode = debug_mode
+
+ @property
+ def device(self):
+ """
+ ### Get model device
+ """
+ return next(iter(self.eps_model.parameters())).device
+
+
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor):
+ """
+ ### Predict noise
+
+ Predict noise given the latent representation $x_t$, time step $t$, and the
+ conditioning context $c$.
+
+ $$\epsilon_\text{cond}(x_t, c)$$
+ """
+ return self.eps_model(x, t)
+
+ def q_xt_x0(self, x0: torch.Tensor,
+ t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ #### Get $q(x_t|x_0)$ distribution
+ """
+
+ # [gather](utils.html) $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$
+ mean = gather(self.alpha_bar, t)**0.5 * x0
+ # $(1-\bar\alpha_t) \mathbf{I}$
+ var = 1 - gather(self.alpha_bar, t)
+ #
+ return mean, var
+
+ def q_sample(
+ self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None
+ ):
+ """
+ #### Sample from $q(x_t|x_0)$
+ """
+
+ # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
+ if eps is None:
+ eps = torch.randn_like(x0)
+
+ # get $q(x_t|x_0)$
+ mean, var = self.q_xt_x0(x0, t)
+ # Sample from $q(x_t|x_0)$
+ return mean + (var**0.5) * eps
+
+ def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
+ """
+ #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
+ """
+
+ # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
+ eps_theta = self.eps_model(xt, t)
+ # [gather](utils.html) $\bar\alpha_t$
+ alpha_bar = gather(self.alpha_bar, t)
+        # [gather](utils.html) $\bar\alpha_{t-1}$
+ alpha_bar_prev = gather(self.alpha_bar_prev, t)
+ # [gather](utils.html) $\sigma_t$
+ sigma_ddim = gather(self.sigma_ddim, t)
+
+ # DDIM sampling
+ # $\frac{x_t-\sqrt{1-\bar\alpha_t}\epsilon}{\sqrt{\bar\alpha_t}}$
+ predicted_x0 = (xt - (1-alpha_bar)**0.5 * eps_theta) / (alpha_bar)**.5
+        # $\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta$
+ direction_to_xt = (1 - alpha_bar_prev - sigma_ddim**2)**0.5 * eps_theta
+
+ # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
+ eps = torch.randn(xt.shape, device=xt.device)
+
+ # Sample
+ x_tm_1 = alpha_bar_prev**0.5 * predicted_x0 + direction_to_xt + sigma_ddim * eps
+ return x_tm_1
+
+ def loss(
+ self,
+ x0: torch.Tensor,
+ #autoreg_cond: Union[torch.Tensor, None], #This means it can be either a tensor or none
+ #external_cond: Union[torch.Tensor, None],
+ noise: Optional[torch.Tensor] = None,
+ ):
+ """
+ #### Simplified Loss
+ """
+ # Get batch size
+ batch_size = x0.shape[0]
+ # Get random $t$ for each sample in the batch
+ t = torch.randint(
+ 0, self.n_steps, (batch_size, ), device=x0.device, dtype=torch.long
+ )
+
+
+ #autoreg_cond = -torch.ones(x0.size(0), 1, self.eps_model.d_cond, device=x0.device, dtype=x0.dtype)
+ #cond = autoreg_cond
+
+ if x0.size(1) == self.eps_model.out_channels: # generating form
+ if self.debug_mode:
+ print('In the mode of root level:', x0.size())
+ if noise is None:
+ x0 = x0.to(torch.float32)
+ noise = torch.randn_like(x0)
+
+ xt = self.q_sample(x0, t, eps=noise)
+
+ eps_theta = self.eps_model(xt, t)
+
+ loss = F.mse_loss(noise, eps_theta)
+ else:
+ if self.debug_mode:
+ print('In the mode of non-root level:', x0.size())
+
+ if noise is None:
+ noise = torch.randn_like(x0[:, 0: 2])
+
+ front_t = self.q_sample(x0[:, 0: 2], t, eps=noise)
+
+ background_cond = x0[:, 2:]
+
+ xt = torch.cat([front_t, background_cond], 1)
+
+ eps_theta = self.eps_model(xt, t)
+
+ loss = F.mse_loss(noise, eps_theta)
+ if self.debug_mode:
+ print('loss:', loss)
+ return loss
diff --git a/model/model_sdf.py b/model/model_sdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac5a0f3603ba536528ec8d9b3b056315341d0ce5
--- /dev/null
+++ b/model/model_sdf.py
@@ -0,0 +1,55 @@
+import torch
+import os
+import torch.nn as nn
+from .latent_diffusion import LatentDiffusion
+
+
+class Diffpro_SDF(nn.Module):
+
+ def __init__(
+ self,
+ ldm: LatentDiffusion,
+ ):
+ """
+ cond_type: {chord, texture}
+ cond_mode: {cond, mix, uncond}
+ mix: use a special condition for unconditional learning with probability of 0.2
+ use_enc: whether to use pretrained chord encoder to generate encoded condition
+ """
+ super(Diffpro_SDF, self).__init__()
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.ldm = ldm
+
+ @classmethod
+ def load_trained(
+ cls,
+ ldm,
+ chkpt_fpath,
+ ):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = cls(ldm)
+        trained_learner = torch.load(chkpt_fpath, map_location=device)
+        try:
+            model.load_state_dict(trained_learner["model"])
+        except RuntimeError:
+            model_dict = trained_learner["model"]
+ model_dict = {k.replace('cond_enc', 'autoreg_cond_enc'): v for k, v in model_dict.items()}
+ model_dict = {k.replace('style_enc', 'external_cond_enc'): v for k, v in model_dict.items()}
+ model.load_state_dict(model_dict)
+ return model
+
+ def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
+ return self.ldm.p_sample(xt, t)
+
+ def q_sample(self, x0: torch.Tensor, t: torch.Tensor):
+ return self.ldm.q_sample(x0, t)
+
+ def get_loss_dict(self, batch, step):
+ """
+ z_y is the stuff the diffusion model needs to learn
+ """
+ # x = batch.float().to(self.device)
+
+        x = batch
+ loss = self.ldm.loss(x)
+ return {"loss": loss}
diff --git a/model/sampler_sdf.py b/model/sampler_sdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c40682b7b298261adfc697c803d58b1aa099ef
--- /dev/null
+++ b/model/sampler_sdf.py
@@ -0,0 +1,538 @@
+from typing import Optional, List, Union
+import numpy as np
+import torch
+from labml import monit
+from .latent_diffusion import LatentDiffusion
+
+def set_seed(seed):
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+# Call the function to set the seed
+# set_seed(42)
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class DiffusionSampler:
+ """
+ ## Base class for sampling algorithms
+ """
+ model: LatentDiffusion
+
+ def __init__(self, model: LatentDiffusion):
+ """
+ :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
+ """
+ super().__init__()
+ # Set the model $\epsilon_\text{cond}(x_t, c)$
+ self.model = model
+ # Get number of steps the model was trained with $T$
+ self.n_steps = model.n_steps
+
+
+class SDFSampler(DiffusionSampler):
+ """
+    ## SDF Sampler (DDIM-style)
+
+    This extends the [`DiffusionSampler` base class](index.html).
+
+    It removes noise step by step over a selected subsequence of time steps,
+    sampling from $p_\theta(x_{t-1} | x_t)$,
+
+ \begin{align}
+
+ p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\
+
+ \mu_t(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
+ + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
+
+ \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\
+
+ x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta \\
+
+ \end{align}
+ """
+
+ model: LatentDiffusion
+
+ def __init__(
+ self,
+ model: LatentDiffusion,
+ max_l,
+ h,
+ is_autocast=False,
+ is_show_image=False,
+ device=None,
+ debug_mode=False
+ ):
+ """
+ :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
+ """
+ super().__init__(model)
+ if device is None:
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ self.device = device
+ # selected time steps ($\tau$) $1, 2, \dots, T$
+ # self.time_steps = np.asarray(list(range(self.n_steps)), dtype=np.int32)
+ self.tau = torch.tensor([13, 53, 116, 193, 310, 443, 587, 730, 845, 999], device=self.device) # torch.tensor([999, 845, 730, 587, 443, 310, 193, 116, 53, 13])
+ # self.tau = torch.tensor(np.asarray(list(range(self.n_steps)), dtype=np.int32), device=self.device)
+ self.used_n_steps = len(self.tau)
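+        # `self.tau` is a 10-step increasing subsequence of the 1000 training steps; sampling is expected
+        # to walk these indices from the last (t=999) down to the first (t=13), DDIM-style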
+
+ self.is_show_image = is_show_image
+
+ self.autocast = torch.cuda.amp.autocast(enabled=is_autocast)
+
+ self.out_channel = self.model.eps_model.out_channels
+ self.max_l = max_l
+ self.h = h
+ self.debug_mode = debug_mode
+ self.guidance_scale = 7.5
+ self.guidance_rescale = 0.7
+
+ # now, we set the coefficients
+ with torch.no_grad():
+ # $\bar\alpha_t$
+ self.alpha_bar = self.model.alpha_bar
+ # $\beta_t$ schedule
+ beta = self.model.beta
+ # $\bar\alpha_{t-1}$
+ self.alpha_bar_prev = torch.cat([self.alpha_bar.new_tensor([1.]), self.alpha_bar[:-1]])
+ # $\sigma_t$ in DDIM
+ self.sigma_ddim = torch.sqrt((1-self.alpha_bar_prev)/(1-self.alpha_bar)*(1-self.alpha_bar/self.alpha_bar_prev)) # DDPM noise schedule
+
+ # $\frac{1}{\sqrt{\bar\alpha}}$
+ self.one_over_sqrt_alpha_bar = 1 / (self.alpha_bar ** 0.5)
+ # $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$
+ self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar = (1 - self.alpha_bar)**0.5 / self.alpha_bar**0.5
+
+ # $\sqrt{\bar\alpha}$
+ self.sqrt_alpha_bar = self.alpha_bar ** 0.5
+ # $\sqrt{1 - \bar\alpha}$
+ self.sqrt_1m_alpha_bar = (1 - self.alpha_bar) ** 0.5
+ # # $\sqrt{\bar\alpha_{t-1}}$
+ # self.sqrt_alpha_bar_prev = self.alpha_bar_prev ** 0.5
+ # # $\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}$
+ # self.sqrt_1m_alpha_bar_prev_m_sigma2 = (1 - self.alpha_bar_prev - self.sigma_ddim ** 2) ** 0.5
+
+ #@property
+ # def d_cond(self):
+ #return self.model.eps_model.d_cond
+
+ def get_eps(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ background_cond: Optional[torch.Tensor],
+
+ uncond_scale: Optional[float],
+ ):
+ """
+ ## Get $\epsilon(x_t, c)$
+
+ :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
+ :param t: is $t$ of shape `[batch_size]`
+        :param background_cond: background condition, concatenated to `x` along the channel axis
+        :param uncond_scale: is the unconditional guidance scale $s$ (not used inside this function;
+            classifier-free guidance is applied in `p_sample`)
+ """
+ # When the scale $s = 1$
+ # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
+
+ batch_size = x.size(0)
+
+ # if hasattr(self.model, 'style_enc'):
+ # if external_cond is not None:
+ # external_cond = self.model.external_cond_enc(external_cond)
+ # if uncond_scale is None or uncond_scale == 1:
+ # external_uncond = (-torch.ones_like(external_cond)).to(self.device)
+ # else:
+ # external_uncond = None
+ # # if random.random() < 0.2:
+ # # external_cond = (-torch.ones_like(external_cond)).to(self.device)
+ # else:
+ # external_cond = -torch.ones(batch_size, 4, self.d_cond, device=x.device, dtype=x.dtype)
+ # external_uncond = None
+ # cond = torch.cat([autoreg_cond, external_cond], 1)
+ # if external_uncond is None:
+ # uncond = None
+ # else:
+ # uncond = torch.cat([autoreg_cond, external_uncond], 1)
+ # else:
+ # cond = autoreg_cond
+ # uncond = None
+
+        if background_cond is not None:
+            x = torch.cat([x, background_cond], 1)
+
+ # if uncond is None:
+ # e_t = self.model(x, t, cond)
+ # else:
+ # e_t_cond = self.model(x, t, cond)
+ # e_t_uncond = self.model(x, t, uncond)
+ # e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
+
+ e_t = self.model(x,t)
+ return e_t
+
+ @torch.no_grad()
+ def p_sample(
+ self,
+ x: torch.Tensor,
+ background_cond: Optional[torch.Tensor],
+ #autoreg_cond: Optional[torch.Tensor],
+ #external_cond: Optional[torch.Tensor],
+ t: torch.Tensor,
+ step: int,
+ repeat_noise: bool = False,
+ temperature: float = 1.,
+ uncond_scale: float = 1.,
+ same_noise_all_measure: bool = False,
+ X0EditFunc = None,
+ use_classifier_free_guidance = False,
+ use_lsh = False,
+ reduce_extra_notes=True,
+ rhythm_control="Yes",
+ ):
+ print("p_sample")
+ """
+ ### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$
+
+ :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
+        :param background_cond: background condition
+ :param t: is $t$ of shape `[batch_size]`
+ :param step: is the step $t$ as an integer
+        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
+ :param temperature: is the noise temperature (random noise gets multiplied by this)
+ :param uncond_scale: is the unconditional guidance scale $s$. This is used for
+            $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
+ """
+ # Get current tau_i and tau_{i-1}
+ tau_i = self.tau[t]
+ step_tau_i = self.tau[step]
+
+ # Get $\epsilon_\theta$
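+        # Classifier-free guidance combines the noise estimates from the real and the null condition as
+        # $\epsilon = \epsilon_\text{null} + s\,(\epsilon_\text{real} - \epsilon_\text{null})$, with $s$ = `self.guidance_scale`,
+        # optionally rescaled by `rescale_noise_cfg` when `self.guidance_rescale > 0`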
+ with self.autocast:
+ if use_classifier_free_guidance:
+ if use_lsh:
+ assert background_cond.shape[1] == 6 # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain, lsh_onset, lsh_sustain
+ null_lsh = -torch.ones_like(background_cond[:,4:,:,:])
+ null_background_cond = torch.cat([background_cond[:,2:4,:,:], null_lsh], axis=1)
+ real_background_cond = torch.cat([background_cond[:,:2,:,:], background_cond[:,4:,:,:]], axis=1)
+
+ e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale)
+ e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale)
+ e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real-e_tau_i_null)
+ if self.guidance_rescale > 0:
+ e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale)
+ else:
+ assert background_cond.shape[1] == 4 # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain
+ null_background_cond = background_cond[:,2:,:,:]
+ real_background_cond = background_cond[:,:2,:,:]
+ e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale)
+ e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale)
+ e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real-e_tau_i_null)
+ if self.guidance_rescale > 0:
+ e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale)
+ else:
+ if use_lsh:
+ assert background_cond.shape[1] == 4 # chd_onset, chd_sustain, lsh_onset, lsh_sustain
+ e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale)
+ else:
+ assert background_cond.shape[1] == 2 # chd_onset, chd_sustain
+ e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale)
+
+ # Get batch size
+ bs = x.shape[0]
+
+ # $\frac{1}{\sqrt{\bar\alpha}}$
+ one_over_sqrt_alpha_bar = x.new_full(
+ (bs, 1, 1, 1), self.one_over_sqrt_alpha_bar[step_tau_i]
+ )
+ # $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$
+ sqrt_1m_alpha_bar_over_sqrt_alpha_bar = x.new_full(
+ (bs, 1, 1, 1), self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar[step_tau_i]
+ )
+
+ # $\sigma_t$ in DDIM
+ sigma_ddim = x.new_full(
+ (bs, 1, 1, 1), self.sigma_ddim[step_tau_i]
+ )
+
+
+ # Calculate $x_0$ with current $\epsilon_\theta$
+ #
+ # predicted x_0 in DDIM
+ predicted_x0 = one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - sqrt_1m_alpha_bar_over_sqrt_alpha_bar * e_tau_i
+
+ # edit predicted x_0
+ if X0EditFunc is not None:
+ predicted_x0 = X0EditFunc(predicted_x0, background_cond, reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control)
+ e_tau_i = (one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - predicted_x0) / sqrt_1m_alpha_bar_over_sqrt_alpha_bar
+
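+        # Note: when X0EditFunc is applied, the edited x_0 is converted back into an
+        # equivalent noise estimate by inverting x_t = sqrt(alpha_bar) * x_0
+        # + sqrt(1 - alpha_bar) * eps, so the DDIM update below stays consistent with
+        # the edited prediction.
+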
+        # Do not add noise when $t = 1$ (the final step of the sampling process).
+        # Note that `step` is `0` when $t = 1$.
+ if step == 0:
+ noise = 0
+ # If same noise is used for all samples in the batch
+ elif repeat_noise:
+ if same_noise_all_measure:
+ noise = torch.randn((1, predicted_x0.shape[1], 16, predicted_x0.shape[3]), device=self.device).repeat(1,1,int(predicted_x0.shape[2]/16),1)
+ else:
+ noise = torch.randn((1, *predicted_x0.shape[1:]), device=self.device)
+ # Different noise for each sample
+ else:
+ if same_noise_all_measure:
+ noise = torch.randn(predicted_x0.shape[0], predicted_x0.shape[1], 16, predicted_x0.shape[3], device=self.device).repeat(1,1,int(predicted_x0.shape[2]/16),1)
+ else:
+ noise = torch.randn(predicted_x0.shape, device=self.device)
+
+ # Multiply noise by the temperature
+ noise = noise * temperature
+
+ if step > 0:
+ step_tau_i_m_1 = self.tau[step-1]
+ # $\sqrt{\bar\alpha_{\tau_i-1}}$
+ sqrt_alpha_bar_prev = x.new_full(
+ (bs, 1, 1, 1), self.sqrt_alpha_bar[step_tau_i_m_1]
+ )
+ # $\sqrt{1-\bar\alpha_{\tau_i-1}-\sigma_\tau^2}$
+ sqrt_1m_alpha_bar_prev_m_sigma2 = x.new_full(
+ (bs, 1, 1, 1), (1 - self.alpha_bar[step_tau_i_m_1] - self.sigma_ddim[step_tau_i] ** 2) ** 0.5
+ )
+ direction_to_xt = sqrt_1m_alpha_bar_prev_m_sigma2 * e_tau_i
+ x_prev = sqrt_alpha_bar_prev * predicted_x0 + direction_to_xt + sigma_ddim * noise
+ else:
+ x_prev = predicted_x0 + sigma_ddim * noise
+
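+        # For reference, for step > 0 the update above is the DDIM step (Song et al., 2020):
+        #
+        #     x_{tau_{i-1}} = sqrt(alpha_bar_{tau_{i-1}}) * predicted_x0
+        #                     + sqrt(1 - alpha_bar_{tau_{i-1}} - sigma^2) * e_tau_i
+        #                     + sigma * noise
+        #
+        # with sigma = 0 recovering fully deterministic DDIM sampling.
+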
+        # Return $x_{\tau_{i-1}}$, the predicted $x_0$, and the (possibly edited) noise estimate.
+ return x_prev, predicted_x0, e_tau_i
+
+ @torch.no_grad()
+ def q_sample(
+ self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None
+ ):
+ """
+ ### Sample from $q(x_t|x_0)$
+
+ $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$
+
+ :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
+ :param index: is the time step $t$ index
+ :param noise: is the noise, $\epsilon$
+ """
+
+ # Random noise, if noise is not specified
+ if noise is None:
+ noise = torch.randn_like(x0, device=self.device)
+
+ # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
+ return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
+
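+    # (q_sample implements the closed-form forward process; paint() further below uses
+    # it as `self.q_sample(orig, self.tau[step], noise=orig_noise)` to re-noise the
+    # original latent to the current sampling step.)
+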
+ @torch.no_grad()
+ def sample(
+ self,
+ shape: List[int],
+ background_cond: Optional[torch.Tensor] = None,
+ #autoreg_cond: Optional[torch.Tensor] = None,
+ #external_cond: Optional[torch.Tensor] = None,
+ repeat_noise: bool = False,
+ temperature: float = 1.,
+ uncond_scale: float = 1.,
+ x_last: Optional[torch.Tensor] = None,
+ t_start: int = 0,
+ same_noise_all_measure: bool = False,
+ X0EditFunc = None,
+ use_classifier_free_guidance = False,
+ use_lsh = False,
+ reduce_extra_notes=True,
+ rhythm_control="Yes",
+ ):
+ """
+ ### Sampling Loop
+
+ :param shape: is the shape of the generated images in the
+ form `[batch_size, channels, height, width]`
+ :param background_cond: background condition
+ :param autoreg_cond: autoregressive condition
+ :param external_cond: external condition
+ :param repeat_noise: specified whether the noise should be same for all samples in the batch
+ :param temperature: is the noise temperature (random noise gets multiplied by this)
+ :param x_last: is $x_T$. If not provided random noise will be used.
+ :param uncond_scale: is the unconditional guidance scale $s$. This is used for
+ $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$
+ :param t_start: t_start
+ """
+
+        # Get batch size
+        bs = shape[0]
+
+        print(shape)  # debug: log the requested output shape
+
+
+ # Get $x_T$
+ if same_noise_all_measure:
+ x = x_last if x_last is not None else torch.randn(shape[0],shape[1],16,shape[3], device=self.device).repeat(1,1,int(shape[2]/16),1)
+ else:
+ x = x_last if x_last is not None else torch.randn(shape, device=self.device)
+
+        # Time steps to sample at: $T - t' - 1, \dots, 1, 0$ (with $t'$ = `t_start`)
+ time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:]
+
+ # Sampling loop
+ for step in monit.iterate('Sample', time_steps):
+ # Time step $t$
+ ts = x.new_full((bs, ), step, dtype=torch.long)
+
+ x, pred_x0, e_t = self.p_sample(
+ x,
+ background_cond,
+ #autoreg_cond,
+ #external_cond,
+ ts,
+ step,
+ repeat_noise=repeat_noise,
+ temperature=temperature,
+ uncond_scale=uncond_scale,
+ same_noise_all_measure=same_noise_all_measure,
+ X0EditFunc = X0EditFunc,
+ use_classifier_free_guidance = use_classifier_free_guidance,
+ use_lsh=use_lsh,
+ reduce_extra_notes=reduce_extra_notes,
+ rhythm_control=rhythm_control
+ )
+
+ s1 = step + 1
+
+ # if self.is_show_image:
+ # if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0):
+ # show_image(x, f"exp/img/x{s1}.png")
+
+ # Return $x_0$
+ # if self.is_show_image:
+ # show_image(x, f"exp/img/x0.png")
+
+ return x
+
+ @torch.no_grad()
+ def paint(
+ self,
+ x: Optional[torch.Tensor] = None,
+ background_cond: Optional[torch.Tensor] = None,
+ #autoreg_cond: Optional[torch.Tensor] = None,
+ #external_cond: Optional[torch.Tensor] = None,
+ t_start: int = 0,
+ orig: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ orig_noise: Optional[torch.Tensor] = None,
+ uncond_scale: float = 1.,
+ same_noise_all_measure: bool = False,
+ X0EditFunc = None,
+ use_classifier_free_guidance = False,
+ use_lsh = False,
+ ):
+ """
+ ### Painting Loop
+
+ :param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]`
+ :param background_cond: background condition
+ :param autoreg_cond: autoregressive condition
+ :param external_cond: external condition
+ :param t_start: is the sampling step to start from, $S'$
+ :param orig: is the original image in latent page which we are in paining.
+ If this is not provided, it'll be an image to image transformation.
+ :param mask: is the mask to keep the original image.
+ :param orig_noise: is fixed noise to be added to the original image.
+ :param uncond_scale: is the unconditional guidance scale $s$. This is used for
+ $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$
+ """
+        # Get batch size (from the original image if given, otherwise from x)
+        bs = orig.size(0) if orig is not None else x.size(0)
+
+        if x is None:
+            x = torch.randn(orig.shape, device=self.device)
+
+        # Time steps to sample at $\tau_{S'}, \tau_{S'-1}, \dots, \tau_1$
+ # time_steps = np.flip(self.time_steps[: t_start])
+ time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:]
+
+ for i, step in monit.enum('Paint', time_steps):
+ # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
+ # index = len(time_steps) - i - 1
+ # Time step $\tau_i$
+ ts = x.new_full((bs, ), step, dtype=torch.long)
+
+ # Sample $x_{\tau_{i-1}}$
+ x, _, _ = self.p_sample(
+ x,
+ background_cond,
+ #autoreg_cond,
+ #external_cond,
+ t=ts,
+ step=step,
+ uncond_scale=uncond_scale,
+ same_noise_all_measure=same_noise_all_measure,
+ X0EditFunc = X0EditFunc,
+ use_classifier_free_guidance = use_classifier_free_guidance,
+ use_lsh=use_lsh,
+ )
+
+ # Replace the masked area with original image
+ if orig is not None:
+ assert mask is not None
+ # Get the $q_{\sigma,\tau}(x_{\tau_i}|x_0)$ for original image in latent space
+ orig_t = self.q_sample(orig, self.tau[step], noise=orig_noise)
+ # Replace the masked area
+ x = orig_t * mask + x * (1 - mask)
+
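+            # In the blend above, mask == 1 keeps the re-noised original content and
+            # mask == 0 keeps the freshly generated content for that region.
+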
+ s1 = step + 1
+
+ # if self.is_show_image:
+ # if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0):
+ # show_image(x, f"exp/img/x{s1}.png")
+
+ # if self.is_show_image:
+ # show_image(x, f"exp/img/x0.png")
+ return x
+
+ def generate(self, background_cond=None, batch_size=1, uncond_scale=None,
+ same_noise_all_measure=False, X0EditFunc=None,
+ use_classifier_free_guidance=False, use_lsh=False, reduce_extra_notes=True, rhythm_control="Yes"):
+
+ shape = [batch_size, self.out_channel, self.max_l, self.h]
+
+ if self.debug_mode:
+ return torch.randn(shape, dtype=torch.float)
+
+ return self.sample(shape, background_cond, uncond_scale=uncond_scale, same_noise_all_measure=same_noise_all_measure,
+ X0EditFunc=X0EditFunc, use_classifier_free_guidance=use_classifier_free_guidance, use_lsh=use_lsh,
+ reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control
+ )
+
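+# ---------------------------------------------------------------------------
+# Illustrative usage sketch (hypothetical; not called anywhere in the repo).
+# Assuming `sampler` is an initialized instance of this sampler and `cond` is a
+# background-condition tensor whose channel layout matches the assertions in
+# p_sample (e.g. [chd_onset, chd_sustain, lsh_onset, lsh_sustain] when
+# use_lsh=True and classifier-free guidance is off), a batch could be generated
+# with:
+#
+#     piano_roll = sampler.generate(
+#         background_cond=cond,
+#         batch_size=cond.shape[0],
+#         uncond_scale=1.,
+#         use_lsh=True,
+#         use_classifier_free_guidance=False,
+#     )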
diff --git a/output_0.mid b/output_0.mid
new file mode 100644
index 0000000000000000000000000000000000000000..26a8e00407f4dfa8be73d4643ab5a6560dbc4462
Binary files /dev/null and b/output_0.mid differ
diff --git a/output_0.wav b/output_0.wav
new file mode 100644
index 0000000000000000000000000000000000000000..86b791aa3e6a84ec9d2732490dbb0908b4af8969
--- /dev/null
+++ b/output_0.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9abffb8b039f86161f025cab6419eecf93ec741ea67e66964ca8e79d333c9d4
+size 2469720
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5cb2b5f4bb78a09b9b44fdc10a858d09f83a5b07
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1 @@
+TiMidity++
\ No newline at end of file
diff --git a/piano_roll.png b/piano_roll.png
new file mode 100644
index 0000000000000000000000000000000000000000..584ba9d69def93476be0a7bf04cedc490060da97
--- /dev/null
+++ b/piano_roll.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf4d433689089c7895ed3ebf569e2dda9284a8d443481f6d8aa9f9575089cc37
+size 16389
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1628aed3bcc7a7c613f97c177380a9295e56eed5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+gradio#==4.44.0
+imageio#==2.35.1
+imageio[ffmpeg]
+labml#==0.5.3
+librosa
+mir_eval
+matplotlib
+music21
+numba==0.53.1
+numpy==1.19.5
+opencv-python
+# pandas==1.2.5
+pretty_midi
+pydub
+requests
+soundfile
+fluidsynth
+scikit-learn
+torch==2.4.1
+torchvision
+tqdm
+tensorboard
\ No newline at end of file
diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt
new file mode 100644
index 0000000000000000000000000000000000000000..f1e3c65efd105af1dfb49178192f4a1c01647c56
--- /dev/null
+++ b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:852cfd7b011bb1ba0ce1d0d05a7acd672c7cd4934756b2e0d357d8002b5ecb6b
+size 441623592
diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0
new file mode 100644
index 0000000000000000000000000000000000000000..4973f25dbe7aff4d6b1de8911ea7fdb6c9402412
Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0 differ
diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2
new file mode 100644
index 0000000000000000000000000000000000000000..2282c4934cc1887fd2bcbb353a3f38a358f5ad73
Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2 differ
diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1
new file mode 100644
index 0000000000000000000000000000000000000000..64251a673711c650dbdb300800991466aca755c4
Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1 differ
diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4
new file mode 100644
index 0000000000000000000000000000000000000000..5b10b32f0a0fd3ab8d9aefa118dfae26b9bc0a04
Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4 differ
diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3 b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3
new file mode 100644
index 0000000000000000000000000000000000000000..e2781e652ddfeb83f30c1dfc006e7d67a4081eca
Binary files /dev/null and b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3 differ
diff --git a/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json
new file mode 100644
index 0000000000000000000000000000000000000000..d5eff601a3fbd0fce6396ef1b785330777b4e0b8
--- /dev/null
+++ b/results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json
@@ -0,0 +1 @@
+{"batch_size": 16, "max_epoch": 10, "learning_rate": 5e-05, "max_grad_norm": 10, "fp16": true, "in_channels": 6, "out_channels": 2, "channels": 64, "attention_levels": [2, 3], "n_res_blocks": 2, "channel_multipliers": [1, 2, 4, 4], "n_heads": 4, "tf_layers": 1, "d_cond": 2, "linear_start": 0.00085, "linear_end": 0.012, "n_steps": 1000, "latent_scaling_factor": 0.18215}
\ No newline at end of file
diff --git a/rhythm_plot_0.png b/rhythm_plot_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..0513bc832e5434f65f06b11a2039a3c83cfea7ad
--- /dev/null
+++ b/rhythm_plot_0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:967207ba25f584c5f1559beea33807766b5c7472a025e4e9e182b82e3876e143
+size 11476
diff --git a/runtime.txt b/runtime.txt
new file mode 100644
index 0000000000000000000000000000000000000000..032aea2dbcb01c94cd1199869e1e97d9dc351e2e
--- /dev/null
+++ b/runtime.txt
@@ -0,0 +1 @@
+python-3.9
\ No newline at end of file
diff --git a/samples/control_vs_uncontrol/example_1_acc_control.jpg b/samples/control_vs_uncontrol/example_1_acc_control.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..72932aacab9f1c895623c59988f9a97c4608f689
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_1_acc_control.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b1addf1c8a495c2cd6dc3f4c2fe2d86139e6808ad839ab8fbc7070bf8d02314
+size 122732
diff --git a/samples/control_vs_uncontrol/example_1_acc_control.wav b/samples/control_vs_uncontrol/example_1_acc_control.wav
new file mode 100644
index 0000000000000000000000000000000000000000..932a8f6f3180e528abd10fdd806a1e4a51e3cfe9
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_1_acc_control.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98254caf2f2943c0f4bf4a07cfddf0c0dfa8bc33f41e5aafcd2626f33763b680
+size 2469720
diff --git a/samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg b/samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c5f84c39af65b24d2813462c7efdf99bda35062e
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c857985742375d62e30b6a7ad86c2ac27978a5d2f9417a01d85cadf7aa87042
+size 136185
diff --git a/samples/control_vs_uncontrol/example_1_acc_uncontrol.wav b/samples/control_vs_uncontrol/example_1_acc_uncontrol.wav
new file mode 100644
index 0000000000000000000000000000000000000000..7775b7cef3c71484f71351381c36ad89ecfc4d9c
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_1_acc_uncontrol.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:582e9d559424e40216f3d7e2ded5b07844311da1783fd77792b58a27fa27ba8c
+size 2469720
diff --git a/samples/control_vs_uncontrol/example_1_mel_chd.jpg b/samples/control_vs_uncontrol/example_1_mel_chd.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e91525f4de913dc0b0c1e1cd9676491647a8dae7
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_1_mel_chd.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7308a7c0255420a712837437cbb37a567b2c01606c0d020caa4d44f94a1f8464
+size 71507
diff --git a/samples/control_vs_uncontrol/example_1_mel_chd.wav b/samples/control_vs_uncontrol/example_1_mel_chd.wav
new file mode 100644
index 0000000000000000000000000000000000000000..59df9f6eee415112786a72c92103881fd9af7a86
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_1_mel_chd.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cedfa10bc44605e9cd1c0a04a82d90f7b6c31dfe68d790a7d9c8e6e27117c90e
+size 5292046
diff --git a/samples/control_vs_uncontrol/example_2_acc_control.jpg b/samples/control_vs_uncontrol/example_2_acc_control.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8dce49aec16af5785380c1fbf83b20b5b709162e
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_2_acc_control.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e7a9691db8538d58c3d4fff42184c3aac32a73849cc4f2c14d4ec540e9f5d30
+size 96583
diff --git a/samples/control_vs_uncontrol/example_2_acc_control.wav b/samples/control_vs_uncontrol/example_2_acc_control.wav
new file mode 100644
index 0000000000000000000000000000000000000000..717fb8d86981d66e25122fdaa5f96421ea6b509a
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_2_acc_control.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6e70d13d7d97c0cb3cbf38d9ad2473f5df5eb04c69a54ddb228aca1134f819a
+size 2364108
diff --git a/samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg b/samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..eae50dc0c93519922b38abbabcec7d2969d89050
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01da8567d60a0ab49c4d0e9a78b21ab1a102e196c2eccb39c30a5a721f1998d9
+size 108920
diff --git a/samples/control_vs_uncontrol/example_2_acc_uncontrol.wav b/samples/control_vs_uncontrol/example_2_acc_uncontrol.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ff4b8970c5d4f8a86e39ee413160b4e052b6b0c3
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_2_acc_uncontrol.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a60053aca94dac302379e536774832a3f90bf2fd7a0d7a23c2dc97b742f78910
+size 2364108
diff --git a/samples/control_vs_uncontrol/example_2_mel_chd.jpg b/samples/control_vs_uncontrol/example_2_mel_chd.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c9bdbd0c363b6cf36204117081d97b75679e03f9
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_2_mel_chd.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad9825c1cff525d4f348d60e77e8a28769e58a4ef93730a4c8fcfc8394f40e92
+size 60218
diff --git a/samples/control_vs_uncontrol/example_2_mel_chd.wav b/samples/control_vs_uncontrol/example_2_mel_chd.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2b0657befab4806dc202f3fe80ad49c605ac039f
--- /dev/null
+++ b/samples/control_vs_uncontrol/example_2_mel_chd.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cafc0f0a3c375155457479eb2403e17aead89cd78aa021cdfeec78301930ee4f
+size 5292046
diff --git a/samples/different_styles/chinese_1.wav b/samples/different_styles/chinese_1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..e131cf54c032e1fc1eb0601baae6042aa5b09aaa
--- /dev/null
+++ b/samples/different_styles/chinese_1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0cb5284cedced19d23ea5471ada1b3a9357326c64440170ecf05eec74fecf53f
+size 2364108
diff --git a/samples/different_styles/chinese_2.wav b/samples/different_styles/chinese_2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8ebc18741528e863de67ed8b532fa7de9819322b
--- /dev/null
+++ b/samples/different_styles/chinese_2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84ae2c4c480d225239fd28aec180d83ff13b2a3d58074d22fc78e018ebf2d790
+size 2364108
diff --git a/samples/different_styles/chinese_scale.jpg b/samples/different_styles/chinese_scale.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..83c0d10e24c84ddc19bd3df6634b3bebc2ad060a
--- /dev/null
+++ b/samples/different_styles/chinese_scale.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bafa09570c3f9413ddc6e2f643e935ec6932e0e5a20c070f1a617ace53e6f08d
+size 23100
diff --git a/samples/different_styles/dorian_1.wav b/samples/different_styles/dorian_1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..43fff5448aa6534540147fd2f9045b4dadf39871
--- /dev/null
+++ b/samples/different_styles/dorian_1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8913ce841a7b2e8b6eed4c4ffc87da3e22478908fecf8d2965d46790c44d29de
+size 2364108
diff --git a/samples/different_styles/dorian_2.wav b/samples/different_styles/dorian_2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..02a2b6a05e66a3631ff3d57ef1820365834df828
--- /dev/null
+++ b/samples/different_styles/dorian_2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8746fe2cabc9997ace5584e49dc42d96f71f30fe869eed286310e71ffc1442ed
+size 2364108
diff --git a/samples/different_styles/dorian_scale.jpg b/samples/different_styles/dorian_scale.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..604cd0d906592469b94e7a5fa0904fc16795a97f
--- /dev/null
+++ b/samples/different_styles/dorian_scale.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78380800a624c05b867e63f92b13f27008b4ca5217d622849a79fbe7f2c3a58b
+size 26351
diff --git a/samples/diy_examples/example1/example1.jpg b/samples/diy_examples/example1/example1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1817ef2b7cfe25ecdeeb3f330c6772b063d6a211
--- /dev/null
+++ b/samples/diy_examples/example1/example1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e93fb967cd7b52ea6435d775f9533053cc917b90745b3a56eedc4e4a16b9dc27
+size 80422
diff --git a/samples/diy_examples/example1/example1.npy b/samples/diy_examples/example1/example1.npy
new file mode 100644
index 0000000000000000000000000000000000000000..341c0740f47d5ff51a4b48eadc3fd1ba9651d0ae
Binary files /dev/null and b/samples/diy_examples/example1/example1.npy differ
diff --git a/samples/diy_examples/example1/example_1_mel.wav b/samples/diy_examples/example1/example_1_mel.wav
new file mode 100644
index 0000000000000000000000000000000000000000..7b279f604b1cfefbd98a72acc0c1068a641c86ba
--- /dev/null
+++ b/samples/diy_examples/example1/example_1_mel.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:99de59651e0df5796d02cd08eb3980b09c11319f52d4d4c074a61cfc9f6f4e06
+size 2626764
diff --git a/samples/diy_examples/example1/sample1.mid b/samples/diy_examples/example1/sample1.mid
new file mode 100644
index 0000000000000000000000000000000000000000..bd85386415644767c3239663970b6df26b654bdf
Binary files /dev/null and b/samples/diy_examples/example1/sample1.mid differ
diff --git a/samples/diy_examples/example1/sample1.wav b/samples/diy_examples/example1/sample1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2424674521564a3206ed1ade75faf5912066d9db
--- /dev/null
+++ b/samples/diy_examples/example1/sample1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d81863d650333e13918c6f39b550fc301e268370c9f257ae9a03c4d0cb24708
+size 2626764
diff --git a/samples/diy_examples/example2/example2.jpg b/samples/diy_examples/example2/example2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4f758687e062b1d70cda2a42137be5dd5002a516
--- /dev/null
+++ b/samples/diy_examples/example2/example2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:136c3f4e30ac8c1a3860f689f82f0f59e2f2c1666671113d193ba4413163c864
+size 87430
diff --git a/samples/diy_examples/example2/example2.mid b/samples/diy_examples/example2/example2.mid
new file mode 100644
index 0000000000000000000000000000000000000000..8b4df24b1be2d570379974985fd2d3737a763174
Binary files /dev/null and b/samples/diy_examples/example2/example2.mid differ
diff --git a/samples/diy_examples/example2/example2.npy b/samples/diy_examples/example2/example2.npy
new file mode 100644
index 0000000000000000000000000000000000000000..d988a4c733ba4801ec59b82ae0c15d0e28a2f905
Binary files /dev/null and b/samples/diy_examples/example2/example2.npy differ
diff --git a/samples/diy_examples/example2/example_2_mel.wav b/samples/diy_examples/example2/example_2_mel.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ed67a8b610939ab66988c17887a34a97f9b7862d
--- /dev/null
+++ b/samples/diy_examples/example2/example_2_mel.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f545731a0d7b01b73df05fe2c5e1aba995365e685b276c81a3073d7dc054929b
+size 2364108
diff --git a/samples/diy_examples/example2/sample1.mid b/samples/diy_examples/example2/sample1.mid
new file mode 100644
index 0000000000000000000000000000000000000000..ab0aaae10c65ba748bd6c5f7a1f02987c0cd51bc
Binary files /dev/null and b/samples/diy_examples/example2/sample1.mid differ
diff --git a/samples/diy_examples/example2/sample1.wav b/samples/diy_examples/example2/sample1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..807079ce13e1af6e7927044d7e5ec446ae22e336
--- /dev/null
+++ b/samples/diy_examples/example2/sample1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae68eb033e06bc43f9c26eceaed67bac4745a481c42492a66ad8400ca56b3a9d
+size 2364108
diff --git a/samples/diy_examples/example3/example3.jpg b/samples/diy_examples/example3/example3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9da8c41b79887955a4933dba653826f85aad63fa
--- /dev/null
+++ b/samples/diy_examples/example3/example3.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45deba5c0642d156ba05df0f5ebacda82cfc3d786704560278fcefec9f94c4a0
+size 101831
diff --git a/samples/diy_examples/example3/example3.mid b/samples/diy_examples/example3/example3.mid
new file mode 100644
index 0000000000000000000000000000000000000000..5216951964b117ca02104e155e737b7d6b6dc53c
Binary files /dev/null and b/samples/diy_examples/example3/example3.mid differ
diff --git a/samples/diy_examples/example3/example3.npy b/samples/diy_examples/example3/example3.npy
new file mode 100644
index 0000000000000000000000000000000000000000..65780f67cfedb735b6fbf11132a6ae33d71d3013
Binary files /dev/null and b/samples/diy_examples/example3/example3.npy differ
diff --git a/samples/diy_examples/example3/example_3_mel.wav b/samples/diy_examples/example3/example_3_mel.wav
new file mode 100644
index 0000000000000000000000000000000000000000..7f13b8aa125923bf8a3733ab1466e2900a3621c8
--- /dev/null
+++ b/samples/diy_examples/example3/example_3_mel.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a1206053ff85d6ccb5449206a2395bf220dcd88c5cad430eccdefea9c4956d9
+size 2364108
diff --git a/samples/diy_examples/example3/sample1.mid b/samples/diy_examples/example3/sample1.mid
new file mode 100644
index 0000000000000000000000000000000000000000..5e8e3593d1458c68ab301ae36b800a89e5afeb29
Binary files /dev/null and b/samples/diy_examples/example3/sample1.mid differ
diff --git a/samples/diy_examples/example3/sample1.wav b/samples/diy_examples/example3/sample1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..659e97eaceac9022ac37f4663360d39be7cc756a
--- /dev/null
+++ b/samples/diy_examples/example3/sample1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75c4e5568759f78e30f346f6a928ea839063259068d9628dea5b47db249a1361
+size 2364108
diff --git a/samples/diy_examples/example4/example4.jpg b/samples/diy_examples/example4/example4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..036f138c273a02b9c7755987aa52cb818122768a
--- /dev/null
+++ b/samples/diy_examples/example4/example4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:974035fcdc0f415c6cc41acb7ffd013e5499930169c156f65ed645ed5f24eccf
+size 77774
diff --git a/samples/diy_examples/example4/example4.npy b/samples/diy_examples/example4/example4.npy
new file mode 100644
index 0000000000000000000000000000000000000000..334289164de39ad1283ec21e41ddfc75218ea648
Binary files /dev/null and b/samples/diy_examples/example4/example4.npy differ
diff --git a/samples/diy_examples/example4/example_4_mel.wav b/samples/diy_examples/example4/example_4_mel.wav
new file mode 100644
index 0000000000000000000000000000000000000000..4b53f0f93530a7069302bd6068df8ec3d0f54f9d
--- /dev/null
+++ b/samples/diy_examples/example4/example_4_mel.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d71a72473f89275da9f3947bb702ea037f17938cb71ccb504aa693b833a90212
+size 2364108
diff --git a/samples/diy_examples/example4/sample1.mid b/samples/diy_examples/example4/sample1.mid
new file mode 100644
index 0000000000000000000000000000000000000000..df3c989fa9d75069fd030bc08e0d531e212cd763
Binary files /dev/null and b/samples/diy_examples/example4/sample1.mid differ
diff --git a/samples/diy_examples/example4/sample1.wav b/samples/diy_examples/example4/sample1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a6bf2312a42fbe9044be9e0069ee3c7cfe9c7b5a
--- /dev/null
+++ b/samples/diy_examples/example4/sample1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:666b08d944c809662f3781f373bbfbd83af18925d6acc9004944598da532aa09
+size 2364108
diff --git a/samples/diy_examples/rhythm_plot_1.png b/samples/diy_examples/rhythm_plot_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..bfcb7d5623206b2441c4b890e95e35a68465edf9
--- /dev/null
+++ b/samples/diy_examples/rhythm_plot_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe87b8025fb47e2a41cae493aeb69bc9690784349e646d1db80bf2e100a47632
+size 11844
diff --git a/samples/diy_examples/rhythm_plot_2.png b/samples/diy_examples/rhythm_plot_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..bc0ceda8f6bdda7b24635d89f061fee119f89691
--- /dev/null
+++ b/samples/diy_examples/rhythm_plot_2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13e5ee8c7bad342216a57b22486680022112a605e1b45e6b951b148daab6152d
+size 10835
diff --git a/samples/diy_examples/rhythm_plot_3.png b/samples/diy_examples/rhythm_plot_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..961d8753080221d7875b51708f7088e2926cdece
--- /dev/null
+++ b/samples/diy_examples/rhythm_plot_3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b57d1e070f9cd784ffb505c2c0d05593ea9ffd3cf07a718ead783645ce3562c
+size 12093
diff --git a/samples/diy_examples/rhythm_plot_4.png b/samples/diy_examples/rhythm_plot_4.png
new file mode 100644
index 0000000000000000000000000000000000000000..bfcb7d5623206b2441c4b890e95e35a68465edf9
--- /dev/null
+++ b/samples/diy_examples/rhythm_plot_4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe87b8025fb47e2a41cae493aeb69bc9690784349e646d1db80bf2e100a47632
+size 11844
diff --git a/samples/diy_examples/rhythm_plot_default.png b/samples/diy_examples/rhythm_plot_default.png
new file mode 100644
index 0000000000000000000000000000000000000000..9d258ef3cf18ae16a480b9f1997134356b41e4e6
--- /dev/null
+++ b/samples/diy_examples/rhythm_plot_default.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:931d7148bbb618c0befbab6d0c26fa3d8b65acdc87ca71b94686404d3b389cb3
+size 11216
diff --git a/train/__init__.py b/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..535b969dee506b444a93a6421142452abd0f47b2
--- /dev/null
+++ b/train/__init__.py
@@ -0,0 +1,45 @@
+import torch
+import json
+import os
+from datetime import datetime
+from torch.utils.data import DataLoader
+from torch.optim import Optimizer
+from .learner import DiffproLearner
+
+
+class TrainConfig:
+
+ model: torch.nn.Module
+ train_dl: DataLoader
+ val_dl: DataLoader
+ optimizer: Optimizer
+
+ def __init__(self, params, param_scheduler, output_dir) -> None:
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.params = params
+ self.param_scheduler = param_scheduler
+ self.output_dir = output_dir
+
+ def train(self):
+ # collect and display total parameters
+ total_parameters = sum(
+ p.numel() for p in self.model.parameters() if p.requires_grad
+ )
+ print(f"Total parameters: {total_parameters}")
+
+        # handle output storage: resume if a checkpoint exists, otherwise create a timestamped run folder
+ output_dir = self.output_dir
+ if os.path.exists(f"{output_dir}/chkpts/weights.pt"):
+ print("Checkpoint already exists.")
+ if input("Resume training? (y/n)") != "y":
+ return
+ else:
+ output_dir = f"{output_dir}/{datetime.now().strftime('%m-%d_%H%M%S')}"
+ print(f"Creating new log folder as {output_dir}")
+
+ # prepare the learner structure and parameters
+        # DiffproLearner does not take a param_scheduler; pass the params only
+        learner = DiffproLearner(
+            output_dir, self.model, self.train_dl, self.val_dl, self.optimizer,
+            self.params
+        )
+ learner.train(max_epoch=self.params.max_epoch)
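+
+
+# Concrete configurations are expected to subclass TrainConfig and set `model`,
+# `train_dl`, `val_dl` and `optimizer` before calling train(); see e.g.
+# LdmTrainConfig in train/train_config.py.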
diff --git a/train/__pycache__/__init__.cpython-39.pyc b/train/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bfd6354b8745c9dd31372b585d2769b58c515e9
Binary files /dev/null and b/train/__pycache__/__init__.cpython-39.pyc differ
diff --git a/train/__pycache__/learner.cpython-39.pyc b/train/__pycache__/learner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14829af006935184e207fdac73839197af8d5a25
Binary files /dev/null and b/train/__pycache__/learner.cpython-39.pyc differ
diff --git a/train/__pycache__/train_config.cpython-39.pyc b/train/__pycache__/train_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0588530e141a4d1ca0b80c7fe20b44c0cdb2b98
Binary files /dev/null and b/train/__pycache__/train_config.cpython-39.pyc differ
diff --git a/train/__pycache__/train_params.cpython-39.pyc b/train/__pycache__/train_params.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6423405379640e108400c53c89db11e8451057b3
Binary files /dev/null and b/train/__pycache__/train_params.cpython-39.pyc differ
diff --git a/train/learner.py b/train/learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb6a9a7cfd5ca60366a81cb346ef4b34008fe851
--- /dev/null
+++ b/train/learner.py
@@ -0,0 +1,211 @@
+import torch
+import json
+import torch.nn as nn
+from tqdm import tqdm
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Optional
+import os
+
+
+def nested_map(struct, map_fn):
+ """This is for trasfering into cuda device"""
+ if isinstance(struct, tuple):
+ return tuple(nested_map(x, map_fn) for x in struct)
+ if isinstance(struct, list):
+ return [nested_map(x, map_fn) for x in struct]
+ if isinstance(struct, dict):
+ return {k: nested_map(v, map_fn) for k, v in struct.items()}
+ return map_fn(struct)
+
+
+class DiffproLearner:
+ def __init__(
+ self, output_dir, model, train_dl, val_dl, optimizer, params
+ ):
+ # model output
+ self.output_dir = output_dir
+ self.log_dir = f"{output_dir}/logs"
+ self.checkpoint_dir = f"{output_dir}/chkpts"
+ # model (architecture and loss)
+ self.model = model
+ # data loader
+ self.train_dl = train_dl
+ self.val_dl = val_dl
+ # optimizer
+ self.optimizer = optimizer
+        # training hyperparameters (an AttrDict, see train/train_params.py)
+        self.params = params
+        # training progress counters
+ self.step = 0
+ self.epoch = 0
+ self.grad_norm = 0.
+ # other information
+ self.summary_writer = None
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.autocast = torch.cuda.amp.autocast(enabled=params.fp16)
+ self.scaler = torch.cuda.amp.GradScaler(enabled=params.fp16)
+
+ self.best_val_loss = torch.tensor([1e10], device=self.device)
+
+ # restore if directory exists
+ if os.path.exists(self.output_dir):
+ self.restore_from_checkpoint()
+ else:
+ os.makedirs(self.output_dir)
+ os.makedirs(self.log_dir)
+ os.makedirs(self.checkpoint_dir)
+ with open(f"{output_dir}/params.json", "w") as params_file:
+ json.dump(self.params, params_file)
+
+ print(json.dumps(self.params, sort_keys=True, indent=4))
+
+    def _write_summary(self, losses: dict, scheduled_params: Optional[dict], summary_type: str):
+        """summary_type: "train" or "val"."""
+        summary_losses = losses
+        summary_losses["grad_norm"] = self.grad_norm
+        if scheduled_params is not None:
+            for k, v in scheduled_params.items():
+                summary_losses[f"sched_{k}"] = v
+        writer = self.summary_writer or SummaryWriter(
+            self.log_dir, purge_step=self.step
+        )
+        writer.add_scalars(summary_type, summary_losses, self.step)
+ writer.flush()
+ self.summary_writer = writer
+
+ def state_dict(self):
+ # state dictionary
+ model_state = self.model.state_dict()
+ return {
+ "step": self.step,
+ "epoch": self.epoch,
+ "model":
+ {
+ k: v.cpu() if isinstance(v, torch.Tensor) else v
+ for k, v in model_state.items()
+ },
+ "optimizer":
+ {
+ k: v.cpu() if isinstance(v, torch.Tensor) else v
+ for k, v in self.optimizer.state_dict().items()
+ },
+ "scaler": self.scaler.state_dict(),
+ }
+
+ def load_state_dict(self, state_dict):
+ self.step = state_dict["step"]
+ self.epoch = state_dict["epoch"]
+ self.model.load_state_dict(state_dict["model"])
+ self.optimizer.load_state_dict(state_dict["optimizer"])
+ self.scaler.load_state_dict(state_dict["scaler"])
+
+ def restore_from_checkpoint(self, fname="weights"):
+ try:
+ fpath = f"{self.checkpoint_dir}/{fname}.pt"
+ checkpoint = torch.load(fpath)
+ self.load_state_dict(checkpoint)
+ print(f"Restored from checkpoint {fpath} --> {fname}-{self.epoch}.pt!")
+ return True
+ except FileNotFoundError:
+ print("No checkpoint found. Starting from scratch...")
+ return False
+
+ def _link_checkpoint(self, save_name, link_fpath):
+ if os.path.islink(link_fpath):
+ os.unlink(link_fpath)
+ os.symlink(save_name, link_fpath)
+
+ def save_to_checkpoint(self, fname="weights", is_best=False):
+ save_name = f"{fname}-{self.epoch}.pt"
+ save_fpath = f"{self.checkpoint_dir}/{save_name}"
+ link_best_fpath = f"{self.checkpoint_dir}/{fname}_best.pt"
+ link_fpath = f"{self.checkpoint_dir}/{fname}.pt"
+ torch.save(self.state_dict(), save_fpath)
+ self._link_checkpoint(save_name, link_fpath)
+ if is_best:
+ self._link_checkpoint(save_name, link_best_fpath)
+
+ def train(self, max_epoch=None):
+ self.model.train()
+
+ while True:
+ self.epoch = self.step // len(self.train_dl)
+ if max_epoch is not None and self.epoch >= max_epoch:
+ return
+
+ for batch in tqdm(self.train_dl, desc=f"Epoch {self.epoch}"):
+ #print("type of batch:", type(batch))
+ batch = nested_map(
+ batch, lambda x: x.to(self.device)
+ if isinstance(x, torch.Tensor) else x
+ )
+ #print("type of batch:", type(batch))
+ losses, scheduled_params = self.train_step(batch)
+ # check NaN
+ for loss_value in list(losses.values()):
+ if isinstance(loss_value,
+ torch.Tensor) and torch.isnan(loss_value).any():
+ raise RuntimeError(
+ f"Detected NaN loss at step {self.step}, epoch {self.epoch}"
+ )
+ if self.step % 50 == 0:
+ self._write_summary(losses, scheduled_params, "train")
+ if self.step % 5000 == 0 and self.step != 0 \
+ and self.epoch != 0:
+ self.valid()
+ self.step += 1
+
+ # valid
+ self.valid()
+
+ def valid(self):
+ # self.model.eval()
+ losses = None
+ for batch in self.val_dl:
+ batch = nested_map(
+ batch, lambda x: x.to(self.device) if isinstance(x, torch.Tensor) else x
+ )
+ current_losses, _ = self.val_step(batch)
+            if losses is None:
+                # copy so the first batch is not double-counted below
+                losses = dict(current_losses)
+            else:
+                for k, v in current_losses.items():
+                    losses[k] += v
+ assert losses is not None
+ for k, v in losses.items():
+ losses[k] /= len(self.val_dl)
+ self._write_summary(losses, None, "val")
+
+ if self.best_val_loss >= losses["loss"]:
+ self.best_val_loss = losses["loss"]
+ self.save_to_checkpoint(is_best=True)
+ else:
+ self.save_to_checkpoint(is_best=False)
+
+ def train_step(self, batch):
+        # Set grads to None instead of zeroing them (equivalent to
+        # optimizer.zero_grad(set_to_none=True)); this skips a redundant memory fill.
+ for param in self.model.parameters():
+ param.grad = None
+
+ # here forward the model
+ with self.autocast:
+ scheduled_params = None
+ loss_dict = self.model.get_loss_dict(batch, self.step)
+
+ loss = loss_dict["loss"]
+ self.scaler.scale(loss).backward()
+ self.scaler.unscale_(self.optimizer)
+ self.grad_norm = nn.utils.clip_grad.clip_grad_norm_(
+ self.model.parameters(), self.params.max_grad_norm or 1e9
+ )
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
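+        # The sequence above is the standard torch.cuda.amp recipe: forward under
+        # autocast, scale the loss for backward, unscale before gradient clipping,
+        # then scaler.step() / scaler.update() to apply the optimizer step safely.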
+ return loss_dict, scheduled_params
+
+ def val_step(self, batch):
+ with torch.no_grad():
+ with self.autocast:
+
+ scheduled_params = None
+ loss_dict = self.model.get_loss_dict(batch, self.step)
+
+ return loss_dict, scheduled_params
diff --git a/train/train_config.py b/train/train_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0efe1e2524af469e0d9466fbd67c1b4106121c24
--- /dev/null
+++ b/train/train_config.py
@@ -0,0 +1,45 @@
+from . import *
+
+import sys
+import os
+
+# Determine the absolute path to the external folder
+current_directory = os.path.dirname(os.path.abspath(__file__))
+external_directory = os.path.abspath(os.path.join(current_directory, '../data'))
+
+# Add the external folder to sys.path
+sys.path.append(external_directory)
+
+# Now you can import the external module
+from data_utils import load_datasets, create_train_valid_dataloaders
+from model import init_ldm_model, init_diff_pro_sdf
+
+
+class LdmTrainConfig(TrainConfig):
+
+ def __init__(self, params, output_dir, mode,
+ mask_background, multi_phrase_label, random_pitch_aug, debug_mode=False) -> None:
+ super().__init__(params, None, output_dir)
+ self.debug_mode = debug_mode
+ #self.use_autoreg_cond = use_autoreg_cond
+ #self.use_external_cond = use_external_cond
+ self.mask_background = mask_background
+ self.multi_phrase_label = multi_phrase_label
+ self.random_pitch_aug = random_pitch_aug
+
+ # create model
+ self.ldm_model = init_ldm_model(mode, params, debug_mode)
+ self.model = init_diff_pro_sdf(self.ldm_model, params, self.device)
+
+ # Create dataloader
+ load_first_n = 10 if self.debug_mode else None
+ train_set, valid_set = load_datasets(
+ mode, multi_phrase_label, random_pitch_aug,
+ mask_background, load_first_n
+ )
+ self.train_dl, self.val_dl = create_train_valid_dataloaders(params.batch_size, train_set, valid_set)
+
+ # Create optimizer
+ self.optimizer = torch.optim.Adam(
+ self.model.parameters(), lr=params.learning_rate
+ )
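+
+
+# Illustrative usage sketch (argument values are hypothetical):
+#
+#     from train.train_params import params_chord_lsh_cond
+#
+#     config = LdmTrainConfig(params_chord_lsh_cond, "results/test", mode="chord_lsh_cond",
+#                             mask_background=False, multi_phrase_label=False,
+#                             random_pitch_aug=True)
+#     config.train()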
diff --git a/train/train_params.py b/train/train_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b227cf33a0c70292f7ceefe82abae867bd2f6f
--- /dev/null
+++ b/train/train_params.py
@@ -0,0 +1,97 @@
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+ def override(self, attrs):
+ if isinstance(attrs, dict):
+ self.__dict__.update(**attrs)
+ elif isinstance(attrs, (list, tuple, set)):
+ for attr in attrs:
+ self.override(attr)
+ elif attrs is not None:
+ raise NotImplementedError
+ return self
+
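+# Minimal AttrDict sketch: values support both attribute and key access, and
+# override() merges new settings in place, e.g.
+#
+#     p = AttrDict(batch_size=16, fp16=True)
+#     p.override({"batch_size": 8})
+#     assert p.batch_size == 8 and p["fp16"] is True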
+
+
+params_chord = AttrDict(
+ # Training params
+ batch_size=16,
+ max_epoch=10,
+ learning_rate=5e-5,
+ max_grad_norm=10,
+ fp16=True,
+
+ # unet
+ in_channels=2,
+ out_channels=2,
+ channels=64,
+ attention_levels=[2, 3],
+ n_res_blocks=2,
+ channel_multipliers=[1, 2, 4, 4],
+ n_heads=4,
+ tf_layers=1,
+ d_cond=12,
+
+ # ldm
+ linear_start=0.00085,
+ linear_end=0.0120,
+ n_steps=1000,
+ latent_scaling_factor=0.18215
+)
+
+
+
+params_chord_cond = AttrDict(
+ # Training params
+ batch_size=16,
+ max_epoch=10,
+ learning_rate=5e-5,
+ max_grad_norm=10,
+ fp16=True,
+
+ # unet
+ in_channels=4,
+ out_channels=2,
+ channels=64,
+ attention_levels=[2, 3],
+ n_res_blocks=2,
+ channel_multipliers=[1, 2, 4, 4],
+ n_heads=4,
+ tf_layers=1,
+ d_cond=2,
+
+ # ldm
+ linear_start=0.00085,
+ linear_end=0.0120,
+ n_steps=1000,
+ latent_scaling_factor=0.18215
+)
+
+
+params_chord_lsh_cond = AttrDict(
+ # Training params
+ batch_size=16,
+ max_epoch=10,
+ learning_rate=5e-5,
+ max_grad_norm=10,
+ fp16=True,
+
+ # unet
+ in_channels=6,
+ out_channels=2,
+ channels=64,
+ attention_levels=[2, 3],
+ n_res_blocks=2,
+ channel_multipliers=[1, 2, 4, 4],
+ n_heads=4,
+ tf_layers=1,
+ d_cond=2,
+
+ # ldm
+ linear_start=0.00085,
+ linear_end=0.0120,
+ n_steps=1000,
+ latent_scaling_factor=0.18215
+)
\ No newline at end of file
diff --git a/train_cond_w_rhythm_onset.ipynb b/train_cond_w_rhythm_onset.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..9f02d2814bf74cf41d810b755ccb9e11520dd648
--- /dev/null
+++ b/train_cond_w_rhythm_onset.ipynb
@@ -0,0 +1,218 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "from torch.optim import Optimizer\n",
+ "import os\n",
+ "from datetime import datetime\n",
+ "from train.learner import DiffproLearner\n",
+ "\n",
+ "class TrainConfig:\n",
+ "\n",
+ " model: torch.nn.Module\n",
+ " train_dl: DataLoader\n",
+ " val_dl: DataLoader\n",
+ " optimizer: Optimizer\n",
+ "\n",
+ " def __init__(self, params, param_scheduler, output_dir) -> None:\n",
+ " self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ " self.params = params\n",
+ " self.param_scheduler = param_scheduler\n",
+ " self.output_dir = output_dir\n",
+ "\n",
+ " def train(self):\n",
+ " # collect and display total parameters\n",
+ " total_parameters = sum(\n",
+ " p.numel() for p in self.model.parameters() if p.requires_grad\n",
+ " )\n",
+ " print(f\"Total parameters: {total_parameters}\")\n",
+ "\n",
+ " # dealing with the output storing\n",
+ " output_dir = self.output_dir\n",
+ " if os.path.exists(f\"{output_dir}/chkpts/weights.pt\"):\n",
+ " print(\"Checkpoint already exists.\")\n",
+ " if input(\"Resume training? (y/n)\") != \"y\":\n",
+ " return\n",
+ " else:\n",
+ " output_dir = f\"{output_dir}/{datetime.now().strftime('%m-%d_%H%M%S')}\"\n",
+ " print(f\"Creating new log folder as {output_dir}\")\n",
+ "\n",
+ " # prepare the learner structure and parameters\n",
+ " learner = DiffproLearner(\n",
+ " output_dir, self.model, self.train_dl, self.val_dl, self.optimizer,\n",
+ " self.params\n",
+ " )\n",
+ " learner.train(max_epoch=self.params.max_epoch)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from model import init_ldm_model, init_diff_pro_sdf\n",
+ "from data.dataset_loading import load_datasets, create_dataloader\n",
+ "\n",
+ "WITH_RHYTHM = \"onset\"\n",
+ "\n",
+ "class LdmTrainConfig(TrainConfig):\n",
+ "\n",
+ " def __init__(self, params, output_dir, debug_mode=False) -> None:\n",
+ " super().__init__(params, None, output_dir)\n",
+ " self.debug_mode = debug_mode\n",
+ " #self.use_autoreg_cond = use_autoreg_cond\n",
+ " #self.use_external_cond = use_external_cond\n",
+ " #self.mask_background = mask_background\n",
+ " #self.random_pitch_aug = random_pitch_aug\n",
+ "\n",
+ " # create model\n",
+ " self.ldm_model = init_ldm_model(params, debug_mode)\n",
+ " self.model = init_diff_pro_sdf(self.ldm_model, params, self.device)\n",
+ "\n",
+ " # Create dataloader\n",
+ " train_set = load_datasets(with_rhythm=WITH_RHYTHM)\n",
+ " self.train_dl = create_dataloader(params.batch_size, train_set)\n",
+ " self.val_dl = create_dataloader(params.batch_size, train_set) # we temporarily use train_set for validation\n",
+ "\n",
+ " # Create optimizer4\n",
+ " self.optimizer = torch.optim.Adam(\n",
+ " self.model.parameters(), lr=params.learning_rate\n",
+ " )\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/music/chord_trainer/train/learner.py:45: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
+ " self.autocast = torch.cuda.amp.autocast(enabled=params.fp16)\n",
+ "/home/music/chord_trainer/train/learner.py:46: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
+ " self.scaler = torch.cuda.amp.GradScaler(enabled=params.fp16)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total parameters: 36755330\n",
+ "Creating new log folder as results/test/09-13_171940\n",
+ "{\n",
+ " \"attention_levels\": [\n",
+ " 2,\n",
+ " 3\n",
+ " ],\n",
+ " \"batch_size\": 16,\n",
+ " \"channel_multipliers\": [\n",
+ " 1,\n",
+ " 2,\n",
+ " 4,\n",
+ " 4\n",
+ " ],\n",
+ " \"channels\": 64,\n",
+ " \"d_cond\": 2,\n",
+ " \"fp16\": true,\n",
+ " \"in_channels\": 4,\n",
+ " \"latent_scaling_factor\": 0.18215,\n",
+ " \"learning_rate\": 5e-05,\n",
+ " \"linear_end\": 0.012,\n",
+ " \"linear_start\": 0.00085,\n",
+ " \"max_epoch\": 10,\n",
+ " \"max_grad_norm\": 10,\n",
+ " \"n_heads\": 4,\n",
+ " \"n_res_blocks\": 2,\n",
+ " \"n_steps\": 1000,\n",
+ " \"out_channels\": 2,\n",
+ " \"tf_layers\": 1\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 0: 100%|██████████| 1141/1141 [00:51<00:00, 22.08it/s]\n",
+ "Epoch 1: 100%|██████████| 1141/1141 [00:50<00:00, 22.43it/s]\n",
+ "Epoch 2: 100%|██████████| 1141/1141 [00:47<00:00, 24.02it/s]\n",
+ "Epoch 3: 100%|██████████| 1141/1141 [00:47<00:00, 24.07it/s]\n",
+ "Epoch 4: 100%|██████████| 1141/1141 [01:04<00:00, 17.70it/s]\n",
+ "Epoch 5: 100%|██████████| 1141/1141 [00:50<00:00, 22.42it/s]\n",
+ "Epoch 6: 100%|██████████| 1141/1141 [00:50<00:00, 22.38it/s]\n",
+ "Epoch 7: 100%|██████████| 1141/1141 [00:50<00:00, 22.38it/s]\n",
+ "Epoch 8: 100%|██████████| 1141/1141 [01:05<00:00, 17.38it/s]\n",
+ "Epoch 9: 100%|██████████| 1141/1141 [00:49<00:00, 22.83it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "# Import necessary libraries\n",
+ "from train.train_params import params_chord_cond, params_chord\n",
+ "import os\n",
+ "\n",
+ "# Set the argument values directly\n",
+ "args = {\n",
+ " 'output_dir': 'results',\n",
+ " 'uniform_pitch_shift': False,\n",
+ " # 'debug': False,\n",
+ " # 'data_source': \"lmd\",\n",
+ " # 'load_chkpt_from': None,\n",
+ " # 'dataset_path': \"data/lmd_sample/no_drum_sample\",\n",
+ "}\n",
+ "\n",
+ "# Determine random pitch augmentation\n",
+ "random_pitch_aug = not args['uniform_pitch_shift']\n",
+ "\n",
+ "# Generate the filename based on argument settings\n",
+ "fn = 'test'\n",
+ "\n",
+ "# Set the output directory\n",
+ "output_dir = os.path.join(args['output_dir'], fn)\n",
+ "\n",
+ "# Create the training configuration\n",
+ "config = LdmTrainConfig(params_chord_cond, output_dir)\n",
+ "\n",
+ "config.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "music_demo",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.19"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/utility.ipynb b/utility.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..9ef29f12074bed96dac825de4080e6506593f3b5
--- /dev/null
+++ b/utility.ipynb
@@ -0,0 +1,172 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/music/.conda/envs/music_demo/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(256, 320) (514, 1880, 3)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "from app import * \n",
+ "\n",
+ "test_chd_roll = np.concatenate([np.tile(CHORD_DICTIONARY[\"C:major\"], (16, 1)), \n",
+ " np.tile(CHORD_DICTIONARY[\"C:major\"], (16, 1)), \n",
+ " np.tile(CHORD_DICTIONARY[\"C:major\"], (16, 1)), \n",
+ " np.tile(CHORD_DICTIONARY[\"C:major\"], (16, 1))])\n",
+ "\n",
+ "rhythms = [m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm]\n",
+ "\n",
+ "chd_roll = np.concatenate([test_chd_roll[np.newaxis,:,:], test_chd_roll[np.newaxis,:,:]], axis=0)\n",
+ "\n",
+ "chd_roll = circular_extend(chd_roll)\n",
+ "chd_roll = -chd_roll-1\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "example3 = np.load(\"samples/diy_examples/example3.npy\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "example3[:2,:,:] = example3[2:4,:,:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.save(\"samples/diy_examples/example3.npy\", example3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_chd_roll.min(axis=-1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "example0 = np.load(\"samples/diy_examples/example3.npy\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "example0[2,:,:] = np.min(example0[2:4,:,:], axis=0)\n",
+ "example0[3,:,:] = np.min(example0[2:4,:,:], axis=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([-2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,\n",
+ " -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,\n",
+ " -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,\n",
+ " -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,\n",
+ " -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.])"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "example0[2,:,:].min(axis=-1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.save(\"samples/diy_examples/example3.npy\", example0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "music_demo",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.19"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/values_below_threshold2.npy b/values_below_threshold2.npy
new file mode 100644
index 0000000000000000000000000000000000000000..fc12ed54bb62f03ccc479731cbaeff8270c87898
Binary files /dev/null and b/values_below_threshold2.npy differ