Spaces:
Runtime error
Runtime error
File size: 5,028 Bytes
c207bc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
#!/usr/bin/env python3
"""
Quick test script for YourMT3+ instrument conditioning
Run this to test if everything is working before launching the full interface
"""
import sys
import os
from pathlib import Path
# Make the project-local helper modules (model_helper, html_helper) that live
# under amt/src importable. The relative path is resolved against the current
# working directory; main() checks that this directory contains app.py.
_AMT_SRC_DIR = os.path.abspath('amt/src')
sys.path.append(_AMT_SRC_DIR)
def test_basic_import():
    """Check that the runtime dependencies and YourMT3 helper modules import.

    Prints one status line per module as soon as it has been imported.

    Returns:
        True if every import succeeded; False as soon as one fails. The
        error is printed, not raised.
    """
    print("🔍 Testing basic imports...")
    try:
        # Imported only to prove they are installed; the names are unused.
        import torch  # noqa: F401
        print("✅ torch")
        import torchaudio  # noqa: F401
        print("✅ torchaudio")
        import gradio  # noqa: F401
        print("✅ gradio")

        # YourMT3 helpers are resolved through the amt/src entry on sys.path.
        from model_helper import load_model_checkpoint, transcribe  # noqa: F401
        print("✅ model_helper")
        from html_helper import create_html_from_midi, to_data_url  # noqa: F401
        print("✅ html_helper")
    except Exception as e:
        # Deliberately broad: a helper module can fail with any exception while
        # running its top-level code, not only with ImportError.
        print(f"❌ Import error: {e}")
        return False
    return True
def test_model_loading(device="cpu", precision=None):
    """Load the 'YPTF.MoE+Multi (noPS)' checkpoint and print its task config.

    Args:
        device: Device passed to ``load_model_checkpoint`` (default "cpu").
        precision: Value for the ``-pr`` flag. ``None`` (default) selects
            "16" when CUDA is available and "32" otherwise.

    Returns:
        The loaded model, or None if loading failed (the traceback is
        printed). A failure in the optional debug step does NOT discard a
        successfully loaded model.
    """
    print("\n🔍 Testing model loading...")
    try:
        from model_helper import load_model_checkpoint

        if precision is None:
            import torch
            # fp16 is risky for CPU-only inference (many PyTorch CPU ops have
            # historically lacked Half kernels), so only request it with a GPU.
            # NOTE(review): assumed to match app.py's choice; confirm.
            precision = '16' if torch.cuda.is_available() else '32'

        # Use the same args as app.py
        model_name = 'YPTF.MoE+Multi (noPS)'
        project = '2024'
        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
        print(f"Loading {model_name}...")
        model = load_model_checkpoint(args=args, device=device)
        print("✅ Model loaded successfully!")
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    # Diagnostics are best-effort: a missing or failing debug helper must not
    # turn a successful load into a reported "model loading failed".
    try:
        from model_helper import debug_model_task_config
        debug_model_task_config(model)
    except Exception as e:
        print(f"⚠️ Task-config debug output unavailable: {e}")
    return model
def _build_audio_info(audio_path):
    """Return the audio_info dict passed to ``transcribe`` for *audio_path*.

    Metadata comes from ``torchaudio.info``. The track name is the file stem
    plus a ``_test`` suffix, and ``duration`` is truncated to whole seconds.
    """
    import torchaudio

    info = torchaudio.info(str(audio_path))
    return {
        "filepath": str(audio_path),
        "track_name": audio_path.stem + "_test",
        "sample_rate": int(info.sample_rate),
        # bits_per_sample may be absent on the info object; default to 16.
        "bits_per_sample": int(getattr(info, "bits_per_sample", 16)),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str(info.encoding).lower(),
    }


def test_instrument_conditioning(model):
    """Transcribe one example clip without a hint and with a "vocals" hint.

    Args:
        model: The model returned by ``test_model_loading``.

    Returns:
        True if both transcriptions ran without raising. False if no
        ``examples/*.wav`` file exists or a transcription raised. Whether the
        hint actually changes the output is not checked; compare the printed
        results by hand.
    """
    print("\n🔍 Testing instrument conditioning...")

    # sorted() so the same clip is picked on every run; raw glob order
    # depends on the filesystem.
    example_files = sorted(Path("examples").glob("*.wav"))
    if not example_files:
        print("❌ No example files found")
        return False
    test_file = example_files[0]
    print(f"Using test file: {test_file}")

    try:
        from model_helper import transcribe

        audio_info = _build_audio_info(test_file)

        print("\n--- Testing normal transcription ---")
        midifile1 = transcribe(model, audio_info, instrument_hint=None)
        print(f"Normal transcription result: {midifile1}")

        print("\n--- Testing vocals conditioning ---")
        # Give the conditioned run its own track name so it cannot overwrite the
        # unconditioned result. NOTE(review): assumes transcribe names its
        # output after track_name; confirm in model_helper.
        vocals_info = {**audio_info, "track_name": audio_info["track_name"] + "_vocals"}
        midifile2 = transcribe(model, vocals_info, instrument_hint="vocals")
        print(f"Vocals transcription result: {midifile2}")

        print("✅ Instrument conditioning test completed!")
        return True
    except Exception as e:
        print(f"❌ Instrument conditioning test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run the quick checks in order: imports, model loading, conditioning.

    Exits with status 1 if the working directory lacks app.py, the imports
    fail, or the model cannot be loaded. A failed conditioning test only
    prints a warning, because the app may still be usable.
    """
    print("🎵 YourMT3+ Quick Test")
    print("=" * 40)

    # Relative paths used by this script (amt/src, examples/) assume we are
    # running from the YourMT3 directory.
    if not Path("app.py").exists():
        print("❌ Please run this from the YourMT3 directory")
        sys.exit(1)
    print(f"📁 Working directory: {os.getcwd()}")

    # Test imports
    if not test_basic_import():
        print("\n❌ Basic imports failed - install dependencies first")
        sys.exit(1)

    # Test model loading
    model = test_model_loading()
    if model is None:
        print("\n❌ Model loading failed - check model weights")
        sys.exit(1)

    # Test instrument conditioning (non-fatal on failure)
    if test_instrument_conditioning(model):
        print("\n🎉 All tests passed!")
        print("\nYou can now run:")
        print(" python app.py")
        print("\nThen visit: http://127.0.0.1:7860")
    else:
        print("\n⚠️ Some tests failed but basic functionality should work")
        print("You can still try running: python app.py")


if __name__ == "__main__":
    main()
|