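"""Training, evaluation, and query-based source-separation inference for the
Bandit and PasstFiLMConditionedBandit models on MoisesDB.

Every public function in this module is exposed on the command line via
``fire.Fire()`` (see the ``__main__`` block at the bottom of the file).
"""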
import json
import os.path
from pprint import pprint
import random
import string
from types import SimpleNamespace
from typing import List
import pandas as pd
import librosa
import torch
from tqdm import tqdm
from core.data.moisesdb.datamodule import (
MoisesTestDataModule,
MoisesValidationDataModule,
MoisesDataModule,
MoisesBalancedTrainDataModule,
MoisesVDBODataModule,
)
from core.losses.base import AdversarialLossHandler, BaseLossHandler
from core.losses.l1snr import (
L1SNRDecibelMatchLoss,
L1SNRLoss,
WeightedL1Loss,
L1SNRLossIgnoreSilence,
)
from core.metrics.base import BaseMetricHandler, MultiModeMetricHandler
from core.metrics.snr import (
SafeScaleInvariantSignalNoiseRatio,
SafeSignalNoiseRatio,
PredictedDecibels,
TargetDecibels,
)
from core.models.ebase import EndToEndLightningSystem
from core.models.e2e.bandit.bandit import Bandit, PasstFiLMConditionedBandit
from omegaconf import OmegaConf
from core.types import LossHandler, OptimizationBundle
from torch import nn, optim
from torch.optim import lr_scheduler
import torchaudio as ta
import numpy as np
import torchmetrics as tm
import pytorch_lightning as pl
import pytorch_lightning.callbacks
import pytorch_lightning.loggers
from pytorch_lightning.profilers import AdvancedProfiler
import torch.backends.cudnn
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
def _allowed_classes_to_dict(allowed_classes):
return {cls.__name__: cls for cls in allowed_classes}
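# Whitelists of classes that may be instantiated by name from the config files.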
ALLOWED_MODELS = [
Bandit,
PasstFiLMConditionedBandit,
]
ALLOWED_MODELS_DICT = _allowed_classes_to_dict(ALLOWED_MODELS)
ALLOWED_DATAMODULES = [
MoisesDataModule,
MoisesBalancedTrainDataModule,
MoisesVDBODataModule,
MoisesValidationDataModule,
MoisesTestDataModule,
]
ALLOWED_DATAMODULE_DICT = _allowed_classes_to_dict(ALLOWED_DATAMODULES)
ALLOWED_LOSSES = [
L1SNRLoss,
WeightedL1Loss,
L1SNRDecibelMatchLoss,
L1SNRLossIgnoreSilence,
]
ALLOWED_LOSS_DICT = _allowed_classes_to_dict(ALLOWED_LOSSES)
def _build_model(config: OmegaConf) -> nn.Module:
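    """Instantiate the model class named in ``config.model.cls`` with its kwargs."""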
model_config = config.model
model_name = model_config.cls
kwargs = model_config.get("kwargs", {})
if model_name in ALLOWED_MODELS_DICT:
model = ALLOWED_MODELS_DICT[model_name](**kwargs)
else:
raise ValueError(f"Unknown model name: {model_name}")
return model
def _build_inner_loss(config: OmegaConf) -> nn.Module:
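    """Build the raw loss module, from the project whitelist or, failing that, from ``torch.nn``."""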
loss_config = config.loss
loss_name = loss_config.cls
kwargs = loss_config.get("kwargs", {})
if loss_name in ALLOWED_LOSS_DICT:
loss = ALLOWED_LOSS_DICT[loss_name](**kwargs)
elif loss_name in nn.modules.loss.__dict__:
loss = nn.modules.loss.__dict__[loss_name](**kwargs)
else:
raise ValueError(f"Unknown loss name: {loss_name}")
return loss
def _build_loss(config: OmegaConf) -> BaseLossHandler:
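    """Wrap the configured loss in a ``BaseLossHandler`` bound to its modality and optional name."""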
loss_handler = BaseLossHandler(
loss=_build_inner_loss(config),
modality=config.loss.modality,
name=config.loss.get("name", None),
)
return loss_handler
def _dummy_metrics(config: OmegaConf) -> MultiModeMetricHandler:
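    """Create identical SNR / SI-SNR / decibel metric collections per stem for train, val, and test."""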
metrics = MultiModeMetricHandler(
train_metrics={
stem: BaseMetricHandler(
stem=stem,
metric=tm.MetricCollection(
SafeSignalNoiseRatio(),
SafeScaleInvariantSignalNoiseRatio(),
PredictedDecibels(),
TargetDecibels(),
),
modality="audio",
name="snr",
)
for stem in config.stems
},
val_metrics={
stem: BaseMetricHandler(
stem=stem,
metric=tm.MetricCollection(
SafeSignalNoiseRatio(),
SafeScaleInvariantSignalNoiseRatio(),
PredictedDecibels(),
TargetDecibels(),
),
modality="audio",
name="snr",
)
for stem in config.stems
},
test_metrics={
stem: BaseMetricHandler(
stem=stem,
metric=tm.MetricCollection(
SafeSignalNoiseRatio(),
SafeScaleInvariantSignalNoiseRatio(),
PredictedDecibels(),
TargetDecibels(),
),
modality="audio",
name="snr",
)
for stem in config.stems
},
)
return metrics
def _build_optimization_bundle(config: OmegaConf) -> OptimizationBundle:
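    """Resolve the optimizer class from ``torch.optim`` and an optional LR scheduler from ``torch.optim.lr_scheduler``."""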
optim_config = config.optim
print(optim_config)
optimizer_name = optim_config.optimizer.cls
kwargs = optim_config.optimizer.get("kwargs", {})
optimizer = getattr(optim, optimizer_name)
optim_bundle = SimpleNamespace(
optimizer=SimpleNamespace(cls=optimizer, kwargs=kwargs), scheduler=None
)
scheduler_config = optim_config.get("scheduler", None)
if scheduler_config is not None:
scheduler_name = scheduler_config.cls
scheduler_kwargs = scheduler_config.get("kwargs", {})
if scheduler_name in lr_scheduler.__dict__:
scheduler = lr_scheduler.__dict__[scheduler_name]
else:
raise ValueError(f"Unknown scheduler name: {scheduler_name}")
optim_bundle.scheduler = SimpleNamespace(cls=scheduler, kwargs=scheduler_kwargs)
return optim_bundle
def _dummy_augmentation():
return nn.Identity()
def _load_config(config_path: str) -> OmegaConf:
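    """Load a YAML config; top-level string values ending in ``.yml`` are themselves loaded and merged in."""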
config = OmegaConf.load(config_path)
config_dict = {}
for k, v in config.items():
if isinstance(v, str) and v.endswith(".yml"):
config_dict[k] = OmegaConf.load(v)
else:
config_dict[k] = v
config = OmegaConf.merge(config_dict)
return config
def _build_datamodule(config: OmegaConf) -> pl.LightningDataModule:
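    """Instantiate the whitelisted LightningDataModule named in ``config.data.cls``."""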
DataModule = ALLOWED_DATAMODULE_DICT[config.data.cls]
datamodule = DataModule(
data_root=config.data.data_root,
batch_size=config.data.batch_size,
num_workers=config.data.num_workers,
train_kwargs=config.data.get("train_kwargs", None),
val_kwargs=config.data.get("val_kwargs", None),
test_kwargs=config.data.get("test_kwargs", None),
datamodule_kwargs=config.data.get("datamodule_kwargs", None),
)
return datamodule
def train(
config_path: str,
profile: bool = False,
ckpt_path: str = None,
validate_only: bool = False,
inference_only: bool = False,
output_dir: str = None,
test_datamodule: bool = False,
precision=32,
):
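    """Train the configured system, or (depending on the flags) only validate, only run
    inference, or just iterate the dataloaders as a smoke test."""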
config = _load_config(config_path)
pl.seed_everything(config.seed, workers=True)
if inference_only:
config["data"]["batch_size"] = 1
datamodule = _build_datamodule(config)
if test_datamodule:
for batch in tqdm(datamodule.train_dataloader()):
pass
for batch in tqdm(datamodule.val_dataloader()):
pass
for batch in tqdm(datamodule.test_dataloader()):
pass
return
model = _build_model(config)
loss_handler = _build_loss(config)
system = EndToEndLightningSystem(
model=model,
loss_handler=loss_handler,
metrics=_dummy_metrics(config),
augmentation_handler=_dummy_augmentation(),
inference_handler=config.get("inference", None),
optimization_bundle=_build_optimization_bundle(config),
fast_run=config.fast_run,
batch_size=config.data.batch_size,
effective_batch_size=config.data.get("effective_batch_size", None),
commitment_weight=config.get("commitment_weight", 1.0),
)
rand_str = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(6)
)
logger = pytorch_lightning.loggers.TensorBoardLogger(
save_dir=os.path.join(
config.trainer.logger.save_dir, os.environ.get("SLURM_JOB_ID", rand_str)
),
)
callbacks = [
pytorch_lightning.callbacks.ModelCheckpoint(
monitor=config.trainer.callbacks.checkpoint.monitor,
mode=config.trainer.callbacks.checkpoint.mode,
save_top_k=config.trainer.callbacks.checkpoint.save_top_k,
save_last=config.trainer.callbacks.checkpoint.save_last,
),
pytorch_lightning.callbacks.ModelCheckpoint(
monitor=None,
        ),  # also keep the most recent checkpoint, regardless of the monitored metric
pytorch_lightning.callbacks.RichModelSummary(max_depth=3),
]
if profile:
profiler = AdvancedProfiler(filename="profile.txt", dirpath=".")
if config.trainer.accumulate_grad_batches is None:
config.trainer.accumulate_grad_batches = 1
if config.data.effective_batch_size is not None:
config.trainer.accumulate_grad_batches = int(
config.data.effective_batch_size / config.data.batch_size
)
trainer = pl.Trainer(
accelerator="gpu" if torch.cuda.is_available() else "cpu",
max_epochs=1 if profile else config.trainer.max_epochs,
callbacks=callbacks,
logger=logger,
profiler=profiler if profile else None,
limit_train_batches=int(8) if profile else float(1.0),
limit_val_batches=int(8) if profile else float(1.0),
accumulate_grad_batches=config.trainer.accumulate_grad_batches,
precision=precision,
gradient_clip_val=config.trainer.get("gradient_clip_val", None),
gradient_clip_algorithm=config.trainer.get("gradient_clip_algorithm", "norm"),
)
if validate_only:
trainer.validate(system, datamodule, ckpt_path=ckpt_path)
elif inference_only:
if output_dir is None:
output_dir = os.path.join(
os.path.dirname(os.path.dirname(ckpt_path)), "inference"
)
system.set_output_path(output_dir)
trainer.predict(system, datamodule, ckpt_path=ckpt_path)
else:
trainer.logger.log_hyperparams(OmegaConf.to_object(config))
trainer.logger.save()
trainer.fit(system, datamodule, ckpt_path=ckpt_path)
def query_validate(config_path: str, ckpt_path: str):
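    """Validate one query stem at a time and collect the per-stem metrics into ``validation_metrics.csv``."""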
config = _load_config(config_path)
datamodule = _build_datamodule(config)
model = _build_model(config)
loss_handler = _build_loss(config)
system = EndToEndLightningSystem(
model=model,
loss_handler=loss_handler,
metrics=_dummy_metrics(config),
augmentation_handler=_dummy_augmentation(),
inference_handler=None,
optimization_bundle=_build_optimization_bundle(config),
fast_run=config.fast_run,
batch_size=config.data.batch_size,
effective_batch_size=config.data.get("effective_batch_size", None),
commitment_weight=config.get("commitment_weight", 1.0),
)
logger = pytorch_lightning.loggers.CSVLogger(
save_dir=os.path.join(config.trainer.logger.save_dir, "validate"),
)
trainer = pl.Trainer(
accelerator="gpu" if torch.cuda.is_available() else "cpu",
logger=logger,
)
allowed_stems = config.data.val_kwargs.get("allowed_stems", None)
data = []
os.makedirs(trainer.logger.log_dir, exist_ok=True)
    with open(os.path.join(trainer.logger.log_dir, "config.txt"), "w") as f:
f.write(ckpt_path)
dl = datamodule.val_dataloader()
for stem, val_dl in zip(allowed_stems, dl):
metrics = trainer.validate(system, val_dl, ckpt_path=ckpt_path)[0]
print(stem)
pprint(metrics)
for metric, value in metrics.items():
data.append({"metric": metric, "value": value, "stem": stem})
df = pd.DataFrame(data)
df.to_csv(
os.path.join(trainer.logger.log_dir, "validation_metrics.csv"), index=False
)
def query_test(config_path: str, ckpt_path: str):
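    """Run the test loop on the test dataloader and log metrics and hyperparameters with a CSV logger."""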
config = _load_config(config_path)
pprint(config)
pprint(config.data.inference_kwargs)
datamodule = _build_datamodule(config)
model = _build_model(config)
loss_handler = _build_loss(config)
system = EndToEndLightningSystem(
model=model,
loss_handler=loss_handler,
metrics=_dummy_metrics(config),
augmentation_handler=_dummy_augmentation(),
inference_handler=config.data.inference_kwargs,
optimization_bundle=_build_optimization_bundle(config),
fast_run=config.fast_run,
batch_size=config.data.batch_size,
effective_batch_size=config.data.get("effective_batch_size", None),
commitment_weight=config.get("commitment_weight", 1.0),
)
rand_str = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(6)
)
use_own_query = config.data.test_kwargs.get("use_own_query", False)
prefix = "test-o" if use_own_query else "test"
logger = pytorch_lightning.loggers.CSVLogger(
save_dir=os.path.join(
config.trainer.logger.save_dir,
prefix,
os.environ.get("SLURM_JOB_ID", rand_str),
),
)
trainer = pl.Trainer(
accelerator="gpu" if torch.cuda.is_available() else "cpu",
logger=logger,
)
os.makedirs(trainer.logger.log_dir, exist_ok=True)
    with open(os.path.join(trainer.logger.log_dir, "config.txt"), "w") as f:
f.write(ckpt_path)
trainer.logger.log_hyperparams(OmegaConf.to_object(config))
trainer.logger.save()
dl = datamodule.test_dataloader()
trainer.test(system, dl, ckpt_path=ckpt_path)
def query_inference(config_path: str, ckpt_path: str):
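    """Run prediction on the test dataloader so the inference handler can write out estimates."""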
config = _load_config(config_path)
pprint(config)
pprint(config.data.inference_kwargs)
datamodule = _build_datamodule(config)
model = _build_model(config)
loss_handler = _build_loss(config)
system = EndToEndLightningSystem(
model=model,
loss_handler=loss_handler,
metrics=_dummy_metrics(config),
augmentation_handler=_dummy_augmentation(),
inference_handler=config.data.inference_kwargs,
optimization_bundle=_build_optimization_bundle(config),
fast_run=config.fast_run,
batch_size=config.data.batch_size,
effective_batch_size=config.data.get("effective_batch_size", None),
commitment_weight=config.get("commitment_weight", 1.0),
)
rand_str = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(6)
)
use_own_query = config.data.test_kwargs.get("use_own_query", False)
prefix = "inference-o" if use_own_query else "inference-d"
logger = pytorch_lightning.loggers.CSVLogger(
save_dir=os.path.join(
config.trainer.logger.save_dir,
prefix,
os.environ.get("SLURM_JOB_ID", rand_str),
),
)
trainer = pl.Trainer(
accelerator="gpu" if torch.cuda.is_available() else "cpu",
logger=logger,
)
os.makedirs(trainer.logger.log_dir, exist_ok=True)
    with open(os.path.join(trainer.logger.log_dir, "config.txt"), "w") as f:
f.write(ckpt_path)
trainer.logger.log_hyperparams(OmegaConf.to_object(config))
trainer.logger.save()
dl = datamodule.test_dataloader()
trainer.predict(system, dl, ckpt_path=ckpt_path)
def clean_validation_metrics(path):
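    """Map the numeric dataloader suffixes in a raw validation-metrics CSV back to stem names
    and write the result next to the input as ``*_clean.csv``."""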
df = pd.read_csv(path).T
data = []
stems = [
"drums",
"lead_male_singer",
"lead_female_singer",
# "human_choir",
"background_vocals",
# "other_vocals",
"bass_guitar",
"bass_synthesizer",
# "contrabass_double_bass",
# "tuba",
# "bassoon",
"fx",
"clean_electric_guitar",
"distorted_electric_guitar",
# "lap_steel_guitar_or_slide_guitar",
"acoustic_guitar",
"other_plucked",
"pitched_percussion",
"grand_piano",
"electric_piano",
"organ_electric_organ",
"synth_pad",
"synth_lead",
# "violin",
# "viola",
# "cello",
# "violin_section",
# "viola_section",
# "cello_section",
"string_section",
"other_strings",
"brass",
# "flutes",
"reeds",
"other_wind",
]
for metric, value in df.iterrows():
mm = metric.split("/")
idx = mm[-1]
m = "/".join(mm[:-1])
print(metric, idx)
try:
idx = int(idx.split("_")[-1])
except ValueError as e:
assert "invalid literal for int() with base 10" in str(e)
continue
data.append({m: value, "stem": stems[idx]})
df = pd.DataFrame(data)
new_path = path.replace(".csv", "_clean.csv")
df.to_csv(new_path, index=False)
def query_inference_one(
config_path: str,
ckpt_path: str,
input_path: str,
output_path: str,
query_id: str,
stems: List[str],
fs: int = 44100,
):
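    """Separate the requested stems from a single mixture, using the pre-computed 10-second
    query clips stored as ``.npy`` under ``$DATA_ROOT/moisesdb/npyq/<query_id>/``."""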
config = _load_config(config_path)
pprint(config)
pprint(config.data.inference_kwargs)
model = _build_model(config)
loss_handler = _build_loss(config)
system = EndToEndLightningSystem.load_from_checkpoint(
os.path.expandvars(ckpt_path),
strict=True,
model=model,
loss_handler=loss_handler,
metrics=_dummy_metrics(config),
augmentation_handler=_dummy_augmentation(),
inference_handler=config.data.inference_kwargs,
optimization_bundle=_build_optimization_bundle(config),
fast_run=config.fast_run,
batch_size=config.data.batch_size,
effective_batch_size=config.data.get("effective_batch_size", None),
commitment_weight=config.get("commitment_weight", 1.0),
)
os.makedirs(output_path, exist_ok=True)
mixture, fs = ta.load(input_path)
if fs != 44100:
mixture = ta.functional.resample(mixture, orig_freq=fs, new_freq=44100)
for stem in stems:
query = np.load(
os.path.expandvars(
os.path.join(
"$DATA_ROOT/moisesdb/npyq", query_id, f"{stem}.query-10s.npy"
)
)
)
batch = {
"mixture": {"audio": mixture.unsqueeze(0).cuda()},
"query": {
"audio": torch.from_numpy(query).to(torch.float32).unsqueeze(0).cuda()
},
"metadata": {"stem": [stem]},
"estimates": {},
}
out = system.chunked_inference(batch)
out_path_stem = os.path.join(output_path, f"{stem}.wav")
ta.save(out_path_stem, out["estimates"][stem]["audio"].squeeze().cpu(), 44100)
def init(
ckpt_path: str,
config_path: str = None,
batch_size: int = None,
use_cuda: bool = True,
):
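    """Load a trained ``EndToEndLightningSystem`` from a checkpoint and move it to GPU or CPU."""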
if config_path is None:
config_path = "./expt/bandit-everything-test.yml"
config = _load_config(config_path)
if batch_size is not None:
config.data.inference_kwargs.batch_size = batch_size
pprint(config)
pprint(config.data.inference_kwargs)
model = _build_model(config)
loss_handler = _build_loss(config)
system = EndToEndLightningSystem.load_from_checkpoint(
os.path.expandvars(ckpt_path),
strict=True,
model=model,
loss_handler=loss_handler,
metrics=_dummy_metrics(config),
augmentation_handler=_dummy_augmentation(),
inference_handler=config.data.inference_kwargs,
optimization_bundle=_build_optimization_bundle(config),
fast_run=config.fast_run,
batch_size=config.data.batch_size,
effective_batch_size=config.data.get("effective_batch_size", None),
commitment_weight=config.get("commitment_weight", 1.0),
)
if use_cuda:
system.cuda()
else:
system.cpu()
return system
def inference_file(
system,
input_path: str,
output_path: str,
query_path: str,
stem_name: str = "target",
model_fs: int = 44100,
query_length_seconds: float = 10.0,
):
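    """Separate ``stem_name`` from ``input_path`` with an audio query: convert mono inputs to
    stereo, resample to the model rate, fit the query to ``query_length_seconds``, run chunked
    inference, and write the estimate (resampled back if needed) to ``output_path``."""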
assert query_length_seconds == 10.0, "Only 10s queries are supported at the moment."
assert model_fs == 44100, "Only 44.1kHz models are supported at the moment."
os.makedirs(os.path.dirname(output_path), exist_ok=True)
mixture, fsm = ta.load(input_path)
if mixture.shape[0] == 1:
mixture = torch.cat([mixture, mixture], dim=0)
print("Converting mono mixture to stereo")
query, fsq = ta.load(query_path)
if query.shape[0] == 1:
query = torch.cat([query, query], dim=0)
print("Converting mono query to stereo")
if fsm != model_fs:
mixture = ta.functional.resample(mixture, orig_freq=fsm, new_freq=model_fs)
if fsq != model_fs:
query = ta.functional.resample(query, orig_freq=fsq, new_freq=model_fs)
if query.shape[1] > int(query_length_seconds * model_fs):
print(f"Query is longer than {query_length_seconds} seconds. Extracting most active segment.")
query = extract_most_active_segment(query, sr=model_fs, chunk_length=query_length_seconds)
elif query.shape[1] < int(query_length_seconds * model_fs):
print(f"Query is shorter than {query_length_seconds} seconds. Tiling.")
query = torch.cat([query] * (int(query_length_seconds * model_fs) // query.shape[1] + 1), dim=1)
query = query[:, :int(query_length_seconds * model_fs)]
assert query.shape[1] == int(query_length_seconds * model_fs)
query = query.unsqueeze(0).to(device=system.device)
mixture = mixture.unsqueeze(0).to(device=system.device)
batch = {
"mixture": {"audio": mixture},
"query": {
"audio": query
},
"metadata": {"stem": [stem_name]},
"estimates": {},
}
out = system.chunked_inference(batch)
estimate = out["estimates"][stem_name]["audio"].squeeze().cpu()
if fsm != model_fs:
print("Resampling estimate back to the mixture's original sampling rate.")
estimate = ta.functional.resample(estimate, orig_freq=model_fs, new_freq=fsm)
ta.save(output_path, estimate, fsm)
def inference_file_text(
system,
input_path: str,
output_path: str,
query_text: str,
stem_name: str = "target",
model_fs: int = 44100,
):
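    """Separate ``stem_name`` from ``input_path`` with a free-form text query and write the
    estimate to ``output_path``."""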
assert model_fs == 44100, "Only 44.1kHz models are supported at the moment."
os.makedirs(os.path.dirname(output_path), exist_ok=True)
mixture, fsm = ta.load(input_path)
if mixture.shape[0] == 1:
mixture = torch.cat([mixture, mixture], dim=0)
print("Converting mono mixture to stereo")
if fsm != model_fs:
mixture = ta.functional.resample(mixture, orig_freq=fsm, new_freq=model_fs)
query = [query_text]
mixture = mixture.unsqueeze(0).to(device=system.device)
batch = {
"mixture": {"audio": mixture},
"query": {
"text": query
},
"metadata": {"stem": [stem_name]},
"estimates": {},
}
out = system.chunked_inference(batch)
estimate = out["estimates"][stem_name]["audio"].squeeze().cpu()
if fsm != model_fs:
print("Resampling estimate back to the mixture's original sampling rate.")
estimate = ta.functional.resample(estimate, orig_freq=model_fs, new_freq=fsm)
ta.save(output_path, estimate, fsm)
def extract_most_active_segment(
audio: torch.Tensor, # (c, l)
sr: int = 44100,
    chunk_length: float = 10.0,  # seconds
hop_size: int = 512
) -> torch.Tensor:
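    """Return the ``chunk_length``-second segment with the highest mean onset strength,
    computed with librosa on a mono downmix of ``audio``."""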
audio_mono = audio.mean(dim=0).numpy()
chunk_size = int(chunk_length * sr)
onset_strength = librosa.onset.onset_strength(
y=audio_mono, sr=sr, hop_length=hop_size
)
n_frames_per_chunk = chunk_size // hop_size
onset_strength_slide = np.lib.stride_tricks.sliding_window_view(
onset_strength, n_frames_per_chunk, axis=0
)
onset_strength = np.mean(onset_strength_slide, axis=1)
max_onset_frame = np.argmax(onset_strength)
max_onset_samples = librosa.frames_to_samples(max_onset_frame, hop_length=hop_size)
print("max onset at time", max_onset_samples / sr)
segment = audio[:, max_onset_samples : max_onset_samples + chunk_size]
return segment
def inference_byoq(
ckpt_path: str,
input_path: str,
query_path: str,
output_path: str,
config_path: str = None,
stem_name: str = "target",
model_fs: int = 44100,
query_length_seconds: float = 10.0,
batch_size: int = None,
use_cuda: bool = True,
):
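    """Load the system from ``ckpt_path`` and separate ``input_path`` using the audio query at ``query_path``."""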
system = init(ckpt_path, config_path, batch_size, use_cuda)
inference_file(system, input_path, output_path, query_path, stem_name, model_fs, query_length_seconds)
def inference_byoq_text(
ckpt_path: str,
input_path: str,
query_text: str,
output_path: str,
config_path: str = None,
stem_name: str = "target",
model_fs: int = 44100,
batch_size: int = None,
use_cuda: bool = True,
):
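    """Load the system from ``ckpt_path`` and separate ``input_path`` using the text query ``query_text``."""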
system = init(ckpt_path, config_path, batch_size, use_cuda)
inference_file_text(system, input_path, output_path, query_text, stem_name, model_fs)
def inference_test_folder(
ckpt_path: str,
input_dir: str,
output_dir: str,
query_name: str,
input_name: str = "mixture",
config_path: str = None,
stem_name: str = "target",
model_fs: int = 44100,
query_length_seconds: float = 10.0,
batch_size: int = None,
use_cuda: bool = True,
):
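    """Run audio-query separation for every subfolder of ``input_dir`` that contains both
    ``<input_name>.wav`` and ``<query_name>.wav``, mirroring the folder structure under ``output_dir``."""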
system = init(ckpt_path, config_path, batch_size, use_cuda)
subdirs = [
dirpath for dirpath, _, files in os.walk(input_dir)
if f"{input_name}.wav" in files and f"{query_name}.wav" in files
]
for i, subdir in enumerate(subdirs):
print(f"Processing {i+1}/{len(subdirs)}")
rel_path = os.path.relpath(subdir, input_dir)
input_path = os.path.join(input_dir, rel_path, f"{input_name}.wav")
query_path = os.path.join(input_dir, rel_path, f"{query_name}.wav")
output_path = os.path.join(output_dir, rel_path, f"{query_name}.wav")
print(input_path, query_path, output_path)
inference_file(system, input_path, output_path, query_path, stem_name, model_fs, query_length_seconds)
if __name__ == "__main__":
import fire
fire.Fire()
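    # Example CLI usage via python-fire; the script name and file paths below are
    # illustrative assumptions, not part of the repository:
    #
    #   python train.py train --config_path ./expt/bandit-everything-test.yml
    #   python train.py inference_byoq --ckpt_path last.ckpt --input_path mixture.wav \
    #       --query_path query.wav --output_path separated/target.wav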