Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- app.py +41 -0
- mappings.py +81 -0
- t2a.py +35 -0
app.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import openai
|
| 3 |
+
from t2a import text_to_audio
|
| 4 |
+
import joblib
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
# Pre-trained regressor; get_drummer_output calls reg.predict() on the prompt
# embedding to obtain a tempo estimate. joblib files are pickle-based, so only
# load a file that ships with the app.
reg = joblib.load('text_reg.joblib')
# Sentence encoder that turns the prompt text into the embedding fed to `reg`.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Identifier of the fine-tuned OpenAI model, passed as `engine` in get_note_text.
finetune = "davinci:ft-personal:autodrummer-v4-2022-11-01-22-44-58"
|
| 11 |
+
|
| 12 |
+
def get_note_text(prompt):
    """Return the drum-note text the fine-tuned model completes for ``prompt``.

    The prompt is suffixed with " ->" before being sent, generation stops at
    "###", and the text of the first choice is returned with surrounding
    whitespace stripped.
    """
    request = dict(
        engine=finetune,
        prompt=prompt + " ->",
        temperature=0.7,
        max_tokens=100,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["###"],
    )
    completion = openai.Completion.create(**request)
    first_choice = completion.choices[0]
    return first_choice.text.strip()
|
| 26 |
+
|
| 27 |
+
def get_drummer_output(prompt, openai_api_key):
    """Generate a drum track for ``prompt`` and return it for a Gradio audio output.

    Args:
        prompt: Free-text description of the desired drum pattern.
        openai_api_key: Key stored on the ``openai`` module before the
            completion request is made.

    Returns:
        A ``(sample_rate, samples)`` tuple. ``samples`` is a float32 NumPy
        array of shape ``(n,)`` for mono audio or ``(n, channels)`` otherwise.
    """
    openai.api_key = openai_api_key
    # Space-separated note tokens, e.g. "k n k n s n h n".
    note_text = get_note_text(prompt)

    # Tempo: regressor prediction on the prompt embedding, plus a fixed offset of 20.
    prompt_enc = model.encode([prompt])
    bpm = int(reg.predict(prompt_enc)[0]) + 20
    print(bpm, "bpm", "notes are", note_text)

    audio = text_to_audio(note_text, bpm)

    # get_array_of_samples() interleaves the channels, so split them into
    # columns and report the segment's real frame rate instead of assuming a
    # fixed 96000 Hz stream.
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
    if audio.channels > 1:
        samples = samples.reshape((-1, audio.channels))
    return (audio.frame_rate, samples)
