Spaces:
Sleeping
Sleeping
| import json | |
| import os.path | |
| from pprint import pprint | |
| import random | |
| import string | |
| from types import SimpleNamespace | |
| from typing import List | |
| import pandas as pd | |
| import librosa | |
| import torch | |
| from tqdm import tqdm | |
| from core.data.moisesdb.datamodule import ( | |
| MoisesTestDataModule, | |
| MoisesValidationDataModule, | |
| MoisesDataModule, | |
| MoisesBalancedTrainDataModule, | |
| MoisesVDBODataModule, | |
| ) | |
| from core.losses.base import AdversarialLossHandler, BaseLossHandler | |
| from core.losses.l1snr import ( | |
| L1SNRDecibelMatchLoss, | |
| L1SNRLoss, | |
| WeightedL1Loss, | |
| L1SNRLossIgnoreSilence, | |
| ) | |
| from core.metrics.base import BaseMetricHandler, MultiModeMetricHandler | |
| from core.metrics.snr import ( | |
| SafeScaleInvariantSignalNoiseRatio, | |
| SafeSignalNoiseRatio, | |
| PredictedDecibels, | |
| TargetDecibels, | |
| ) | |
| from core.models.ebase import EndToEndLightningSystem | |
| from core.models.e2e.bandit.bandit import Bandit, PasstFiLMConditionedBandit | |
| from omegaconf import OmegaConf | |
| from core.types import LossHandler, OptimizationBundle | |
| from torch import nn, optim | |
| from torch.optim import lr_scheduler | |
| import torchaudio as ta | |
| import numpy as np | |
| import torchmetrics as tm | |
| import pytorch_lightning as pl | |
| import pytorch_lightning.callbacks | |
| import pytorch_lightning.loggers | |
| from pytorch_lightning.profilers import AdvancedProfiler | |
| import torch.backends.cudnn | |
# Process-wide PyTorch settings, applied once when this module is imported.
# Select the "high" float32 matmul precision mode (presumably a speed/precision
# trade-off on supporting hardware — confirm against the PyTorch docs).
torch.set_float32_matmul_precision("high")
# Turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True
| def _allowed_classes_to_dict(allowed_classes): | |
| return {cls.__name__: cls for cls in allowed_classes} | |
# Whitelists of classes that a config file may select by class name.
# Each list is paired with a name -> class lookup table used by the builders.

# Model architectures selectable through ``config.model.cls``.
ALLOWED_MODELS = [
    Bandit,
    PasstFiLMConditionedBandit,
]
ALLOWED_MODELS_DICT = _allowed_classes_to_dict(ALLOWED_MODELS)

# Lightning data modules selectable through ``config.data.cls``.
ALLOWED_DATAMODULES = [
    MoisesDataModule,
    MoisesBalancedTrainDataModule,
    MoisesVDBODataModule,
    MoisesValidationDataModule,
    MoisesTestDataModule,
]
ALLOWED_DATAMODULE_DICT = _allowed_classes_to_dict(ALLOWED_DATAMODULES)

# Project-specific losses selectable through ``config.loss.cls``.
ALLOWED_LOSSES = [
    L1SNRLoss,
    WeightedL1Loss,
    L1SNRDecibelMatchLoss,
    L1SNRLossIgnoreSilence,
]
ALLOWED_LOSS_DICT = _allowed_classes_to_dict(ALLOWED_LOSSES)
def _build_model(config: OmegaConf) -> nn.Module:
    """Instantiate the model named by ``config.model.cls``.

    Args:
        config: Experiment config; ``config.model.cls`` is the class name and
            the optional ``config.model.kwargs`` are passed to its constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: If the class name is not in ``ALLOWED_MODELS_DICT``.
    """
    spec = config.model
    cls_name = spec.cls
    init_kwargs = spec.get("kwargs", {})

    # Guard clause: reject anything that is not explicitly whitelisted.
    if cls_name not in ALLOWED_MODELS_DICT:
        raise ValueError(f"Unknown model name: {cls_name}")

    return ALLOWED_MODELS_DICT[cls_name](**init_kwargs)
def _build_inner_loss(config: OmegaConf) -> nn.Module:
    """Instantiate the raw loss named by ``config.loss.cls``.

    Project losses in ``ALLOWED_LOSS_DICT`` take priority; otherwise the name
    is looked up in the namespace of ``torch.nn.modules.loss``.

    Args:
        config: Experiment config; ``config.loss.cls`` is the name and the
            optional ``config.loss.kwargs`` are passed to the constructor.

    Returns:
        The constructed loss object.

    Raises:
        ValueError: If the name is found in neither place.
    """
    spec = config.loss
    cls_name = spec.cls
    init_kwargs = spec.get("kwargs", {})
    torch_loss_namespace = nn.modules.loss.__dict__

    if cls_name in ALLOWED_LOSS_DICT:
        factory = ALLOWED_LOSS_DICT[cls_name]
    elif cls_name in torch_loss_namespace:
        factory = torch_loss_namespace[cls_name]
    else:
        raise ValueError(f"Unknown loss name: {cls_name}")

    return factory(**init_kwargs)
