Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Quick test script for YourMT3+ instrument conditioning | |
| Run this to test if everything is working before launching the full interface | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
# Make the YourMT3 model code under amt/src importable. The path is resolved
# against the current working directory, which is why main() requires the
# script to be started from the YourMT3 directory.
sys.path.append(os.path.join(os.path.abspath("amt"), "src"))
def test_basic_import():
    """Check that runtime dependencies and the YourMT3 helper modules import.

    Each successful import is reported on stdout as it happens, so the last
    line printed before a failure identifies the module that broke.

    Returns:
        True if every import succeeded, False otherwise (the error is printed).
    """
    print("🔍 Testing basic imports...")
    try:
        import torch  # noqa: F401
        print("✅ torch")
        import torchaudio  # noqa: F401
        print("✅ torchaudio")
        import gradio as gr  # noqa: F401
        print("✅ gradio")

        # YourMT3 helpers. Importing the names (not just the modules) also
        # verifies that the expected functions are actually defined.
        from model_helper import load_model_checkpoint, transcribe  # noqa: F401
        print("✅ model_helper")
        from html_helper import create_html_from_midi, to_data_url  # noqa: F401
        print("✅ html_helper")
    except Exception as e:
        # Deliberately broad: a module can fail at import time with errors
        # other than ImportError (e.g. from its own top-level code).
        print(f"❌ Import error: {e}")
        return False
    return True
def test_model_loading(device="cpu"):
    """Load the YPTF.MoE+Multi (noPS) checkpoint and print its task config.

    The checkpoint name and argument list mirror the ones used by app.py.

    Args:
        device: Device passed to ``load_model_checkpoint`` (default "cpu").

    Returns:
        The loaded model, or None if loading failed (the traceback is printed).
    """
    print("\n🔍 Testing model loading...")
    try:
        from model_helper import load_model_checkpoint

        # Use the same args as app.py
        model_name = 'YPTF.MoE+Multi (noPS)'
        precision = '16'
        project = '2024'
        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]

        print(f"Loading {model_name}...")
        model = load_model_checkpoint(args=args, device=device)
        print("✅ Model loaded successfully!")
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    # The task-config dump is diagnostic only. A missing or failing
    # debug_model_task_config must not discard a model that loaded fine,
    # because main() treats a None result as a model-weights problem.
    try:
        from model_helper import debug_model_task_config
        debug_model_task_config(model)
    except Exception as e:
        print(f"⚠️ Could not print model task config: {e}")
    return model
def test_instrument_conditioning(model=None):
    """Transcribe one example clip with and without a "vocals" instrument hint.

    Uses the first (by sorted path) ``*.wav`` file under ./examples. The test
    only checks that both ``transcribe`` calls complete without raising; it
    prints both results but does not compare them.

    Args:
        model: Model returned by ``load_model_checkpoint``. If None, the test
            fails immediately instead of erroring inside ``transcribe``.

    Returns:
        True if both transcriptions ran, False otherwise.
    """
    print("\n🔍 Testing instrument conditioning...")

    if model is None:
        print("❌ No model provided")
        return False

    # Glob order depends on the filesystem; sort so the chosen test file is
    # the same on every machine.
    example_files = sorted(Path("examples").glob("*.wav"))
    if not example_files:
        print("❌ No example files found")
        return False

    test_file = example_files[0]
    print(f"Using test file: {test_file}")

    try:
        import torchaudio
        from model_helper import transcribe

        info = torchaudio.info(str(test_file))
        audio_info = {
            "filepath": str(test_file),
            "track_name": test_file.stem + "_test",
            "sample_rate": int(info.sample_rate),
            # Fall back to 16 bits when the metadata object has no bit depth.
            "bits_per_sample": int(getattr(info, "bits_per_sample", 16)),
            "num_channels": int(info.num_channels),
            "num_frames": int(info.num_frames),
            # Whole seconds, truncated.
            "duration": int(info.num_frames / info.sample_rate),
            "encoding": str(info.encoding).lower(),
        }

        print("\n--- Testing normal transcription ---")
        midifile1 = transcribe(model, audio_info, instrument_hint=None)
        print(f"Normal transcription result: {midifile1}")

        print("\n--- Testing vocals conditioning ---")
        # NOTE(review): transcribe presumably derives its output name from
        # "track_name"; a distinct name keeps this run from overwriting the
        # result above -- confirm against model_helper.transcribe.
        vocals_info = {**audio_info, "track_name": test_file.stem + "_test_vocals"}
        midifile2 = transcribe(model, vocals_info, instrument_hint="vocals")
        print(f"Vocals transcription result: {midifile2}")

        print("✅ Instrument conditioning test completed!")
        return True
    except Exception as e:
        print(f"❌ Instrument conditioning test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run the quick checks in order: imports, model loading, conditioning.

    Must be run from the YourMT3 directory (the one containing app.py).
    Exits with status 1 if the working directory is wrong, imports fail, or
    the model cannot be loaded. A failing conditioning test only prints a
    warning and main() returns normally.
    """
    print("🎵 YourMT3+ Quick Test")
    print("=" * 40)

    # The amt/src entry added to sys.path at import time is cwd-relative,
    # so the script only works when started from the YourMT3 directory.
    if not Path("app.py").exists():
        print("❌ Please run this from the YourMT3 directory")
        sys.exit(1)
    print(f"📁 Working directory: {os.getcwd()}")

    if not test_basic_import():
        print("\n❌ Basic imports failed - install dependencies first")
        sys.exit(1)

    model = test_model_loading()
    if model is None:
        print("\n❌ Model loading failed - check model weights")
        sys.exit(1)

    if test_instrument_conditioning(model):
        print("\n🎉 All tests passed!")
        print("\nYou can now run:")
        print(" python app.py")
        print("\nThen visit: http://127.0.0.1:7860")
    else:
        print("\n⚠️ Some tests failed but basic functionality should work")
        print("You can still try running: python app.py")


if __name__ == "__main__":
    main()