File size: 9,600 Bytes
c207bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
#!/usr/bin/env python3
"""
YourMT3+ Local Setup and Debug Script

This script helps set up and debug YourMT3+ locally instead of using Colab.
Run this to check your setup and identify issues.
"""

import os
import sys
import subprocess
from pathlib import Path

def check_dependencies(required_packages=None):
    """Check that every required package can be imported.

    Each name is imported in turn and one status line per package is
    printed. If anything is missing, a ``pip install`` command listing the
    missing names is printed as well.

    Args:
        required_packages: Optional iterable of importable module names to
            verify. A single string is treated as one module name. When
            omitted, the default set used by this script is checked
            (torch, torchaudio, transformers, gradio, pytorch_lightning,
            einops, numpy, librosa).

    Returns:
        bool: True when every package imported, False when at least one
        raised ImportError.
    """
    # Function-scope import so this block does not depend on the file header.
    import importlib

    if required_packages is None:
        required_packages = (
            'torch', 'torchaudio', 'transformers', 'gradio',
            'pytorch_lightning', 'einops', 'numpy', 'librosa',
        )
    elif isinstance(required_packages, str):
        # Iterating a bare string would check it character by character.
        required_packages = (required_packages,)

    print("πŸ” Checking dependencies...")

    missing_packages = []
    for package in required_packages:
        try:
            importlib.import_module(package)
        except ImportError:
            print(f"  ❌ {package} - MISSING")
            missing_packages.append(package)
        else:
            print(f"  βœ… {package}")

    if missing_packages:
        print(f"\n⚠️  Missing packages: {', '.join(missing_packages)}")
        print("Install them with:")
        print(f"pip install {' '.join(missing_packages)}")
        return False

    print("βœ… All dependencies found!")
    return True

def check_model_weights():
    """Report whether model checkpoints are present under amt/logs/2024.

    Looks (relative to the current working directory) for the default
    checkpoint file first. If it is absent, every ``*.ckpt`` file in the
    directory is listed instead, followed by download instructions when
    there are none.

    Returns:
        bool: False when the directory is missing or holds no ``.ckpt``
        file; True when the default checkpoint exists or at least one
        other ``.ckpt`` file was found.
    """
    print("\nπŸ” Checking model weights...")

    weights_dir = Path("amt/logs/2024")
    if not weights_dir.exists():
        print(f"❌ Model directory not found: {weights_dir}")
        print("Create the directory with: mkdir -p amt/logs/2024")
        return False

    # Default checkpoint expected by the rest of this script.
    default_ckpt = weights_dir / "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"

    if default_ckpt.exists():
        size_gb = default_ckpt.stat().st_size / 1024 ** 3
        print(f"βœ… Model checkpoint found: {default_ckpt}")
        print(f"   Size: {size_gb:.2f} GB")
        return True

    print(f"❌ Model checkpoint not found: {default_ckpt}")
    print("\nAvailable checkpoints:")

    # Fall back to listing whatever checkpoints are there.
    other_names = [candidate.name for candidate in weights_dir.glob("*.ckpt")]
    for ckpt_name in other_names:
        print(f"   πŸ“„ {ckpt_name}")

    if not other_names:
        print("   (none found)")
        print("\nπŸ’‘ You need to download model weights:")
        print("   1. Download from the official YourMT3 repository")
        print("   2. Place .ckpt files in amt/logs/2024/")

    return bool(other_names)

