Spaces:
Runtime error
Runtime error
| import argparse | |
| import pickle | |
| import torch | |
| from torch import nn | |
| import numpy as np | |
| from scipy import linalg | |
| from tqdm import tqdm | |
| from model import Generator | |
| from calc_inception import load_patched_inception_v3 | |
def extract_feature_from_samples(
    generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
    """Sample images from ``generator`` and return their ``inception`` features.

    Latents are drawn with ``torch.randn(batch, 512)`` on ``device`` and fed to
    ``generator([latent], truncation=..., truncation_latent=...)``. The first
    output of ``inception(img)`` is flattened to one row per image and moved
    to the CPU.

    Args:
        generator: Callable returning ``(images, _)`` for a list of latents.
        inception: Callable whose first output holds the per-image features.
        truncation: Forwarded to ``generator`` as ``truncation``.
        truncation_latent: Forwarded to ``generator`` as ``truncation_latent``.
        batch_size: Maximum number of samples per forward pass.
        n_sample: Total number of samples to generate.
        device: Device on which the latents are created.

    Returns:
        A CPU tensor with ``n_sample`` rows of flattened features.

    Raises:
        ValueError: If ``batch_size`` or ``n_sample`` is not positive.
    """
    if batch_size <= 0 or n_sample <= 0:
        raise ValueError('batch_size and n_sample must be positive')

    n_batch, resid = divmod(n_sample, batch_size)
    batch_sizes = [batch_size] * n_batch
    # Only add a trailing partial batch when one exists; an unconditional
    # [resid] would push an empty (0, 512) batch through both networks.
    if resid:
        batch_sizes.append(resid)

    features = []
    # Inference only: without no_grad the features keep the autograd graph
    # alive and cannot be converted with .numpy() afterwards.
    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            latent = torch.randn(batch, 512, device=device)
            # Call the `generator` argument rather than a module-level global.
            img, _ = generator(
                [latent], truncation=truncation, truncation_latent=truncation_latent
            )
            feat = inception(img)[0].view(img.shape[0], -1)
            features.append(feat.to('cpu'))

    return torch.cat(features, 0)
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """Return the Frechet distance between two Gaussians.

    Computes ``|mu_s - mu_r|^2 + Tr(C_s) + Tr(C_r) - 2 * Tr(sqrt(C_s @ C_r))``.

    Args:
        sample_mean: Mean vector of the generated-sample features.
        sample_cov: Covariance matrix of the generated-sample features.
        real_mean: Mean vector of the reference features.
        real_cov: Covariance matrix of the reference features.
        eps: Diagonal offset added to both covariances when the matrix
            square root of their product is not finite.

    Returns:
        The distance as a scalar.

    Raises:
        ValueError: If the matrix square root has a diagonal imaginary part
            larger than ``1e-3`` in absolute value.
    """
    # Call sqrtm without the `disp` flag: the plain call returns just the
    # matrix on every SciPy version, whereas `disp=False` (tuple return) is
    # deprecated in recent SciPy releases.
    cov_sqrt = linalg.sqrtm(sample_cov @ real_cov)

    if not np.isfinite(cov_sqrt).all():
        print('product of cov matrices is singular')
        # Retry with a small diagonal offset on both covariances.
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        # Small imaginary parts are treated as numerical noise and dropped;
        # anything larger on the diagonal is reported as an error.
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')

        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    return mean_norm + trace
if __name__ == '__main__':
    # Script entry point: load a generator checkpoint, extract features for
    # generated samples, and print the FID against precomputed statistics.

    # Fall back to the CPU so the script still runs on hosts without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--batch', type=int, default=64)
    parser.add_argument('--n_sample', type=int, default=50000)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--inception', type=str, default=None, required=True)
    parser.add_argument('ckpt', metavar='CHECKPOINT')
    args = parser.parse_args()

    # map_location lets a checkpoint saved on a GPU load on the chosen device.
    ckpt = torch.load(args.ckpt, map_location=device)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt['g_ema'])
    g = nn.DataParallel(g)
    g.eval()

    if args.truncation < 1:
        with torch.no_grad():
            # nn.DataParallel does not expose the wrapped model's custom
            # methods, so reach the underlying generator through `.module`.
            mean_latent = g.module.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()

    features = extract_feature_from_samples(
        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
    ).numpy()
    print(f'extracted {features.shape[0]} features')

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # NOTE: pickle.load can execute arbitrary code from the file; only pass
    # statistics files from a trusted source to --inception.
    with open(args.inception, 'rb') as f:
        embeds = pickle.load(f)

    real_mean = embeds['mean']
    real_cov = embeds['cov']

    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
    print('fid:', fid)