# Spaces page status banner captured with this listing: "Runtime error" (shown twice).
| #!/usr/bin/env python3 | |
| """ | |
| YourMT3+ Local Setup and Debug Script | |
| This script helps set up and debug YourMT3+ locally instead of using Colab. | |
| Run this to check your setup and identify issues. | |
| """ | |
| import os | |
| import sys | |
| import subprocess | |
| from pathlib import Path | |
def check_dependencies(required_packages=None):
    """Check that every required Python package can be imported.

    Each package is actually imported (not merely located), so a package that
    is installed but fails while importing is reported as well. One status
    line is printed per package, followed by a summary.

    Args:
        required_packages: Optional iterable of importable module names to
            check. When omitted, the default YourMT3+ dependency list is
            used, so existing zero-argument callers behave as before.

    Returns:
        True when every package imported cleanly, False otherwise.
    """
    # Local import: keeps this function independent of the file-level imports.
    import importlib

    print("π Checking dependencies...")
    if required_packages is None:
        required_packages = [
            'torch', 'torchaudio', 'transformers', 'gradio',
            'pytorch_lightning', 'einops', 'numpy', 'librosa',
        ]

    missing_packages = []
    broken_packages = []
    for package in required_packages:
        try:
            importlib.import_module(package)
        except ImportError:
            print(f" β {package} - MISSING")
            missing_packages.append(package)
        except Exception as exc:
            # The package was found but raised something other than
            # ImportError while importing. Report it instead of letting the
            # whole checker crash; "pip install" alone may not fix it, so it
            # is tracked separately from the missing packages.
            print(f" β {package} - BROKEN ({type(exc).__name__}: {exc})")
            broken_packages.append(package)
        else:
            print(f" β {package}")

    if missing_packages:
        print(f"\nβ οΈ Missing packages: {', '.join(missing_packages)}")
        print("Install them with:")
        print(f"pip install {' '.join(missing_packages)}")
    if broken_packages:
        print(f"\nβ οΈ Packages that failed to import: {', '.join(broken_packages)}")
        print("Try reinstalling them.")
    if missing_packages or broken_packages:
        return False

    print("β All dependencies found!")
    return True
def check_model_weights(
    base_path="amt/logs/2024",
    checkpoint_name="mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
):
    """Check whether model checkpoint files are present on disk.

    Args:
        base_path: Directory expected to contain the ``.ckpt`` files.
            Defaults to the directory the original script hard-coded.
        checkpoint_name: File name of the preferred checkpoint inside
            ``base_path``. Defaults to the original hard-coded name.

    Returns:
        True when the preferred checkpoint exists, or when it is absent but
        at least one other ``*.ckpt`` file sits directly in ``base_path``.
        False when the directory is missing or holds no ``.ckpt`` file.
    """
    print("\nπ Checking model weights...")
    base_path = Path(base_path)
    # as_posix() keeps the printed hints in forward-slash form on every OS,
    # matching the literal paths the messages used before.
    base_display = base_path.as_posix()

    # is_dir() rather than exists(): a regular file at this location is not a
    # usable model directory either.
    if not base_path.is_dir():
        print(f"β Model directory not found: {base_path}")
        print(f"Create the directory with: mkdir -p {base_display}")
        return False

    # Check for the default model checkpoint.
    # NOTE(review): this looks for a flat file literally named
    # "<experiment>@last.ckpt" inside base_path. Confirm that this matches how
    # the model loader resolves the "@" in its checkpoint argument.
    checkpoint_path = base_path / checkpoint_name
    if checkpoint_path.exists():
        size = checkpoint_path.stat().st_size / (1024**3)  # GB
        print(f"β Model checkpoint found: {checkpoint_path}")
        print(f" Size: {size:.2f} GB")
        return True

    print(f"β Model checkpoint not found: {checkpoint_path}")
    print("\nAvailable checkpoints:")
    # Sorted so the listing is stable from run to run.
    other_checkpoints = sorted(ckpt.name for ckpt in base_path.glob("*.ckpt"))
    for name in other_checkpoints:
        print(f" π {name}")
    if not other_checkpoints:
        print(" (none found)")
        print("\nπ‘ You need to download model weights:")
        print(" 1. Download from the official YourMT3 repository")
        print(f" 2. Place .ckpt files in {base_display}/")
    # Any other checkpoint still counts as "weights available" for the caller.
    return bool(other_checkpoints)
