|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections.abc import Callable, Sequence |
|
|
from functools import partial |
|
|
from typing import Any, NamedTuple, Protocol, Union, cast, overload |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from typing_extensions import TypeAlias, Unpack |
|
|
|
|
|
from .wavelet import Wavelet as minimal_wavelet |
|
|
|
|
|
|
|
|
class WaveletDetailTuple2d(NamedTuple):
    """The three detail-coefficient tensors of a single 2D decomposition level."""

    # Detail band named "horizontal".
    horizontal: torch.Tensor
    # Detail band named "vertical".
    vertical: torch.Tensor
    # Detail band named "diagonal".
    diagonal: torch.Tensor
|
|
|
|
|
|
|
|
# Coefficients of a multi-level 2D transform: the approximation tensor first,
# followed by one WaveletDetailTuple2d per decomposition level.
WaveletCoeff2d: TypeAlias = tuple[torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]]]

# Detail tensors of one level of an N-dimensional transform, keyed by a string.
# NOTE(review): the key naming scheme is not defined in this section — confirm
# against the N-D transform code.
WaveletDetailDict: TypeAlias = dict[str, torch.Tensor]

# Coefficients of a multi-level N-dimensional transform: the approximation
# tensor first, followed by one WaveletDetailDict per decomposition level.
WaveletCoeffNd: TypeAlias = tuple[torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]]]
|
|
|
|
|
|
|
|
class Wavelet(Protocol):
    """Structural interface for the wavelet objects used by this module.

    Any object exposing these attributes (for example the local minimal
    ``Wavelet`` class returned by ``_as_wavelet`` for string names) can be used
    wherever a ``Wavelet`` is expected.
    """

    # Wavelet identifier.
    name: str

    # Decomposition filter taps (low-pass / high-pass).
    dec_lo: Sequence[float]
    dec_hi: Sequence[float]

    # Reconstruction filter taps (low-pass / high-pass).
    rec_lo: Sequence[float]
    rec_hi: Sequence[float]

    # Presumably the decomposition / reconstruction filter lengths — TODO confirm.
    dec_len: int
    rec_len: int

    # Presumably (dec_lo, dec_hi, rec_lo, rec_hi) — TODO confirm the ordering.
    filter_bank: tuple[Sequence[float], Sequence[float], Sequence[float], Sequence[float]]

    def __len__(self) -> int:
        """Return the number of taps in ``dec_lo``."""
        low_pass = self.dec_lo
        return len(low_pass)
|
|
|
|
|
|
|
|
class WaveletTensorTuple(NamedTuple):
    """The four filters of a wavelet as tensors.

    Fields are the decomposition pair (``dec_lo``, ``dec_hi``) followed by the
    reconstruction pair (``rec_lo``, ``rec_hi``).
    """

    dec_lo: torch.Tensor
    dec_hi: torch.Tensor
    rec_lo: torch.Tensor
    rec_hi: torch.Tensor

    @classmethod
    def from_wavelet(cls, wavelet: Wavelet, dtype: torch.dtype) -> "WaveletTensorTuple":
        """Copy each filter of ``wavelet`` into a new tensor of the given dtype.

        The field names double as the attribute names read from ``wavelet``.
        Tensors are created with ``torch.tensor`` and no explicit device, so
        callers are responsible for moving them where needed.
        """
        return cls(*(torch.tensor(getattr(wavelet, field), dtype=dtype) for field in cls._fields))
|
|
|
|
|
|
|
|
def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
    """Resolve ``wavelet`` to a wavelet object.

    A string is passed to the module's minimal ``Wavelet`` constructor. Any other
    value is returned unchanged and is assumed to satisfy the ``Wavelet`` protocol.
    """
    return minimal_wavelet(wavelet) if isinstance(wavelet, str) else wavelet
