import torch
import torch.nn as nn
import numpy as np
from einops.layers.torch import Rearrange


class SineLayer(nn.Module):
    """
    Paper: Implicit Neural Representations with Periodic Activation Functions (SIREN)
    """

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30
    ):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        # SIREN initialization: the first layer is drawn from U(-1/n, 1/n),
        # deeper layers from U(-sqrt(6/n)/omega_0, sqrt(6/n)/omega_0), keeping
        # pre-activations well-scaled for the sine nonlinearity.
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(
                    -1 / self.in_features, 1 / self.in_features
                )
            else:
                self.linear.weight.uniform_(
                    -np.sqrt(6 / self.in_features) / self.omega_0,
                    np.sqrt(6 / self.in_features) / self.omega_0,
                )

    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))
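

# Illustrative sketch, not part of the original module: SineLayer is designed
# to stack into a SIREN-style coordinate MLP. `build_siren_mlp` and its
# defaults are hypothetical names for this example; e.g. build_siren_mlp(2, 256, 3)
# would map 2-D coordinates to RGB values.
def build_siren_mlp(in_features, hidden_features, out_features, num_layers=3, omega_0=30.0):
    """Sine hidden layers followed by a plain linear output head."""
    layers = [SineLayer(in_features, hidden_features, is_first=True, omega_0=omega_0)]
    for _ in range(num_layers - 2):
        layers.append(SineLayer(hidden_features, hidden_features, omega_0=omega_0))
    layers.append(nn.Linear(hidden_features, out_features))
    return nn.Sequential(*layers)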


class ToPixel(nn.Module):
    """Decodes a sequence of patch tokens back into an image tensor."""

    def __init__(
        self, to_pixel="linear", img_size=256, in_channels=3, in_dim=512, patch_size=16
    ) -> None:
        super().__init__()
        self.to_pixel_name = to_pixel
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.in_channels = in_channels
        if to_pixel == "linear":
            # One linear projection per token to a flattened pixel patch.
            self.model = nn.Linear(in_dim, in_channels * patch_size * patch_size)
        elif to_pixel == "conv":
            # Lay tokens out on the patch grid, then upsample each one to a
            # full patch with a transposed convolution.
            self.model = nn.Sequential(
                Rearrange("b (h w) c -> b c h w", h=img_size // patch_size),
                nn.ConvTranspose2d(
                    in_dim, in_channels, kernel_size=patch_size, stride=patch_size
                ),
            )
        elif to_pixel == "siren":
            # Two sine layers; the second emits one flattened patch per token
            # (patch_size**2 * in_channels values, matching the linear head).
            self.model = nn.Sequential(
                SineLayer(in_dim, in_dim * 2, is_first=True, omega_0=30.0),
                SineLayer(
                    in_dim * 2,
                    patch_size * patch_size * in_channels,
                    is_first=False,
                    omega_0=30.0,
                ),
            )
        elif to_pixel == "identity":
            self.model = nn.Identity()
        else:
            raise NotImplementedError(f"Unknown to_pixel mode: {to_pixel}")

    def get_last_layer(self):
        # Weight of the final pixel-producing layer (useful e.g. for adaptive
        # loss weighting); the identity head has no such layer.
        if self.to_pixel_name == "linear":
            return self.model.weight
        elif self.to_pixel_name == "siren":
            return self.model[1].linear.weight
        elif self.to_pixel_name == "conv":
            return self.model[1].weight
        else:
            return None

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        p = self.patch_size
        c = self.in_channels
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        # (N, h, w, p, p, C) -> (N, C, h, p, w, p) -> (N, C, h*p, w*p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x):
        if self.to_pixel_name == "linear":
            x = self.model(x)
            x = self.unpatchify(x)
        elif self.to_pixel_name == "siren":
            x = self.model(x)
            # Flatten the per-token patch values straight onto the image grid.
            side = self.patch_size * int(self.num_patches**0.5)
            x = x.view(x.shape[0], self.in_channels, side, side)
        elif self.to_pixel_name == "conv":
            x = self.model(x)
        elif self.to_pixel_name == "identity":
            pass
        return x
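

# Minimal smoke test; an illustrative sketch using the module's default sizes,
# not part of the original file. Each head maps (B, L, in_dim) tokens to an
# image of shape (B, in_channels, img_size, img_size).
if __name__ == "__main__":
    tokens = torch.randn(2, (256 // 16) ** 2, 512)  # (B, num_patches, in_dim)
    for mode in ("linear", "conv", "siren"):
        head = ToPixel(to_pixel=mode, img_size=256, in_channels=3, in_dim=512, patch_size=16)
        print(mode, tuple(head(tokens).shape))  # expect (2, 3, 256, 256)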