|
| 39 |
+
|
| 40 |
+
# Web UI: a prompt box and an API-key box in, generated drum audio out.
# The key field is a password-type textbox so the secret is not shown in
# clear text; its label matches the one Gradio derives for a plain "text" input.
iface = gr.Interface(
    fn=get_drummer_output,
    inputs=["text", gr.Textbox(label="openai_api_key", type="password")],
    outputs="audio",
)
iface.launch()
|
mappings.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Older table from note number to a long instrument name.
# NOTE(review): keys look like MIDI percussion note numbers - confirm.
inverse_mapping_old = {
    36: 'kick',
    38: 'snr',  # snare
    42: 'hh',  # hihat
    48: 'tom',
    # Key 49 used to be listed twice ('csh' for crash here, 'cymbal' further
    # down). In a dict literal the later entry wins, so the effective value
    # is spelled out once. TODO: confirm whether crash and cymbal were meant
    # to have different note numbers.
    49: 'cymbal',
    51: 'ride',
    39: 'clap',
    56: 'cbl',  # cowbell
    75: 'claves',
    64: 'conga',
    70: 'maracas',
    76: 'guiro',
    69: 'cabasa',
    60: 'bongo',
    37: 'shkr',  # shaker
    54: 'tamb',  # tambourine
    81: 'triangle',
    35: 'kick',  # bass drum of some kind
    55: 'spl',  # splash cymbal
    0: 'none',
    46: 'hh_open',  # hihat_open
    44: 'hh',  # hihat_pedal
    40: 'snr',  # snare_rimshot
    43: 'tom_high_floor',
    -1: 'none',
    22: 'kick',  # VERIFY
    58: 'vibraslap',
    53: 'ride_bell',
    50: 'tom_high',
    59: 'ride_2',
    45: 'tom_low',
    47: 'tom_low_mid',
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Table from note number to the short token used in the note text.
# NOTE(review): keys look like MIDI percussion note numbers - confirm.
inverse_mapping = {
    36: 'k',  # kick
    38: 's',  # snare
    42: 'h',  # hihat
    48: 't0',  # tom
    # Key 49 used to be listed twice ('c' for crash here, 'y' for cymbal
    # further down). In a dict literal the later entry wins, so the effective
    # value is spelled out once. TODO: confirm whether crash and cymbal were
    # meant to have different note numbers.
    49: 'y',  # cymbal
    51: 'r',  # ride
    39: 'l',  # clap
    56: 'b',  # cowbell
    37: 'z',  # shaker
    54: 'a',  # tambourine
    81: 'i',  # triangle
    35: 'k',  # bass drum of some kind
    55: 'p',  # splash cymbal
    0: 'n',  # none
    46: 'h1',  # hihat_open
    44: 'h',  # hihat_pedal
    40: 's',  # snare_rimshot
    43: 't2',  # tom_high_floor
    -1: 'n',  # none
    22: 'k',  # VERIFY
    58: 'v',  # vibraslap
    53: 'd',  # ride_bell
    50: 't1',  # tom_high
    59: 'e',  # ride_2
    45: 't3',  # tom_low
    47: 't4',  # tom_low_mid
}
|
| 68 |
+
|
| 69 |
+
# Note token -> wav file of the sample played for it.
# "c" and "y" deliberately share the same cymbal sample.
mappings = dict(
    k="drum-samples/kick.wav",
    s="drum-samples/snare.wav",
    h="drum-samples/hihat.wav",
    c="drum-samples/cymbal.wav",
    y="drum-samples/cymbal.wav",
    l="drum-samples/clap.wav",
)
|
| 77 |
+
|
| 78 |
+
# Substrings rewritten in the note text before it is split into tokens;
# both hi-hat variants collapse to "hh".
replacements = dict(
    hh_closed="hh",
    hh_open="hh",
)
|
t2a.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydub import AudioSegment
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import os
|
| 4 |
+
from mappings import mappings, replacements
|
| 5 |
+
|
| 6 |
+
def bpm_to_ms(bpm):
    """Return the duration in milliseconds of half a beat at ``bpm``.

    Raises ZeroDivisionError when ``bpm`` is 0.
    """
    half_minute_ms = 60000 / 2
    return half_minute_ms / bpm
|
| 8 |
+
|
| 9 |
+
def _fit_to_slot(segment, slot_ms):
    """Pad ``segment`` with silence, or trim it, so it fills one note slot."""
    if len(segment) < slot_ms:
        return segment + AudioSegment.silent(duration=slot_ms - len(segment))
    if len(segment) > slot_ms:
        return segment[:slot_ms]
    return segment


def text_to_audio(text, bpm):
    """Render whitespace-separated note tokens as one concatenated audio segment.

    Every token occupies one slot of ``bpm_to_ms(bpm)`` milliseconds:

    * a token found in ``mappings`` plays that sample, padded or trimmed to
      the slot length;
    * ``"n"`` is a silent slot;
    * any other token is played as the clap sample (``mappings["l"]``). This
      includes the ``"hh"`` produced by ``replacements``, which is not a key
      of ``mappings``.

    Empty tokens (from repeated or surrounding whitespace) are ignored, so an
    empty ``text`` yields a zero-length segment.

    Args:
        text: Note text such as ``"k n s n"``.
        bpm: Tempo used to derive the slot length.

    Returns:
        A pydub ``AudioSegment`` holding all slots back to back.
    """
    slot_ms = bpm_to_ms(bpm)

    for key, value in replacements.items():
        text = text.replace(key, value)

    rest = AudioSegment.silent(duration=slot_ms)
    # Sample path -> segment already fitted to one slot, so each wav file is
    # read and fitted once per call instead of once per note.
    fitted = {}

    audio = AudioSegment.silent(duration=0)
    for note in text.split():
        if note not in mappings and note == "n":
            audio = audio + rest
            continue
        # Unknown tokens fall back to the clap sample.
        path = mappings[note] if note in mappings else mappings["l"]
        if path not in fitted:
            fitted[path] = _fit_to_slot(AudioSegment.from_wav(path), slot_ms)
        audio = audio + fitted[path]
    return audio
|