|
|
|
|
|
|
|
|
def _is_dtype_supported(dtype: torch.dtype) -> bool: |
|
|
return dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64] |
|
|
|
|
|
|
|
|
def _outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
|
|
a_flat = torch.reshape(a, [-1]) |
|
|
b_flat = torch.reshape(b, [-1]) |
|
|
a_mul = torch.unsqueeze(a_flat, dim=-1) |
|
|
b_mul = torch.unsqueeze(b_flat, dim=0) |
|
|
return a_mul * b_mul |
|
|
|
|
|
|
|
|
def _check_if_tensor(array: Any) -> torch.Tensor: |
|
|
if not isinstance(array, torch.Tensor): |
|
|
raise ValueError("First element of coeffs must be the approximation coefficient tensor.") |
|
|
return array |
|
|
|
|
|
|
|
|
def _check_axes_argument(axes: Sequence[int]) -> None: |
|
|
if len(set(axes)) != len(axes): |
|
|
raise ValueError("Cant transform the same axis twice.") |
|
|
|
|
|
|
|
|
def _check_same_device(tensor: torch.Tensor, torch_device: torch.device) -> torch.Tensor: |
|
|
if torch_device != tensor.device: |
|
|
raise ValueError("coefficients must be on the same device") |
|
|
return tensor |
|
|
|
|
|
|
|
|
def _check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.Tensor: |
|
|
if torch_dtype != tensor.dtype: |
|
|
raise ValueError("coefficients must have the same dtype") |
|
|
return tensor |
|
|
|
|
|
|
|
|
@overload
def _coeff_tree_map(
    coeffs: list[torch.Tensor], function: Callable[[torch.Tensor], torch.Tensor]
) -> list[torch.Tensor]: ...


@overload
def _coeff_tree_map(coeffs: WaveletCoeff2d, function: Callable[[torch.Tensor], torch.Tensor]) -> WaveletCoeff2d: ...


@overload
def _coeff_tree_map(coeffs: WaveletCoeffNd, function: Callable[[torch.Tensor], torch.Tensor]) -> WaveletCoeffNd: ...


def _coeff_tree_map(coeffs, function):
    """Apply ``function`` to every tensor in a wavelet coefficient container.

    Three layouts are handled:

    * a flat list of tensors;
    * an approximation tensor followed by 3-tuples (rebuilt as ``WaveletDetailTuple2d``);
    * an approximation tensor followed by dicts of tensors.

    Args:
        coeffs: The container; ``coeffs[0]`` is mapped directly.
        function: Tensor-to-tensor callable applied to every leaf, in container order.

    Returns:
        The output type is decided by the first detail entry only.
        A list ``[approx, *details]`` when the details are tensors.
        With no details: ``[approx]`` if ``coeffs`` is a list, otherwise ``(approx,)``.
        Otherwise a tuple ``(approx, *details)``.

    Raises:
        ValueError: If a detail entry is not a tuple, dict or tensor.
    """

    def _map_level(level):
        # Map a single detail entry; tuples keep only their first three items.
        if isinstance(level, tuple):
            return WaveletDetailTuple2d(function(level[0]), function(level[1]), function(level[2]))
        if isinstance(level, dict):
            return {key: function(value) for key, value in level.items()}
        if isinstance(level, torch.Tensor):
            return function(level)
        raise ValueError(f"Unexpected input type {type(level)}")

    approx = function(coeffs[0])
    mapped: list[Any] = [_map_level(level) for level in coeffs[1:]]

    if not mapped:
        return [approx] if isinstance(coeffs, list) else (approx,)
    if isinstance(mapped[0], torch.Tensor):
        return [approx, *cast(list[torch.Tensor], mapped)]
    return (approx, *cast(Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], mapped))
|
|
|
|
|
|
|
|
def _check_same_device_dtype(
    coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
) -> tuple[torch.device, torch.dtype]:
    """Verify that every coefficient tensor matches the approximation's device and dtype.

    The whole container is checked for device first, then for dtype.

    Returns:
        ``(device, dtype)`` of ``coeffs[0]``.

    Raises:
        ValueError: If ``coeffs[0]`` is not a tensor, or any tensor differs in
            device or dtype.
    """
    reference = _check_if_tensor(coeffs[0])
    device, dtype = reference.device, reference.dtype
    checks = (
        partial(_check_same_device, torch_device=device),
        partial(_check_same_dtype, torch_dtype=dtype),
    )
    for check in checks:
        # Only the validation side effect matters; the mapped result is discarded.
        _coeff_tree_map(coeffs, check)
    return device, dtype
