# Source: yourmt3 / transcribe_cli.py (upstream commit c207bc4)
#!/usr/bin/env python3
"""
YourMT3+ CLI with Instrument Conditioning
Command-line interface for transcribing audio with instrument-specific hints.
Usage:
python transcribe_cli.py audio.wav
python transcribe_cli.py audio.wav --instrument vocals
python transcribe_cli.py audio.wav --instrument guitar --confidence-threshold 0.8
"""
import os
import sys
import argparse
import torch
import torchaudio
from pathlib import Path
# Add the amt/src directory to the path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
from model_helper import load_model_checkpoint, transcribe
def main():
    """CLI entry point: transcribe an audio file to MIDI with YourMT3+.

    Parses command-line arguments, validates the input path and the
    confidence threshold, loads the selected model checkpoint, runs
    transcription via ``model_helper.transcribe``, and moves the
    resulting MIDI file to the requested output location.

    Exits with status 1 when the input file is missing, the threshold
    is out of range, or transcription raises.
    """
    parser = argparse.ArgumentParser(
        description="YourMT3+ Audio Transcription with Instrument Conditioning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s audio.wav                          # Transcribe all instruments
  %(prog)s audio.wav --instrument vocals      # Focus on vocals only
  %(prog)s audio.wav --instrument guitar      # Focus on guitar only
  %(prog)s audio.wav --single-instrument      # Force single instrument output
  %(prog)s audio.wav --instrument piano --confidence-threshold 0.9

Supported instruments:
  vocals, singing, voice, guitar, piano, violin, drums, bass, saxophone, flute
"""
    )

    # Required arguments
    parser.add_argument('audio_file', help='Path to the audio file to transcribe')

    # Instrument conditioning options
    parser.add_argument('--instrument', type=str,
                        choices=['vocals', 'singing', 'voice', 'guitar', 'piano', 'violin',
                                 'drums', 'bass', 'saxophone', 'flute'],
                        help='Specify the primary instrument to transcribe')
    parser.add_argument('--single-instrument', action='store_true',
                        help='Force single instrument output (apply consistency filtering)')
    parser.add_argument('--confidence-threshold', type=float, default=0.7,
                        help='Confidence threshold for instrument consistency filtering (0.0-1.0, default: 0.7)')

    # Model selection
    parser.add_argument('--model', type=str,
                        default='YPTF.MoE+Multi (noPS)',
                        choices=['YMT3+', 'YPTF+Single (noPS)', 'YPTF+Multi (PS)',
                                 'YPTF.MoE+Multi (noPS)', 'YPTF.MoE+Multi (PS)'],
                        help='Model checkpoint to use (default: YPTF.MoE+Multi (noPS))')

    # Output options
    parser.add_argument('--output', '-o', type=str, default=None,
                        help='Output MIDI file path (default: auto-generated from input filename)')
    parser.add_argument('--precision', type=str, default='16', choices=['16', '32', 'bf16-mixed'],
                        help='Floating point precision (default: 16)')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Enable verbose output')

    args = parser.parse_args()

    # Validate input file before doing any expensive model work.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file '{args.audio_file}' not found.")
        sys.exit(1)

    # Validate confidence threshold (argparse only checks it parses as float).
    if not 0.0 <= args.confidence_threshold <= 1.0:
        print("Error: Confidence threshold must be between 0.0 and 1.0.")
        sys.exit(1)

    # Default output path: same name as the input, with a .mid suffix.
    if args.output is None:
        input_path = Path(args.audio_file)
        args.output = input_path.with_suffix('.mid')

    if args.verbose:
        print(f"Input file: {args.audio_file}")
        print(f"Output file: {args.output}")
        print(f"Model: {args.model}")
        if args.instrument:
            print(f"Target instrument: {args.instrument}")
        if args.single_instrument:
            print(f"Single instrument mode: enabled (threshold: {args.confidence_threshold})")

    try:
        # Load the checkpoint on CPU first, then move to GPU if one exists.
        if args.verbose:
            print("Loading model...")
        model_args = get_model_args(args.model, args.precision)
        model = load_model_checkpoint(args=model_args, device="cpu")
        model.to("cuda" if torch.cuda.is_available() else "cpu")
        if args.verbose:
            print("Model loaded successfully!")

        # Audio metadata passed through to transcribe().
        audio_info = {
            "filepath": args.audio_file,
            "track_name": Path(args.audio_file).stem,
        }
        info = torchaudio.info(args.audio_file)
        audio_info.update({
            "sample_rate": int(info.sample_rate),
            # Some backends/containers do not report bits_per_sample; assume 16.
            "bits_per_sample": int(info.bits_per_sample) if hasattr(info, 'bits_per_sample') else 16,
            "num_channels": int(info.num_channels),
            "num_frames": int(info.num_frames),
            # NOTE: truncated to whole seconds, matching the original behavior.
            "duration": int(info.num_frames / info.sample_rate),
            "encoding": str.lower(str(info.encoding)),
        })

        # Determine the instrument hint: an explicit instrument wins;
        # otherwise --single-instrument asks the helper to auto-detect
        # the dominant instrument but keep a single output track.
        instrument_hint = None
        if args.instrument:
            instrument_hint = args.instrument
        elif args.single_instrument:
            instrument_hint = "auto"

        if args.verbose:
            print("Starting transcription...")

        # NOTE(review): --confidence-threshold is validated above but is not
        # forwarded to transcribe(); consistency filtering currently relies on
        # the defaults inside model_helper. Confirm before exposing the flag
        # as fully functional.
        midifile = transcribe(model, audio_info, instrument_hint)

        # transcribe() writes the MIDI next to its own working path; move it
        # to the user-requested location if they differ.
        if str(args.output) != midifile:
            import shutil
            shutil.move(midifile, args.output)
            midifile = str(args.output)

        print("Transcription completed successfully!")
        print(f"Output saved to: {midifile}")

        if args.verbose:
            # Print some basic statistics about the result.
            file_size = os.path.getsize(midifile)
            print(f"Output file size: {file_size} bytes")
            print(f"Duration: {audio_info['duration']} seconds")

    except Exception as e:
        # Top-level CLI boundary: report, optionally show the traceback,
        # and exit non-zero so shell pipelines see the failure.
        print(f"Error during transcription: {str(e)}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
def get_model_args(model_name, precision):
    """Build the argv-style argument list for ``load_model_checkpoint``.

    Args:
        model_name: One of the model names accepted by the CLI ``--model``
            flag (e.g. ``"YPTF.MoE+Multi (noPS)"``).
        precision: Floating point precision string (``'16'``, ``'32'``,
            or ``'bf16-mixed'``).

    Returns:
        A flat list of strings: ``[checkpoint, '-p', project, *flags,
        '-pr', precision]``, in the exact order expected by the loader.

    Raises:
        ValueError: If ``model_name`` is not a known model.
    """
    project = '2024'

    # Shared flag groups — the Perceiver-TF options are common to every
    # non-baseline model, and both MoE variants use an identical flag set.
    ptf_flags = ['-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1']
    multi_flags = ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26']
    moe_flags = (multi_flags[:4] + ['-nl', '26', '-enc', 'perceiver-tf',
                 '-sqr', '1', '-ff', 'moe', '-wf', '4', '-nmoe', '8',
                 '-kmoe', '2', '-act', 'silu', '-epe', 'rope', '-rp', '1',
                 '-ac', 'spec', '-hop', '300', '-atc', '1'])

    # model name -> (checkpoint file, model-specific flags)
    configs = {
        "YMT3+": (
            "[email protected]",
            []),
        "YPTF+Single (noPS)": (
            "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
            ptf_flags),
        "YPTF+Multi (PS)": (
            "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
            multi_flags + ptf_flags),
        "YPTF.MoE+Multi (noPS)": (
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
            moe_flags),
        "YPTF.MoE+Multi (PS)": (
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
            moe_flags),
    }

    try:
        checkpoint, flags = configs[model_name]
    except KeyError:
        raise ValueError(f"Unknown model name: {model_name}") from None

    return [checkpoint, '-p', project] + flags + ['-pr', precision]
# Script entry point: run the CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()