def test_model_loading():
    """Try to load the default YourMT3+ checkpoint on CPU and describe it.

    Makes ``amt/src`` importable, imports ``load_model_checkpoint`` from
    ``model_helper`` and calls it with the fixed argument list below. On
    success it prints what the model's ``task_manager`` exposes.

    Returns:
        True when the model loads and has a ``task_manager`` attribute.
        False when that attribute is absent or when any exception is raised;
        in the latter case the traceback is printed.
    """
    print("\nπ Testing model loading...")
    try:
        # Add amt/src to the import path, but only once: appending blindly
        # stacks a duplicate entry every time one of the checks runs.
        src_dir = os.path.abspath('amt/src')
        if src_dir not in sys.path:
            sys.path.append(src_dir)
        from model_helper import load_model_checkpoint

        # Test with minimal args
        model_name = 'YPTF.MoE+Multi (noPS)'
        precision = '16'
        project = '2024'
        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
        print(f"Loading model: {model_name}")
        model = load_model_checkpoint(args=args, device="cpu")

        # Test task manager
        if not hasattr(model, 'task_manager'):
            print("β Model doesn't have task_manager")
            return False

        print("β Model has task_manager")
        task_manager = model.task_manager
        if hasattr(task_manager, 'task_name'):
            print(f" Task name: {task_manager.task_name}")
        if hasattr(task_manager, 'task'):
            task_config = task_manager.task
            print(f" Task config keys: {list(task_config.keys())}")
            if 'eval_subtask_prefix' in task_config:
                prefixes = list(task_config['eval_subtask_prefix'].keys())
                print(f" Available subtask prefixes: {prefixes}")
            else:
                print(" No eval_subtask_prefix found")
        print("β Model loaded successfully!")
        return True
    except Exception as e:
        # Diagnostic boundary: any failure (import, checkpoint, inspection)
        # is reported with its traceback and turned into a False result.
        print(f"β Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return False


# Despite the "test_" prefix this is a setup diagnostic that returns a status,
# not a unit test; keep pytest-style collectors from picking it up.
test_model_loading.__test__ = False
def test_example_transcription():
    """Transcribe one example WAV file, without and with an instrument hint.

    Picks the first ``*.wav`` file (by sorted name) from ``examples/``, loads
    the default checkpoint on CPU, builds the audio-info dict from
    ``torchaudio.info`` and calls ``transcribe`` twice: once with
    ``instrument_hint=None`` and once with ``instrument_hint="vocals"``.

    Returns:
        True when both transcriptions complete. False when no example file
        exists or when any exception is raised; in the latter case the
        traceback is printed.
    """
    print("\nπ Testing example transcription...")
    # Just test one file. Sorting makes the choice reproducible, because the
    # order in which a directory is listed depends on the filesystem.
    example_file = next(iter(sorted(Path("examples").glob("*.wav"))), None)
    if example_file is None:
        print("β No example audio files found in examples/")
        return False

    try:
        print(f"Testing with: {example_file}")

        # Import what we need. amt/src is added to the import path only once,
        # so repeated checks do not pile up duplicate sys.path entries.
        src_dir = os.path.abspath('amt/src')
        if src_dir not in sys.path:
            sys.path.append(src_dir)
        from model_helper import transcribe, load_model_checkpoint
        import torchaudio

        # Load model
        model_name = 'YPTF.MoE+Multi (noPS)'
        precision = '16'
        project = '2024'
        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
        model = load_model_checkpoint(args=args, device="cpu")

        # Prepare audio info
        info = torchaudio.info(str(example_file))
        audio_info = {
            "filepath": str(example_file),
            "track_name": example_file.stem,
            "sample_rate": int(info.sample_rate),
            # Falls back to 16 when the metadata object has no such attribute.
            "bits_per_sample": int(getattr(info, 'bits_per_sample', 16)),
            "num_channels": int(info.num_channels),
            "num_frames": int(info.num_frames),
            # Whole seconds, truncated.
            "duration": int(info.num_frames / info.sample_rate),
            "encoding": str(info.encoding).lower(),
        }

        print("Testing normal transcription...")
        midifile = transcribe(model, audio_info, instrument_hint=None)
        print(f"β Normal transcription successful: {midifile}")

        print("Testing with vocals hint...")
        midifile_vocals = transcribe(model, audio_info, instrument_hint="vocals")
        print(f"β Vocals transcription successful: {midifile_vocals}")
        return True
    except Exception as e:
        # Diagnostic boundary: report the failure with its traceback and turn
        # it into a False result.
        print(f"β Error testing transcription: {e}")
        import traceback
        traceback.print_exc()
        return False


# Despite the "test_" prefix this is a setup diagnostic that returns a status,
# not a unit test; keep pytest-style collectors from picking it up.
test_example_transcription.__test__ = False
def create_local_launcher():
    """Write ``run_yourmt3.py``, a small launcher script, into the current directory.

    The generated script changes into its own directory, prints a short
    banner and then executes ``app.py`` in its own namespace. The file is
    written as UTF-8 and, where the platform allows it, marked executable.

    Returns:
        None. The result is the file on disk plus a confirmation line.
    """
    # In the generated script, app.py is read as UTF-8 through a context
    # manager (no leaked handle, no dependence on the locale encoding) and
    # compiled with its file name so tracebacks point at app.py.
    launcher_content = '''#!/usr/bin/env python3
"""
YourMT3+ Local Launcher
Run this script to start the web interface locally
"""
import sys
import os

# Change to the YourMT3 directory
os.chdir(os.path.dirname(os.path.abspath(__file__)))

print("π΅ Starting YourMT3+ with Instrument Conditioning...")
print("π Working directory:", os.getcwd())
print("π Web interface will be available at: http://127.0.0.1:7860")
print("π― New feature: Select specific instruments from the dropdown!")
print()

try:
    # Run the app
    with open('app.py', encoding='utf-8') as app_file:
        app_source = app_file.read()
    exec(compile(app_source, 'app.py', 'exec'))
except KeyboardInterrupt:
    print("\\nπ YourMT3+ stopped by user")
except Exception as e:
    print(f"β Error: {e}")
    import traceback
    traceback.print_exc()
'''
    # The launcher contains non-ASCII characters, so the encoding is explicit:
    # with the platform default, writing can raise UnicodeEncodeError on
    # systems whose locale encoding is not UTF-8.
    with open('run_yourmt3.py', 'w', encoding='utf-8') as f:
        f.write(launcher_content)

    # Make it executable on Unix systems. Best effort only: a failure to
    # change the mode must not stop the setup.
    try:
        os.chmod('run_yourmt3.py', 0o755)
    except OSError:
        pass

    print("β Created launcher script: run_yourmt3.py")
def main():
    """Run the local setup checks in order and report the outcome.

    Steps: verify the script runs from the project root (``app.py`` must
    exist), check dependencies, check model weights, try loading the model,
    create the launcher script, and optionally run a test transcription.

    Raises:
        SystemExit: with status 1 when ``app.py`` is missing, dependencies
            are missing, no model weights are found, or the model fails to
            load.
    """
    print("π΅ YourMT3+ Local Setup Checker")
    print("=" * 50)

    # Check current directory
    if not Path("app.py").exists():
        print("β Not in YourMT3 directory!")
        print("Please run this script from the YourMT3 root directory")
        sys.exit(1)
    print(f"π Working directory: {os.getcwd()}")

    # Run both checks before bailing out, so the user sees every problem in
    # one pass.
    deps_ok = check_dependencies()
    weights_ok = check_model_weights()
    if not deps_ok:
        print("\nβ Please install missing dependencies first")
        sys.exit(1)
    if not weights_ok:
        print("\nβ Please download model weights first")
        print("The app won't work without them")
        sys.exit(1)

    print("\n" + "=" * 50)
    if not test_model_loading():
        print("\nβ Model loading failed - check the errors above")
        # Exit non-zero like the other failed checks, so callers of the
        # script can tell that setup is not usable.
        sys.exit(1)

    print("\nπ Setup looks good!")
    create_local_launcher()
    print("\nπ To start YourMT3+:")
    print(" python run_yourmt3.py")
    print(" OR")
    print(" python app.py")
    print("\nπ‘ Then open: http://127.0.0.1:7860")

    # Ask if user wants to test transcription
    try:
        test_now = input("\nπ§ͺ Test transcription now? (y/n): ").lower().startswith('y')
        if test_now:
            test_example_transcription()
    except (EOFError, KeyboardInterrupt):
        # No interactive stdin, or the user pressed Ctrl-C: the optional
        # test is simply skipped. Anything else is no longer swallowed.
        pass


# Run the checker only when executed as a script, not when imported.
if __name__ == "__main__":
    main()