|
|
|
|
|
|
|
|
def _get_transpose_order(axes: Sequence[int], data_shape: Sequence[int]) -> tuple[list[int], list[int]]: |
|
|
axes = [a + len(data_shape) if a < 0 else a for a in axes] |
|
|
all_axes = list(range(len(data_shape))) |
|
|
remove_transformed = list(filter(lambda a: a not in axes, all_axes)) |
|
|
return remove_transformed, axes |
|
|
|
|
|
|
|
|
def _swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
    """Permute ``data`` so the dimensions in ``axes`` come last, in the order given.

    The remaining dimensions keep their relative order at the front.

    Raises:
        ValueError: If ``axes`` repeats an entry.
    """
    _check_axes_argument(axes)
    kept, moved = _get_transpose_order(axes, list(data.shape))
    permutation = kept + moved
    return torch.permute(data, permutation)
|
|
|
|
|
|
|
|
def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
    """Invert the permutation applied by ``_swap_axes`` with the same ``axes``.

    Raises:
        ValueError: If ``axes`` repeats an entry.
    """
    _check_axes_argument(axes)
    kept, moved = _get_transpose_order(axes, list(data.shape))
    forward = kept + moved
    # The argsort of a permutation is its inverse permutation.
    inverse = sorted(range(len(forward)), key=forward.__getitem__)
    return torch.permute(data, inverse)
|
|
|
|
|
|
|
|
def _fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int]]: |
|
|
dshape = list(data.shape) |
|
|
return (torch.reshape(data, [int(np.prod(dshape[:-keep_no]))] + dshape[-keep_no:]), dshape) |
|
|
|
|
|
|
|
|
def _unfold_axes(data: torch.Tensor, ds: list[int], keep_no: int) -> torch.Tensor: |
|
|
return torch.reshape(data, ds[:-keep_no] + list(data.shape[-keep_no:])) |
|
|
|
|
|
|
|
|
def _preprocess_coeffs(coeffs, ndim: int, axes, add_channel_dim: bool = False):
    """Bring a coefficient container into the layout the ``ndim``-D kernels expect.

    Steps, applied to every tensor via ``_coeff_tree_map``:

    1. If ``axes`` is not the trailing ``ndim`` axes, move the listed axes to the end.
    2. Normalise the leading dimensions to exactly one batch dimension.
       With exactly ``ndim`` dimensions, a size-1 batch dimension is prepended.
       With more than ``ndim + 1`` dimensions, all leading dimensions are folded into one.
    3. Optionally insert a singleton dimension at position 1.

    Args:
        coeffs: Coefficient container; ``coeffs[0]`` must be a tensor.
        ndim: Number of transformed dimensions (must be positive).
        axes: The transformed axes, as an int or a sequence of ``ndim`` ints.
        add_channel_dim: If True, ``unsqueeze(1)`` every tensor at the end.

    Returns:
        ``(coeffs, ds)``, where ``ds`` is the shape of ``coeffs[0]`` after step 1
        and before step 2. Pass ``ds`` to ``_postprocess_coeffs`` to undo the changes.

    Raises:
        ValueError: On a non-tensor first element, an unsupported dtype, a
            non-positive ``ndim``, the wrong number of axes, or too few input
            dimensions.
    """
    axes = (axes,) if isinstance(axes, int) else axes

    dtype = _check_if_tensor(coeffs[0]).dtype
    if not _is_dtype_supported(dtype):
        raise ValueError(f"Input dtype {dtype} not supported")
    if ndim <= 0:
        raise ValueError("Number of dimensions must be positive")

    trailing_axes = tuple(range(-ndim, 0))
    if tuple(axes) != trailing_axes:
        if len(axes) != ndim:
            raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
        coeffs = _coeff_tree_map(coeffs, partial(_swap_axes, axes=axes))

    ds = list(coeffs[0].shape)
    rank = len(ds)
    if rank < ndim:
        raise ValueError(f"At least {ndim} input dimensions required.")
    if rank == ndim:
        # No batch dimension yet: add one.
        coeffs = _coeff_tree_map(coeffs, lambda x: x.unsqueeze(0))
    elif rank > ndim + 1:
        # Several leading dimensions: collapse them into a single batch dimension.
        coeffs = _coeff_tree_map(coeffs, lambda t: _fold_axes(t, ndim)[0])

    if add_channel_dim:
        coeffs = _coeff_tree_map(coeffs, lambda x: x.unsqueeze(1))

    return coeffs, ds