def _build_loss(config: OmegaConf) -> BaseLossHandler:
    """Wrap the configured loss in a ``BaseLossHandler``.

    Args:
        config: Experiment config; reads ``config.loss.modality`` and the
            optional ``config.loss.name`` in addition to what
            ``_build_inner_loss`` reads.

    Returns:
        The loss handler around the freshly built inner loss.
    """
    inner_loss = _build_inner_loss(config)
    loss_section = config.loss
    return BaseLossHandler(
        loss=inner_loss,
        modality=loss_section.modality,
        name=loss_section.get("name", None),
    )
def _dummy_metrics(config: OmegaConf) -> MultiModeMetricHandler:
    """Build identical SNR-style metric handlers for train, val and test.

    Args:
        config: Experiment config; ``config.stems`` lists the stems to track.

    Returns:
        A ``MultiModeMetricHandler`` whose three splits each hold one
        ``BaseMetricHandler`` per stem, every handler owning its own
        ``MetricCollection`` instance.
    """

    def _per_stem_handlers():
        # Called once per split so that no metric object is shared between
        # train / val / test.
        return {
            stem: BaseMetricHandler(
                stem=stem,
                metric=tm.MetricCollection(
                    SafeSignalNoiseRatio(),
                    SafeScaleInvariantSignalNoiseRatio(),
                    PredictedDecibels(),
                    TargetDecibels(),
                ),
                modality="audio",
                name="snr",
            )
            for stem in config.stems
        }

    return MultiModeMetricHandler(
        train_metrics=_per_stem_handlers(),
        val_metrics=_per_stem_handlers(),
        test_metrics=_per_stem_handlers(),
    )
def _build_optimization_bundle(config: OmegaConf) -> OptimizationBundle:
    """Resolve the optimizer (and optional LR scheduler) classes from config.

    Nothing is instantiated here: the bundle only carries the classes and
    their keyword arguments.

    Args:
        config: Experiment config. ``config.optim.optimizer.cls`` names an
            attribute of ``torch.optim``; ``config.optim.scheduler.cls``
            (optional) names an entry of ``torch.optim.lr_scheduler``. Both
            sections may carry a ``kwargs`` mapping.

    Returns:
        A namespace with ``optimizer`` (``cls``/``kwargs``) and ``scheduler``
        (``cls``/``kwargs`` or ``None`` when no scheduler is configured).

    Raises:
        ValueError: If the optimizer or scheduler name cannot be resolved.
    """
    optim_config = config.optim
    print(optim_config)

    optimizer_name = optim_config.optimizer.cls
    optimizer_kwargs = optim_config.optimizer.get("kwargs", {})
    # Fail the same way as the other builders instead of leaking a bare
    # AttributeError from getattr() on a typo in the config.
    optimizer_cls = getattr(optim, optimizer_name, None)
    if optimizer_cls is None:
        raise ValueError(f"Unknown optimizer name: {optimizer_name}")

    optim_bundle = SimpleNamespace(
        optimizer=SimpleNamespace(cls=optimizer_cls, kwargs=optimizer_kwargs),
        scheduler=None,
    )

    scheduler_config = optim_config.get("scheduler", None)
    if scheduler_config is None:
        return optim_bundle

    scheduler_name = scheduler_config.cls
    scheduler_kwargs = scheduler_config.get("kwargs", {})
    if scheduler_name not in lr_scheduler.__dict__:
        raise ValueError(f"Unknown scheduler name: {scheduler_name}")

    optim_bundle.scheduler = SimpleNamespace(
        cls=lr_scheduler.__dict__[scheduler_name], kwargs=scheduler_kwargs
    )
    return optim_bundle
def _dummy_augmentation() -> nn.Module:
    """Return a placeholder augmentation: a fresh ``nn.Identity`` module."""
    return nn.Identity()
def _load_config(config_path: str) -> OmegaConf:
    """Load a YAML config, expanding top-level references to other files.

    Any top-level value that is a string ending in ``.yml`` is replaced by
    the config loaded from that path; every other value is kept as is.

    Args:
        config_path: Path of the root config file.

    Returns:
        The merged OmegaConf config.
    """
    root = OmegaConf.load(config_path)
    expanded = {
        key: (
            OmegaConf.load(value)
            if isinstance(value, str) and value.endswith(".yml")
            else value
        )
        for key, value in root.items()
    }
    return OmegaConf.merge(expanded)
