lokman2k5 committed
Commit e3807d4 · 1 parent: 53d2e48

Add application file

Files changed (1): app.py +347 −0
app.py CHANGED
@@ -0,0 +1,347 @@
import spaces

import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
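# The vendored YourMT3 sources are expected under amt/src (assumption based on
# the path above); adding it to sys.path lets html_helper and model_helper
# below resolve from that tree.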

import subprocess
from typing import Dict, Literal  # Tuple and ctypes.ArgumentError were unused

from html_helper import *
from model_helper import *

import torchaudio
import glob
import gradio as gr
from gradio_log import Log
from pathlib import Path

# gradio_log
log_file = 'amt/log.txt'
Path(log_file).touch()

# @title Load Checkpoint
model_name = 'YPTF.MoE+Multi (noPS)'  # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
precision = '16'  # if torch.cuda.is_available() else '32'  # @param ["32", "bf16-mixed", "16"]
project = '2024'

if model_name == "YMT3+":
    checkpoint = "[email protected]"
    args = [checkpoint, '-p', project, '-pr', precision]
elif model_name == "YPTF+Single (noPS)":
    checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt"
    args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec',
            '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF+Multi (PS)":
    checkpoint = "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256',
            '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
            '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (noPS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (PS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
else:
    raise ValueError(model_name)
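
# Note: these CLI-style flags are parsed by load_model_checkpoint (in the
# vendored amt/src tree). From the pairings above: '-p' project tag, '-pr'
# precision, '-enc'/'-dec' encoder/decoder backbones, '-tk' tokenizer
# vocabulary; the remaining flags configure the Perceiver-TF/MoE architecture
# (assumption: the authoritative definitions live in amt/src).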

model = load_model_checkpoint(args=args, device="cpu")
# model.to("cuda")
# Keep model on CPU for the HuggingFace Spaces free tier
print("Model loaded on CPU for HuggingFace Spaces deployment")
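# Sketch for GPU hardware (not used on the free tier; assumes the checkpoint
# loader returns a standard torch.nn.Module):
# import torch
# if torch.cuda.is_available():
#     model = model.to("cuda")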
# @title Gradio helper


def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate: bool = False) -> Dict:
    """Prepare media from a source path or YouTube URL, and return audio info."""
    # Get audio_file
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        # Remove any stale download from a previous request
        # (was '/download/yt_audio.mp3', which never matched the target below)
        if os.path.exists('./downloaded/yt_audio.mp3'):
            os.remove('./downloaded/yt_audio.mp3')
        # Download from YouTube
        with open(log_file, 'w') as lf:
            audio_file = './downloaded/yt_audio'
            command = ['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
                       '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                       '--extractor-retries', '10',
                       '--force-overwrites', '--username', 'oauth2', '--password', '', '-v']
            if simulate:
                command = command + ['-s']
            process = subprocess.Popen(command,
                                       stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

            for line in iter(process.stdout.readline, ''):
                # Forward only the OAuth device-code prompt and final status lines to the UI log
                print(line)
                if "www.google.com/device" in line:
                    # Highlight the device-login URL (yellow) and the user code (bright red)
                    hl_text = line.replace("https://www.google.com/device",
                                           "\033[93mhttps://www.google.com/device\x1b[0m").split()
                    hl_text[-1] = "\x1b[31;1m" + hl_text[-1] + "\x1b[0m"
                    lf.write(' '.join(hl_text)); lf.flush()
                elif "Authorization successful" in line or "Video unavailable" in line:
                    lf.write(line); lf.flush()
            process.stdout.close()
            process.wait()

        audio_file += '.mp3'
    else:
        raise ValueError(source_type)

    # Create info
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }
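
# Example (hypothetical local file; values depend on the actual audio):
#   prepare_media('examples/piano.mp3', source_type='audio_filepath')
#   -> {'filepath': 'examples/piano.mp3', 'track_name': 'piano',
#       'sample_rate': 44100, 'num_channels': 2, 'duration': 30, ...}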

@spaces.GPU(duration=120)  # 2-minute budget per transcription request
def process_audio(audio_filepath, instrument_hint=None):
    if audio_filepath is None:
        return None
    try:
        print(f"Processing audio: {audio_filepath}")
        if instrument_hint and instrument_hint != "Auto (detect all instruments)":
            print(f"Using instrument hint: {instrument_hint}")

        audio_info = prepare_media(audio_filepath, source_type='audio_filepath')
        midifile = transcribe(model, audio_info, instrument_hint)
        midifile = to_data_url(midifile)
        return create_html_from_midi(midifile)  # html midiplayer
    except Exception as e:
        print(f"Error in process_audio: {e}")
        import traceback
        traceback.print_exc()
        return f"<p style='color: red;'>Error processing audio: {str(e)}</p>"

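# Shared pipeline (process_audio above, process_video below): source ->
# prepare_media() audio metadata -> transcribe() to a MIDI file ->
# to_data_url() -> create_html_from_midi() for the embedded MIDI player.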
# @spaces.GPU  # Comment out for Colab
def process_audio_yt_temp(youtube_url):
    # Serve pre-transcribed MIDI files for the known example URLs
    # (live YouTube downloads are blocked on this Space).
    if youtube_url is None:
        return None
    elif youtube_url == "https://youtu.be/5vJBhdjvVcE?si=s3NFG_SlVju0Iklg":
        midifile = "./mid/Free Jazz Intro Music - Piano Sway (Intro B - 10 seconds) - OurMusicBox.mid"
    elif youtube_url == "https://youtu.be/mw5VIEIvuMI?si=Dp9UFVw00Tl8CXe2":
        midifile = "./mid/Naomi Scott Speechless from Aladdin Official Video Sony vevo Music.mid"
    elif youtube_url == "https://youtu.be/OXXRoa1U6xU?si=dpYMun4LjZHNydSb":
        midifile = "./mid/Mozart_Sonata_for_Piano_and_Violin_(getmp3.pro).mid"
    else:
        return None  # no cached transcription for this URL (avoids UnboundLocalError)
    midifile = to_data_url(midifile)
    return create_html_from_midi(midifile)  # html midiplayer


@spaces.GPU(duration=120)
def process_video(youtube_url, instrument_hint=None):
    if 'youtu' not in youtube_url:
        return None
    audio_info = prepare_media(youtube_url, source_type='youtube_url')
    midifile = transcribe(model, audio_info, instrument_hint)
    midifile = to_data_url(midifile)
    return create_html_from_midi(midifile)  # html midiplayer


def play_video(youtube_url):
    if 'youtu' not in youtube_url:
        return None
    return create_html_youtube_player(youtube_url)

# def oauth_google():
#     return create_html_oauth()

AUDIO_EXAMPLES = glob.glob('examples/*.*', recursive=True)
YOUTUBE_EXAMPLES = ["https://youtu.be/5vJBhdjvVcE?si=s3NFG_SlVju0Iklg",
                    "https://youtu.be/mw5VIEIvuMI?si=Dp9UFVw00Tl8CXe2",
                    "https://youtu.be/OXXRoa1U6xU?si=dpYMun4LjZHNydSb"]
# YOUTUBE_EXAMPLES = ["https://youtu.be/5vJBhdjvVcE?si=s3NFG_SlVju0Iklg",
#                     "https://www.youtube.com/watch?v=vMboypSkj3c",
#                     "https://youtu.be/vRd5KEjX8vw?si=b-qw633ZjaX6Uxy5",
#                     "https://youtu.be/bnS-HK_lTHA?si=PQLVAab3QHMbv0S3https://youtu.be/zJB0nnOc7bM?si=EA1DN8nHWJcpQWp_",
#                     "https://youtu.be/7mjQooXt28o?si=qqmMxCxwqBlLPDI2",
#                     "https://youtu.be/mIWYTg55h10?si=WkbtKfL6NlNquvT8"]

theme = gr.Theme.from_hub("gradio/dracula_revamped")
theme.text_md = '10px'
theme.text_lg = '12px'

theme.body_background_fill_dark = '#060a1c'  # '#372037' '#a17ba5' '#73d3ac'
theme.border_color_primary_dark = '#45507328'
theme.block_background_fill_dark = '#3845685c'

theme.body_text_color_dark = 'white'
theme.block_title_text_color_dark = 'black'
theme.body_text_color_subdued_dark = '#e4e9e9'

css = """
.gradio-container {
  background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
  background-size: 400% 400%;
  animation: gradient 15s ease infinite;
  height: 100vh;
}
@keyframes gradient {
  0% {background-position: 0% 50%;}
  50% {background-position: 100% 50%;}
  100% {background-position: 0% 50%;}
}
#mylog {font-size: 12pt; line-height: 1.2; min-height: 2em; max-height: 4em;}
"""

with gr.Blocks(theme=theme, css=css) as demo:

    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown(
                f"""
## 🎶YourMT3+: Multi-instrument Music Transcription with Enhanced Transformer Architectures and Cross-dataset Stem Augmentation
- Model name: `{model_name}`
<details>
<summary>▶model details◀</summary>

| **Component**         | **Details**                                               |
|-----------------------|-----------------------------------------------------------|
| Encoder backbone      | Perceiver-TF + Mixture of Experts (2/8)                   |
| Decoder backbone      | Multi-channel T5-small                                    |
| Tokenizer             | MT3 tokens with Singing extension                         |
| Dataset               | YourMT3 dataset                                           |
| Augmentation strategy | Intra-/cross-dataset stem augmentation, no pitch-shifting |
| FP Precision          | BF16-mixed for training, FP16 for inference               |
</details>

## Caution:
- For academic reproduction purposes, we strongly recommend using the [Colab Demo](https://colab.research.google.com/drive/1AgOVEBfZknDkjmSRA7leoa81a2vrnhBG?usp=sharing) with multiple checkpoints.

## YouTube transcription (Sorry!! YouTube blocked the HuggingFace IP, so we show a few pre-transcribed examples below!):
- Select one from the `Examples`, click `Get Audio from YouTube`, and then press `Transcribe`.

<div style="display: inline-block;">
  <a href="https://arxiv.org/abs/2407.04822">
    <img src="https://img.shields.io/badge/arXiv:2407.04822-B31B1B?logo=arxiv&logoColor=fff&style=plastic" alt="arXiv Badge"/>
  </a>
</div>
<div style="display: inline-block;">
  <a href="https://github.com/mimbres/YourMT3">
    <img src="https://img.shields.io/badge/GitHub-181717?logo=github&logoColor=fff&style=plastic" alt="GitHub Badge"/>
  </a>
</div>
<div style="display: inline-block;">
  <a href="https://colab.research.google.com/drive/1AgOVEBfZknDkjmSRA7leoa81a2vrnhBG?usp=sharing">
    <img src="https://img.shields.io/badge/Google%20Colab-F9AB00?logo=googlecolab&logoColor=fff&style=plastic" alt="Colab Badge"/>
  </a>
</div>
""")

    with gr.Group():

        with gr.Tab("From YouTube"):
            with gr.Column(scale=4):
                # Input URL
                youtube_url = gr.Textbox(label="YouTube Link URL",
                                         placeholder="https://youtu.be/...")
                # Display examples
                gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_url)
                # Play button
                play_video_button = gr.Button("Get Audio from YouTube", variant="primary")
                # Play youtube
                youtube_player = gr.HTML(render=True)

            with gr.Column(scale=4):
                # Instrument selection for YouTube
                youtube_instrument_selector = gr.Dropdown(
                    choices=["Auto (detect all instruments)", "Vocals/Singing", "Guitar", "Piano",
                             "Violin", "Drums", "Bass", "Saxophone", "Flute"],
                    value="Auto (detect all instruments)",
                    label="Target Instrument",
                    info="Choose the specific instrument you want to transcribe"
                )
                with gr.Row():
                    # Submit button
                    transcribe_video_button = gr.Button("Transcribe", variant="primary")
                    # OAuth button
                    oauth_button = gr.Button("google.com/device", variant="primary",
                                             link="https://www.google.com/device")

            with gr.Column(scale=1):
                # Transcribe
                output_tab2 = gr.HTML(render=True)
                # video_output = gr.Text(label="Video Info")

            def process_youtube_with_instrument(url, instrument_choice):
                # Map UI choices to internal instrument hints
                instrument_map = {
                    "Auto (detect all instruments)": None,
                    "Vocals/Singing": "vocals",
                    "Guitar": "guitar",
                    "Piano": "piano",
                    "Violin": "violin",
                    "Drums": "drums",
                    "Bass": "bass",
                    "Saxophone": "saxophone",
                    "Flute": "flute"
                }
                instrument_hint = instrument_map.get(instrument_choice, None)
                # For now, using the temp function; the hint is unused until then
                return process_audio_yt_temp(url)  # TODO: Replace with process_video(url, instrument_hint)

            transcribe_video_button.click(process_youtube_with_instrument,
                                          inputs=[youtube_url, youtube_instrument_selector],
                                          outputs=output_tab2)
            # transcribe_video_button.click(process_video, inputs=youtube_url, outputs=output_tab2)
            # Play
            play_video_button.click(play_video, inputs=youtube_url, outputs=youtube_player)
            with gr.Column(scale=1):
                Log(log_file, dark=True, xterm_font_size=12, elem_id='mylog')

        with gr.Tab("Upload audio"):
            # Input
            audio_input = gr.Audio(label="Record Audio", type="filepath",
                                   show_share_button=True, show_download_button=True)

            # Instrument selection
            instrument_selector = gr.Dropdown(
                choices=["Auto (detect all instruments)", "Vocals/Singing", "Guitar", "Piano",
                         "Violin", "Drums", "Bass", "Saxophone", "Flute"],
                value="Auto (detect all instruments)",
                label="Target Instrument",
                info="Choose the specific instrument you want to transcribe, or 'Auto' for all instruments"
            )

            # Display examples
            gr.Examples(examples=AUDIO_EXAMPLES, inputs=audio_input)
            # Submit button
            transcribe_audio_button = gr.Button("Transcribe", variant="primary")
            # Transcribe
            output_tab1 = gr.HTML()

            def process_with_instrument(audio_file, instrument_choice):
                # Map UI choices to internal instrument hints
                instrument_map = {
                    "Auto (detect all instruments)": None,
                    "Vocals/Singing": "vocals",
                    "Guitar": "guitar",
                    "Piano": "piano",
                    "Violin": "violin",
                    "Drums": "drums",
                    "Bass": "bass",
                    "Saxophone": "saxophone",
                    "Flute": "flute"
                }
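                # Same mapping as in the YouTube tab; hoisting it to a
                # module-level constant would keep the two copies in sync.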
                instrument_hint = instrument_map.get(instrument_choice, None)
                print(f"UI choice: {instrument_choice} -> instrument_hint: {instrument_hint}")
                return process_audio(audio_file, instrument_hint)

            transcribe_audio_button.click(process_with_instrument,
                                          inputs=[audio_input, instrument_selector],
                                          outputs=output_tab1)

# Launch for HuggingFace Spaces
demo.launch(debug=True)
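# When running locally instead of on Spaces, a public share link can be
# requested (sketch): demo.launch(debug=True, share=True)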