|
|
|
|
|
|
|
|
def _postprocess_coeffs(coeffs, ndim: int, ds: list[int], axes):
    """Undo the batch normalisation and axis swap done by ``_preprocess_coeffs``.

    Args:
        coeffs: Coefficient container in the single-batch-dimension layout.
        ndim: Number of transformed dimensions (must be positive).
        ds: The shape list returned by ``_preprocess_coeffs``.
        axes: The same ``axes`` value that was given to ``_preprocess_coeffs``.

    Returns:
        The container with the original leading dimensions and axis order restored.

    Raises:
        ValueError: On a non-positive ``ndim``, a ``ds`` shorter than ``ndim``,
            or the wrong number of axes.
    """
    axes = (axes,) if isinstance(axes, int) else axes
    if ndim <= 0:
        raise ValueError("Number of dimensions must be positive")

    rank = len(ds)
    if rank < ndim:
        raise ValueError(f"At least {ndim} input dimensions required.")
    if rank == ndim:
        # Drop the batch dimension that preprocessing prepended.
        coeffs = _coeff_tree_map(coeffs, lambda x: x.squeeze(0))
    elif rank > ndim + 1:
        # Re-expand the folded batch dimension into the original leading dims.
        coeffs = _coeff_tree_map(coeffs, partial(_unfold_axes, ds=ds, keep_no=ndim))

    if tuple(axes) != tuple(range(-ndim, 0)):
        if len(axes) != ndim:
            raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
        coeffs = _coeff_tree_map(coeffs, partial(_undo_swap_axes, axes=axes))

    return coeffs
|
|
|
|
|
|
|
|
def _postprocess_tensor(
    data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int]
) -> torch.Tensor:
    """Apply ``_postprocess_coeffs`` to a single tensor.

    The tensor is wrapped as a one-element coefficient list and unwrapped again afterwards.
    """
    (restored,) = _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)
    return restored
|
|
|
|
|
|
|
|
def _get_filter_tensors(
    wavelet: Union[Wavelet, str], flip: bool, device: torch.device, dtype: torch.dtype
) -> WaveletTensorTuple:
    """Build the four filters of ``wavelet`` as tensors on ``device``.

    Args:
        wavelet: A ``Wavelet``-protocol object, or a name resolved by ``_as_wavelet``.
        flip: If True, every filter is reversed along its last axis. Correlation
            ops such as ``torch.nn.functional.conv2d`` then apply a true convolution.
            Reversal is used rather than swapping the decomposition and
            reconstruction banks. The two only coincide for orthogonal wavelets
            (where ``rec_* == reversed(dec_*)``); reversal also works for
            biorthogonal ones.
        device: Device on which the tensors are created.
        dtype: Dtype of the created tensors.

    Returns:
        ``WaveletTensorTuple(dec_lo, dec_hi, rec_lo, rec_hi)``.
    """
    wavelet = _as_wavelet(wavelet)

    def _to_tensor(taps: Sequence[float]) -> torch.Tensor:
        # Allocate directly on the target device instead of copying from the CPU.
        tensor = torch.tensor(taps, device=device, dtype=dtype)
        return tensor.flip(-1) if flip else tensor

    return WaveletTensorTuple(
        _to_tensor(wavelet.dec_lo),
        _to_tensor(wavelet.dec_hi),
        _to_tensor(wavelet.rec_lo),
        _to_tensor(wavelet.rec_hi),
    )
