# Captured from a "Spaces" page; status shown at capture time: Build error
| import matplotlib | |
| matplotlib.use('Agg') | |
| import torch | |
| from torch import nn | |
| from models.encoders import psp_encoders_features | |
def get_keys(d, name):
    """Extract the sub-state-dict that belongs to the submodule ``name``.

    Args:
        d: A checkpoint mapping. If it has a ``'state_dict'`` entry, that entry
            is used as the state dict; otherwise ``d`` itself is used.
        name: Submodule prefix to select, e.g. ``'encoder'``.

    Returns:
        A new dict holding every entry whose key starts with ``name + '.'``,
        with that prefix removed from the key.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on "name." instead of the bare prefix. Otherwise sibling modules
    # that only share the prefix (e.g. "encoder2.*" when name == "encoder")
    # would be included with a mangled key such as ".weight".
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
class pSp(nn.Module):
    """pSp-style wrapper around a pretrained e4e encoder.

    Only the encoder (``psp_encoders_features.Encoder4Editing``) is built and
    loaded here. ``forward`` returns latent codes and the encoder's feature
    output; no generator is constructed in this class.
    """

    def __init__(self, opts):
        """Build the encoder and load its pretrained weights.

        Args:
            opts: Options object. This class reads ``pretrained_e4e_path``,
                ``device`` and ``start_from_latent_avg`` from it; the encoder
                constructor also receives it.
        """
        super().__init__()
        self.opts = opts
        # Define architecture; the encoder is switched to eval mode immediately.
        self.encoder = self.set_encoder().eval()
        # Load weights (encoder + average latent) from the e4e checkpoint.
        self.load_weights()

    def set_encoder(self):
        """Construct and return the e4e encoder.

        Called with ``50`` and ``'ir_se'`` -- presumably backbone depth and
        block type; confirm against ``Encoder4Editing``.
        """
        return psp_encoders_features.Encoder4Editing(50, 'ir_se', self.opts)

    def load_weights(self):
        """Load encoder weights and ``latent_avg`` from ``opts.pretrained_e4e_path``."""
        # We only load the encoder weights
        print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.pretrained_e4e_path))
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints.
        ckpt = torch.load(self.opts.pretrained_e4e_path, map_location='cpu')
        self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
        self.__load_latent_avg(ckpt)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        """Encode ``x`` into latent codes, optionally mixing selected entries.

        Args:
            x: Input batch for the encoder or, when ``input_code`` is True,
                latent codes that are used as-is.
            resize, randomize_noise, return_latents: Accepted for pSp signature
                compatibility; not used by this method.
            latent_mask: Iterable of indices along dim 1 of the codes to
                overwrite, or None to leave the codes untouched.
            input_code: If True, skip the encoder and treat ``x`` as codes.
            inject_latent: Codes supplying the values written at the masked
                indices. If None, the masked indices are set to 0.
            alpha: If given (and ``inject_latent`` is given), masked indices
                become ``alpha * inject_latent + (1 - alpha) * codes``.

        Returns:
            ``(codes, features)``. ``features`` is the encoder's second output,
            or None when ``input_code`` is True (no encoder pass happens).
        """
        if input_code:
            codes = x
            # No encoder pass, so there are no feature maps to return.
            features = None
        else:
            codes, features = self.encoder(x)
            # normalize with respect to the center of an average face
            # NOTE(review): nesting reconstructed from a flattened paste; this
            # follows the upstream pSp layout where only encoder outputs are
            # offset by the average latent -- confirm.
            if self.opts.start_from_latent_avg:
                if codes.ndim == 2:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            # Copy first so the masked writes never touch a tensor the caller
            # owns (with input_code=True, `codes` is the caller's `x`).
            codes = codes.clone()
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0
        return codes, features

    # Forward the modulated feature maps
    def forward_features(self, features):
        """Delegate to ``self.encoder.forward_features``."""
        return self.encoder.forward_features(features)

    def __load_latent_avg(self, ckpt, repeat=None):
        """Set ``self.latent_avg`` from the checkpoint, or None if absent.

        Args:
            ckpt: Loaded checkpoint mapping.
            repeat: If given, tile the average latent ``repeat`` times along
                dim 0 (via ``Tensor.repeat(repeat, 1)``).
        """
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None