Spaces:
Runtime error
Runtime error
| # @title Model helper | |
| # import spaces # for zero-GPU | |
| import os | |
| from collections import Counter | |
| import argparse | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from model.init_train import initialize_trainer, update_config | |
| from utils.task_manager import TaskManager | |
| from config.vocabulary import drum_vocab_presets | |
| from utils.utils import str2bool | |
| from utils.utils import Timer | |
| from utils.audio import slice_padded_array | |
| from utils.note2event import mix_notes | |
| from utils.event2note import merge_zipped_note_events_and_ties_to_notes | |
| from utils.utils import write_model_output_as_midi, write_err_cnt_as_json | |
| from model.ymt3 import YourMT3 | |
def debug_model_task_config(model):
    """Print a diagnostic dump of the task setup attached to ``model``.

    Depending on which attributes exist, the report covers the task
    manager's ``task_name``, the keys of its ``task`` mapping, every
    ``eval_subtask_prefix`` entry, ``subtask_tokens``, and the types of the
    tokenizer and its codec (plus ``codec._event_ranges`` when present).

    Args:
        model: Object that may expose a ``task_manager`` attribute.

    Returns:
        None. The function only writes to stdout.
    """
    footer = "=" * 40
    print("=== Model Task Configuration Debug ===")

    # Nothing else can be inspected without a task manager.
    if not hasattr(model, 'task_manager'):
        print("✗ Model doesn't have task_manager")
        print(footer)
        return

    manager = model.task_manager
    print("✓ Model has task_manager")
    print(f" Task name: {getattr(manager, 'task_name', 'Unknown')}")

    # --- task configuration -------------------------------------------------
    if not hasattr(manager, 'task'):
        print(" No task config found")
    else:
        cfg = manager.task
        print(f" Task config keys: {list(cfg.keys())}")
        if 'eval_subtask_prefix' not in cfg:
            print(" No eval_subtask_prefix found")
        else:
            prefixes = cfg['eval_subtask_prefix']
            print(f" Available subtask prefixes: {list(prefixes.keys())}")
            for prefix_name, prefix_events in prefixes.items():
                print(f" {prefix_name}: {prefix_events}")
        if 'subtask_tokens' in cfg:
            print(f" Subtask tokens: {cfg['subtask_tokens']}")

    # --- tokenizer / codec --------------------------------------------------
    if not hasattr(manager, 'tokenizer'):
        print(" No tokenizer found")
    else:
        tok = manager.tokenizer
        print(f" Tokenizer available: {type(tok)}")
        # Try to inspect available events in the codec
        if hasattr(tok, 'codec'):
            event_codec = tok.codec
            print(f" Codec type: {type(event_codec)}")
            if hasattr(event_codec, '_event_ranges'):
                print(f" Event ranges: {event_codec._event_ranges}")

    print(footer)