def _build_datamodule(config: OmegaConf) -> pl.LightningDataModule:
    """Instantiate the data module named by ``config.data.cls``.

    Args:
        config: Experiment config; ``config.data`` must provide ``cls``,
            ``data_root``, ``batch_size`` and ``num_workers``. The
            ``train_kwargs``, ``val_kwargs``, ``test_kwargs`` and
            ``datamodule_kwargs`` sections are optional and default to None.

    Returns:
        The constructed data module.

    Raises:
        KeyError: If ``config.data.cls`` is not in ``ALLOWED_DATAMODULE_DICT``.
    """
    data_cfg = config.data
    datamodule_cls = ALLOWED_DATAMODULE_DICT[data_cfg.cls]

    init_kwargs = dict(
        data_root=data_cfg.data_root,
        batch_size=data_cfg.batch_size,
        num_workers=data_cfg.num_workers,
    )
    # Optional sections are forwarded as None when absent.
    for optional_key in (
        "train_kwargs",
        "val_kwargs",
        "test_kwargs",
        "datamodule_kwargs",
    ):
        init_kwargs[optional_key] = data_cfg.get(optional_key, None)

    return datamodule_cls(**init_kwargs)
def train(
    config_path: str,
    profile: bool = False,
    ckpt_path: str = None,
    validate_only: bool = False,
    inference_only: bool = False,
    output_dir: str = None,
    test_datamodule: bool = False,
    precision=32,
):
    """Train, validate, or run prediction for the configured system.

    Args:
        config_path: Path of the experiment config.
        profile: Attach an ``AdvancedProfiler`` and limit the run to one
            epoch of 8 train / 8 val batches.
        ckpt_path: Checkpoint passed to the Lightning trainer.
        validate_only: Run ``trainer.validate`` instead of fitting. Takes
            precedence over ``inference_only``.
        inference_only: Run ``trainer.predict`` with batch size forced to 1.
        output_dir: Where predictions are written in inference mode. Defaults
            to ``<two levels above ckpt_path>/inference``.
        test_datamodule: Only iterate once over the train/val/test loaders
            (smoke test of the data pipeline) and return.
        precision: Forwarded to ``pl.Trainer``.

    Raises:
        ValueError: In inference mode when neither ``output_dir`` nor
            ``ckpt_path`` is given, so no output location can be derived.
    """
    config = _load_config(config_path)
    pl.seed_everything(config.seed, workers=True)

    if inference_only:
        config["data"]["batch_size"] = 1

    datamodule = _build_datamodule(config)

    if test_datamodule:
        for _ in tqdm(datamodule.train_dataloader()):
            pass
        for _ in tqdm(datamodule.val_dataloader()):
            pass
        for _ in tqdm(datamodule.test_dataloader()):
            pass
        return

    model = _build_model(config)
    loss_handler = _build_loss(config)

    system = EndToEndLightningSystem(
        model=model,
        loss_handler=loss_handler,
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=config.get("inference", None),
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    # Fallback run id for when the job is not launched through SLURM.
    rand_str = "".join(
        random.choice(string.ascii_uppercase + string.digits) for _ in range(6)
    )

    logger = pytorch_lightning.loggers.TensorBoardLogger(
        save_dir=os.path.join(
            config.trainer.logger.save_dir, os.environ.get("SLURM_JOB_ID", rand_str)
        ),
    )

    callbacks = [
        pytorch_lightning.callbacks.ModelCheckpoint(
            monitor=config.trainer.callbacks.checkpoint.monitor,
            mode=config.trainer.callbacks.checkpoint.mode,
            save_top_k=config.trainer.callbacks.checkpoint.save_top_k,
            save_last=config.trainer.callbacks.checkpoint.save_last,
        ),
        # Second, unmonitored checkpoint callback with default settings.
        # NOTE(review): the previous comment said "also save the last 3
        # epochs", but no save_top_k is passed here — confirm the intent.
        pytorch_lightning.callbacks.ModelCheckpoint(
            monitor=None,
        ),
        pytorch_lightning.callbacks.RichModelSummary(max_depth=3),
    ]

    profiler = (
        AdvancedProfiler(filename="profile.txt", dirpath=".") if profile else None
    )

    # Use .get() so that configs omitting these keys behave like configs that
    # set them to null (the system construction above already reads
    # effective_batch_size this way).
    # NOTE(review): the nesting of the two conditions was reconstructed from a
    # whitespace-stripped source; confirm that effective_batch_size should only
    # be applied when accumulate_grad_batches is unset.
    if config.trainer.get("accumulate_grad_batches", None) is None:
        config.trainer.accumulate_grad_batches = 1

        effective_batch_size = config.data.get("effective_batch_size", None)
        if effective_batch_size is not None:
            config.trainer.accumulate_grad_batches = int(
                effective_batch_size / config.data.batch_size
            )

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        max_epochs=1 if profile else config.trainer.max_epochs,
        callbacks=callbacks,
        logger=logger,
        profiler=profiler,
        limit_train_batches=8 if profile else 1.0,
        limit_val_batches=8 if profile else 1.0,
        accumulate_grad_batches=config.trainer.accumulate_grad_batches,
        precision=precision,
        gradient_clip_val=config.trainer.get("gradient_clip_val", None),
        gradient_clip_algorithm=config.trainer.get("gradient_clip_algorithm", "norm"),
    )

    if validate_only:
        trainer.validate(system, datamodule, ckpt_path=ckpt_path)
    elif inference_only:
        if output_dir is None:
            if ckpt_path is None:
                # Previously this surfaced as a TypeError from os.path.dirname(None).
                raise ValueError(
                    "inference_only requires output_dir or ckpt_path to "
                    "determine where predictions are written"
                )
            output_dir = os.path.join(
                os.path.dirname(os.path.dirname(ckpt_path)), "inference"
            )

        system.set_output_path(output_dir)
        trainer.predict(system, datamodule, ckpt_path=ckpt_path)
    else:
        trainer.logger.log_hyperparams(OmegaConf.to_object(config))
        trainer.logger.save()
        trainer.fit(system, datamodule, ckpt_path=ckpt_path)
