| # Copyright 2020 Erik Härkönen. All rights reserved. | |
| # This file is licensed to you under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. You may obtain a copy | |
| # of the License at http://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software distributed under | |
| # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS | |
| # OF ANY KIND, either express or implied. See the License for the specific language | |
| # governing permissions and limitations under the License. | |
| # Patch for broken CTRL+C handler | |
| # https://github.com/ContinuumIO/anaconda-issues/issues/905 | |
| import os | |
| os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1' | |
| import numpy as np | |
| import os | |
| from pathlib import Path | |
| import re | |
| import sys | |
| import datetime | |
| import argparse | |
| import torch | |
| import json | |
| from types import SimpleNamespace | |
| import scipy | |
| from scipy.cluster.vq import kmeans | |
| from tqdm import trange | |
| from netdissect.nethook import InstrumentedModel | |
| from config import Config | |
| from estimators import get_estimator | |
| from models import get_instrumented_model | |
# Fixed RNG seeds for each randomized stage, so results are reproducible
# across runs and independent of one another.
SEED_SAMPLING = 1       # latent sampling when fitting the estimator
SEED_RANDOM_DIRS = 2    # random-direction baseline (get_random_dirs)
SEED_LINREG = 3         # least-squares regression back to latent space
SEED_VISUALIZATION = 5  # reserved for visualization (not used in this chunk)

B = 20                  # batch size; overwritten per-run in compute() (global)
n_clusters = 500        # NOTE(review): appears unused in this chunk — confirm
def get_random_dirs(components, dimensions):
    """Draw `components` random unit-norm directions of length `dimensions`.

    A fixed seed (SEED_RANDOM_DIRS) makes the baseline directions identical
    across runs. Returns a float32 array of shape (components, dimensions)
    whose rows have unit L2 norm.
    """
    rng = np.random.RandomState(seed=SEED_RANDOM_DIRS)
    directions = rng.normal(size=(components, dimensions))
    row_norms = np.sqrt(np.sum(directions ** 2, axis=1, keepdims=True))
    return (directions / row_norms).astype(np.float32)
# Compute maximum batch size for given VRAM and network
def get_max_batch_size(inst, device, layer_name=None):
    """Empirically find a batch size whose peak VRAM use stays within budget.

    Runs forward passes with increasing batch sizes (2, 4, ..., 18) and
    returns the first size whose peak allocated memory exceeds half of the
    device's total memory; otherwise returns B_max.

    Args:
        inst: InstrumentedModel wrapper around the generator.
        device: CUDA device to measure.
        layer_name: if given, only forward up to this layer (partial_forward).

    Returns:
        int batch size (even, at most B_max).
    """
    inst.remove_edits()

    # Reset peak-memory statistics.
    # reset_max_memory_cached()/reset_max_memory_allocated() are deprecated;
    # reset_peak_memory_stats() resets both peak counters in one call.
    torch.cuda.reset_peak_memory_stats(device)
    total_mem = torch.cuda.get_device_properties(device).total_memory

    B_max = 20

    # Measure actual usage; max_memory_allocated is a running peak, so it is
    # monotone over the loop and reflects the largest batch tried so far.
    for i in range(2, B_max, 2):
        z = inst.model.sample_latent(n_samples=i)
        if layer_name:
            inst.model.partial_forward(z, layer_name)
        else:
            inst.model.forward(z)
        maxmem = torch.cuda.max_memory_allocated(device)
        del z

        if maxmem > 0.5*total_mem:
            print('Batch size {:d}: memory usage {:.0f}MB'.format(i, maxmem / 1e6))
            return i

    return B_max
