File size: 8,669 Bytes
c207bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python3
"""
YourMT3+ CLI with Instrument Conditioning
Command-line interface for transcribing audio with instrument-specific hints.

Usage:
    python transcribe_cli.py audio.wav
    python transcribe_cli.py audio.wav --instrument vocals
    python transcribe_cli.py audio.wav --instrument guitar --confidence-threshold 0.8
"""

import os
import sys
import argparse
import torch
import torchaudio
from pathlib import Path

# Add the amt/src directory to the path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))

from model_helper import load_model_checkpoint, transcribe


def main():
    """CLI entry point: parse arguments, load the model, and write a MIDI transcription.

    Exits with status 1 on a missing input file, an out-of-range confidence
    threshold, or any error raised during model loading / transcription.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validate input file before doing any expensive work.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file '{args.audio_file}' not found.")
        sys.exit(1)

    # Validate confidence threshold (argparse only checks it is a float).
    if not 0.0 <= args.confidence_threshold <= 1.0:
        print("Error: Confidence threshold must be between 0.0 and 1.0.")
        sys.exit(1)

    # Default the output path to the input filename with a .mid suffix.
    if args.output is None:
        args.output = Path(args.audio_file).with_suffix('.mid')

    if args.verbose:
        print(f"Input file: {args.audio_file}")
        print(f"Output file: {args.output}")
        print(f"Model: {args.model}")
        if args.instrument:
            print(f"Target instrument: {args.instrument}")
        if args.single_instrument:
            print(f"Single instrument mode: enabled (threshold: {args.confidence_threshold})")

    try:
        # Load model
        if args.verbose:
            print("Loading model...")

        model_args = get_model_args(args.model, args.precision)
        # Checkpoint is loaded on CPU first, then moved to GPU when available.
        model = load_model_checkpoint(args=model_args, device="cpu")
        model.to("cuda" if torch.cuda.is_available() else "cpu")

        if args.verbose:
            print("Model loaded successfully!")

        audio_info = _collect_audio_info(args.audio_file)

        # Instrument hint: an explicit --instrument wins; --single-instrument
        # alone asks the transcriber to auto-detect the dominant instrument
        # while still forcing single-instrument output.
        instrument_hint = None
        if args.instrument:
            instrument_hint = args.instrument
        elif args.single_instrument:
            instrument_hint = "auto"

        if args.verbose:
            print("Starting transcription...")

        # NOTE(review): args.confidence_threshold is validated but not yet
        # forwarded — transcribe() currently applies its own default when
        # filtering for instrument consistency.
        midifile = transcribe(model, audio_info, instrument_hint)

        # transcribe() writes to its own location; relocate the result if the
        # user asked for a different output path.
        if str(args.output) != midifile:
            import shutil
            shutil.move(midifile, args.output)
            midifile = str(args.output)

        print("Transcription completed successfully!")
        print(f"Output saved to: {midifile}")

        if args.verbose:
            # Print some basic statistics
            file_size = os.path.getsize(midifile)
            print(f"Output file size: {file_size} bytes")
            print(f"Duration: {audio_info['duration']} seconds")

    except Exception as e:
        print(f"Error during transcription: {str(e)}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


def _build_parser():
    """Build and return the argparse parser for the CLI."""
    parser = argparse.ArgumentParser(
        description="YourMT3+ Audio Transcription with Instrument Conditioning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s audio.wav                                  # Transcribe all instruments
  %(prog)s audio.wav --instrument vocals              # Focus on vocals only
  %(prog)s audio.wav --instrument guitar              # Focus on guitar only
  %(prog)s audio.wav --single-instrument              # Force single instrument output
  %(prog)s audio.wav --instrument piano --confidence-threshold 0.9

Supported instruments:
  vocals, singing, voice, guitar, piano, violin, drums, bass, saxophone, flute
        """
    )

    # Required arguments
    parser.add_argument('audio_file', help='Path to the audio file to transcribe')

    # Instrument conditioning options
    parser.add_argument('--instrument', type=str,
                       choices=['vocals', 'singing', 'voice', 'guitar', 'piano', 'violin',
                               'drums', 'bass', 'saxophone', 'flute'],
                       help='Specify the primary instrument to transcribe')

    parser.add_argument('--single-instrument', action='store_true',
                       help='Force single instrument output (apply consistency filtering)')

    parser.add_argument('--confidence-threshold', type=float, default=0.7,
                       help='Confidence threshold for instrument consistency filtering (0.0-1.0, default: 0.7)')

    # Model selection
    parser.add_argument('--model', type=str,
                       default='YPTF.MoE+Multi (noPS)',
                       choices=['YMT3+', 'YPTF+Single (noPS)', 'YPTF+Multi (PS)',
                               'YPTF.MoE+Multi (noPS)', 'YPTF.MoE+Multi (PS)'],
                       help='Model checkpoint to use (default: YPTF.MoE+Multi (noPS))')

    # Output options
    parser.add_argument('--output', '-o', type=str, default=None,
                       help='Output MIDI file path (default: auto-generated from input filename)')

    parser.add_argument('--precision', type=str, default='16', choices=['16', '32', 'bf16-mixed'],
                       help='Floating point precision (default: 16)')

    parser.add_argument('--verbose', '-v', action='store_true',
                       help='Enable verbose output')

    return parser


def _collect_audio_info(audio_file):
    """Probe *audio_file* with torchaudio and return the metadata dict transcribe() expects."""
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        "track_name": Path(audio_file).stem,
        "sample_rate": int(info.sample_rate),
        # Some torchaudio backends do not report bits_per_sample; assume 16-bit then.
        "bits_per_sample": int(info.bits_per_sample) if hasattr(info, 'bits_per_sample') else 16,
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        # Whole seconds only (truncated toward zero), matching the original behavior.
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(str(info.encoding)),
    }


def get_model_args(model_name, precision):
    """Build the CLI-style argument list for load_model_checkpoint.

    Args:
        model_name: One of the supported model display names.
        precision: Precision string ('16', '32', or 'bf16-mixed').

    Returns:
        A list of strings: the checkpoint name followed by flag/value pairs.

    Raises:
        ValueError: If model_name is not a recognized model.
    """
    project = '2024'

    # Flag bundles shared by several configurations.
    ptf_flags = ['-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1']
    moe_flags = ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                 '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                 '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                 '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1']

    # Per-model (checkpoint, extra flags) lookup table.
    configs = {
        "YMT3+": (
            "[email protected]",
            []),
        "YPTF+Single (noPS)": (
            "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
            ptf_flags),
        "YPTF+Multi (PS)": (
            "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
            ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26'] + ptf_flags),
        "YPTF.MoE+Multi (noPS)": (
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
            moe_flags),
        "YPTF.MoE+Multi (PS)": (
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
            moe_flags),
    }

    try:
        checkpoint, extra = configs[model_name]
    except KeyError:
        raise ValueError(f"Unknown model name: {model_name}") from None

    return [checkpoint, '-p', project] + extra + ['-pr', precision]


if __name__ == "__main__":
    main()