def query_validate(config_path: str, ckpt_path: str):
    """Validate a checkpoint stem by stem and dump the metrics to CSV.

    The validation dataloaders are paired, in order, with
    ``config.data.val_kwargs.allowed_stems``; each pair is validated
    separately and all resulting metrics are written to
    ``validation_metrics.csv`` inside the logger directory.

    Args:
        config_path: Path of the experiment config.
        ckpt_path: Checkpoint to evaluate; also recorded in ``config.txt``.
    """
    config = _load_config(config_path)
    datamodule = _build_datamodule(config)

    system = EndToEndLightningSystem(
        model=_build_model(config),
        loss_handler=_build_loss(config),
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=None,
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        logger=pytorch_lightning.loggers.CSVLogger(
            save_dir=os.path.join(config.trainer.logger.save_dir, "validate"),
        ),
    )

    allowed_stems = config.data.val_kwargs.get("allowed_stems", None)

    # Record which checkpoint these numbers belong to.
    log_dir = trainer.logger.log_dir
    os.makedirs(log_dir, exist_ok=True)
    with open(log_dir + "/config.txt", "w") as handle:
        handle.write(ckpt_path)

    records = []
    for stem, stem_loader in zip(allowed_stems, datamodule.val_dataloader()):
        stem_metrics = trainer.validate(system, stem_loader, ckpt_path=ckpt_path)[0]
        print(stem)
        pprint(stem_metrics)
        records.extend(
            {"metric": metric_name, "value": metric_value, "stem": stem}
            for metric_name, metric_value in stem_metrics.items()
        )

    pd.DataFrame(records).to_csv(
        os.path.join(trainer.logger.log_dir, "validation_metrics.csv"), index=False
    )
def query_test(config_path: str, ckpt_path: str):
    """Run ``trainer.test`` on the test dataloader for a checkpoint.

    Logs go to ``<save_dir>/test`` (or ``test-o`` when
    ``config.data.test_kwargs.use_own_query`` is set), in a sub-directory
    named after ``SLURM_JOB_ID`` or a random 6-character tag.

    Args:
        config_path: Path of the experiment config.
        ckpt_path: Checkpoint to evaluate; also recorded in ``config.txt``.
    """
    config = _load_config(config_path)
    pprint(config)
    inference_kwargs = config.data.inference_kwargs
    pprint(inference_kwargs)

    datamodule = _build_datamodule(config)

    system = EndToEndLightningSystem(
        model=_build_model(config),
        loss_handler=_build_loss(config),
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=inference_kwargs,
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    # Random tag used as run id when not running under SLURM.
    alphabet = string.ascii_uppercase + string.digits
    run_tag = "".join(random.choice(alphabet) for _ in range(6))

    if config.data.test_kwargs.get("use_own_query", False):
        prefix = "test-o"
    else:
        prefix = "test"

    csv_logger = pytorch_lightning.loggers.CSVLogger(
        save_dir=os.path.join(
            config.trainer.logger.save_dir,
            prefix,
            os.environ.get("SLURM_JOB_ID", run_tag),
        ),
    )
    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        logger=csv_logger,
    )

    # Record which checkpoint this run evaluates.
    log_dir = trainer.logger.log_dir
    os.makedirs(log_dir, exist_ok=True)
    with open(log_dir + "/config.txt", "w") as handle:
        handle.write(ckpt_path)

    trainer.logger.log_hyperparams(OmegaConf.to_object(config))
    trainer.logger.save()

    trainer.test(system, datamodule.test_dataloader(), ckpt_path=ckpt_path)