def create_instrument_task_tokens(model, instrument_hint, n_segments):
    """Create task tokens for instrument-specific transcription conditioning.

    The hint is mapped to one of the entries of
    ``model.task_manager.task['eval_subtask_prefix']``:

    * ``vocals`` / ``singing`` / ``voice``  -> ``'singing-only'``
    * ``drums`` / ``drum`` / ``percussion`` -> ``'drum-only'``
    * anything else (including an empty or ``None`` hint) -> ``'default'``

    If the specific entry is missing, ``'default'`` is used instead. Each
    event of the selected prefix is encoded with
    ``model.task_manager.tokenizer.codec.encode_event``; events that fail to
    encode are skipped with a warning.

    Args:
        model: YourMT3 model instance
        instrument_hint: String indicating desired instrument
            ('vocals', 'guitar', 'piano', etc.). Matching is case-insensitive.
        n_segments: Number of audio segments

    Returns:
        torch.LongTensor of shape ``(n_segments, 1, task_len)`` holding the
        same token row for every segment, or ``None`` when the model has no
        usable task configuration, the selected prefix is empty, or no event
        could be encoded.
    """
    # Same device rule as the rest of this module: CUDA whenever it is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Check what task configuration is available in the model
    task_manager = getattr(model, 'task_manager', None)
    if task_manager is None or not hasattr(task_manager, 'task'):
        print(f"Warning: Model doesn't have task configuration, skipping task tokens for {instrument_hint}")
        return None

    task_config = task_manager.task
    if 'eval_subtask_prefix' not in task_config:
        # For models without subtask support, return None to use regular transcription
        print(f"Model doesn't support subtask prefixes, using general transcription for {instrument_hint}")
        return None

    subtask_prefixes = task_config['eval_subtask_prefix']
    print(f"Model supports subtask prefixes: {list(subtask_prefixes.keys())}")

    # Normalise once; a missing hint simply selects the default prefix.
    hint = str(instrument_hint).lower() if instrument_hint else ""
    if hint in ('vocals', 'singing', 'voice'):
        subtask, label = 'singing-only', 'Singing'
    elif hint in ('drums', 'drum', 'percussion'):
        subtask, label = 'drum-only', 'Drum'
    else:
        subtask, label = None, None

    if subtask is None:
        # For other instruments, use default transcribe_all
        prefix_tokens = subtask_prefixes.get('default', [])
        print(f"Using default task tokens for {instrument_hint}: {prefix_tokens}")
    elif subtask in subtask_prefixes:
        prefix_tokens = subtask_prefixes[subtask]
        print(f"Using {subtask} task tokens: {prefix_tokens}")
    else:
        prefix_tokens = subtask_prefixes.get('default', [])
        print(f"{label} task not available, using default: {prefix_tokens}")

    if not prefix_tokens:
        # Nothing to condition on; the caller falls back to plain transcription.
        return None

    # Best effort: any failure below degrades to "no task tokens" with a warning
    # instead of aborting the transcription.
    try:
        tokenizer = task_manager.tokenizer
        task_token_ids = []
        for event in prefix_tokens:
            try:
                token_id = tokenizer.codec.encode_event(event)
            except Exception as e:  # codec error types are project-specific
                print(f"Warning: Could not encode event {event}: {e}")
                continue
            task_token_ids.append(token_id)
            print(f"Encoded event {event} -> token {token_id}")

        if not task_token_ids:
            print("No valid task tokens could be created")
            return None

        # Build the token row once, then broadcast it to every segment:
        # (task_len,) -> (n_segments, 1, task_len) for single channel.
        token_row = torch.tensor(task_token_ids, dtype=torch.long, device=device)
        task_tokens = torch.empty((n_segments, 1, len(task_token_ids)), dtype=torch.long, device=device)
        task_tokens[:] = token_row
        print(f"Created task tokens with shape: {task_tokens.shape}")
        return task_tokens
    except Exception as e:
        print(f"Warning: Could not create task tokens for {instrument_hint}: {e}")
        return None
def _note_with_program(note, program):
    """Return a copy of ``note`` whose ``program`` attribute is ``program``.

    Supports namedtuple-style notes (``_replace``), dataclass notes
    (``dataclasses.replace``) and plain mutable objects (shallow copy). As a
    last resort it rebuilds the note through its constructor from
    ``start``/``end``/``pitch``/``velocity``. The input note is never mutated.
    """
    import copy
    import dataclasses

    if hasattr(note, '_replace'):
        return note._replace(program=program)
    if dataclasses.is_dataclass(note) and not isinstance(note, type):
        return dataclasses.replace(note, program=program)
    try:
        clone = copy.copy(note)
        clone.program = program
        return clone
    except AttributeError:
        # Immutable note type without a known copy protocol: rebuild it.
        return note.__class__(
            start=note.start,
            end=note.end,
            pitch=note.pitch,
            velocity=note.velocity,
            program=program,
        )