def test_model_loading():
    """Try to load the default checkpoint on CPU and inspect its task manager.

    Appends ``amt/src`` (resolved against the current directory) to
    ``sys.path``, imports ``load_model_checkpoint`` from ``model_helper``
    and calls it with a fixed argument list. When the loaded model has a
    ``task_manager`` attribute, its task name, task config keys and
    ``eval_subtask_prefix`` keys are printed if present.

    Returns:
        bool: True when the model loads and has a ``task_manager``; False
        when it lacks one or when any exception is raised (the traceback
        is printed).
    """
    print("\nπŸ” Testing model loading...")

    try:
        # model_helper lives in amt/src, which is not on the path by default.
        sys.path.append(os.path.abspath('amt/src'))

        from model_helper import load_model_checkpoint

        model_name = 'YPTF.MoE+Multi (noPS)'

        # Checkpoint name followed by flag/value pairs, one pair per line.
        loader_args = [
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
            '-p', '2024',
            '-tk', 'mc13_full_plus_256',
            '-dec', 'multi-t5',
            '-nl', '26',
            '-enc', 'perceiver-tf',
            '-sqr', '1',
            '-ff', 'moe',
            '-wf', '4',
            '-nmoe', '8',
            '-kmoe', '2',
            '-act', 'silu',
            '-epe', 'rope',
            '-rp', '1',
            '-ac', 'spec',
            '-hop', '300',
            '-atc', '1',
            '-pr', '16',
        ]

        print(f"Loading model: {model_name}")
        model = load_model_checkpoint(args=loader_args, device="cpu")

        if not hasattr(model, 'task_manager'):
            print("❌ Model doesn't have task_manager")
            return False

        print("βœ… Model has task_manager")
        manager = model.task_manager

        if hasattr(manager, 'task_name'):
            print(f"   Task name: {manager.task_name}")

        if hasattr(manager, 'task'):
            task_config = manager.task
            print(f"   Task config keys: {list(task_config.keys())}")

            if 'eval_subtask_prefix' not in task_config:
                print("   No eval_subtask_prefix found")
            else:
                prefixes = list(task_config['eval_subtask_prefix'].keys())
                print(f"   Available subtask prefixes: {prefixes}")

        print("βœ… Model loaded successfully!")
        return True

    except Exception as e:
        # Diagnostic tool: report the failure with its traceback rather than crash.
        import traceback
        print(f"❌ Error loading model: {e}")
        traceback.print_exc()
        return False

def test_example_transcription():
    """Transcribe one example WAV file twice: without and with a vocals hint.

    Takes the first ``*.wav`` file that globbing ``examples/`` yields,
    loads the default checkpoint on CPU via ``model_helper``, builds an
    audio-info dict from ``torchaudio.info`` and calls ``transcribe`` once
    with ``instrument_hint=None`` and once with ``instrument_hint="vocals"``.

    Returns:
        bool: True when both transcriptions complete; False when no example
        file exists or any exception is raised (the traceback is printed).
    """
    print("\nπŸ” Testing example transcription...")

    wav_files = list(Path("examples").glob("*.wav"))
    if not wav_files:
        print("❌ No example audio files found in examples/")
        return False

    try:
        # One file is enough for a smoke test.
        example_file = wav_files[0]
        print(f"Testing with: {example_file}")

        # model_helper lives in amt/src, which is not on the path by default.
        sys.path.append(os.path.abspath('amt/src'))
        from model_helper import transcribe, load_model_checkpoint
        import torchaudio

        # Checkpoint name followed by flag/value pairs, one pair per line.
        loader_args = [
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
            '-p', '2024',
            '-tk', 'mc13_full_plus_256',
            '-dec', 'multi-t5',
            '-nl', '26',
            '-enc', 'perceiver-tf',
            '-sqr', '1',
            '-ff', 'moe',
            '-wf', '4',
            '-nmoe', '8',
            '-kmoe', '2',
            '-act', 'silu',
            '-epe', 'rope',
            '-rp', '1',
            '-ac', 'spec',
            '-hop', '300',
            '-atc', '1',
            '-pr', '16',
        ]
        model = load_model_checkpoint(args=loader_args, device="cpu")

        # Describe the audio file in the dict form passed to transcribe().
        info = torchaudio.info(str(example_file))
        audio_info = {
            "filepath": str(example_file),
            "track_name": example_file.stem,
            "sample_rate": int(info.sample_rate),
        }
        if hasattr(info, 'bits_per_sample'):
            audio_info["bits_per_sample"] = int(info.bits_per_sample)
        else:
            # Fallback value used when the metadata has no bit depth.
            audio_info["bits_per_sample"] = 16
        audio_info["num_channels"] = int(info.num_channels)
        audio_info["num_frames"] = int(info.num_frames)
        # Whole seconds, truncated.
        audio_info["duration"] = int(info.num_frames / info.sample_rate)
        audio_info["encoding"] = str(info.encoding).lower()

        print("Testing normal transcription...")
        midifile = transcribe(model, audio_info, instrument_hint=None)
        print(f"βœ… Normal transcription successful: {midifile}")

        print("Testing with vocals hint...")
        midifile_vocals = transcribe(model, audio_info, instrument_hint="vocals")
        print(f"βœ… Vocals transcription successful: {midifile_vocals}")

        return True

    except Exception as e:
        # Diagnostic tool: report the failure with its traceback rather than crash.
        import traceback
        print(f"❌ Error testing transcription: {e}")
        traceback.print_exc()
        return False

