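"""End-to-end Lightning system for audio source separation.

Wraps a separation model together with loss, metric, augmentation, and
inference handlers, and supports chunked overlap-add inference for both
query-based (single-stem) and multi-stem (VDBO) separation.
"""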
import math
import os.path
from collections import defaultdict
from typing import Any, Dict, Mapping, Optional

import pytorch_lightning as pl
import torch
import torchaudio as ta
from torch.nn import functional as F
from tqdm import tqdm

from core.types import (
    AugmentationHandler,
    BatchedInputOutput,
    InferenceHandler,
    LossHandler,
    LossOutputType,
    MetricHandler,
    MetricOutputType,
    MetricType,
    ModelType,
    OperationMode,
    OptimizationBundle,
    OutputType,
    RawInputType,
    SimpleishNamespace,
)

class EndToEndLightningSystem(pl.LightningModule):
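    """Generic Lightning wrapper for source-separation models.

    Training, validation, and testing delegate to pluggable handlers for
    loss, metrics, augmentation, and inference, so the same system can be
    reused across models and separation tasks.
    """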
def __init__(
self,
model: ModelType,
loss_handler: LossHandler,
metrics: MetricHandler,
augmentation_handler: AugmentationHandler,
inference_handler: InferenceHandler,
optimization_bundle: OptimizationBundle,
fast_run: bool = False,
commitment_weight: float = 1.0,
batch_size: Optional[int] = None,
effective_batch_size: Optional[int] = None,
) -> None:
super().__init__()
self.model = model
self.loss = loss_handler
self.metrics = metrics
self.optimization = optimization_bundle
self.augmentation = augmentation_handler
self.inference = inference_handler
self.fast_run = fast_run
self.model.fast_run = fast_run
self.commitment_weight = commitment_weight
self.batch_size = batch_size
        self.effective_batch_size = effective_batch_size if effective_batch_size is not None else batch_size
        # Gradient-accumulation ratio: number of micro-batches per effective
        # batch. Guard against either size being unset.
        if self.effective_batch_size is not None and self.batch_size is not None:
            self.accum_ratio = self.effective_batch_size // self.batch_size
        else:
            self.accum_ratio = 1
self.output_dir = None
self.split_size = None
def configure_optimizers(self) -> Any:
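        """Instantiate the optimizer (and optional LR scheduler) from the
        optimization bundle, in the dict format Lightning expects."""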
optimizer = self.optimization.optimizer.cls(
self.model.parameters(),
**self.optimization.optimizer.kwargs
)
ret = {
"optimizer": optimizer,
}
if self.optimization.scheduler is not None:
scheduler = self.optimization.scheduler.cls(
optimizer,
**self.optimization.scheduler.kwargs
)
ret["lr_scheduler"] = scheduler
return ret
def compute_loss(
self,
batch: BatchedInputOutput,
        mode: OperationMode = OperationMode.TRAIN,
) -> LossOutputType:
loss_dict = self.loss(batch)
return loss_dict
# TODO: move to a metric handler
def update_metrics(
self,
batch: BatchedInputOutput,
mode: OperationMode = OperationMode.TRAIN,
) -> None:
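        """Update each stem's metric with the current batch, skipping stems
        for which no estimate is present."""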
metrics: MetricType = self.metrics.get_mode(mode)
for stem, metric in metrics.items():
if stem not in batch.estimates.keys():
continue
metric.update(batch)
# TODO: move to a metric handler
def compute_metrics(self, mode: OperationMode) -> MetricOutputType:
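        """Aggregate per-stem metrics into a flat dict keyed ``stem/metric``."""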
metrics: MetricType = self.metrics.get_mode(mode)
metric_dict = {}
for stem, metric in metrics.items():
md = metric.compute()
metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
self.log_dict(metric_dict, prog_bar=True, logger=False)
return metric_dict
# TODO: move to a metric handler
def reset_metrics(self, mode: OperationMode) -> None:
metrics: MetricType = self.metrics.get_mode(mode)
for _, metric in metrics.items():
metric.reset()
    def forward(self, batch: RawInputType) -> OutputType:
batch = self.model(batch)
return batch
    def common_step(
        self, batch: RawInputType, mode: OperationMode, batch_idx: int = -1
    ) -> LossOutputType:
batch = BatchedInputOutput.from_dict(batch)
batch = self.forward(batch)
loss_dict = self.compute_loss(batch, mode=mode)
if not self.fast_run:
with torch.no_grad():
self.update_metrics(batch, mode=mode)
return loss_dict
def training_step(self, batch: RawInputType, batch_idx: int) -> LossOutputType:
# augmented_batch = self.augmentation(batch, mode=OperationMode.TRAIN)
self.model.train()
loss_dict = self.common_step(batch, mode=OperationMode.TRAIN, batch_idx=batch_idx)
self.log_dict_with_prefix(loss_dict, prefix=OperationMode.TRAIN, prog_bar=True)
return loss_dict
def on_train_batch_end(
self, outputs: OutputType, batch: RawInputType, batch_idx: int
) -> None:
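        """Log and reset training metrics once per effective batch, i.e.
        every ``accum_ratio`` micro-batches."""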
if self.fast_run:
return
if (batch_idx + 1) % self.accum_ratio == 0:
metric_dict = self.compute_metrics(mode=OperationMode.TRAIN)
self.log_dict_with_prefix(metric_dict, prefix=OperationMode.TRAIN)
self.reset_metrics(mode=OperationMode.TRAIN)
    @torch.inference_mode()
    def validation_step(
        self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
    ) -> Dict[str, Any]:
        self.model.eval()
        loss_dict = self.common_step(batch, mode=OperationMode.VAL)
        self.log_dict_with_prefix(loss_dict, prefix=OperationMode.VAL)
        return loss_dict
def on_validation_epoch_start(self) -> None:
self.reset_metrics(mode=OperationMode.VAL)
def on_validation_epoch_end(self) -> None:
if self.fast_run:
return
metric_dict = self.compute_metrics(mode=OperationMode.VAL)
self.log_dict_with_prefix(
metric_dict, OperationMode.VAL, prog_bar=True, add_dataloader_idx=False
)
self.reset_metrics(mode=OperationMode.VAL)
    def save_to_audio(self, batch: BatchedInputOutput, batch_idx: int) -> None:
        """Write the single estimated stem to ``<log_dir>/audio/<song_id>/<stem>.wav``."""
        batch_size = batch["mixture"]["audio"].shape[0]
        assert batch_size == 1, "Batch size must be 1 for inference"
        metadata = batch.metadata
        song_id = metadata["mix"][0]
        stem = metadata["stem"][0]
        log_dir = os.path.join(self.logger.log_dir, "audio")
        os.makedirs(os.path.join(log_dir, song_id), exist_ok=True)
        audio = batch.estimates[stem]["audio"].squeeze(0).cpu()
        audio_path = os.path.join(log_dir, song_id, f"{stem}.wav")
        ta.save(audio_path, audio, self.inference.fs)
    def save_vdbo_to_audio(self, batch: BatchedInputOutput, batch_idx: int) -> None:
        """Write every estimated stem to ``<log_dir>/audio/<song_id>/<stem>.wav``."""
        batch_size = batch["mixture"]["audio"].shape[0]
        assert batch_size == 1, "Batch size must be 1 for inference"
        metadata = batch.metadata
        song_id = metadata["song_id"][0]
        log_dir = os.path.join(self.logger.log_dir, "audio")
        os.makedirs(os.path.join(log_dir, song_id), exist_ok=True)
        for stem, estimate in batch.estimates.items():
            audio = estimate["audio"].squeeze(0).cpu()
            audio_path = os.path.join(log_dir, song_id, f"{stem}.wav")
            ta.save(audio_path, audio, self.inference.fs)
@torch.inference_mode()
def chunked_inference(
self, batch: RawInputType, batch_idx: int = -1, dataloader_idx: int = 0
) -> BatchedInputOutput:
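        """Run overlap-add inference on a full-length mixture.

        The mixture is padded with 2 * overlap samples of context on each
        side, cut into overlapping chunks with ``unfold``, separated chunk by
        chunk, then Hann-windowed and reassembled with ``fold``.
        """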
batch = BatchedInputOutput.from_dict(batch)
audio = batch["mixture"]["audio"]
b, c, n_samples = audio.shape
assert b == 1
fs = self.inference.fs
chunk_size = int(self.inference.chunk_size_seconds * fs)
hop_size = int(self.inference.hop_size_seconds * fs)
batch_size = self.inference.batch_size
overlap = chunk_size - hop_size
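        # A Hann window of length `chunk_size` overlap-added at `hop_size`
        # sums to the constant chunk_size / (2 * hop_size); dividing by
        # `scaler` below makes the reconstruction unity-gain.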
scaler = chunk_size / (2 * hop_size)
n_chunks = int(math.ceil(
(n_samples + 4 * overlap - chunk_size) / hop_size
)) + 1
pad = (n_chunks - 1) * hop_size + chunk_size - n_samples
audio = F.pad(
audio,
pad=(2 * overlap, 2 * overlap + pad),
mode="constant",
value=0,
)
padded_length = audio.shape[-1]
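        # Treat the waveform as a (c, 1, T, 1) "image" so that unfold/fold
        # can extract and later reassemble overlapping 1-D chunks.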
audio = audio.reshape(c, 1, -1, 1)
chunked_audio = F.unfold(
audio,
kernel_size=(chunk_size, 1),
stride=(hop_size, 1)
) # (c, chunk_size, n_chunk)
chunked_audio = chunked_audio.permute(2, 0, 1).reshape(-1, c, chunk_size)
n_chunks = chunked_audio.shape[0]
n_batch = math.ceil(n_chunks / batch_size)
outputs = []
for i in tqdm(range(n_batch)):
start = i * batch_size
end = min((i + 1) * batch_size, n_chunks)
chunked_batch = SimpleishNamespace(
mixture={
"audio": chunked_audio[start:end]
},
query=batch["query"],
estimates=batch["estimates"]
)
output = self.forward(chunked_batch)
outputs.append(output.estimates["target"]["audio"])
output = torch.cat(outputs, dim=0) # (n_chunks, c, chunk_size)
window = torch.hann_window(chunk_size, device=self.device).reshape(1, 1, chunk_size)
output = output * window / scaler
output = torch.permute(output, (1, 2, 0))
output = F.fold(
output,
output_size=(padded_length, 1),
kernel_size=(chunk_size, 1),
stride=(hop_size, 1)
) # (c, 1, t, 1)
output = output[None, :, 0, 2*overlap: n_samples + 2*overlap, 0]
stem = batch.metadata["stem"][0]
batch["estimates"][stem] = {
"audio": output
}
return batch
def chunked_vdbo_inference(
self, batch: RawInputType, batch_idx: int = -1, dataloader_idx: int = 0
) -> BatchedInputOutput:
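        """Overlap-add inference that separates all stems at once.

        Same chunking scheme as ``chunked_inference``, but the model returns
        every stem in a single pass and the mixture is reflect-padded
        instead of zero-padded.
        """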
batch = BatchedInputOutput.from_dict(batch)
audio = batch["mixture"]["audio"]
b, c, n_samples = audio.shape
assert b == 1
fs = self.inference.fs
chunk_size = int(self.inference.chunk_size_seconds * fs)
hop_size = int(self.inference.hop_size_seconds * fs)
batch_size = self.inference.batch_size
overlap = chunk_size - hop_size
scaler = chunk_size / (2 * hop_size)
n_chunks = int(math.ceil(
(n_samples + 4 * overlap - chunk_size) / hop_size
)) + 1
pad = (n_chunks - 1) * hop_size + chunk_size - n_samples
audio = F.pad(
audio,
pad=(2 * overlap, 2 * overlap + pad),
mode="reflect"
)
padded_length = audio.shape[-1]
audio = audio.reshape(c, 1, -1, 1)
chunked_audio = F.unfold(
audio,
kernel_size=(chunk_size, 1),
stride=(hop_size, 1)
) # (c, chunk_size, n_chunk)
chunked_audio = chunked_audio.permute(2, 0, 1).reshape(-1, c, chunk_size)
n_chunks = chunked_audio.shape[0]
n_batch = math.ceil(n_chunks / batch_size)
outputs = defaultdict(list)
for i in tqdm(range(n_batch)):
start = i * batch_size
end = min((i + 1) * batch_size, n_chunks)
chunked_batch = SimpleishNamespace(
mixture={
"audio": chunked_audio[start:end]
},
estimates=batch["estimates"]
)
output = self.forward(chunked_batch)
for stem, estimate in output.estimates.items():
outputs[stem].append(estimate["audio"])
        window = torch.hann_window(chunk_size, device=self.device).reshape(1, 1, chunk_size)
        for stem, stem_outputs in outputs.items():
            output = torch.cat(stem_outputs, dim=0)  # (n_chunks, c, chunk_size)
output = output * window / scaler
output = torch.permute(output, (1, 2, 0))
output = F.fold(
output,
output_size=(padded_length, 1),
kernel_size=(chunk_size, 1),
stride=(hop_size, 1)
) # (c, 1, t, 1)
output = output[None, :, 0, 2*overlap: n_samples + 2*overlap, 0]
batch["estimates"][stem] = {
"audio": output
}
return batch
def on_test_epoch_start(self) -> None:
self.reset_metrics(mode=OperationMode.TEST)
def test_step(
self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
) -> Any:
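        """Full-song evaluation: route query-based batches to single-stem
        inference and everything else to multi-stem (VDBO) inference. Metrics
        are reset before and after updating so each song is scored and
        logged independently."""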
self.model.eval()
if "query" in batch.keys():
batch = self.chunked_inference(batch, batch_idx, dataloader_idx)
else:
batch = self.chunked_vdbo_inference(batch, batch_idx, dataloader_idx)
self.reset_metrics(mode=OperationMode.TEST)
self.update_metrics(batch, mode=OperationMode.TEST)
metrics = self.compute_metrics(mode=OperationMode.TEST)
# metrics["song_id"] = batch.metadata["mix"][0]
self.log_dict_with_prefix(metrics, OperationMode.TEST,
on_step=True, on_epoch=False, prog_bar=True)
self.reset_metrics(mode=OperationMode.TEST)
return batch
def on_test_epoch_end(self) -> None:
self.reset_metrics(mode=OperationMode.TEST)
def set_output_path(self, output_dir: str) -> None:
self.output_dir = output_dir
def predict_step(
self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
) -> Any:
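        """Run full-song inference and write the estimated stems to disk."""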
self.model.eval()
if "query" in batch.keys():
batch = self.chunked_inference(batch, batch_idx, dataloader_idx)
self.save_to_audio(batch, batch_idx)
else:
batch = self.chunked_vdbo_inference(batch, batch_idx, dataloader_idx)
self.save_vdbo_to_audio(batch, batch_idx)
    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = False
    ) -> Any:
        # Deliberately ignore `strict`: checkpoints are always loaded
        # non-strictly so missing or unexpected keys do not raise.
        return super().load_state_dict(state_dict, strict=False)
def log_dict_with_prefix(
self,
dict_: Dict[str, torch.Tensor],
prefix: str,
batch_size: Optional[int] = None,
**kwargs: Any,
) -> None:
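        """Log ``dict_`` with every key prefixed by ``prefix``, then flush the logger."""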
self.log_dict(
{f"{prefix}/{k}": v for k, v in dict_.items()},
batch_size=batch_size,
logger=True,
sync_dist=True,
**kwargs,
# on_step=True,
# on_epoch=False,
)
self.logger.save()