def filter_instrument_consistency(pred_notes, primary_instrument=None, confidence_threshold=0.7, instrument_hint=None):
    """Post-process transcribed notes to maintain instrument consistency.

    The share of notes belonging to the primary instrument is computed; when
    that share reaches ``confidence_threshold`` every other note is
    re-assigned to the primary instrument, otherwise the input is returned
    untouched.

    Args:
        pred_notes: List of Note objects from transcription. A note without a
            ``program`` attribute is counted as program 0.
        primary_instrument: Target instrument program number (if known).
            When ``None`` the most frequent program is used.
        confidence_threshold: Minimum fraction of notes the primary instrument
            must hold for the conversion to be applied.
        instrument_hint: Original instrument hint to help with mapping. If it
            maps to a program that occurs in ``pred_notes``, that program
            overrides ``primary_instrument``.

    Returns:
        List of Note objects. This is ``pred_notes`` itself when it is empty
        or the threshold is not met; otherwise a new list in which converted
        notes are copies and already-matching notes are the original objects.
    """
    if not pred_notes:
        return pred_notes

    # Count instrument occurrences to find dominant instrument
    instrument_counts = Counter(getattr(note, 'program', 0) for note in pred_notes)
    total_notes = len(pred_notes)
    print(f"Found instruments in transcription: {dict(instrument_counts)}")

    # Determine primary instrument
    if primary_instrument is None:
        primary_instrument = max(instrument_counts, key=instrument_counts.get)
    primary_ratio = instrument_counts.get(primary_instrument, 0) / total_notes
    print(f"Primary instrument: {primary_instrument} ({primary_ratio:.2%} of notes)")

    # Map instrument hints to preferred MIDI programs
    instrument_program_map = {
        'vocals': 100,  # Singing voice in YourMT3
        'singing': 100,
        'voice': 100,
        'piano': 0,  # Acoustic Grand Piano
        'guitar': 24,  # Acoustic Guitar (nylon)
        'violin': 40,  # Violin
        'drums': 128,  # Drum kit
        'bass': 32,  # Acoustic Bass
        'saxophone': 64,  # Soprano Sax
        'flute': 73,  # Flute
    }

    # If we have an instrument hint, try to use the appropriate program
    hint_key = instrument_hint.lower() if instrument_hint else None
    if hint_key in instrument_program_map:
        target_program = instrument_program_map[hint_key]
        print(f"Target program for {instrument_hint}: {target_program}")
        # The hint only wins when the transcription actually contains it.
        if target_program in instrument_counts:
            primary_instrument = target_program
            primary_ratio = instrument_counts[target_program] / total_notes
            print(f"Found target instrument in transcription: {primary_ratio:.2%} of notes")

    if primary_ratio < confidence_threshold:
        print(f"Primary instrument ratio ({primary_ratio:.2%}) below threshold ({confidence_threshold:.2%}), keeping all instruments")
        return pred_notes

    # Primary instrument is dominant enough: fold every other note into it.
    print(f"Applying consistency filter (threshold: {confidence_threshold:.2%})")
    filtered_notes = []
    converted_count = 0
    for note in pred_notes:
        if getattr(note, 'program', 0) == primary_instrument:
            filtered_notes.append(note)
        else:
            filtered_notes.append(_note_with_program(note, primary_instrument))
            converted_count += 1
    print(f"Converted {converted_count} notes to primary instrument {primary_instrument}")
    return filtered_notes
def load_model_checkpoint(args=None, device='cpu'):
    """Build a YourMT3 model from CLI-style arguments and load its checkpoint.

    The arguments are parsed with the same option set as the training/test
    scripts, the trainer and configs are initialised for the ``'test'`` stage,
    a ``TaskManager`` and ``YourMT3`` model are created, and the weights found
    at ``dir_info["last_ckpt_path"]`` are loaded (entries whose key contains
    ``'pitchshift'`` are dropped, loading is non-strict).

    Args:
        args: List of argument strings passed to ``argparse``; ``None`` makes
            argparse read ``sys.argv``. The first positional item is the
            experiment id.
        device: Currently ignored. The model is always placed on CUDA when it
            is available and on CPU otherwise. The parameter is kept so
            existing callers keep working.

    Returns:
        The ``YourMT3`` model in eval mode.

    Raises:
        KeyError: If ``--eval-drum-vocab`` names an unknown preset.
    """
    # yapf: disable
    parser = argparse.ArgumentParser(description="YourMT3")
    # General
    parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
    parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
    parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300} 128 for MT3, 300 for PerceiverTF. If None, default value defined in config.py will be used.')
    parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
    # Model configurations
    parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
    parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
    parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
    parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
    parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
    parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
    parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
    parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
    parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False')
    parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
    parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
    parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
    parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
    # Perceiver-TF configurations
    parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
    parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
    parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
    parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
    # Decoder configurations
    parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    # Task and Evaluation configurations
    parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
    parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.')
    parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).')
    parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False')
    parser.add_argument('-w', '--write-model-output', type=str2bool, default=True, help='write model test output to file (default=True). True or False')
    # Trainer configurations
    parser.add_argument('-pr','--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
    parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
    parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
    parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
    parser.add_argument('-wb', '--wandb-mode', type=str, default="disabled", help='wandb mode for logging (default="disabled"). "disabled" or "online" or "offline".')
    # Debug
    parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False')
    parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). semitone in int.')
    args = parser.parse_args(args)
    # yapf: enable

    # Feature-detect instead of comparing version strings: a lexicographic
    # comparison ranks "1.9" above "1.13" and would call a missing function.
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
    args.epochs = None

    # Initialize and update config
    _, _, dir_info, shared_cfg = initialize_trainer(args, stage='test')
    shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test')

    if args.eval_drum_vocab is not None:
        # The preset is not consumed anywhere below; the lookup is kept so an
        # unknown preset name still fails fast with a KeyError.
        _ = drum_vocab_presets[args.eval_drum_vocab]

    # Initialize task manager
    tm = TaskManager(task_name=args.task,
                     max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]),
                     debug_mode=args.debug_mode)
    print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")

    # Use GPU if available.
    # NOTE(review): this overrides the `device` argument; confirm with callers
    # before honouring the parameter, since they may rely on the CUDA placement.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model
    write_output_dir = dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
    model = YourMT3(
        audio_cfg=audio_cfg,
        model_cfg=model_cfg,
        shared_cfg=shared_cfg,
        optimizer=None,
        task_manager=tm,  # tokenizer is a member of task_manager
        eval_subtask_key=args.eval_subtask_key,
        write_output_dir=write_output_dir,
    ).to(device)

    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device, weights_only=False)
    state_dict = checkpoint['state_dict']
    # Entries whose key contains 'pitchshift' are not loaded; strict=False
    # tolerates the resulting key mismatch.
    new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
    model.load_state_dict(new_state_dict, strict=False)
    return model.eval()  # model lives on `device` selected above
