yourmt3 / setup_local.py
asdd12e2ad's picture
asd
c207bc4
#!/usr/bin/env python3
"""
YourMT3+ Local Setup and Debug Script
This script helps set up and debug YourMT3+ locally instead of using Colab.
Run this to check your setup and identify issues.
"""
import os
import sys
import subprocess
from pathlib import Path
def check_dependencies(required_packages=None):
    """Check that every required package can be imported.

    Each package is imported for real, so a package that is installed but
    fails while importing is reported instead of crashing the checker.

    Args:
        required_packages: Optional iterable of importable module names to
            probe. When omitted, the default list of packages this script
            expects is used.

    Returns:
        bool: True when every package imported cleanly, False when at least
        one is missing or failed during import.
    """
    import importlib

    print("πŸ” Checking dependencies...")
    if required_packages is None:
        required_packages = [
            'torch', 'torchaudio', 'transformers', 'gradio',
            'pytorch_lightning', 'einops', 'numpy', 'librosa',
        ]

    missing_packages = []
    broken_packages = []
    for package in required_packages:
        try:
            importlib.import_module(package)
        except ImportError:
            print(f" ❌ {package} - MISSING")
            missing_packages.append(package)
        except Exception as exc:
            # Installed but unusable: the import itself raised something other
            # than ImportError. Report it rather than aborting the whole check.
            print(f" ❌ {package} - IMPORT FAILED ({type(exc).__name__}: {exc})")
            broken_packages.append(package)
        else:
            print(f" βœ… {package}")

    if missing_packages:
        print(f"\n⚠️ Missing packages: {', '.join(missing_packages)}")
        print("Install them with:")
        print(f"pip install {' '.join(missing_packages)}")
    if broken_packages:
        print(f"\n⚠️ Packages that failed to import: {', '.join(broken_packages)}")
        print("Try reinstalling them with:")
        print(f"pip install --force-reinstall {' '.join(broken_packages)}")

    if missing_packages or broken_packages:
        return False

    print("βœ… All dependencies found!")
    return True
def check_model_weights(
    base_path="amt/logs/2024",
    checkpoint_name="mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
):
    """Check whether model checkpoint files are present on disk.

    Args:
        base_path: Directory expected to contain ``.ckpt`` files.
        checkpoint_name: File name of the default checkpoint to look for
            inside ``base_path``.

    Returns:
        bool: True when the default checkpoint exists, or when it is absent
        but at least one other ``.ckpt`` file is present in ``base_path``.
        False when the directory is missing or holds no ``.ckpt`` files.
    """
    print("\nπŸ” Checking model weights...")
    base_path = Path(base_path)
    # A regular file at this location is as useless as a missing directory.
    if not base_path.is_dir():
        print(f"❌ Model directory not found: {base_path}")
        print(f"Create the directory with: mkdir -p {base_path.as_posix()}")
        return False

    # Check for the default model checkpoint
    checkpoint_path = base_path / checkpoint_name
    if checkpoint_path.is_file():
        size = checkpoint_path.stat().st_size / (1024**3)  # GB
        print(f"βœ… Model checkpoint found: {checkpoint_path}")
        print(f" Size: {size:.2f} GB")
        return True

    print(f"❌ Model checkpoint not found: {checkpoint_path}")
    print("\nAvailable checkpoints:")
    # Sorted so the listing is stable; glob order is filesystem-dependent.
    other_checkpoints = sorted(ckpt.name for ckpt in base_path.glob("*.ckpt"))
    for name in other_checkpoints:
        print(f" πŸ“„ {name}")
    if not other_checkpoints:
        print(" (none found)")
    print("\nπŸ’‘ You need to download model weights:")
    print(" 1. Download from the official YourMT3 repository")
    print(f" 2. Place .ckpt files in {base_path.as_posix()}/")
    return bool(other_checkpoints)
