Spaces:
Runtime error
Runtime error
| import torch | |
| from einops import rearrange | |
def video_to_image(func):
    """Decorate a method so it can be applied frame-by-frame to 5-D input.

    When ``x`` is 5-D it is rearranged with the einops pattern
    ``"b c t h w -> (b t) c h w"`` (axis 2 is folded into axis 0), ``func``
    is called on the resulting 4-D tensor, and the result is unfolded back
    to ``b c t h w`` using the original size of axis 2.  Input of any other
    rank is handed to ``func`` unchanged, so the decorated method is always
    executed.

    Args:
        func: Method with signature ``func(self, x, *args, **kwargs)``.
            For 5-D input its result must be 4-D with an unchanged leading
            ``(b t)`` axis so that it can be unfolded again.

    Returns:
        The wrapping method, carrying ``func``'s name and docstring.
    """
    # Function-scope import keeps this block self-contained.
    from functools import wraps

    @wraps(func)
    def wrapper(self, x, *args, **kwargs):
        if x.dim() != 5:
            # Nothing to fold: apply the method directly instead of
            # silently skipping it.
            return func(self, x, *args, **kwargs)

        # Remember the size of axis 2 so the fold can be undone afterwards.
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = func(self, x, *args, **kwargs)
        return rearrange(x, "(b t) c h w -> b c t h w", t=t)

    return wrapper
def nonlinearity(x):
    """Return ``x`` multiplied element-wise by ``torch.sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
def cast_tuple(t, length=1):
    """Return ``t`` as a tuple.

    A value that is already a ``tuple`` is returned untouched (its length
    is not checked); anything else is repeated ``length`` times.
    """
    if isinstance(t, tuple):
        return t
    return (t,) * length
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move one axis of ``x`` to a new position, keeping the others in order.

    Args:
        x: Object exposing ``shape``, ``permute`` and ``contiguous``.
        src_dim: Index of the axis to move; negative values count from
            the end.
        dest_dim: Index the axis should occupy in the result; negative
            values count from the end.
        make_contiguous: When true, ``contiguous()`` is called on the
            permuted result before returning it.

    Returns:
        The permuted (and optionally contiguous) result.

    Raises:
        AssertionError: If either index is out of range for ``x``.
    """
    rank = len(x.shape)

    # Normalise negative indices to their non-negative equivalents.
    if src_dim < 0:
        src_dim += rank
    if dest_dim < 0:
        dest_dim += rank
    assert 0 <= src_dim < rank and 0 <= dest_dim < rank

    # Pull the source axis out of the identity ordering and drop it back
    # in at the destination slot; the remaining axes keep their order.
    order = list(range(rank))
    order.insert(dest_dim, order.pop(src_dim))

    moved = x.permute(order)
    return moved.contiguous() if make_contiguous else moved