# @title Model helper
# import spaces # for zero-GPU

import os
from collections import Counter
import argparse
import torch
import torchaudio
import numpy as np

from model.init_train import initialize_trainer, update_config
from utils.task_manager import TaskManager
from config.vocabulary import drum_vocab_presets
from utils.utils import str2bool
from utils.utils import Timer
from utils.audio import slice_padded_array
from utils.note2event import mix_notes
from utils.event2note import merge_zipped_note_events_and_ties_to_notes
from utils.utils import write_model_output_as_midi, write_err_cnt_as_json
from model.ymt3 import YourMT3


def debug_model_task_config(model):
    """Debug function to inspect what task configurations are available in the model"""
    print("=== Model Task Configuration Debug ===")
    
    if hasattr(model, 'task_manager'):
        print(f"✓ Model has task_manager")
        print(f"  Task name: {getattr(model.task_manager, 'task_name', 'Unknown')}")
        
        if hasattr(model.task_manager, 'task'):
            task_config = model.task_manager.task
            print(f"  Task config keys: {list(task_config.keys())}")
            
            if 'eval_subtask_prefix' in task_config:
                print(f"  Available subtask prefixes: {list(task_config['eval_subtask_prefix'].keys())}")
                for key, value in task_config['eval_subtask_prefix'].items():
                    print(f"    {key}: {value}")
            else:
                print("  No eval_subtask_prefix found")
                
            if 'subtask_tokens' in task_config:
                print(f"  Subtask tokens: {task_config['subtask_tokens']}")
        else:
            print("  No task config found")
            
        if hasattr(model.task_manager, 'tokenizer'):
            tokenizer = model.task_manager.tokenizer
            print(f"  Tokenizer available: {type(tokenizer)}")
            
            # Try to inspect available events in the codec
            if hasattr(tokenizer, 'codec'):
                codec = tokenizer.codec
                print(f"  Codec type: {type(codec)}")
                if hasattr(codec, '_event_ranges'):
                    print(f"  Event ranges: {codec._event_ranges}")
        else:
            print("  No tokenizer found")
    else:
        print("✗ Model doesn't have task_manager")
    
    print("=" * 40)


def create_instrument_task_tokens(model, instrument_hint, n_segments):
    """Create task tokens for instrument-specific transcription conditioning.
    
    Args:
        model: YourMT3 model instance
        instrument_hint: String indicating desired instrument ('vocals', 'guitar', 'piano', etc.)
        n_segments: Number of audio segments
        
    Returns:
        torch.LongTensor: Task tokens for conditioning the model
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Check what task configuration is available in the model
    if not hasattr(model, 'task_manager') or not hasattr(model.task_manager, 'task'):
        print(f"Warning: Model doesn't have task configuration, skipping task tokens for {instrument_hint}")
        return None
    
    task_config = model.task_manager.task
    
    # Check if this model supports subtask prefixes
    if 'eval_subtask_prefix' in task_config:
        print(f"Model supports subtask prefixes: {list(task_config['eval_subtask_prefix'].keys())}")
        
        # Map instrument hints to available subtask prefixes
        if instrument_hint.lower() in ['vocals', 'singing', 'voice']:
            if 'singing-only' in task_config['eval_subtask_prefix']:
                prefix_tokens = task_config['eval_subtask_prefix']['singing-only']
                print(f"Using singing-only task tokens: {prefix_tokens}")
            else:
                prefix_tokens = task_config['eval_subtask_prefix'].get('default', [])
                print(f"Singing task not available, using default: {prefix_tokens}")
        elif instrument_hint.lower() in ['drums', 'drum', 'percussion']:
            if 'drum-only' in task_config['eval_subtask_prefix']:
                prefix_tokens = task_config['eval_subtask_prefix']['drum-only']
                print(f"Using drum-only task tokens: {prefix_tokens}")
            else:
                prefix_tokens = task_config['eval_subtask_prefix'].get('default', [])
                print(f"Drum task not available, using default: {prefix_tokens}")
        else:
            # For other instruments, use default transcribe_all
            prefix_tokens = task_config['eval_subtask_prefix'].get('default', [])
            print(f"Using default task tokens for {instrument_hint}: {prefix_tokens}")
    else:
        print(f"Model doesn't support subtask prefixes, using general transcription for {instrument_hint}")
        # For models without subtask support, return None to use regular transcription
        return None
    
    # Convert to token IDs if we have prefix tokens
    if prefix_tokens:
        try:
            tokenizer = model.task_manager.tokenizer
            task_token_ids = []
            
            for event in prefix_tokens:
                try:
                    token_id = tokenizer.codec.encode_event(event)
                    task_token_ids.append(token_id)
                    print(f"Encoded event {event} -> token {token_id}")
                except Exception as e:
                    print(f"Warning: Could not encode event {event}: {e}")
                    continue
            
            if task_token_ids:
                # Create task token array: (n_segments, 1, task_len) for single channel
                task_len = len(task_token_ids)
                task_tokens = torch.zeros((n_segments, 1, task_len), dtype=torch.long, device=device)
                for i in range(n_segments):
                    task_tokens[i, 0, :] = torch.tensor(task_token_ids, dtype=torch.long)
                
                print(f"Created task tokens with shape: {task_tokens.shape}")
                return task_tokens
            else:
                print("No valid task tokens could be created")
                return None
                
        except Exception as e:
            print(f"Warning: Could not create task tokens for {instrument_hint}: {e}")
    
    return None
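
# A minimal usage sketch (assumption: the loaded checkpoint's task preset
# defines 'eval_subtask_prefix'; see config/task.py). The instrument hint and
# batch size below are illustrative, mirroring the call made in transcribe().
#
#   n_segments = audio_segments.shape[0]
#   task_tokens = create_instrument_task_tokens(model, 'vocals', n_segments)
#   if task_tokens is not None:
#       # shape: (n_segments, 1, task_len), one copy of the prefix per segment
#       pred_token_arr, _ = model.inference_file(bsz=8,
#                                                audio_segments=audio_segments,
#                                                task_token_array=task_tokens)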


def filter_instrument_consistency(pred_notes, primary_instrument=None, confidence_threshold=0.7, instrument_hint=None):
    """Post-process transcribed notes to maintain instrument consistency.
    
    Args:
        pred_notes: List of Note objects from transcription
        primary_instrument: Target instrument program number (if known)
        confidence_threshold: Threshold for maintaining instrument consistency
        instrument_hint: Original instrument hint to help with mapping
        
    Returns:
        List of filtered Note objects
    """
    if not pred_notes:
        return pred_notes
    
    # Count instrument occurrences to find dominant instrument
    instrument_counts = {}
    total_notes = len(pred_notes)
    
    for note in pred_notes:
        program = getattr(note, 'program', 0)
        instrument_counts[program] = instrument_counts.get(program, 0) + 1
    
    print(f"Found instruments in transcription: {instrument_counts}")
    
    # Determine primary instrument
    if primary_instrument is None:
        primary_instrument = max(instrument_counts, key=instrument_counts.get)
    
    primary_count = instrument_counts.get(primary_instrument, 0)
    primary_ratio = primary_count / total_notes if total_notes > 0 else 0
    
    print(f"Primary instrument: {primary_instrument} ({primary_ratio:.2%} of notes)")
    
    # Map instrument hints to preferred MIDI programs
    instrument_program_map = {
        'vocals': 100,  # Singing voice in YourMT3
        'singing': 100,
        'voice': 100,
        'piano': 0,     # Acoustic Grand Piano
        'guitar': 24,   # Acoustic Guitar (nylon)
        'violin': 40,   # Violin
        'drums': 128,   # Drum kit
        'bass': 32,     # Acoustic Bass
        'saxophone': 64, # Soprano Sax
        'flute': 73,    # Flute
    }
    
    # If we have an instrument hint, try to use the appropriate program
    if instrument_hint and instrument_hint.lower() in instrument_program_map:
        target_program = instrument_program_map[instrument_hint.lower()]
        print(f"Target program for {instrument_hint}: {target_program}")
        
        # Check if the target program exists in the transcription
        if target_program in instrument_counts:
            primary_instrument = target_program
            primary_ratio = instrument_counts[target_program] / total_notes
            print(f"Found target instrument in transcription: {primary_ratio:.2%} of notes")
    
    # If primary instrument is dominant enough, filter out other instruments
    if primary_ratio >= confidence_threshold:
        print(f"Applying consistency filter (threshold: {confidence_threshold:.2%})")
        filtered_notes = []
        converted_count = 0
        
        for note in pred_notes:
            note_program = getattr(note, 'program', 0)
            if note_program == primary_instrument:
                filtered_notes.append(note)
            else:
                # Convert note to primary instrument
                try:
                    note_copy = note._replace(program=primary_instrument)
                    filtered_notes.append(note_copy)
                    converted_count += 1
                except AttributeError:
                    # Handle different note types
                    note_copy = note.__class__(
                        start=note.start,
                        end=note.end, 
                        pitch=note.pitch,
                        velocity=note.velocity,
                        program=primary_instrument
                    )
                    filtered_notes.append(note_copy)
                    converted_count += 1
        
        print(f"Converted {converted_count} notes to primary instrument {primary_instrument}")
        return filtered_notes
    else:
        print(f"Primary instrument ratio ({primary_ratio:.2%}) below threshold ({confidence_threshold:.2%}), keeping all instruments")
    
    return pred_notes
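
# Illustrative behaviour (hypothetical numbers): with the default 0.7 threshold,
# a transcription that is 80% program 100 (singing voice) and 20% program 0
# (piano) under instrument_hint='vocals' gets the piano notes re-labelled to
# program 100; a 50/50 split is returned unchanged because no single instrument
# clears the threshold.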




def load_model_checkpoint(args=None, device='cpu'):
    parser = argparse.ArgumentParser(description="YourMT3")
    # General
    parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
    parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
    parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300} 128 for MT3, 300 for PerceiverTF. If None, default value defined in config.py will be used.')
    parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
    # Model configurations
    parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
    parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
    parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
    parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
    parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
    parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
    parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
    parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
    parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False')
    parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
    parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
    parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
    parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
    # Perceiver-TF configurations
    parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
    parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
    parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
    parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
    # Decoder configurations
    parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    # Task and Evaluation configurations
    parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
    parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.')
    parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).')
    parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False')
    parser.add_argument('-w', '--write-model-output', type=str2bool, default=True, help='write model test output to file (default=True). True or False')
    # Trainer configurations
    parser.add_argument('-pr','--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
    parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
    parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
    parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
    parser.add_argument('-wb', '--wandb-mode', type=str, default="disabled", help='wandb mode for logging (default=None). "disabled" or "online" or "offline". If None, default value defined in config.py will be used.')
    # Debug
    parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False')
    parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). semitone in int.')
    args = parser.parse_args(args)
    # yapf: enable
    if torch.__version__ >= "1.13":
        torch.set_float32_matmul_precision("high")
    args.epochs = None

    # Initialize and update config
    _, _, dir_info, shared_cfg = initialize_trainer(args, stage='test')
    shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test')

    if args.eval_drum_vocab is not None:  # override eval_drum_vocab
        eval_drum_vocab = drum_vocab_presets[args.eval_drum_vocab]

    # Initialize task manager
    tm = TaskManager(task_name=args.task,
                     max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]),
                     debug_mode=args.debug_mode)
    print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model
    model = YourMT3(
        audio_cfg=audio_cfg,
        model_cfg=model_cfg,
        shared_cfg=shared_cfg,
        optimizer=None,
        task_manager=tm,  # tokenizer is a member of task_manager
        eval_subtask_key=args.eval_subtask_key,
        write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
        ).to(device)
    checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device, weights_only=False)
    state_dict = checkpoint['state_dict']
    new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
    model.load_state_dict(new_state_dict, strict=False)
    return model.eval()  # return the loaded model in eval mode for inference
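
# Example call (hedged: '<exp_id>' is a placeholder for an experiment id whose
# checkpoint exists locally; '-pr 32' simply selects full precision, as parsed
# by the argument parser above):
#
#   model = load_model_checkpoint(args=['<exp_id>', '-pr', '32'])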


def transcribe(model, audio_info, instrument_hint=None):
    t = Timer()

    # Converting Audio
    t.start()
    audio, sr = torchaudio.load(uri=audio_info['filepath'])
    audio = torch.mean(audio, dim=0).unsqueeze(0)
    audio = torchaudio.functional.resample(audio, sr, model.audio_cfg['sample_rate'])
    audio_segments = slice_padded_array(audio, model.audio_cfg['input_frames'], model.audio_cfg['input_frames'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    audio_segments = torch.from_numpy(audio_segments.astype('float32')).to(device).unsqueeze(1) # (n_seg, 1, seg_sz)
    t.stop(); t.print_elapsed_time("converting audio");

    # Inference
    t.start()
    
    # Debug model configuration when using instrument hints
    if instrument_hint:
        print(f"Attempting to create task tokens for instrument: {instrument_hint}")
        debug_model_task_config(model)
    
    # Create task tokens for instrument-specific transcription
    task_tokens = None
    if instrument_hint:
        task_tokens = create_instrument_task_tokens(model, instrument_hint, audio_segments.shape[0])
    
    pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments, task_token_array=task_tokens)
    t.stop(); t.print_elapsed_time("model inference");

    # Post-processing
    t.start()
    num_channels = model.task_manager.num_decoding_channels
    n_items = audio_segments.shape[0]
    start_secs_file = [model.audio_cfg['input_frames'] * i / model.audio_cfg['sample_rate'] for i in range(n_items)]
    pred_notes_in_file = []
    n_err_cnt = Counter()
    for ch in range(num_channels):
        pred_token_arr_ch = [arr[:, ch, :] for arr in pred_token_arr]  # (B, L)
        zipped_note_events_and_tie, list_events, ne_err_cnt = model.task_manager.detokenize_list_batches(
            pred_token_arr_ch, start_secs_file, return_events=True)
        pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
        pred_notes_in_file.append(pred_notes_ch)
        n_err_cnt += n_err_cnt_ch
    pred_notes = mix_notes(pred_notes_in_file)  # This is the mixed notes from all channels
    
    # Apply instrument consistency filter if instrument hint was provided
    if instrument_hint:
        print(f"Applying instrument consistency filter for: {instrument_hint}")
        # Use more aggressive filtering if task tokens weren't available
        confidence_threshold = 0.6 if task_tokens is not None else 0.4
        print(f"Using confidence threshold: {confidence_threshold}")
        pred_notes = filter_instrument_consistency(pred_notes, 
                                                 confidence_threshold=confidence_threshold,
                                                 instrument_hint=instrument_hint)

    # Write MIDI
    write_model_output_as_midi(pred_notes, './',
                              audio_info['track_name'], model.midi_output_inverse_vocab)
    t.stop(); t.print_elapsed_time("post processing");
    midifile = os.path.join('./model_output/', audio_info['track_name'] + '.mid')
    assert os.path.exists(midifile)
    return midifile
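

if __name__ == "__main__":
    # Minimal end-to-end sketch. '<exp_id>' and 'example.wav' are placeholders,
    # not shipped defaults; point them at a real checkpoint and audio file.
    example_audio = {'filepath': 'example.wav', 'track_name': 'example'}
    model = load_model_checkpoint(args=['<exp_id>', '-pr', '32'])
    midi_path = transcribe(model, example_audio, instrument_hint='vocals')
    print(f"Transcribed MIDI written to: {midi_path}")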