# Solve for directions in latent space that match PCs in activation space
def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config):
    """Regress latent vectors onto projected activation coordinates.

    Samples latents z, collects activations at config.layer, projects the
    (mean-subtracted, stdev-normalized) activations onto the given
    components, then solves min_M ||A M - Z|| by least squares so that the
    rows of M map PCA coordinates back to latent directions.

    Args:
        comp_np: (n_comp, feat_dims) component basis in activation space.
        mean_np: activation mean (broadcastable to a flattened activation).
        stdev_np: (n_comp,) per-component standard deviations.
        inst: InstrumentedModel with config.layer retained.
        config: run configuration (reads config.n and config.layer).

    Returns:
        (Z_comp, Z_mean): (n_comp, latent_dims) latent-space directions and
        the (1, latent_dims) mean of the sampled latents.
    """
    print('Performing least squares regression', flush=True)

    # Seed both frameworks so the sampled latents are reproducible
    torch.manual_seed(SEED_LINREG)
    np.random.seed(SEED_LINREG)

    comp = torch.from_numpy(comp_np).float().to(inst.model.device)
    mean = torch.from_numpy(mean_np).float().to(inst.model.device)
    stdev = torch.from_numpy(stdev_np).float().to(inst.model.device)

    # B is the module-level batch size set by compute()
    n_samp = max(10_000, config.n) // B * B # make divisible
    n_comp = comp.shape[0]
    latent_dims = inst.model.get_latent_dims()

    # We're looking for M s.t. M*P*G'(Z) = Z => M*A = Z
    # Z = batch of latent vectors (n_samples x latent_dims)
    # G'(Z) = batch of activations at intermediate layer
    # A = P*G'(Z) = projected activations (n_samples x pca_coords)
    # M = linear mapping (pca_coords x latent_dims)

    # Minimization min_M ||MA - Z||_l2 rewritten as min_M.T ||A.T*M.T - Z.T||_l2
    # to match format expected by pytorch.lstsq

    # TODO: regression on pixel-space outputs? (using nonlinear optimizer)
    # min_M lpips(G_full(MA), G_full(Z))

    # Tensors to fill with data
    # Dimensions other way around, so these are actually the transposes
    A = np.zeros((n_samp, n_comp), dtype=np.float32)
    Z = np.zeros((n_samp, latent_dims), dtype=np.float32)

    # Project tensor X onto PCs, return coordinates
    def project(X, comp):
        N = X.shape[0]
        K = comp.shape[0]
        # Batched matmul: each sample's flattened activation dotted with
        # every component row, yielding (N, K) coordinates
        coords = torch.bmm(comp.expand([N]+[-1]*comp.ndim), X.view(N, -1, 1))
        return coords.reshape(N, K)

    # Collect (projected activation, latent) pairs batch by batch
    for i in trange(n_samp // B, desc='Collecting samples', ascii=True):
        z = inst.model.sample_latent(B)
        inst.model.partial_forward(z, config.layer)
        act = inst.retained_features()[config.layer].reshape(B, -1)

        # Project onto basis
        act = act - mean
        coords = project(act, comp)
        coords_scaled = coords / stdev

        A[i*B:(i+1)*B] = coords_scaled.detach().cpu().numpy()
        Z[i*B:(i+1)*B] = z.detach().cpu().numpy().reshape(B, -1)

    # Solve least squares fit
    # gelsd = divide-and-conquer SVD; good default
    # gelsy = complete orthogonal factorization; sometimes faster
    # gelss = SVD; slow but less memory hungry
    M_t = scipy.linalg.lstsq(A, Z, lapack_driver='gelsd')[0] # torch.lstsq(Z, A)[0][:n_comp, :]

    # Solution given by rows of M_t
    Z_comp = M_t[:n_comp, :]
    Z_mean = np.mean(Z, axis=0, keepdims=True)

    return Z_comp, Z_mean
def regression(comp, mean, stdev, inst, config):
    """Map activation-space components back to latent space.

    Emits a warning when the component rows are not an orthonormal basis,
    then delegates to linreg_lstsq() for the actual least-squares fit.
    """
    # Sanity check: the Gram matrix of an orthonormal basis is the identity
    gram = np.dot(comp, comp.T)
    basis_ok = np.allclose(gram, np.identity(gram.shape[0]))
    if not basis_ok:
        det = np.linalg.det(gram)
        print(f'WARNING: Computed basis is not orthonormal (determinant={det})')

    return linreg_lstsq(comp, mean, stdev, inst, config)
def compute(config, dump_name, instrumented_model):
    """Run the full decomposition pipeline and save results to `dump_name`.

    Samples N latents, collects activations at config.layer, fits the
    configured estimator (PCA/IPCA/ICA/...), regresses components back to
    latent space, and writes an .npz archive with activation- and
    latent-space components, means, stdevs and variance ratios.

    Args:
        dump_name: pathlib.Path of the output .npz file.
        instrumented_model: optional InstrumentedModel to reuse; when None a
            new one is constructed (and closed at the end).

    Side effects: mutates the module-level batch size B, reseeds
    torch/numpy, writes the output file, and may sys.exit(1) on interrupt.
    """
    global B

    timestamp = lambda : datetime.datetime.now().strftime("%d.%m %H:%M")
    print(f'[{timestamp()}] Computing', dump_name.name)

    # Ensure reproducibility
    torch.manual_seed(0) # also sets cuda seeds
    np.random.seed(0)

    # Speed up backend
    torch.backends.cudnn.benchmark = True

    has_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if has_gpu else 'cpu')
    layer_key = config.layer

    if instrumented_model is None:
        inst = get_instrumented_model(config.model, config.output_class, layer_key, device)
        model = inst.model
    else:
        print('Reusing InstrumentedModel instance')
        inst = instrumented_model
        model = inst.model
        inst.remove_edits()
        model.set_output_class(config.output_class)

    # Regress back to w space
    if config.use_w:
        print('Using W latent space')
        model.use_w()

    # Probe the layer once to learn the activation shape
    inst.retain_layer(layer_key)
    model.partial_forward(model.sample_latent(1), layer_key)
    sample_shape = inst.retained_features()[layer_key].shape
    sample_dims = np.prod(sample_shape)
    print('Feature shape:', sample_shape)

    input_shape = inst.model.get_latent_shape()
    input_dims = inst.model.get_latent_dims()

    # Cannot extract more components than feature dimensions
    config.components = min(config.components, sample_dims)
    transformer = get_estimator(config.estimator, config.components, config.sparsity)

    X = None
    X_global_mean = None

    # Figure out batch size if not provided
    B = config.batch_size or get_max_batch_size(inst, device, layer_key)

    # Divisible by B (ignored in output name)
    N = config.n // B * B

    # Compute maximum batch size based on RAM + pagefile budget
    target_bytes = 20 * 1_000_000_000 # GB
    feat_size_bytes = sample_dims * np.dtype('float64').itemsize
    N_limit_RAM = np.floor_divide(target_bytes, feat_size_bytes)
    if not transformer.batch_support and N > N_limit_RAM:
        print('WARNING: estimator does not support batching, ' \
            'given config will use {:.1f} GB memory.'.format(feat_size_bytes / 1_000_000_000 * N))

    # 32-bit LAPACK gets very unhappy about huge matrices (in linalg.svd)
    if config.estimator == 'ica':
        lapack_max_N = np.floor_divide(np.iinfo(np.int32).max // 4, sample_dims) # 4x extra buffer
        if N > lapack_max_N:
            raise RuntimeError(f'Matrices too large for ICA, please use N <= {lapack_max_N}')

    print('B={}, N={}, dims={}, N/dims={:.1f}'.format(B, N, sample_dims, N/sample_dims), flush=True)

    # Must not depend on chosen batch size (reproducibility)
    NB = max(B, max(2_000, 3*config.components)) # ipca: as large as possible!

    samples = None
    if not transformer.batch_support:
        samples = np.zeros((N + NB, sample_dims), dtype=np.float32)

    torch.manual_seed(config.seed or SEED_SAMPLING)
    np.random.seed(config.seed or SEED_SAMPLING)

    # Use exactly the same latents regardless of batch size
    # Store in main memory, since N might be huge (1M+)
    # Run in batches, since sample_latent() might perform Z -> W mapping
    n_lat = ((N + NB - 1) // B + 1) * B
    latents = np.zeros((n_lat, *input_shape[1:]), dtype=np.float32)
    with torch.no_grad():
        for i in trange(n_lat // B, desc='Sampling latents'):
            latents[i*B:(i+1)*B] = model.sample_latent(n_samples=B).cpu().numpy()

    # Decomposition on non-Gaussian latent space
    samples_are_latents = layer_key in ['g_mapping', 'style'] and inst.model.latent_space_name() == 'W'

    canceled = False
    try:
        X = np.ones((NB, sample_dims), dtype=np.float32)
        action = 'Fitting' if transformer.batch_support else 'Collecting'
        for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True):
            for mb in range(0, NB, B):
                z = torch.from_numpy(latents[gi+mb:gi+mb+B]).to(device)

                if samples_are_latents:
                    # Decomposition on latents directly (e.g. StyleGAN W)
                    batch = z.reshape((B, -1))
                else:
                    # Decomposition on intermediate layer
                    with torch.no_grad():
                        model.partial_forward(z, layer_key)

                    # Permuted to place PCA dimensions last
                    batch = inst.retained_features()[layer_key].reshape((B, -1))

                space_left = min(B, NB - mb)
                X[mb:mb+space_left] = batch.cpu().numpy()[:space_left]

            if transformer.batch_support:
                # Incremental fit; estimator may signal early convergence
                if not transformer.fit_partial(X.reshape(-1, sample_dims)):
                    break
            else:
                samples[gi:gi+NB, :] = X.copy()
    except KeyboardInterrupt:
        if not transformer.batch_support:
            sys.exit(1) # no progress yet

        # Rename the dump to reflect how many samples were actually processed.
        # NOTE(review): if the interrupt fires before the first loop iteration,
        # `gi` is undefined here and this raises NameError — confirm intended.
        dump_name = dump_name.parent / dump_name.name.replace(f'n{N}', f'n{gi}')
        print(f'Saving current state to "{dump_name.name}" before exiting')
        canceled = True

    if not transformer.batch_support:
        X = samples # Use all samples
        X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32) # TODO: activations surely multi-modal...!
        X -= X_global_mean

        print(f'[{timestamp()}] Fitting whole batch')
        t_start_fit = datetime.datetime.now()

        transformer.fit(X)

        print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}')
        assert np.all(transformer.transformer.mean_ < 1e-3), 'Mean of normalized data should be zero'
    else:
        # Batched estimator already tracked the mean; center X for reuse below
        X_global_mean = transformer.transformer.mean_.reshape((1, sample_dims))
        X = X.reshape(-1, sample_dims)
        X -= X_global_mean

    X_comp, X_stdev, X_var_ratio = transformer.get_components()

    assert X_comp.shape[1] == sample_dims \
        and X_comp.shape[0] == config.components \
        and X_global_mean.shape[1] == sample_dims \
        and X_stdev.shape[0] == config.components, 'Invalid shape'

    # 'Activations' are really latents in a secondary latent space
    if samples_are_latents:
        Z_comp = X_comp
        Z_global_mean = X_global_mean
    else:
        Z_comp, Z_global_mean = regression(X_comp, X_global_mean, X_stdev, inst, config)

    # Normalize
    Z_comp /= np.linalg.norm(Z_comp, axis=-1, keepdims=True)

    # Random projections
    # We expect these to explain much less of the variance
    random_dirs = get_random_dirs(config.components, np.prod(sample_shape))
    n_rand_samples = min(5000, X.shape[0])
    X_view = X[:n_rand_samples, :].T
    assert np.shares_memory(X_view, X), "Error: slice produced copy"
    X_stdev_random = np.dot(random_dirs, X_view).std(axis=1)

    # Inflate back to proper shapes (for easier broadcasting)
    X_comp = X_comp.reshape(-1, *sample_shape)
    X_global_mean = X_global_mean.reshape(sample_shape)
    Z_comp = Z_comp.reshape(-1, *input_shape)
    Z_global_mean = Z_global_mean.reshape(input_shape)

    # Compute stdev in latent space if non-Gaussian
    lat_stdev = np.ones_like(X_stdev)
    if config.use_w:
        samples = model.sample_latent(5000).reshape(5000, input_dims).detach().cpu().numpy()
        coords = np.dot(Z_comp.reshape(-1, input_dims), samples.T)
        lat_stdev = coords.std(axis=1)

    os.makedirs(dump_name.parent, exist_ok=True)
    np.savez_compressed(dump_name, **{
        'act_comp': X_comp.astype(np.float32),
        'act_mean': X_global_mean.astype(np.float32),
        'act_stdev': X_stdev.astype(np.float32),
        'lat_comp': Z_comp.astype(np.float32),
        'lat_mean': Z_global_mean.astype(np.float32),
        'lat_stdev': lat_stdev.astype(np.float32),
        'var_ratio': X_var_ratio.astype(np.float32),
        'random_stdevs': X_stdev_random.astype(np.float32),
    })

    if canceled:
        sys.exit(1)

    # Don't shutdown if passed as param
    if instrumented_model is None:
        inst.close()
        del inst
        del model

    # Free large buffers before returning.
    # NOTE(review): `batch` is undefined if the collection loop never ran
    # (N == 0) — `del batch` would raise NameError; confirm N > 0 is assumed.
    del X
    del X_comp
    del random_dirs
    del batch
    del samples
    del latents
    torch.cuda.empty_cache()
