Spaces:
Running
on
Zero
Running
on
Zero
| """Op utility functions.""" | |
| from __future__ import annotations | |
| from functools import partial | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
def multi_apply(func, *args, **kwargs):
    """Map ``func`` over zipped argument iterables and transpose its outputs.

    ``func`` is called once per position across the iterables in ``args``
    (stopping at the shortest one), with ``kwargs`` forwarded to every call.
    Each call must return an iterable of outputs; the i-th output of every
    call is gathered into the i-th list of the returned tuple.

    Args:
        func (Callable): Function applied to each group of positional items.
        *args: Iterables whose elements are passed positionally to ``func``.
        **kwargs: Keyword arguments forwarded unchanged to every call.

    Returns:
        tuple[list]: One list per output position of ``func``. The tuple is
        empty when ``func`` is never called.
    """
    if kwargs:
        # Bind the shared keyword arguments once instead of per call.
        func = partial(func, **kwargs)
    per_call_outputs = list(map(func, *args))
    # Transpose: row = one call's outputs -> column = one output kind.
    return tuple(list(column) for column in zip(*per_call_outputs))
def flat_interpolate(
    flat_tensor: Tensor,
    old: tuple[int, int],
    new: tuple[int, int],
    antialias: bool = True,
    mode: str = "bilinear",
) -> Tensor:
    """Resize a flattened, channels-last spatial grid to a new (h, w).

    ``flat_tensor`` is reshaped to ``(b, old_h, old_w, c)``, where ``c`` is
    inferred from the remaining elements. It is resized with
    ``F.interpolate`` and flattened back to ``(b, new_h * new_w, c)``.

    Args:
        flat_tensor: Tensor whose per-batch elements form an
            ``old_h x old_w`` grid with channels last, e.g. shaped
            ``(b, old_h * old_w, c)``.
        old: Source grid size ``(h, w)``.
        new: Target grid size ``(h, w)``.
        antialias: Forwarded to ``F.interpolate`` for ``"bilinear"`` and
            ``"bicubic"``. Ignored for other modes, which do not accept it.
        mode: Interpolation mode accepted by ``F.interpolate`` for 4-D
            input (e.g. ``"bilinear"``, ``"bicubic"``, ``"nearest"``,
            ``"area"``).

    Returns:
        ``flat_tensor`` itself (unchanged) when ``old == new``. Otherwise a
        contiguous tensor of shape ``(b, new_h * new_w, c)``.
    """
    old_h, old_w = old[0], old[1]
    new_h, new_w = new[0], new[1]
    if old_h == new_h and old_w == new_w:
        return flat_tensor

    batch = flat_tensor.shape[0]
    # b (h w) c -> b h w c -> b c h w. reshape is a free view whenever one
    # is possible and copies instead of raising otherwise.
    grid = flat_tensor.reshape(batch, old_h, old_w, -1).permute(0, 3, 1, 2)

    # F.interpolate raises ValueError if align_corners is set for
    # non-interpolating modes (nearest/area), or if antialias is set for
    # anything other than bilinear/bicubic.
    if mode in ("bilinear", "bicubic"):
        interp_kwargs = {"align_corners": False, "antialias": antialias}
    else:
        interp_kwargs = {}
    resized = F.interpolate(grid, size=(new_h, new_w), mode=mode, **interp_kwargs)

    # b c h w -> b c (h w) -> b (h w) c. Use the explicit channel count
    # rather than -1 so the reshape is never ambiguous.
    channels = resized.shape[1]
    flat = resized.reshape(batch, channels, new_h * new_w).permute(0, 2, 1)
    return flat.contiguous()