def transcribe(model, audio_info, instrument_hint=None):
    """Transcribe one audio file to MIDI and return the MIDI file path.

    Steps: load the audio and average its channels to mono, resample to
    ``model.audio_cfg['sample_rate']``, cut it into non-overlapping segments
    of ``model.audio_cfg['input_frames']`` samples, run
    ``model.inference_file``, detokenize every decoding channel into notes,
    mix the channels, optionally apply the instrument consistency filter, and
    write the MIDI file.

    Args:
        model: YourMT3 model exposing ``audio_cfg``, ``task_manager``,
            ``inference_file`` and ``midi_output_inverse_vocab``.
        audio_info: Mapping with at least ``'filepath'`` (audio to load) and
            ``'track_name'`` (base name of the MIDI output).
        instrument_hint: Optional instrument name. When given, task tokens are
            requested for it and the transcription is post-filtered towards a
            single instrument.

    Returns:
        Path of the written MIDI file, ``./model_output/<track_name>.mid``.

    Raises:
        FileNotFoundError: If the MIDI file is missing after writing.
    """
    t = Timer()
    sample_rate = model.audio_cfg['sample_rate']
    input_frames = model.audio_cfg['input_frames']

    # Converting Audio
    t.start()
    audio, sr = torchaudio.load(uri=audio_info['filepath'])
    audio = torch.mean(audio, dim=0).unsqueeze(0)  # average channels, keep a channel axis
    audio = torchaudio.functional.resample(audio, sr, sample_rate)
    # hop == segment length, i.e. segments do not overlap
    audio_segments = slice_padded_array(audio, input_frames, input_frames)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    audio_segments = torch.from_numpy(audio_segments.astype('float32')).to(device).unsqueeze(1)  # (n_seg, 1, seg_sz)
    t.stop()
    t.print_elapsed_time("converting audio")

    # Inference
    t.start()
    task_tokens = None
    if instrument_hint:
        # Debug model configuration when using instrument hints
        print(f"Attempting to create task tokens for instrument: {instrument_hint}")
        debug_model_task_config(model)
        # Create task tokens for instrument-specific transcription
        task_tokens = create_instrument_task_tokens(model, instrument_hint, audio_segments.shape[0])
    pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments, task_token_array=task_tokens)
    t.stop()
    t.print_elapsed_time("model inference")

    # Post-processing
    t.start()
    num_channels = model.task_manager.num_decoding_channels
    n_items = audio_segments.shape[0]
    # Start time (seconds) of every segment within the file.
    start_secs_file = [input_frames * i / sample_rate for i in range(n_items)]
    pred_notes_in_file = []
    n_err_cnt = Counter()  # accumulated per channel; not reported by this function
    for ch in range(num_channels):
        pred_token_arr_ch = [arr[:, ch, :] for arr in pred_token_arr]  # (B, L)
        zipped_note_events_and_tie, _, _ = model.task_manager.detokenize_list_batches(
            pred_token_arr_ch, start_secs_file, return_events=True)
        pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
        pred_notes_in_file.append(pred_notes_ch)
        n_err_cnt += n_err_cnt_ch
    pred_notes = mix_notes(pred_notes_in_file)  # This is the mixed notes from all channels

    # Apply instrument consistency filter if instrument hint was provided
    if instrument_hint:
        print(f"Applying instrument consistency filter for: {instrument_hint}")
        # Use more aggressive filtering if task tokens weren't available
        confidence_threshold = 0.6 if task_tokens is not None else 0.4
        print(f"Using confidence threshold: {confidence_threshold}")
        pred_notes = filter_instrument_consistency(pred_notes,
                                                   confidence_threshold=confidence_threshold,
                                                   instrument_hint=instrument_hint)

    # Write MIDI
    write_model_output_as_midi(pred_notes, './',
                               audio_info['track_name'], model.midi_output_inverse_vocab)
    t.stop()
    t.print_elapsed_time("post processing")

    midifile = os.path.join('./model_output/', audio_info['track_name'] + '.mid')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(midifile):
        raise FileNotFoundError(f"Expected MIDI output was not written: {midifile}")
    return midifile