def query_inference(config_path: str, ckpt_path: str):
    """Run ``trainer.predict`` on the test dataloader for a checkpoint.

    Logs go to ``<save_dir>/inference-d`` (or ``inference-o`` when
    ``config.data.test_kwargs.use_own_query`` is set), in a sub-directory
    named after ``SLURM_JOB_ID`` or a random 6-character tag.

    Args:
        config_path: Path of the experiment config.
        ckpt_path: Checkpoint to load; also recorded in ``config.txt``.
    """
    config = _load_config(config_path)
    pprint(config)
    inference_kwargs = config.data.inference_kwargs
    pprint(inference_kwargs)

    datamodule = _build_datamodule(config)

    system = EndToEndLightningSystem(
        model=_build_model(config),
        loss_handler=_build_loss(config),
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=inference_kwargs,
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    # Random tag used as run id when not running under SLURM.
    alphabet = string.ascii_uppercase + string.digits
    run_tag = "".join(random.choice(alphabet) for _ in range(6))

    if config.data.test_kwargs.get("use_own_query", False):
        prefix = "inference-o"
    else:
        prefix = "inference-d"

    csv_logger = pytorch_lightning.loggers.CSVLogger(
        save_dir=os.path.join(
            config.trainer.logger.save_dir,
            prefix,
            os.environ.get("SLURM_JOB_ID", run_tag),
        ),
    )
    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        logger=csv_logger,
    )

    # Record which checkpoint this run uses.
    log_dir = trainer.logger.log_dir
    os.makedirs(log_dir, exist_ok=True)
    with open(log_dir + "/config.txt", "w") as handle:
        handle.write(ckpt_path)

    trainer.logger.log_hyperparams(OmegaConf.to_object(config))
    trainer.logger.save()

    trainer.predict(system, datamodule.test_dataloader(), ckpt_path=ckpt_path)
def clean_validation_metrics(path):
    """Re-key a metrics CSV by stem name and write it next to the input.

    Every column of the input whose name ends in ``.../<something>_<int>`` is
    kept: the integer is used as an index into the hard-coded stem list below
    and the part before the last ``/`` becomes the metric name. Columns whose
    last component does not end in an integer are skipped.

    The result is written to ``<path without extension>_clean<extension>``
    (``metrics.csv`` -> ``metrics_clean.csv``).

    Args:
        path: Path of the CSV file to clean.

    Raises:
        IndexError: If a column's index exceeds the stem list.
    """
    # Transposed so that iterating rows walks the original columns.
    table = pd.read_csv(path).T

    # Position in this list must match the index embedded in the column name.
    # Commented-out entries are stems that were excluded from the evaluated
    # set (presumably; kept for reference — verify against the datamodule).
    stems = [
        "drums",
        "lead_male_singer",
        "lead_female_singer",
        # "human_choir",
        "background_vocals",
        # "other_vocals",
        "bass_guitar",
        "bass_synthesizer",
        # "contrabass_double_bass",
        # "tuba",
        # "bassoon",
        "fx",
        "clean_electric_guitar",
        "distorted_electric_guitar",
        # "lap_steel_guitar_or_slide_guitar",
        "acoustic_guitar",
        "other_plucked",
        "pitched_percussion",
        "grand_piano",
        "electric_piano",
        "organ_electric_organ",
        "synth_pad",
        "synth_lead",
        # "violin",
        # "viola",
        # "cello",
        # "violin_section",
        # "viola_section",
        # "cello_section",
        "string_section",
        "other_strings",
        "brass",
        # "flutes",
        "reeds",
        "other_wind",
    ]

    data = []
    for metric, value in table.iterrows():
        *name_parts, last_part = metric.split("/")
        metric_name = "/".join(name_parts)
        print(metric, last_part)

        try:
            stem_idx = int(last_part.split("_")[-1])
        except ValueError:
            # Not a per-stem column (no trailing integer): skip it.
            continue

        # NOTE(review): ``value`` is the whole transposed row (a Series), not
        # a scalar; kept as in the original — confirm this is intended.
        data.append({metric_name: value, "stem": stems[stem_idx]})

    # Only touch the file extension: str.replace(".csv", ...) would also
    # rewrite ".csv" inside directory names, and would return the input path
    # unchanged (overwriting the input) when the file is not named *.csv.
    root, ext = os.path.splitext(path)
    new_path = f"{root}_clean{ext}"
    pd.DataFrame(data).to_csv(new_path, index=False)