# Return cached results or compute if needed
# Pass existing InstrumentedModel instance to reuse it
def get_or_compute(config, model=None, submit_config=None, force_recompute=False):
    """Public entry point (called directly by run.py).

    Builds a default submit_config rooted at this file's directory when none
    is supplied, then delegates to _compute().
    """
    if submit_config is None:
        here = str(Path(__file__).parent.resolve())
        submit_config = SimpleNamespace(run_dir_root = here, run_dir = here)

    return _compute(submit_config, config, model, force_recompute)
| def _compute(submit_config, config, model=None, force_recompute=False): | |
| basedir = Path(submit_config.run_dir) | |
| outdir = basedir / 'out' | |
| if config.n is None: | |
| raise RuntimeError('Must specify number of samples with -n=XXX') | |
| if model and not isinstance(model, InstrumentedModel): | |
| raise RuntimeError('Passed model has to be wrapped in "InstrumentedModel"') | |
| if config.use_w and not 'StyleGAN' in config.model: | |
| raise RuntimeError(f'Cannot change latent space of non-StyleGAN model {config.model}') | |
| transformer = get_estimator(config.estimator, config.components, config.sparsity) | |
| dump_name = "{}-{}_{}_{}_n{}{}{}.npz".format( | |
| config.model.lower(), | |
| config.output_class.replace(' ', '_'), | |
| config.layer.lower(), | |
| transformer.get_param_str(), | |
| config.n, | |
| '_w' if config.use_w else '', | |
| f'_seed{config.seed}' if config.seed else '' | |
| ) | |
| dump_path = basedir / 'cache' / 'components' / dump_name | |
| if not dump_path.is_file() or force_recompute: | |
| print('Not cached') | |
| t_start = datetime.datetime.now() | |
| compute(config, dump_path, model) | |
| print('Total time:', datetime.datetime.now() - t_start) | |
| return dump_path |