Spaces:
Build error
Build error
| import jax | |
| import jax.numpy as jnp | |
| import flax | |
| import flax.linen as nn | |
| import numpy as np | |
| import os | |
| import functools | |
| import argparse | |
| import scipy | |
| from tqdm import tqdm | |
| import logging | |
| from . import inception | |
| from . import utils | |
# Logger shared by everything in this module, named after the module itself.
logger = logging.getLogger(name=__name__)
class FID:
    """Compute the Frechet Inception Distance (FID) between a generator and a dataset.

    Implementation mostly taken from https://github.com/matthias-wright/jax-fid
    Reference: https://arxiv.org/abs/1706.08500
    """

    def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
        """
        Evaluates the FID score for a given generator and a given dataset.

        Args:
            generator (nn.Module): Generator network.
            dataset (tf.data.Dataset): Dataset containing the real images.
            config (argparse.Namespace): Configuration. Reads ``num_fid_images``, ``batch_size``,
                ``c_dim``, ``z_dim`` and ``resolution``.
            use_cache (bool): If True, only compute the activation stats once for the real images and store them.
            truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
        """
        self.num_images = config.num_fid_images
        self.batch_size = config.batch_size
        self.c_dim = config.c_dim
        self.z_dim = config.z_dim
        self.dataset = dataset
        self.num_devices = jax.device_count()
        self.num_local_devices = jax.local_device_count()
        self.use_cache = use_cache
        # Always create the dict so attribute access is safe even when caching is disabled;
        # it is only written to when use_cache is True.
        self.cache = {}

        rng = jax.random.PRNGKey(0)
        inception_net = inception.InceptionV3(pretrained=True)
        self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3)))
        # Parameters are replicated across local devices for the pmapped apply functions below.
        self.inception_params = flax.jax_utils.replicate(self.inception_params)
        self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch')
        self.generator_apply = jax.pmap(
            functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'),
            axis_name='batch')

    def compute_fid(self, generator_params, seed_offset=0):
        """Compute the FID between the real dataset and images sampled from the generator.

        Args:
            generator_params: Generator parameters (not yet replicated; replicated here).
            seed_offset (int): Offset added to the per-batch PRNG seed used for sampling.

        Returns:
            The FID score.
        """
        generator_params = flax.jax_utils.replicate(generator_params)
        mu_real, sigma_real = self.compute_stats_for_dataset()
        mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
        return self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)

    def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
        """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

        Computes ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrtm(sigma1 @ sigma2)).
        Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py

        Args:
            mu1, mu2: Mean vectors of the two activation sets.
            sigma1, sigma2: Covariance matrices of the two activation sets.
            eps (float): Added to the covariance diagonals if the product's square root is not finite.

        Raises:
            AssertionError: If the means or covariances have mismatched shapes.
            ValueError: If the matrix square root has a significant imaginary component.
        """
        # `import scipy` alone does not load the linalg submodule on older SciPy releases.
        import scipy.linalg

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_1d(sigma1)
        sigma2 = np.atleast_1d(sigma2)
        assert mu1.shape == mu2.shape, 'Mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, 'Covariances have different dimensions'

        diff = mu1 - mu2
        covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            logger.info('fid calculation produces singular product; '
                        'adding %s to diagonal of cov estimates', eps)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give a slight imaginary component.
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

    def compute_stats_for_dataset(self):
        """Return (mu, sigma) of Inception activations over up to ``num_images`` real images.

        Results are cached on the instance when ``use_cache`` is True.
        """
        if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
            logger.info('Use cached statistics for dataset...')
            return self.cache['mu'], self.cache['sigma']
        print()
        logger.info('Compute statistics for dataset...')
        image_count = 0
        activations = []
        for batch in utils.prefetch(self.dataset, n_prefetch=2):
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
            act = self._flatten_device_batch(act)
            activations.append(act)
            # Count the images actually processed, not the nominal batch size.
            image_count += act.shape[0]
            if image_count >= self.num_images:
                break
        mu, sigma = self._stats_from_activations(activations, 'dataset')
        if self.use_cache:
            self.cache['mu'] = mu
            self.cache['sigma'] = sigma
        return mu, sigma

    def compute_stats_for_generator(self, generator_params, seed_offset):
        """Return (mu, sigma) of Inception activations over ``num_images`` generated images.

        Args:
            generator_params: Generator parameters, already replicated across local devices.
            seed_offset (int): Offset added to the per-batch PRNG seed.
        """
        print()
        logger.info('Compute statistics for generator...')
        images_per_step = self.batch_size * self.num_local_devices
        num_batches = int(np.ceil(self.num_images / images_per_step))
        activations = []
        for i in range(num_batches):
            rng = jax.random.PRNGKey(seed_offset + i)
            labels = None
            if self.c_dim > 0:
                # Use an independent key for the labels; drawing labels and latents from the
                # same key would make them correlated.
                rng, label_rng = jax.random.split(rng)
                labels = jax.random.randint(label_rng, shape=(images_per_step,), minval=0, maxval=self.c_dim)
                labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
                labels = jnp.reshape(labels, (self.num_local_devices, self.batch_size, self.c_dim))
            z_latent = jax.random.normal(rng, shape=(self.num_local_devices, self.batch_size, self.z_dim))
            image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
            # Rescale to [-1, 1] using min/max over the whole generated batch (all devices).
            image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
            image = 2 * image - 1
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
            activations.append(self._flatten_device_batch(act))
        return self._stats_from_activations(activations, 'generator')

    @staticmethod
    def _flatten_device_batch(act):
        """Merge the two leading axes of pmapped activations into one row axis."""
        # Assumes act is (devices, per_device_batch, ...) as produced by pmap; matches the
        # former (num_local_devices * batch_size, -1) reshape for full batches but also
        # tolerates a smaller final batch.
        return jnp.reshape(act, (act.shape[0] * act.shape[1], -1))

    def _stats_from_activations(self, activations, source):
        """Concatenate per-batch activations, keep the first ``num_images`` rows, return (mean, cov)."""
        if not activations:
            raise ValueError('No activations were computed for the {}; cannot compute FID statistics.'.format(source))
        activations = jnp.concatenate(activations, axis=0)
        if activations.shape[0] < self.num_images:
            logger.warning('Only %d images available from the %s, fewer than the requested %d.',
                           activations.shape[0], source, self.num_images)
        # Move to host once so mean and covariance do not each trigger a device transfer.
        activations = np.asarray(activations[:self.num_images])
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return mu, sigma