#!/usr/bin/env python3
"""
YourMT3+ CLI with Instrument Conditioning
Command-line interface for transcribing audio with instrument-specific hints.
Usage:
python transcribe_cli.py audio.wav
python transcribe_cli.py audio.wav --instrument vocals
python transcribe_cli.py audio.wav --instrument guitar --confidence-threshold 0.8
"""
import os
import sys
import argparse
import torch
import torchaudio
from pathlib import Path
# Add the amt/src directory to the path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
from model_helper import load_model_checkpoint, transcribe
def main():
    """Parse CLI arguments, run YourMT3+ transcription, and write a MIDI file.

    Exits with status 1 on a missing input file, an out-of-range confidence
    threshold, or any transcription failure.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validate input file before doing any expensive model work.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file '{args.audio_file}' not found.")
        sys.exit(1)
    # Validate confidence threshold range.
    if not 0.0 <= args.confidence_threshold <= 1.0:
        print("Error: Confidence threshold must be between 0.0 and 1.0.")
        sys.exit(1)
    # Default output path: input filename with a .mid suffix.
    if args.output is None:
        args.output = Path(args.audio_file).with_suffix('.mid')

    if args.verbose:
        print(f"Input file: {args.audio_file}")
        print(f"Output file: {args.output}")
        print(f"Model: {args.model}")
        if args.instrument:
            print(f"Target instrument: {args.instrument}")
        if args.single_instrument:
            print(f"Single instrument mode: enabled (threshold: {args.confidence_threshold})")

    try:
        # Load model checkpoint on CPU, then move to GPU if one is available.
        if args.verbose:
            print("Loading model...")
        model_args = get_model_args(args.model, args.precision)
        model = load_model_checkpoint(args=model_args, device="cpu")
        model.to("cuda" if torch.cuda.is_available() else "cpu")
        if args.verbose:
            print("Model loaded successfully!")

        audio_info = _gather_audio_info(args.audio_file)

        # Instrument hint: an explicit --instrument wins; --single-instrument
        # alone asks the pipeline to auto-detect the dominant instrument.
        instrument_hint = None
        if args.instrument:
            instrument_hint = args.instrument
        elif args.single_instrument:
            instrument_hint = "auto"

        if args.verbose:
            print("Starting transcription...")
        # NOTE(review): args.confidence_threshold is validated above but not
        # forwarded here -- consistency filtering currently happens inside
        # transcribe() with its own default. TODO: plumb the threshold through
        # once transcribe() accepts it.
        midifile = transcribe(model, audio_info, instrument_hint)

        # transcribe() chooses its own output location; relocate the result
        # if the user requested a different path.
        if str(args.output) != midifile:
            import shutil
            shutil.move(midifile, args.output)
            midifile = str(args.output)

        print("Transcription completed successfully!")
        print(f"Output saved to: {midifile}")
        if args.verbose:
            # Print some basic statistics about the result.
            file_size = os.path.getsize(midifile)
            print(f"Output file size: {file_size} bytes")
            print(f"Duration: {audio_info['duration']} seconds")
    except Exception as e:
        print(f"Error during transcription: {str(e)}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


def _build_parser():
    """Construct the argparse parser for the CLI (single place for all flags)."""
    parser = argparse.ArgumentParser(
        description="YourMT3+ Audio Transcription with Instrument Conditioning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s audio.wav                               # Transcribe all instruments
  %(prog)s audio.wav --instrument vocals           # Focus on vocals only
  %(prog)s audio.wav --instrument guitar           # Focus on guitar only
  %(prog)s audio.wav --single-instrument           # Force single instrument output
  %(prog)s audio.wav --instrument piano --confidence-threshold 0.9

Supported instruments:
  vocals, singing, voice, guitar, piano, violin, drums, bass, saxophone, flute
"""
    )
    # Required arguments
    parser.add_argument('audio_file', help='Path to the audio file to transcribe')
    # Instrument conditioning options
    parser.add_argument('--instrument', type=str,
                        choices=['vocals', 'singing', 'voice', 'guitar', 'piano', 'violin',
                                 'drums', 'bass', 'saxophone', 'flute'],
                        help='Specify the primary instrument to transcribe')
    parser.add_argument('--single-instrument', action='store_true',
                        help='Force single instrument output (apply consistency filtering)')
    parser.add_argument('--confidence-threshold', type=float, default=0.7,
                        help='Confidence threshold for instrument consistency filtering (0.0-1.0, default: 0.7)')
    # Model selection
    parser.add_argument('--model', type=str,
                        default='YPTF.MoE+Multi (noPS)',
                        choices=['YMT3+', 'YPTF+Single (noPS)', 'YPTF+Multi (PS)',
                                 'YPTF.MoE+Multi (noPS)', 'YPTF.MoE+Multi (PS)'],
                        help='Model checkpoint to use (default: YPTF.MoE+Multi (noPS))')
    # Output options
    parser.add_argument('--output', '-o', type=str, default=None,
                        help='Output MIDI file path (default: auto-generated from input filename)')
    parser.add_argument('--precision', type=str, default='16', choices=['16', '32', 'bf16-mixed'],
                        help='Floating point precision (default: 16)')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Enable verbose output')
    return parser


def _gather_audio_info(audio_file):
    """Collect metadata for *audio_file* in the dict shape transcribe() expects."""
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        "track_name": Path(audio_file).stem,
        "sample_rate": int(info.sample_rate),
        # Some torchaudio backends do not expose bits_per_sample; fall back
        # to 16 (assumed 16-bit PCM -- TODO confirm against transcribe()).
        "bits_per_sample": int(info.bits_per_sample) if hasattr(info, 'bits_per_sample') else 16,
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        # Truncated to whole seconds, matching the original behavior.
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(str(info.encoding)),
    }
def get_model_args(model_name, precision):
    """Return the checkpoint-loader argument list for the named model.

    Args:
        model_name: One of the supported YourMT3+ model identifiers.
        precision: Floating point precision flag value ('16', '32', 'bf16-mixed').

    Returns:
        A list of CLI-style strings: [checkpoint, '-p', project, <model
        flags...>, '-pr', precision].

    Raises:
        ValueError: If model_name is not a known model.
    """
    project = '2024'
    # Flags shared by both MoE variants (they differ only in checkpoint file).
    moe_flags = ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                 '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                 '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                 '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1']
    # model name -> (checkpoint file, model-specific flags)
    catalog = {
        "YMT3+": (
            "[email protected]",
            []),
        "YPTF+Single (noPS)": (
            "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
            ['-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1']),
        "YPTF+Multi (PS)": (
            "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
            ['-tk', 'mc13_full_plus_256',
             '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
             '-ac', 'spec', '-hop', '300', '-atc', '1']),
        "YPTF.MoE+Multi (noPS)": (
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
            moe_flags),
        "YPTF.MoE+Multi (PS)": (
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
            moe_flags),
    }
    if model_name not in catalog:
        raise ValueError(f"Unknown model name: {model_name}")
    checkpoint, flags = catalog[model_name]
    return [checkpoint, '-p', project, *flags, '-pr', precision]
# Script entry point: run the CLI only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()