def create_local_launcher():
    """Write ``run_yourmt3.py``, a small launcher script, into the current directory.

    The generated script changes into its own directory, prints a short
    banner and then executes ``app.py``. The file is written as UTF-8 and,
    where the platform allows it, marked executable (mode 0o755).

    Returns:
        None. Any existing ``run_yourmt3.py`` is overwritten.

    Raises:
        OSError: If the launcher file cannot be written.
    """
    launcher_content = '''#!/usr/bin/env python3
"""
YourMT3+ Local Launcher
Run this script to start the web interface locally
"""

import sys
import os

# Change to the YourMT3 directory
os.chdir(os.path.dirname(os.path.abspath(__file__)))

print("🎡 Starting YourMT3+ with Instrument Conditioning...")
print("πŸ“ Working directory:", os.getcwd())
print("🌐 Web interface will be available at: http://127.0.0.1:7860")
print("🎯 New feature: Select specific instruments from the dropdown!")
print()

try:
    # Run the app. Read it as UTF-8 (Python's default source encoding) and
    # close the file before executing, since the app may run for a long time.
    with open('app.py', encoding='utf-8') as app_file:
        app_source = app_file.read()
    exec(compile(app_source, 'app.py', 'exec'))
except KeyboardInterrupt:
    print("\\nπŸ‘‹ YourMT3+ stopped by user")
except Exception as e:
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()
'''

    launcher_path = 'run_yourmt3.py'

    # The launcher text contains non-ASCII characters. Writing it with the
    # platform default encoding can fail or produce a file Python cannot parse,
    # so write UTF-8 explicitly (the default encoding for Python source).
    with open(launcher_path, 'w', encoding='utf-8') as f:
        f.write(launcher_content)

    # Make it executable on Unix systems. This is best-effort: a filesystem
    # that rejects the mode change should not fail the whole setup.
    try:
        os.chmod(launcher_path, 0o755)
    except OSError:
        pass

    print("βœ… Created launcher script: run_yourmt3.py")

def main():
    """Run the setup checks in order and report the outcome.

    Steps: confirm ``app.py`` exists in the current directory, check
    dependencies, check model weights, try loading the model, then create
    the launcher script and optionally run the example transcription.

    Returns:
        None when every required check passes.

    Raises:
        SystemExit: With status 1 when ``app.py`` is not in the current
            directory, when dependencies or model weights are missing, or
            when the model fails to load.
    """
    print("🎡 YourMT3+ Local Setup Checker")
    print("=" * 50)

    # Every path used by the checks is relative to the project root.
    if not Path("app.py").exists():
        print("❌ Not in YourMT3 directory!")
        print("Please run this script from the YourMT3 root directory")
        sys.exit(1)

    print(f"πŸ“ Working directory: {os.getcwd()}")

    # Run both checks before bailing out so the user sees all problems at once.
    deps_ok = check_dependencies()
    weights_ok = check_model_weights()

    if not deps_ok:
        print("\n❌ Please install missing dependencies first")
        sys.exit(1)

    if not weights_ok:
        print("\n❌ Please download model weights first")
        print("The app won't work without them")
        sys.exit(1)

    print("\n" + "=" * 50)

    if not test_model_loading():
        print("\n❌ Model loading failed - check the errors above")
        # Exit non-zero like the other failure paths so calling scripts can tell.
        sys.exit(1)

    print("\nπŸŽ‰ Setup looks good!")
    create_local_launcher()

    print("\nπŸš€ To start YourMT3+:")
    print("   python run_yourmt3.py")
    print("   OR")
    print("   python app.py")

    print("\nπŸ’‘ Then open: http://127.0.0.1:7860")

    # The transcription test is optional. Skip it quietly when there is no
    # interactive stdin (EOFError) or the user presses Ctrl-C, but let any
    # other error surface instead of hiding it.
    try:
        answer = input("\nπŸ§ͺ Test transcription now? (y/n): ")
        if answer.lower().startswith('y'):
            test_example_transcription()
    except (EOFError, KeyboardInterrupt):
        pass

# Run the checks only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()