Spaces:
Runtime error
Runtime error
| """ | |
| YourMT3+ with Instrument Conditioning - Google Colab Version | |
| Instructions for use in Google Colab: | |
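| 1. First, run this cell to install dependencies (gradio_log is imported by this script and yt-dlp is invoked for YouTube downloads): | |
| !pip install torch torchaudio transformers gradio gradio_log yt-dlp pytorch-lightning | |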
| 2. Clone the YourMT3 repository: | |
| !git clone https://github.com/mimbres/YourMT3.git | |
| %cd YourMT3 | |
| 3. Copy this code to a cell and run it to launch the interface | |
| 4. The Gradio interface will provide a public URL you can access | |
| """ | |
| import sys | |
| import os | |
| # Add the amt/src directory to Python path | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src'))) | |
| import subprocess | |
| from typing import Tuple, Dict, Literal | |
| from ctypes import ArgumentError | |
| from html_helper import * | |
| from model_helper import * | |
| import torchaudio | |
| import glob | |
| import gradio as gr | |
| from gradio_log import Log | |
| from pathlib import Path | |
# Create log file (prepare_media streams yt-dlp output into it).
# The parent directory is created first so touch() cannot fail on a fresh checkout.
log_file = 'amt/log.txt'
Path(log_file).parent.mkdir(parents=True, exist_ok=True)
Path(log_file).touch()

# Model Configuration
model_name = 'YPTF.MoE+Multi (noPS)'  # You can change this
precision = '16'
project = '2024'
print(f"Loading model: {model_name}")

# Flag groups shared by several model variants. They are concatenated in the
# same order the original per-model argument lists used.
_MULTI_T5_FLAGS = ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26']
_PERCEIVER_FLAGS = ['-enc', 'perceiver-tf']
_MOE_FLAGS = ['-sqr', '1', '-ff', 'moe', '-wf', '4', '-nmoe', '8', '-kmoe', '2',
              '-act', 'silu', '-epe', 'rope', '-rp', '1']
_SPEC_FLAGS = ['-ac', 'spec', '-hop', '300', '-atc', '1']

# model name -> (checkpoint name, model-specific flags)
MODEL_CONFIGS = {
    "YMT3+": (
        # NOTE(review): this literal was garbled in the reviewed copy; restored
        # from the upstream YourMT3 demo — confirm against the checkpoint on disk.
        "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt",
        [],
    ),
    "YPTF+Single (noPS)": (
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        _PERCEIVER_FLAGS + _SPEC_FLAGS,
    ),
    "YPTF+Multi (PS)": (
        "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
        _MULTI_T5_FLAGS + _PERCEIVER_FLAGS + _SPEC_FLAGS,
    ),
    "YPTF.MoE+Multi (noPS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        _MULTI_T5_FLAGS + _PERCEIVER_FLAGS + _MOE_FLAGS + _SPEC_FLAGS,
    ),
    "YPTF.MoE+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        _MULTI_T5_FLAGS + _PERCEIVER_FLAGS + _MOE_FLAGS + _SPEC_FLAGS,
    ),
}

# Get model arguments based on selection
if model_name not in MODEL_CONFIGS:
    raise ValueError(f"Unknown model: {model_name}")
checkpoint, _model_flags = MODEL_CONFIGS[model_name]
# Final layout: checkpoint, project, model-specific flags, precision.
args = [checkpoint, '-p', project] + _model_flags + ['-pr', precision]
# Load model
print("Loading model checkpoint...")
# Bind the name up front so `model` always exists at module level, even when
# loading fails and the UI is launched anyway (best-effort start-up).
model = None
try:
    # The checkpoint is loaded with device="cpu" and then moved to "cuda".
    model = load_model_checkpoint(args=args, device="cpu")
    model.to("cuda")
    print("✅ Model loaded successfully!")
except Exception as e:
    # Deliberately non-fatal: report the problem and keep going.
    print(f"❌ Error loading model: {e}")
    print("Make sure the model checkpoints are available in amt/logs/")
| # Helper functions | |
def _download_youtube_audio(url, simulate=False):
    """Download the audio track of `url` with yt-dlp and return the mp3 path.

    yt-dlp's combined stdout/stderr is echoed to the console and written to
    the module-level `log_file`.

    Raises:
        RuntimeError: if the expected mp3 file is missing after a real
            (non-simulated) run.
    """
    output_stem = '/content/yt_audio'  # Colab path
    mp3_path = output_stem + '.mp3'
    if os.path.exists(mp3_path):
        os.remove(mp3_path)

    command = ['yt-dlp', '-x', '-f', 'bestaudio',
               '-o', output_stem, '--audio-format', 'mp3', '--restrict-filenames',
               '--extractor-retries', '10', '--force-overwrites']
    if simulate:
        command.append('-s')
    # The URL comes from the user: keep it after `--` so a value starting with
    # '-' can never be interpreted as a yt-dlp option.
    command += ['--', url]

    with open(log_file, 'w') as lf:
        # The context manager closes the pipe and waits for the process.
        with subprocess.Popen(command, stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT, text=True) as process:
            for line in process.stdout:
                print(line)
                lf.write(line)
                lf.flush()

    if not simulate and not os.path.exists(mp3_path):
        raise RuntimeError(
            f"yt-dlp did not produce {mp3_path} (exit code {process.returncode})")
    return mp3_path


def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate = False) -> Dict:
    """Prepare media from a local path or a YouTube URL and return audio info.

    Args:
        source_path_or_url: path of an audio file, or a YouTube URL.
        source_type: 'audio_filepath' uses the path as-is; 'youtube_url'
            downloads the audio to '/content/yt_audio.mp3' first.
        delete_video: accepted for interface compatibility; not used here.
        simulate: when True, '-s' is passed to yt-dlp.

    Returns:
        Dict with "filepath", "track_name", "sample_rate", "bits_per_sample",
        "num_channels", "num_frames", "duration" and "encoding", filled from
        `torchaudio.info`.

    Raises:
        ValueError: if `source_type` is not one of the two supported values.
        RuntimeError: if the YouTube download produced no audio file.
    """
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        audio_file = _download_youtube_audio(source_path_or_url, simulate=simulate)
    else:
        raise ValueError(source_type)

    # Create info
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        # Everything before the first '.' of the file name.
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        # Truncated to an int (num_frames / sample_rate).
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": info.encoding.lower(),
    }
def process_audio(audio_filepath, instrument_hint=None):
    """Process uploaded audio with optional instrument conditioning.

    Args:
        audio_filepath: path of the uploaded audio file, or None.
        instrument_hint: value forwarded unchanged to `transcribe`.

    Returns:
        None when no file was given; otherwise the HTML produced by
        `create_html_from_midi`, or a red error paragraph if any step raised.
    """
    from html import escape  # local import keeps this block self-contained

    if audio_filepath is None:
        return None
    try:
        audio_info = prepare_media(audio_filepath, source_type='audio_filepath')
        midifile = transcribe(model, audio_info, instrument_hint)
        midifile = to_data_url(midifile)
        return create_html_from_midi(midifile)
    except Exception as e:
        # UI boundary: show the failure instead of crashing the app. The text
        # is escaped because it is rendered as HTML and may contain markup.
        return f"<p style='color: red;'>Error processing audio: {escape(str(e))}</p>"
def process_video(youtube_url, instrument_hint=None):
    """Process a YouTube video with optional instrument conditioning.

    Args:
        youtube_url: URL typed by the user; ignored unless it contains 'youtu'.
        instrument_hint: value forwarded unchanged to `transcribe`.

    Returns:
        None when the URL is missing or does not contain 'youtu'; otherwise
        the HTML produced by `create_html_from_midi`, or a red error paragraph
        if any step raised.
    """
    from html import escape  # local import keeps this block self-contained

    # `not youtube_url` also covers None, which would break the `in` check.
    if not youtube_url or 'youtu' not in youtube_url:
        return None
    try:
        audio_info = prepare_media(youtube_url, source_type='youtube_url')
        midifile = transcribe(model, audio_info, instrument_hint)
        midifile = to_data_url(midifile)
        return create_html_from_midi(midifile)
    except Exception as e:
        # UI boundary: show the failure instead of crashing the app. The text
        # is escaped because it is rendered as HTML and may echo user input.
        return f"<p style='color: red;'>Error processing YouTube video: {escape(str(e))}</p>"
def play_video(youtube_url):
    """Return the embedded-player HTML for `youtube_url`.

    Returns None when the URL is missing or does not contain 'youtu';
    otherwise the result of `create_html_youtube_player`.
    """
    # `not youtube_url` also covers None, which would break the `in` check.
    if not youtube_url or 'youtu' not in youtube_url:
        return None
    return create_html_youtube_player(youtube_url)
# Get example files.
# Sorted so the subset shown in the UI is the same on every run (glob order is
# arbitrary). The pattern has no '**', so no recursive flag is needed.
AUDIO_EXAMPLES = sorted(glob.glob('examples/*.*'))
YOUTUBE_EXAMPLES = [
    "https://youtu.be/5vJBhdjvVcE?si=s3NFG_SlVju0Iklg",
    "https://youtu.be/mw5VIEIvuMI?si=Dp9UFVw00Tl8CXe2",
    "https://youtu.be/OXXRoa1U6xU?si=dpYMun4LjZHNydSb",
]
# Custom stylesheet passed to gr.Blocks: an animated four-colour gradient
# behind the whole Gradio container.
css = """
.gradio-container {
background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
background-size: 400% 400%;
animation: gradient 15s ease infinite;
}
@keyframes gradient {
0% {background-position: 0% 50%;}
50% {background-position: 100% 50%;}
100% {background-position: 0% 50%;}
}
"""

# Gradio theme, loaded by name through gr.Theme.from_hub.
theme = gr.Theme.from_hub("gradio/dracula_revamped")
# Dropdown label -> instrument hint handed to process_audio / process_video.
# Single source of truth for both tabs and both event handlers.
INSTRUMENT_HINTS = {
    "Auto (detect all instruments)": None,
    "Vocals/Singing": "vocals",
    "Guitar": "guitar",
    "Piano": "piano",
    "Violin": "violin",
    "Drums": "drums",
    "Bass": "bass",
    "Saxophone": "saxophone",
    "Flute": "flute",
}
DEFAULT_INSTRUMENT = "Auto (detect all instruments)"

# Page header. Paragraphs are separated by blank lines so the trailing `---`
# renders as a horizontal rule rather than a heading underline.
HEADER_MARKDOWN = f"""
# 🎶 YourMT3+ with Instrument Conditioning

**Enhanced music transcription with instrument-specific control!**

**New Feature**: Select which instrument you want to transcribe from the dropdown menu.
This solves the problem of the model switching between instruments mid-track.

**Model**: `{model_name}` | **Running in**: Google Colab

---
"""


def _instrument_dropdown(info):
    """Build the target-instrument dropdown (call inside a gr.Blocks context)."""
    return gr.Dropdown(
        choices=list(INSTRUMENT_HINTS),
        value=DEFAULT_INSTRUMENT,
        label="🎯 Target Instrument",
        info=info,
    )


# Create Gradio interface
with gr.Blocks(theme=theme, css=css) as demo:
    gr.Markdown(HEADER_MARKDOWN)

    with gr.Tabs():
        with gr.Tab("🎵 Upload Audio"):
            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(
                        label="Upload Audio File",
                        type="filepath",
                        format="wav",
                    )
                    instrument_selector = _instrument_dropdown(
                        "NEW! Choose the specific instrument you want to transcribe"
                    )
                    transcribe_button = gr.Button("🎼 Transcribe", variant="primary", size="lg")
                    # Only offer examples when the examples/ folder has files.
                    if AUDIO_EXAMPLES:
                        gr.Examples(examples=AUDIO_EXAMPLES[:5], inputs=audio_input)
            with gr.Row():
                output_audio = gr.HTML(label="Transcription Result")

        with gr.Tab("📺 YouTube"):
            with gr.Row():
                with gr.Column():
                    youtube_input = gr.Textbox(
                        label="YouTube URL",
                        placeholder="https://youtu.be/...",
                    )
                    youtube_instrument_selector = _instrument_dropdown(
                        "Choose the specific instrument you want to transcribe"
                    )
                    with gr.Row():
                        play_button = gr.Button("▶️ Preview Video", variant="secondary")
                        transcribe_yt_button = gr.Button("🎼 Transcribe", variant="primary")
                    gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_input)
            with gr.Row():
                with gr.Column():
                    youtube_player = gr.HTML(label="Video Preview")
                with gr.Column():
                    output_youtube = gr.HTML(label="Transcription Result")

    # Event handlers: translate the dropdown label into an instrument hint
    # (unknown labels fall back to None) and delegate to the processing helpers.
    def process_with_instrument_audio(audio_file, instrument_choice):
        instrument_hint = INSTRUMENT_HINTS.get(instrument_choice, None)
        return process_audio(audio_file, instrument_hint)

    def process_with_instrument_youtube(url, instrument_choice):
        instrument_hint = INSTRUMENT_HINTS.get(instrument_choice, None)
        return process_video(url, instrument_hint)

    # Connect events
    transcribe_button.click(
        process_with_instrument_audio,
        inputs=[audio_input, instrument_selector],
        outputs=output_audio,
    )
    transcribe_yt_button.click(
        process_with_instrument_youtube,
        inputs=[youtube_input, youtube_instrument_selector],
        outputs=output_youtube,
    )
    play_button.click(play_video, inputs=youtube_input, outputs=youtube_player)
# Console banner with usage tips, printed right before the server starts.
print("🚀 Launching YourMT3+ with Instrument Conditioning...")
print("📝 Tips:")
print(" • Try 'Vocals/Singing' for vocal tracks to avoid instrument switching")
print(" • Use 'Guitar' for guitar solos to get complete transcriptions")
print(" • 'Auto' works like the original YourMT3+")

# Launch with share=True for Colab public URL
demo.launch(share=True, debug=True)