def query_inference_one(
    config_path: str,
    ckpt_path: str,
    input_path: str,
    output_path: str,
    query_id: str,
    stems: List[str],
    fs: int = 44100,
):
    """Separate several stems from one mixture using stored query embeddings.

    For each stem, the query is loaded from
    ``$DATA_ROOT/moisesdb/npyq/<query_id>/<stem>.query-10s.npy`` and the
    estimate is written to ``<output_path>/<stem>.wav``.

    Args:
        config_path: Path of the experiment config.
        ckpt_path: Checkpoint to load (environment variables are expanded).
        input_path: Audio file containing the mixture.
        output_path: Directory receiving one wav file per stem.
        query_id: Sub-directory of the query store to read queries from.
        stems: Names of the stems to extract.
        fs: Sampling rate the mixture is resampled to and the outputs are
            saved at. Previously this argument was overwritten by the rate
            returned from ``ta.load`` and 44100 was hard-coded; the default
            keeps that behaviour.
    """
    config = _load_config(config_path)

    pprint(config)
    pprint(config.data.inference_kwargs)

    model = _build_model(config)
    loss_handler = _build_loss(config)

    system = EndToEndLightningSystem.load_from_checkpoint(
        os.path.expandvars(ckpt_path),
        strict=True,
        model=model,
        loss_handler=loss_handler,
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=config.data.inference_kwargs,
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    os.makedirs(output_path, exist_ok=True)

    # Keep the file's own rate in a separate name so ``fs`` is not clobbered.
    mixture, input_fs = ta.load(input_path)
    if input_fs != fs:
        mixture = ta.functional.resample(mixture, orig_freq=input_fs, new_freq=fs)

    # Loop-invariant: add the batch axis and move to the GPU once.
    # NOTE(review): inputs are placed on CUDA unconditionally, as before;
    # this assumes the loaded system lives on the GPU as well — confirm.
    mixture_batch = mixture.unsqueeze(0).cuda()

    for stem in stems:
        query = np.load(
            os.path.expandvars(
                os.path.join(
                    "$DATA_ROOT/moisesdb/npyq", query_id, f"{stem}.query-10s.npy"
                )
            )
        )

        batch = {
            "mixture": {"audio": mixture_batch},
            "query": {
                "audio": torch.from_numpy(query).to(torch.float32).unsqueeze(0).cuda()
            },
            "metadata": {"stem": [stem]},
            "estimates": {},
        }

        out = system.chunked_inference(batch)

        out_path_stem = os.path.join(output_path, f"{stem}.wav")
        ta.save(out_path_stem, out["estimates"][stem]["audio"].squeeze().cpu(), fs)
def init(
    ckpt_path: str,
    config_path: str = None,
    batch_size: int = None,
    use_cuda: bool = True,
):
    """Load a trained system from a checkpoint, ready for inference.

    Args:
        ckpt_path: Checkpoint to load (environment variables are expanded).
        config_path: Experiment config; defaults to
            ``./expt/bandit-everything-test.yml``.
        batch_size: If given, overrides
            ``config.data.inference_kwargs.batch_size``.
        use_cuda: Move the system to the GPU when True, to the CPU otherwise.

    Returns:
        The loaded ``EndToEndLightningSystem``.
    """
    resolved_config_path = (
        "./expt/bandit-everything-test.yml" if config_path is None else config_path
    )
    config = _load_config(resolved_config_path)

    if batch_size is not None:
        config.data.inference_kwargs.batch_size = batch_size

    pprint(config)
    inference_kwargs = config.data.inference_kwargs
    pprint(inference_kwargs)

    system = EndToEndLightningSystem.load_from_checkpoint(
        os.path.expandvars(ckpt_path),
        strict=True,
        model=_build_model(config),
        loss_handler=_build_loss(config),
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=inference_kwargs,
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    if not use_cuda:
        system.cpu()
    else:
        system.cuda()

    return system
def inference_file(
    system,
    input_path: str,
    output_path: str,
    query_path: str,
    stem_name: str = "target",
    model_fs: int = 44100,
    query_length_seconds: float = 10.0,
):
    """Separate one mixture file using an audio query and save the estimate.

    Mono inputs are duplicated to two channels, both signals are resampled
    to ``model_fs``, and the query is brought to exactly
    ``query_length_seconds`` (most active segment if longer, tiled if
    shorter). The estimate is resampled back to the mixture's original rate
    before saving.

    Args:
        system: Loaded system exposing ``device`` and ``chunked_inference``.
        input_path: Audio file containing the mixture.
        output_path: Destination wav file; its directory is created if needed.
        query_path: Audio file containing the query.
        stem_name: Key under which the estimate is requested and read back.
        model_fs: Sampling rate expected by the model (must be 44100).
        query_length_seconds: Query duration fed to the model (must be 10.0).

    Raises:
        AssertionError: For unsupported ``model_fs`` / ``query_length_seconds``
            or if the prepared query does not have the expected length.
    """
    assert query_length_seconds == 10.0, "Only 10s queries are supported at the moment."
    assert model_fs == 44100, "Only 44.1kHz models are supported at the moment."

    # os.makedirs("") raises FileNotFoundError, so skip it when output_path
    # is a bare file name in the current directory.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    mixture, fsm = ta.load(input_path)
    if mixture.shape[0] == 1:
        mixture = torch.cat([mixture, mixture], dim=0)
        print("Converting mono mixture to stereo")

    query, fsq = ta.load(query_path)
    if query.shape[0] == 1:
        query = torch.cat([query, query], dim=0)
        print("Converting mono query to stereo")

    if fsm != model_fs:
        mixture = ta.functional.resample(mixture, orig_freq=fsm, new_freq=model_fs)

    if fsq != model_fs:
        query = ta.functional.resample(query, orig_freq=fsq, new_freq=model_fs)

    query_samples = int(query_length_seconds * model_fs)

    if query.shape[1] > query_samples:
        print(f"Query is longer than {query_length_seconds} seconds. Extracting most active segment.")
        query = extract_most_active_segment(query, sr=model_fs, chunk_length=query_length_seconds)
    elif query.shape[1] < query_samples:
        print(f"Query is shorter than {query_length_seconds} seconds. Tiling.")
        # Repeat often enough to cover the target length, then trim.
        n_repeats = query_samples // query.shape[1] + 1
        query = torch.cat([query] * n_repeats, dim=1)
        query = query[:, :query_samples]

    assert query.shape[1] == query_samples

    query = query.unsqueeze(0).to(device=system.device)
    mixture = mixture.unsqueeze(0).to(device=system.device)

    batch = {
        "mixture": {"audio": mixture},
        "query": {
            "audio": query
        },
        "metadata": {"stem": [stem_name]},
        "estimates": {},
    }

    out = system.chunked_inference(batch)

    estimate = out["estimates"][stem_name]["audio"].squeeze().cpu()

    if fsm != model_fs:
        print("Resampling estimate back to the mixture's original sampling rate.")
        estimate = ta.functional.resample(estimate, orig_freq=model_fs, new_freq=fsm)

    ta.save(output_path, estimate, fsm)
def inference_file_text(
    system,
    input_path: str,
    output_path: str,
    query_text: str,
    stem_name: str = "target",
    model_fs: int = 44100,
):
    """Separate one mixture file using a text query and save the estimate.

    A mono mixture is duplicated to two channels and resampled to
    ``model_fs``; the estimate is resampled back to the mixture's original
    rate before saving.

    Args:
        system: Loaded system exposing ``device`` and ``chunked_inference``.
        input_path: Audio file containing the mixture.
        output_path: Destination wav file; its directory is created if needed.
        query_text: Text query, passed to the model as a one-element list.
        stem_name: Key under which the estimate is requested and read back.
        model_fs: Sampling rate expected by the model (must be 44100).

    Raises:
        AssertionError: If ``model_fs`` is not 44100.
    """
    assert model_fs == 44100, "Only 44.1kHz models are supported at the moment."

    # os.makedirs("") raises FileNotFoundError, so skip it when output_path
    # is a bare file name in the current directory.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    mixture, fsm = ta.load(input_path)
    if mixture.shape[0] == 1:
        mixture = torch.cat([mixture, mixture], dim=0)
        print("Converting mono mixture to stereo")

    if fsm != model_fs:
        mixture = ta.functional.resample(mixture, orig_freq=fsm, new_freq=model_fs)

    query = [query_text]
    mixture = mixture.unsqueeze(0).to(device=system.device)

    batch = {
        "mixture": {"audio": mixture},
        "query": {
            "text": query
        },
        "metadata": {"stem": [stem_name]},
        "estimates": {},
    }

    out = system.chunked_inference(batch)

    estimate = out["estimates"][stem_name]["audio"].squeeze().cpu()

    if fsm != model_fs:
        print("Resampling estimate back to the mixture's original sampling rate.")
        estimate = ta.functional.resample(estimate, orig_freq=model_fs, new_freq=fsm)

    ta.save(output_path, estimate, fsm)
def extract_most_active_segment(
    audio: torch.Tensor,  # (c, l)
    sr: int = 44100,
    chunk_length: float = 10,  # seconds
    hop_size: int = 512
) -> torch.Tensor:
    """Return the ``chunk_length``-second window with the highest onset activity.

    The channel-averaged signal's onset strength is averaged over a sliding
    window of one chunk, and the window with the largest mean is selected.

    Args:
        audio: Tensor indexed as (channels, samples).
        sr: Sampling rate of ``audio``.
        chunk_length: Length of the returned segment in seconds.
        hop_size: Hop length, in samples, of the onset-strength frames.

    Returns:
        A slice of ``audio`` along the sample axis. It has exactly
        ``int(chunk_length * sr)`` samples whenever ``audio`` is longer than
        that; otherwise ``audio`` is returned unchanged.
    """
    chunk_size = int(chunk_length * sr)
    n_samples = audio.shape[1]

    # Nothing to choose from: the whole signal already fits in one chunk.
    # (The sliding window below would otherwise fail or truncate the audio.)
    if n_samples <= chunk_size:
        return audio

    # detach()/cpu() so the conversion also works for CUDA or grad tensors.
    audio_mono = audio.detach().mean(dim=0).cpu().numpy()

    onset_strength = librosa.onset.onset_strength(
        y=audio_mono, sr=sr, hop_length=hop_size
    )

    n_frames_per_chunk = chunk_size // hop_size

    onset_strength_slide = np.lib.stride_tricks.sliding_window_view(
        onset_strength, n_frames_per_chunk, axis=0
    )
    window_mean_strength = np.mean(onset_strength_slide, axis=1)

    max_onset_frame = np.argmax(window_mean_strength)
    max_onset_samples = int(
        librosa.frames_to_samples(max_onset_frame, hop_length=hop_size)
    )

    # The frame grid can place the best window so close to the end that fewer
    # than chunk_size samples remain; pull the start back so the slice is
    # always full length.
    max_onset_samples = min(max_onset_samples, n_samples - chunk_size)

    print("max onset at time", max_onset_samples / sr)

    segment = audio[:, max_onset_samples : max_onset_samples + chunk_size]

    return segment
def inference_byoq(
    ckpt_path: str,
    input_path: str,
    query_path: str,
    output_path: str,
    config_path: str = None,
    stem_name: str = "target",
    model_fs: int = 44100,
    query_length_seconds: float = 10.0,
    batch_size: int = None,
    use_cuda: bool = True,
):
    """Load a checkpoint and separate one file with an audio query.

    Convenience wrapper: ``init`` followed by ``inference_file``. See those
    functions for the meaning of the arguments.
    """
    system = init(
        ckpt_path,
        config_path=config_path,
        batch_size=batch_size,
        use_cuda=use_cuda,
    )
    inference_file(
        system,
        input_path=input_path,
        output_path=output_path,
        query_path=query_path,
        stem_name=stem_name,
        model_fs=model_fs,
        query_length_seconds=query_length_seconds,
    )
def inference_byoq_text(
    ckpt_path: str,
    input_path: str,
    query_text: str,
    output_path: str,
    config_path: str = None,
    stem_name: str = "target",
    model_fs: int = 44100,
    batch_size: int = None,
    use_cuda: bool = True,
):
    """Load a checkpoint and separate one file with a text query.

    Convenience wrapper: ``init`` followed by ``inference_file_text``. See
    those functions for the meaning of the arguments.
    """
    system = init(
        ckpt_path,
        config_path=config_path,
        batch_size=batch_size,
        use_cuda=use_cuda,
    )
    inference_file_text(
        system,
        input_path=input_path,
        output_path=output_path,
        query_text=query_text,
        stem_name=stem_name,
        model_fs=model_fs,
    )
def inference_test_folder(
    ckpt_path: str,
    input_dir: str,
    output_dir: str,
    query_name: str,
    input_name: str = "mixture",
    config_path: str = None,
    stem_name: str = "target",
    model_fs: int = 44100,
    query_length_seconds: float = 10.0,
    batch_size: int = None,
    use_cuda: bool = True,
):
    """Run audio-query separation on every matching folder under a tree.

    A folder is processed when it contains both ``<input_name>.wav`` and
    ``<query_name>.wav``. The estimate is written to the same relative
    location under ``output_dir``, named ``<query_name>.wav``.

    Args:
        ckpt_path: Checkpoint to load.
        input_dir: Root of the tree to scan.
        output_dir: Root of the mirrored output tree.
        query_name: Base name (without extension) of the query file.
        input_name: Base name (without extension) of the mixture file.
        config_path: Forwarded to ``init``.
        stem_name: Forwarded to ``inference_file``.
        model_fs: Forwarded to ``inference_file``.
        query_length_seconds: Forwarded to ``inference_file``.
        batch_size: Forwarded to ``init``.
        use_cuda: Forwarded to ``init``.
    """
    system = init(ckpt_path, config_path, batch_size, use_cuda)

    mixture_file = f"{input_name}.wav"
    query_file = f"{query_name}.wav"

    # Collect the folders first so progress can be reported as i/total.
    subdirs = []
    for dirpath, _, files in os.walk(input_dir):
        if mixture_file in files and query_file in files:
            subdirs.append(dirpath)

    total = len(subdirs)
    for position, subdir in enumerate(subdirs, start=1):
        print(f"Processing {position}/{total}")
        rel_path = os.path.relpath(subdir, input_dir)
        input_path = os.path.join(input_dir, rel_path, mixture_file)
        query_path = os.path.join(input_dir, rel_path, query_file)
        output_path = os.path.join(output_dir, rel_path, query_file)

        print(input_path, query_path, output_path)

        inference_file(
            system,
            input_path,
            output_path,
            query_path,
            stem_name,
            model_fs,
            query_length_seconds,
        )
# Script entry point: hand the command line over to python-fire.
if __name__ == "__main__":
    # Imported here so `fire` is only needed when the file is run as a script.
    import fire

    # Called without a component (presumably exposing this module's functions
    # as sub-commands — see the Fire documentation).
    fire.Fire()