def test_model_loading():
    """Try to load the default checkpoint on CPU and report what it exposes.

    Returns:
        bool: True when the model loads and has a ``task_manager`` attribute.
        False when that attribute is absent, or when any exception is raised
        while importing the helper or loading the model (the traceback is
        printed in that case).
    """
    print("\nπŸ” Testing model loading...")
    try:
        # model_helper is imported from amt/src, so make that directory importable
        sys.path.append(os.path.abspath('amt/src'))
        from model_helper import load_model_checkpoint

        # Minimal argument set for the default model
        model_name = 'YPTF.MoE+Multi (noPS)'
        precision = '16'
        project = '2024'
        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        cli_args = [checkpoint, '-p', project]
        cli_args += ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5']
        cli_args += ['-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe']
        cli_args += ['-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope']
        cli_args += ['-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]

        print(f"Loading model: {model_name}")
        model = load_model_checkpoint(args=cli_args, device="cpu")

        # Without a task manager there is nothing further to inspect
        if not hasattr(model, 'task_manager'):
            print("❌ Model doesn't have task_manager")
            return False

        print("βœ… Model has task_manager")
        if hasattr(model.task_manager, 'task_name'):
            print(f" Task name: {model.task_manager.task_name}")
        if hasattr(model.task_manager, 'task'):
            task_config = model.task_manager.task
            print(f" Task config keys: {list(task_config.keys())}")
            if 'eval_subtask_prefix' not in task_config:
                print(" No eval_subtask_prefix found")
            else:
                prefixes = list(task_config['eval_subtask_prefix'].keys())
                print(f" Available subtask prefixes: {prefixes}")
        print("βœ… Model loaded successfully!")
        return True
    except Exception as err:
        print(f"❌ Error loading model: {err}")
        import traceback
        traceback.print_exc()
        return False
def test_example_transcription():
    """Transcribe one example WAV file, without and with a vocals hint.

    Takes the first ``*.wav`` returned by globbing ``examples/``, loads the
    default checkpoint on CPU, and runs ``transcribe`` twice on that file.

    Returns:
        bool: True when both transcriptions complete. False when no example
        file exists or any exception is raised (the traceback is printed).
    """
    print("\nπŸ” Testing example transcription...")
    wav_files = list(Path("examples").glob("*.wav"))
    if not wav_files:
        print("❌ No example audio files found in examples/")
        return False

    # One file is enough for a smoke test
    example_file = wav_files[0]
    try:
        print(f"Testing with: {example_file}")

        # Project helpers are imported from amt/src
        sys.path.append(os.path.abspath('amt/src'))
        from model_helper import transcribe, load_model_checkpoint
        import torchaudio

        # Load model
        model_name = 'YPTF.MoE+Multi (noPS)'
        precision = '16'
        project = '2024'
        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        cli_args = [checkpoint, '-p', project]
        cli_args += ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5']
        cli_args += ['-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe']
        cli_args += ['-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope']
        cli_args += ['-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
        model = load_model_checkpoint(args=cli_args, device="cpu")

        # Describe the audio file in the dict passed to transcribe()
        info = torchaudio.info(str(example_file))
        audio_info = {
            "filepath": str(example_file),
            "track_name": example_file.stem,
            "sample_rate": int(info.sample_rate),
            # Fall back to 16 when the metadata object has no bits_per_sample
            "bits_per_sample": int(info.bits_per_sample) if hasattr(info, 'bits_per_sample') else 16,
            "num_channels": int(info.num_channels),
            "num_frames": int(info.num_frames),
            "duration": int(info.num_frames / info.sample_rate),
            "encoding": str(info.encoding).lower(),
        }

        print("Testing normal transcription...")
        midifile = transcribe(model, audio_info, instrument_hint=None)
        print(f"βœ… Normal transcription successful: {midifile}")

        print("Testing with vocals hint...")
        midifile_vocals = transcribe(model, audio_info, instrument_hint="vocals")
        print(f"βœ… Vocals transcription successful: {midifile_vocals}")
        return True
    except Exception as err:
        print(f"❌ Error testing transcription: {err}")
        import traceback
        traceback.print_exc()
        return False
def create_local_launcher(launcher_path='run_yourmt3.py'):
    """Write a small launcher script that runs ``app.py`` from its own directory.

    Args:
        launcher_path: Where to write the launcher. Defaults to
            ``run_yourmt3.py`` in the current working directory.

    The file is written as UTF-8 because its text contains non-ASCII
    characters; the platform default encoding may not be able to encode them.
    Setting the executable bit is best-effort.
    """
    launcher_content = '''#!/usr/bin/env python3
"""
YourMT3+ Local Launcher
Run this script to start the web interface locally
"""
import sys
import os
# Change to the YourMT3 directory
os.chdir(os.path.dirname(os.path.abspath(__file__)))
print("🎡 Starting YourMT3+ with Instrument Conditioning...")
print("πŸ“ Working directory:", os.getcwd())
print("🌐 Web interface will be available at: http://127.0.0.1:7860")
print("🎯 New feature: Select specific instruments from the dropdown!")
print()
try:
    # Run the app
    exec(open('app.py').read())
except KeyboardInterrupt:
    print("\\nπŸ‘‹ YourMT3+ stopped by user")
except Exception as e:
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()
'''
    with open(launcher_path, 'w', encoding='utf-8') as f:
        f.write(launcher_content)

    # Make it executable on Unix systems; a failure here is not fatal because
    # the script can still be started with "python <launcher>".
    try:
        os.chmod(launcher_path, 0o755)
    except OSError:
        pass

    print(f"βœ… Created launcher script: {launcher_path}")
def main():
    """Run the setup checks in order and report whether the app can start.

    Steps: confirm ``app.py`` is in the working directory, check dependencies
    and model weights, try loading the model, write the launcher script, and
    optionally run a transcription smoke test.

    Exits with status 1 when ``app.py`` is absent, when the dependency check
    fails, or when the weights check fails.
    """
    print("🎡 YourMT3+ Local Setup Checker")
    print("=" * 50)

    # Check current directory
    if not Path("app.py").exists():
        print("❌ Not in YourMT3 directory!")
        print("Please run this script from the YourMT3 root directory")
        sys.exit(1)
    print(f"πŸ“ Working directory: {os.getcwd()}")

    # Run both checks before bailing out so the user sees every problem at once
    deps_ok = check_dependencies()
    weights_ok = check_model_weights()
    if not deps_ok:
        print("\n❌ Please install missing dependencies first")
        sys.exit(1)
    if not weights_ok:
        print("\n❌ Please download model weights first")
        print("The app won't work without them")
        sys.exit(1)

    print("\n" + "=" * 50)
    model_ok = test_model_loading()
    if not model_ok:
        print("\n❌ Model loading failed - check the errors above")
        return

    print("\nπŸŽ‰ Setup looks good!")
    create_local_launcher()
    print("\nπŸš€ To start YourMT3+:")
    print(" python run_yourmt3.py")
    print(" OR")
    print(" python app.py")
    print("\nπŸ’‘ Then open: http://127.0.0.1:7860")

    # Ask if user wants to test transcription
    try:
        answer = input("\nπŸ§ͺ Test transcription now? (y/n): ")
        if answer.lower().startswith('y'):
            test_example_transcription()
    except (EOFError, KeyboardInterrupt):
        # EOFError: input() hit end-of-file (no answer available).
        # KeyboardInterrupt: the user aborted the prompt or the test.
        # Either way the optional test is skipped quietly; any other error
        # is allowed to surface instead of being swallowed.
        pass


if __name__ == "__main__":
    main()