File size: 5,028 Bytes
c207bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python3
"""
Quick test script for YourMT3+ instrument conditioning
Run this to test if everything is working before launching the full interface
"""

import sys
import os
from pathlib import Path

# Make the project-local modules under amt/src (model_helper, html_helper)
# importable. The path is resolved against the current working directory,
# which main() expects to be the YourMT3 directory (the one holding app.py).
sys.path.append(os.path.abspath(os.path.join('amt', 'src')))

def test_basic_import():
    """Check that the runtime dependencies and YourMT3 helper modules import.

    Imports torch, torchaudio and gradio, then the project-local
    ``model_helper`` and ``html_helper`` modules (located through the
    ``amt/src`` entry appended to ``sys.path`` at module load). Each
    successful import is reported on stdout.

    Returns:
        True if every import succeeded; False on the first failure. The error
        is printed rather than raised so the caller can report it.
    """
    print("🔍 Testing basic imports...")

    try:
        import torch  # noqa: F401
        print("✅ torch")

        import torchaudio  # noqa: F401
        print("✅ torchaudio")

        import gradio as gr  # noqa: F401
        print("✅ gradio")

        # YourMT3 helpers; the names are imported only to prove they exist.
        from model_helper import load_model_checkpoint, transcribe  # noqa: F401
        print("✅ model_helper")

        from html_helper import create_html_from_midi, to_data_url  # noqa: F401
        print("✅ html_helper")
    except Exception as e:
        # Broad on purpose. A dependency can fail at import time with more
        # than ImportError, and this diagnostic must report it, not crash.
        print(f"❌ Import error: {e}")
        return False
    return True


# Named test_* for readability, but this is not a pytest test: opt out of
# pytest collection in case the script is ever picked up by test discovery.
test_basic_import.__test__ = False

def test_model_loading():
    """Load the 'YPTF.MoE+Multi (noPS)' checkpoint on CPU and dump its task config.

    The checkpoint name and argument list are meant to mirror the ones app.py
    uses, so a success here indicates the interface can load the model too.

    Returns:
        The loaded model, or None if loading failed (the traceback is
        printed). A failure of the optional task-config dump is only reported
        as a warning and does not affect the return value.
    """
    print("\n🔍 Testing model loading...")

    try:
        from model_helper import load_model_checkpoint

        # Use the same args as app.py
        model_name = 'YPTF.MoE+Multi (noPS)'
        precision = '16'
        project = '2024'

        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]

        print(f"Loading {model_name}...")
        model = load_model_checkpoint(args=args, device="cpu")
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    print("✅ Model loaded successfully!")

    # The task-config dump is diagnostic only. If it fails (e.g. model_helper
    # has no debug_model_task_config), a model that did load must not be
    # reported to the caller as a loading failure.
    try:
        from model_helper import debug_model_task_config
        debug_model_task_config(model)
    except Exception as e:
        print(f"⚠️  Could not print model task config: {e}")
        import traceback
        traceback.print_exc()

    return model


# Named test_* for readability, but this is not a pytest test (it loads a full
# checkpoint): opt out of pytest collection.
test_model_loading.__test__ = False

def test_instrument_conditioning(model):
    """Transcribe one example file with and without an instrument hint.

    Picks the first ``examples/*.wav`` in sorted order (relative to the current
    working directory) and builds the audio-info dict from ``torchaudio.info``.
    It then calls ``model_helper.transcribe`` twice: once with
    ``instrument_hint=None`` and once with ``instrument_hint="vocals"``.

    Args:
        model: The model object returned by ``test_model_loading``.

    Returns:
        True if both transcriptions completed. False if no example file was
        found or any step raised (the traceback is printed).
    """
    print("\n🔍 Testing instrument conditioning...")

    # Sort so the same file is chosen on every run; glob order depends on the
    # filesystem.
    example_files = sorted(Path("examples").glob("*.wav"))
    if not example_files:
        print("❌ No example files found")
        return False

    test_file = example_files[0]
    print(f"Using test file: {test_file}")

    try:
        import torchaudio
        from model_helper import transcribe

        # Create audio info
        info = torchaudio.info(str(test_file))
        audio_info = {
            "filepath": str(test_file),
            "track_name": test_file.stem + "_test",
            "sample_rate": int(info.sample_rate),
            # Fall back to 16 when the info object has no bits_per_sample.
            "bits_per_sample": int(getattr(info, 'bits_per_sample', 16)),
            "num_channels": int(info.num_channels),
            "num_frames": int(info.num_frames),
            # Whole seconds, truncated toward zero.
            "duration": int(info.num_frames / info.sample_rate),
            "encoding": str(info.encoding).lower(),
        }

        print("\n--- Testing normal transcription ---")
        midifile1 = transcribe(model, audio_info, instrument_hint=None)
        print(f"Normal transcription result: {midifile1}")

        print("\n--- Testing vocals conditioning ---")
        midifile2 = transcribe(model, audio_info, instrument_hint="vocals")
        print(f"Vocals transcription result: {midifile2}")
    except Exception as e:
        print(f"❌ Instrument conditioning test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("✅ Instrument conditioning test completed!")
    return True


# Not a pytest test despite its name. Its `model` argument would be resolved
# as a (nonexistent) fixture and error out, so opt out of pytest collection.
test_instrument_conditioning.__test__ = False

def main():
    """Run the quick checks in order and tell the user whether to launch app.py.

    Must be run from the directory that contains app.py. Exits with status 1
    if that check, the imports, or model loading fail. An instrument-
    conditioning failure is reported only as a warning.
    """
    print("🎡 YourMT3+ Quick Test")
    print("=" * 40)

    # Check if we're in the right directory
    if not Path("app.py").exists():
        print("❌ Please run this from the YourMT3 directory")
        sys.exit(1)

    print(f"📁 Working directory: {os.getcwd()}")

    # Test imports
    if not test_basic_import():
        print("\n❌ Basic imports failed - install dependencies first")
        sys.exit(1)

    # Test model loading
    model = test_model_loading()
    if model is None:
        print("\n❌ Model loading failed - check model weights")
        sys.exit(1)

    # Test instrument conditioning; a failure here is not fatal.
    if test_instrument_conditioning(model):
        print("\n🎉 All tests passed!")
        print("\nYou can now run:")
        print("  python app.py")
        print("\nThen visit: http://127.0.0.1:7860")
    else:
        print("\n⚠️  Some tests failed but basic functionality should work")
        print("You can still try running: python app.py")

# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()