# Source commit 9b33fca by RoyYang0714: "feat: Try to build everything locally."
"""Op utility functions."""
from __future__ import annotations
from functools import partial
import torch.nn.functional as F
from torch import Tensor
def multi_apply(func, *args, **kwargs):
    """Map ``func`` over parallel argument iterables and regroup its outputs.

    Note:
        ``func`` is called once per position across the iterables in
        ``args`` (exactly like the builtin ``map``), and ``kwargs`` are
        forwarded unchanged to every call. Each call's result is unpacked
        and transposed with ``zip``, so the i-th element of every result
        ends up in the i-th output list. If results have different lengths,
        ``zip`` truncates to the shortest one.

    Args:
        func (Callable): Function applied to each group of positional
            arguments; expected to return a fixed-length tuple.
        *args: Iterables whose elements are passed positionally to ``func``.
        **kwargs: Keyword arguments bound to every call of ``func``.

    Returns:
        tuple[list]: One list per output slot of ``func``; an empty tuple
        when the iterables yield no elements.
    """
    if kwargs:
        call = partial(func, **kwargs)
    else:
        call = func

    # Lazily compute one result tuple per input position.
    per_input_results = map(call, *args)
    # Transpose: results[i][j] -> outputs[j][i].
    regrouped = zip(*per_input_results)
    return tuple(list(slot) for slot in regrouped)
def flat_interpolate(
    flat_tensor: Tensor,
    old: tuple[int, int],
    new: tuple[int, int],
    antialias: bool = True,
    mode: str = "bilinear",
) -> Tensor:
    """Resize a flattened (token-sequence) feature map to a new spatial size.

    The tokens of ``flat_tensor`` are laid out as ``b (h w) c``. They are
    unflattened to a ``b c h w`` grid of size ``old``, resized with
    ``F.interpolate`` to ``new``, and flattened back to ``b (h w) c``.

    Args:
        flat_tensor (Tensor): Flattened features; dimension 1 must have
            exactly ``old[0] * old[1]`` entries.
        old (tuple[int, int]): Current grid size ``(h, w)``.
        new (tuple[int, int]): Target grid size ``(h, w)``.
        antialias (bool): Anti-alias flag for ``F.interpolate``. It only
            applies to the ``"bilinear"`` and ``"bicubic"`` modes and is
            ignored for any other mode, which does not support it.
        mode (str): Interpolation mode passed to ``F.interpolate``
            (e.g. ``"bilinear"``, ``"bicubic"``, ``"nearest"``, ``"area"``).

    Returns:
        Tensor: Contiguous tensor of shape ``(b, new[0] * new[1], c)``.
        If ``old == new``, ``flat_tensor`` itself is returned unchanged.

    Raises:
        ValueError: If dimension 1 of ``flat_tensor`` does not equal
            ``old[0] * old[1]`` (only checked when a resize is needed).
    """
    if old[0] == new[0] and old[1] == new[1]:
        return flat_tensor

    # Without this check, a token count that is a multiple of h*w would be
    # silently folded into the channel axis by the view below.
    if flat_tensor.ndim < 2 or flat_tensor.shape[1] != old[0] * old[1]:
        raise ValueError(
            f"Expected flat_tensor with {old[0] * old[1]} tokens in dim 1 "
            f"for grid size {tuple(old)}, got shape "
            f"{tuple(flat_tensor.shape)}."
        )

    batch = flat_tensor.shape[0]
    tensor = flat_tensor.view(batch, old[0], old[1], -1).permute(
        0, 3, 1, 2
    )  # b c h w

    # F.interpolate rejects align_corners (any non-None value) and
    # antialias=True for non-interpolating modes such as "nearest"/"area",
    # so only pass them where they are supported for 4-D input.
    supports_smoothing = mode in ("bilinear", "bicubic")
    tensor_interp = F.interpolate(
        tensor,
        size=(new[0], new[1]),
        mode=mode,
        align_corners=False if supports_smoothing else None,
        antialias=antialias and supports_smoothing,
    )

    flat_tensor_interp = tensor_interp.view(
        batch, -1, new[0] * new[1]
    ).permute(
        0, 2, 1
    )  # b (h w) c
    return flat_tensor_interp.contiguous()