|
|
|
|
|
|
|
|
def _adjust_padding_at_reconstruction(tensor_len: int, coeff_len: int, padr: int, padl: int) -> tuple[int, int]: |
|
|
if 2 * coeff_len - tensor_len == 1: |
|
|
padr += 1 |
|
|
elif 2 * coeff_len - tensor_len != 0: |
|
|
raise ValueError("incorrect padding") |
|
|
return padr, padl |
|
|
|
|
|
|
|
|
def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    """Build the separable 2D filter bank from a 1D low-pass / high-pass pair.

    Returns a tensor of shape ``(4, 1, n, n)`` with ``n = lo.numel()``. This
    requires ``hi`` to have the same number of elements, otherwise stacking fails.
    The four kernels are, in order, outer(lo, lo), outer(hi, lo), outer(lo, hi)
    and outer(hi, hi).
    """
    pairs = ((lo, lo), (hi, lo), (lo, hi), (hi, hi))
    kernels = [_outer(first, second) for first, second in pairs]
    # Singleton second dimension: (out_or_in_channels=4, 1, kH, kW).
    return torch.stack(kernels, 0).unsqueeze(1)
|
|
|
|
|
|
|
|
def waverec2d(
    coeffs: WaveletCoeff2d,
    wavelet: Union[Wavelet, str],
    axes: tuple[int, int] = (-2, -1),
) -> torch.Tensor:
    """Reconstruct a signal from a multi-level 2D wavelet decomposition.

    For each level, the running approximation is stacked with that level's three
    detail tensors into four channels. These are upsampled with a stride-2
    transposed convolution using the reconstruction filters, and
    ``(2 * filter_len - 3) // 2`` boundary samples are cropped per side. The crop
    may be adjusted by one sample so the result matches the next level's shape.

    Args:
        coeffs: ``(approx, details, details, ...)``. Each ``details`` entry is a
            3-tuple whose tensors have the same shape as the running approximation
            at that step, so levels are consumed coarsest first.
        wavelet: A ``Wavelet``-protocol object or a wavelet name.
        axes: The two transformed axes. Defaults to the last two.

    Returns:
        The reconstructed tensor, with the leading dimensions and axis order
        restored by ``_postprocess_tensor``.

    Raises:
        ValueError: If a detail entry is not a 3-tuple, if shapes within a level
            disagree, or if preprocessing, device/dtype or padding checks fail.
    """
    coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes)
    device, dtype = _check_same_device_dtype(coeffs)

    # Only the reconstruction filters are needed for the inverse transform.
    *_, syn_lo, syn_hi = _get_filter_tensors(wavelet, flip=False, device=device, dtype=dtype)
    kernel = _construct_2d_filt(lo=syn_lo, hi=syn_hi)
    default_crop = (2 * syn_lo.shape[-1] - 3) // 2

    levels = coeffs[1:]
    approx = coeffs[0]
    for level_idx, details in enumerate(levels):
        if not isinstance(details, tuple) or len(details) != 3:
            raise ValueError(f"Unexpected detail coefficient type: {type(details)}. Must be a 3-tuple.")

        expected_shape = approx.shape
        if any(band.shape != expected_shape for band in details):
            raise ValueError("All coefficients on each level must have the same shape")

        # (B, H, W) x 4 -> (B, 4, H, W) -> upsample -> (B, 1, H', W') -> (B, H', W').
        stacked = torch.stack([approx, *details], 1)
        approx = torch.nn.functional.conv_transpose2d(stacked, kernel, stride=2).squeeze(1)

        crop_top = crop_bottom = crop_left = crop_right = default_crop
        if level_idx < len(levels) - 1:
            # Match the size expected by the next (finer) level.
            next_shape = levels[level_idx + 1][0].shape
            crop_right, crop_left = _adjust_padding_at_reconstruction(
                approx.shape[-1], next_shape[-1], crop_right, crop_left
            )
            crop_bottom, crop_top = _adjust_padding_at_reconstruction(
                approx.shape[-2], next_shape[-2], crop_bottom, crop_top
            )

        # Guarded slices: a zero crop must not turn into an empty ":-0" slice.
        if crop_top > 0:
            approx = approx[..., crop_top:, :]
        if crop_bottom > 0:
            approx = approx[..., :-crop_bottom, :]
        if crop_left > 0:
            approx = approx[..., crop_left:]
        if crop_right > 0:
            approx = approx[..., :-crop_right]

    return _postprocess_tensor(approx, ndim=2, ds=ds, axes=axes)
|
|
|