# @title Model helper
# import spaces # for zero-GPU
import os
from collections import Counter
import argparse
import torch
import torchaudio
import numpy as np
from model.init_train import initialize_trainer, update_config
from utils.task_manager import TaskManager
from config.vocabulary import drum_vocab_presets
from utils.utils import str2bool
from utils.utils import Timer
from utils.audio import slice_padded_array
from utils.note2event import mix_notes
from utils.event2note import merge_zipped_note_events_and_ties_to_notes
from utils.utils import write_model_output_as_midi, write_err_cnt_as_json
from model.ymt3 import YourMT3


def debug_model_task_config(model):
    """Debug function to inspect what task configurations are available in the model."""
    print("=== Model Task Configuration Debug ===")
    if hasattr(model, 'task_manager'):
        print(f"✓ Model has task_manager")
        print(f" Task name: {getattr(model.task_manager, 'task_name', 'Unknown')}")
        if hasattr(model.task_manager, 'task'):
            task_config = model.task_manager.task
            print(f" Task config keys: {list(task_config.keys())}")
            if 'eval_subtask_prefix' in task_config:
                print(f" Available subtask prefixes: {list(task_config['eval_subtask_prefix'].keys())}")
                for key, value in task_config['eval_subtask_prefix'].items():
                    print(f" {key}: {value}")
            else:
                print(" No eval_subtask_prefix found")
            if 'subtask_tokens' in task_config:
                print(f" Subtask tokens: {task_config['subtask_tokens']}")
        else:
            print(" No task config found")
        if hasattr(model.task_manager, 'tokenizer'):
            tokenizer = model.task_manager.tokenizer
            print(f" Tokenizer available: {type(tokenizer)}")
            # Try to inspect available events in the codec
            if hasattr(tokenizer, 'codec'):
                codec = tokenizer.codec
                print(f" Codec type: {type(codec)}")
                if hasattr(codec, '_event_ranges'):
                    print(f" Event ranges: {codec._event_ranges}")
        else:
            print(" No tokenizer found")
    else:
        print("✗ Model doesn't have task_manager")
    print("=" * 40)


def create_instrument_task_tokens(model, instrument_hint, n_segments):
    """Create task tokens for instrument-specific transcription conditioning.

    Args:
        model: YourMT3 model instance
        instrument_hint: String indicating desired instrument ('vocals', 'guitar', 'piano', etc.)
        n_segments: Number of audio segments

    Returns:
        torch.LongTensor: Task tokens for conditioning the model
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Check what task configuration is available in the model
    if not hasattr(model, 'task_manager') or not hasattr(model.task_manager, 'task'):
        print(f"Warning: Model doesn't have task configuration, skipping task tokens for {instrument_hint}")
        return None
    task_config = model.task_manager.task

    # Check if this model supports subtask prefixes
    if 'eval_subtask_prefix' in task_config:
        print(f"Model supports subtask prefixes: {list(task_config['eval_subtask_prefix'].keys())}")
        # Map instrument hints to available subtask prefixes
        if instrument_hint.lower() in ['vocals', 'singing', 'voice']:
            if 'singing-only' in task_config['eval_subtask_prefix']:
                prefix_tokens = task_config['eval_subtask_prefix']['singing-only']
                print(f"Using singing-only task tokens: {prefix_tokens}")
            else:
                prefix_tokens = task_config['eval_subtask_prefix'].get('default', [])
                print(f"Singing task not available, using default: {prefix_tokens}")
        elif instrument_hint.lower() in ['drums', 'drum', 'percussion']:
            if 'drum-only' in task_config['eval_subtask_prefix']:
                prefix_tokens = task_config['eval_subtask_prefix']['drum-only']
                print(f"Using drum-only task tokens: {prefix_tokens}")
            else:
                prefix_tokens = task_config['eval_subtask_prefix'].get('default', [])
                print(f"Drum task not available, using default: {prefix_tokens}")
        else:
            # For other instruments, use default transcribe_all
            prefix_tokens = task_config['eval_subtask_prefix'].get('default', [])
            print(f"Using default task tokens for {instrument_hint}: {prefix_tokens}")
    else:
        print(f"Model doesn't support subtask prefixes, using general transcription for {instrument_hint}")
        # For models without subtask support, return None to use regular transcription
        return None

    # Convert to token IDs if we have prefix tokens
    if prefix_tokens:
        try:
            tokenizer = model.task_manager.tokenizer
            task_token_ids = []
            for event in prefix_tokens:
                try:
                    token_id = tokenizer.codec.encode_event(event)
                    task_token_ids.append(token_id)
                    print(f"Encoded event {event} -> token {token_id}")
                except Exception as e:
                    print(f"Warning: Could not encode event {event}: {e}")
                    continue
            if task_token_ids:
                # Create task token array: (n_segments, 1, task_len) for single channel
                task_len = len(task_token_ids)
                task_tokens = torch.zeros((n_segments, 1, task_len), dtype=torch.long, device=device)
                for i in range(n_segments):
                    task_tokens[i, 0, :] = torch.tensor(task_token_ids, dtype=torch.long)
                print(f"Created task tokens with shape: {task_tokens.shape}")
                return task_tokens
            else:
                print("No valid task tokens could be created")
                return None
        except Exception as e:
            print(f"Warning: Could not create task tokens for {instrument_hint}: {e}")
            return None
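

# --- Usage sketch for create_instrument_task_tokens (illustrative only) ---
# A minimal sketch showing the expected output shape, (n_segments, 1, task_len),
# using tiny mock objects. The mock event names and token ids below are
# placeholders; a real model from load_model_checkpoint carries its own task
# config and codec, so do not treat this mapping as the YourMT3 vocabulary.
def _demo_create_instrument_task_tokens():
    class _MockCodec:
        def encode_event(self, event):
            # Hypothetical event -> token-id mapping, for demonstration only
            return {'transcribe_singing': 1, 'task': 2}.get(str(event), 0)

    class _MockTokenizer:
        codec = _MockCodec()

    class _MockTaskManager:
        task = {'eval_subtask_prefix': {'default': ['task'], 'singing-only': ['transcribe_singing']}}
        tokenizer = _MockTokenizer()

    class _MockModel:
        task_manager = _MockTaskManager()

    tokens = create_instrument_task_tokens(_MockModel(), 'vocals', n_segments=4)
    assert tokens is not None and tuple(tokens.shape) == (4, 1, 1)
    return tokens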


def filter_instrument_consistency(pred_notes, primary_instrument=None, confidence_threshold=0.7, instrument_hint=None):
    """Post-process transcribed notes to maintain instrument consistency.

    Args:
        pred_notes: List of Note objects from transcription
        primary_instrument: Target instrument program number (if known)
        confidence_threshold: Threshold for maintaining instrument consistency
        instrument_hint: Original instrument hint to help with mapping

    Returns:
        List of filtered Note objects
    """
    if not pred_notes:
        return pred_notes

    # Count instrument occurrences to find dominant instrument
    instrument_counts = {}
    total_notes = len(pred_notes)
    for note in pred_notes:
        program = getattr(note, 'program', 0)
        instrument_counts[program] = instrument_counts.get(program, 0) + 1
    print(f"Found instruments in transcription: {instrument_counts}")

    # Determine primary instrument
    if primary_instrument is None:
        primary_instrument = max(instrument_counts, key=instrument_counts.get)
    primary_count = instrument_counts.get(primary_instrument, 0)
    primary_ratio = primary_count / total_notes if total_notes > 0 else 0
    print(f"Primary instrument: {primary_instrument} ({primary_ratio:.2%} of notes)")

    # Map instrument hints to preferred MIDI programs
    instrument_program_map = {
        'vocals': 100,    # Singing voice in YourMT3
        'singing': 100,
        'voice': 100,
        'piano': 0,       # Acoustic Grand Piano
        'guitar': 24,     # Acoustic Guitar (nylon)
        'violin': 40,     # Violin
        'drums': 128,     # Drum kit
        'bass': 32,       # Acoustic Bass
        'saxophone': 64,  # Soprano Sax
        'flute': 73,      # Flute
    }

    # If we have an instrument hint, try to use the appropriate program
    if instrument_hint and instrument_hint.lower() in instrument_program_map:
        target_program = instrument_program_map[instrument_hint.lower()]
        print(f"Target program for {instrument_hint}: {target_program}")
        # Check if the target program exists in the transcription
        if target_program in instrument_counts:
            primary_instrument = target_program
            primary_ratio = instrument_counts[target_program] / total_notes
            print(f"Found target instrument in transcription: {primary_ratio:.2%} of notes")

    # If primary instrument is dominant enough, filter out other instruments
    if primary_ratio >= confidence_threshold:
        print(f"Applying consistency filter (threshold: {confidence_threshold:.2%})")
        filtered_notes = []
        converted_count = 0
        for note in pred_notes:
            note_program = getattr(note, 'program', 0)
            if note_program == primary_instrument:
                filtered_notes.append(note)
            else:
                # Convert note to primary instrument
                try:
                    note_copy = note._replace(program=primary_instrument)
                    filtered_notes.append(note_copy)
                    converted_count += 1
                except AttributeError:
                    # Handle different note types
                    note_copy = note.__class__(
                        start=note.start,
                        end=note.end,
                        pitch=note.pitch,
                        velocity=note.velocity,
                        program=primary_instrument
                    )
                    filtered_notes.append(note_copy)
                    converted_count += 1
        print(f"Converted {converted_count} notes to primary instrument {primary_instrument}")
        return filtered_notes
    else:
        print(f"Primary instrument ratio ({primary_ratio:.2%}) below threshold ({confidence_threshold:.2%}), keeping all instruments")
        return pred_notes
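

# --- Usage sketch for filter_instrument_consistency (illustrative only) ----
# A self-contained sketch using a stand-in namedtuple instead of the real Note
# objects produced by merge_zipped_note_events_and_ties_to_notes (whose exact
# fields may differ). With 8 of 10 notes on program 0, the 0.7 threshold is met
# and the two stray guitar notes are converted to the primary instrument.
def _demo_filter_instrument_consistency():
    from collections import namedtuple
    FakeNote = namedtuple('FakeNote', ['start', 'end', 'pitch', 'velocity', 'program'])
    notes = [FakeNote(0.0, 0.5, 60, 100, 0)] * 8 + [FakeNote(0.5, 1.0, 64, 90, 24)] * 2
    filtered = filter_instrument_consistency(list(notes), confidence_threshold=0.7)
    assert all(n.program == 0 for n in filtered)
    return filtered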


def load_model_checkpoint(args=None, device='cpu'):
    parser = argparse.ArgumentParser(description="YourMT3")
    # General
    parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
    parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
    parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300}: 128 for MT3, 300 for PerceiverTF. If None, default value defined in config.py will be used.')
    parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
    # Model configurations
    parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
    parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
    parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
    parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
    parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
    parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
    parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
    parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
    parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5 (default=False). True or False')
    parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
    parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
    parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
    parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
    # Perceiver-TF configurations
    parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
    parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
    parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
    parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
    # Decoder configurations
    parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    # Task and Evaluation configurations
    parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
    parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.')
    parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).')
    parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False')
    parser.add_argument('-w', '--write-model-output', type=str2bool, default=True, help='write model test output to file (default=True). True or False')
    # Trainer configurations
    parser.add_argument('-pr', '--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
    parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
    parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
    parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
    parser.add_argument('-wb', '--wandb-mode', type=str, default="disabled", help='wandb mode for logging (default="disabled"). "disabled" or "online" or "offline". If None, default value defined in config.py will be used.')
    # Debug
    parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False')
    parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). semitone in int.')
    args = parser.parse_args(args)
    # yapf: enable

    if torch.__version__ >= "1.13":
        torch.set_float32_matmul_precision("high")
    args.epochs = None

    # Initialize and update config
    _, _, dir_info, shared_cfg = initialize_trainer(args, stage='test')
    shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test')

    if args.eval_drum_vocab is not None:  # override eval_drum_vocab
        eval_drum_vocab = drum_vocab_presets[args.eval_drum_vocab]

    # Initialize task manager
    tm = TaskManager(task_name=args.task,
                     max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]),
                     debug_mode=args.debug_mode)
    print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model
    model = YourMT3(
        audio_cfg=audio_cfg,
        model_cfg=model_cfg,
        shared_cfg=shared_cfg,
        optimizer=None,
        task_manager=tm,  # tokenizer is a member of task_manager
        eval_subtask_key=args.eval_subtask_key,
        write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
    ).to(device)

    # Load checkpoint weights, dropping the pitch-shift augmentation layer used only during training
    checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device, weights_only=False)
    state_dict = checkpoint['state_dict']
    new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
    model.load_state_dict(new_state_dict, strict=False)
    return model.eval()
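

# --- Usage sketch for load_model_checkpoint (illustrative only) ------------
# The function parses an argparse-style list, so it can be driven from code.
# The exp_id ('YMT3+@last') and flag values below are placeholders: running
# this requires a matching checkpoint on disk under the project directory,
# which is not assumed to exist here.
def _demo_load_model_checkpoint():
    example_args = [
        'YMT3+@last',            # hypothetical exp_id@checkpoint
        '-p', 'ymt3',            # project name
        '-tk', 'mt3_full_plus',  # task / tokenizer preset
        '-pr', '32',             # full precision, e.g. if bf16 is unsupported
    ]
    return load_model_checkpoint(args=example_args)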


def transcribe(model, audio_info, instrument_hint=None):
    t = Timer()

    # Converting Audio
    t.start()
    audio, sr = torchaudio.load(uri=audio_info['filepath'])
    audio = torch.mean(audio, dim=0).unsqueeze(0)
    audio = torchaudio.functional.resample(audio, sr, model.audio_cfg['sample_rate'])
    audio_segments = slice_padded_array(audio, model.audio_cfg['input_frames'], model.audio_cfg['input_frames'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    audio_segments = torch.from_numpy(audio_segments.astype('float32')).to(device).unsqueeze(1)  # (n_seg, 1, seg_sz)
    t.stop()
    t.print_elapsed_time("converting audio")

    # Inference
    t.start()
    # Debug model configuration when using instrument hints
    if instrument_hint:
        print(f"Attempting to create task tokens for instrument: {instrument_hint}")
        debug_model_task_config(model)

    # Create task tokens for instrument-specific transcription
    task_tokens = None
    if instrument_hint:
        task_tokens = create_instrument_task_tokens(model, instrument_hint, audio_segments.shape[0])

    pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments, task_token_array=task_tokens)
    t.stop()
    t.print_elapsed_time("model inference")

    # Post-processing
    t.start()
    num_channels = model.task_manager.num_decoding_channels
    n_items = audio_segments.shape[0]
    start_secs_file = [model.audio_cfg['input_frames'] * i / model.audio_cfg['sample_rate'] for i in range(n_items)]
    pred_notes_in_file = []
    n_err_cnt = Counter()
    for ch in range(num_channels):
        pred_token_arr_ch = [arr[:, ch, :] for arr in pred_token_arr]  # (B, L)
        zipped_note_events_and_tie, list_events, ne_err_cnt = model.task_manager.detokenize_list_batches(
            pred_token_arr_ch, start_secs_file, return_events=True)
        pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
        pred_notes_in_file.append(pred_notes_ch)
        n_err_cnt += n_err_cnt_ch
    pred_notes = mix_notes(pred_notes_in_file)  # the mixed notes from all channels

    # Apply instrument consistency filter if instrument hint was provided
    if instrument_hint:
        print(f"Applying instrument consistency filter for: {instrument_hint}")
        # Use more aggressive filtering if task tokens weren't available
        confidence_threshold = 0.6 if task_tokens is not None else 0.4
        print(f"Using confidence threshold: {confidence_threshold}")
        pred_notes = filter_instrument_consistency(pred_notes,
                                                   confidence_threshold=confidence_threshold,
                                                   instrument_hint=instrument_hint)

    # Write MIDI
    write_model_output_as_midi(pred_notes, './',
                               audio_info['track_name'], model.midi_output_inverse_vocab)
    t.stop()
    t.print_elapsed_time("post processing")
    midifile = os.path.join('./model_output/', audio_info['track_name'] + '.mid')
    assert os.path.exists(midifile)
    return midifile
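

# --- End-to-end usage sketch (illustrative only) ---------------------------
# Ties the helpers together: load a checkpoint, then transcribe a single file.
# The exp_id, audio path, and track name are placeholders, not files shipped
# with this Space. transcribe() writes './model_output/<track_name>.mid' and
# returns that path.
if __name__ == "__main__":
    model = load_model_checkpoint(args=['YMT3+@last'])  # hypothetical exp_id@checkpoint
    audio_info = {
        'filepath': 'examples/test.wav',  # placeholder audio file
        'track_name': 'test',
    }
    midi_path = transcribe(model, audio_info, instrument_hint='vocals')
    print(f"Wrote MIDI to: {midi_path}")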