Spaces:
Build error
Build error
| from torch import nn | |
| from models.stylegan2.model import PixelNorm | |
| from torch.nn import Linear, LayerNorm, LeakyReLU, Sequential, Module, Conv2d, GroupNorm | |
class TextModulationModule(Module):
    """Modulate a feature map with per-channel affine parameters predicted from an embedding.

    The feature map is passed through a bias-free 3x3 convolution and GroupNorm,
    then scaled by ``1 + exp(gamma_function(embedding))`` and shifted by
    ``beta_function(embedding)``, and finally passed through LeakyReLU.

    Args:
        in_channels: Channel count of the feature map. It is both the input and
            the output width of the convolution and the size of the predicted
            gamma/beta vectors. It must be divisible by ``num_groups``
            (``GroupNorm`` raises otherwise).
        embed_dim: Width of the conditioning embedding and of the hidden layer
            of the gamma/beta MLPs. Defaults to 512, the previously hard-coded value.
        num_groups: Number of GroupNorm groups. Defaults to 32, the previously
            hard-coded value.
    """

    def __init__(self, in_channels, embed_dim=512, num_groups=32):
        super(TextModulationModule, self).__init__()
        # Submodules keep their original names and creation order, so state_dict
        # keys and default parameter initialisation are unchanged.
        self.conv = Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=False)
        self.norm = GroupNorm(num_groups, in_channels)
        self.gamma_function = self._make_affine_head(embed_dim, in_channels)
        self.beta_function = self._make_affine_head(embed_dim, in_channels)
        self.leakyrelu = LeakyReLU()

    @staticmethod
    def _make_affine_head(embed_dim, out_channels):
        """Build the Linear -> LayerNorm -> LeakyReLU -> Linear MLP mapping the embedding to per-channel values."""
        return Sequential(
            Linear(embed_dim, embed_dim),
            LayerNorm([embed_dim]),
            LeakyReLU(),
            Linear(embed_dim, out_channels),
        )

    def forward(self, x, embedding):
        """Return ``x`` convolved, normalised and modulated by ``embedding``.

        Args:
            x: Feature map fed to ``Conv2d``. Presumably (batch, in_channels, H, W);
                TODO confirm against callers.
            embedding: Conditioning vectors, cast to float32 before use.
                Assumed to be (batch, embed_dim), so the predicted gamma/beta
                become (batch, in_channels, 1, 1) and broadcast over H and W.

        Returns:
            The modulated feature map after LeakyReLU, with the conv output's shape.
        """
        x = self.norm(self.conv(x))
        # Cast once and reuse for both heads.
        embedding = embedding.float()
        # exp() keeps the predicted scale positive, so the effective multiplier
        # (1 + gamma) is always greater than 1.
        gamma = self.gamma_function(embedding).exp().unsqueeze(2).unsqueeze(3)
        beta = self.beta_function(embedding).unsqueeze(2).unsqueeze(3)
        return self.leakyrelu(x * (1 + gamma) + beta)
class SubTextMapper(Module):
    """Normalise a feature map with the project's ``PixelNorm``, then apply text modulation.

    Args:
        opts: Options object. It is stored on the instance but not read by this class.
        in_channels: Channel count passed to the single ``TextModulationModule``.
    """

    def __init__(self, opts, in_channels):
        super().__init__()
        self.opts = opts
        self.pixelnorm = PixelNorm()
        # One-element ModuleList; parameter names are ``modulation_module_list.0.*``.
        self.modulation_module_list = nn.ModuleList([TextModulationModule(in_channels)])

    def forward(self, x, embedding):
        """Return ``x`` after ``PixelNorm`` and every modulation module, each conditioned on ``embedding``."""
        hidden = self.pixelnorm(x)
        for block in self.modulation_module_list:
            hidden = block(hidden, embedding)
        return hidden
class CLIPAdapter(Module):
    """Apply text-conditioned ``SubTextMapper``s to a three-level feature tuple.

    Each level has its own mapper, which is skipped when the matching
    ``opts.no_*_mapper`` flag is truthy:

    * ``coarse_mapping`` (512 channels) acts on the third feature map,
    * ``medium_mapping`` (256 channels) acts on the second,
    * ``fine_mapping`` (128 channels) acts on the first.
    """

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        # (attribute name, disable flag, channel count), built coarse -> fine.
        levels = (
            ("coarse_mapping", opts.no_coarse_mapper, 512),
            ("medium_mapping", opts.no_medium_mapper, 256),
            ("fine_mapping", opts.no_fine_mapper, 128),
        )
        for attr_name, disabled, channels in levels:
            if not disabled:
                # nn.Module.__setattr__ registers the submodule under attr_name.
                setattr(self, attr_name, SubTextMapper(opts, channels))

    def forward(self, features, txt_embed):
        """Modulate each enabled level of ``features`` with ``txt_embed``.

        Args:
            features: Sequence of exactly three feature maps (first, second, third).
            txt_embed: Text embedding. It is detached, so no gradient flows back through it.

        Returns:
            A 3-tuple in the input order. Levels whose mapper is disabled are returned unchanged.
        """
        embed = txt_embed.detach()
        fine, medium, coarse = features
        opts = self.opts
        if not opts.no_coarse_mapper:
            coarse = self.coarse_mapping(coarse, embed)
        if not opts.no_medium_mapper:
            medium = self.medium_mapping(medium, embed)
        if not opts.no_fine_mapper:
            fine = self.fine_mapping(fine, embed)
        return fine, medium, coarse