Spaces:
Runtime error
Runtime error
| import os, torch, numpy | |
| from torch.utils.data import TensorDataset | |
def z_dataset_for_model(model, size=100, seed=1):
    """Wrap a deterministic batch of latent samples for *model* in a dataset.

    The samples come from ``z_sample_for_model(model, size, seed)`` and are
    stored as the single tensor of a ``TensorDataset``, so each item is a
    1-tuple holding one latent sample.
    """
    samples = z_sample_for_model(model, size, seed)
    return TensorDataset(samples)
def z_sample_for_model(model, size=100, seed=1):
    """Generate a deterministic batch of latent z samples shaped for *model*.

    If ``model`` has an ``input_shape`` attribute, its first entry is treated
    as the batch dimension and the remaining entries as the per-sample shape;
    enough standard-normal values are drawn to fill that whole shape.
    Otherwise the first ``Conv2d``/``ConvTranspose2d``/``Linear`` module found
    in ``model.modules()`` decides the layout: a convolutional first layer
    yields a ``(size, in_channels, 1, 1)`` tensor, a linear one yields
    ``(size, in_features)``.

    Args:
        model: a ``torch.nn.Module`` (or any object exposing ``input_shape``).
        size: number of samples to generate.
        seed: seed forwarded to ``standard_z_sample``; the same seed always
            yields the same samples.

    Returns:
        A float tensor of latent samples.

    Raises:
        ValueError: if ``model`` has no ``input_shape`` and contains no
            Conv2d, ConvTranspose2d or Linear layer.
    """
    # If the model is marked with an input shape, use it.
    if hasattr(model, 'input_shape'):
        # tuple() so list-valued shapes can be concatenated with (size,).
        per_sample_shape = tuple(model.input_shape[1:])
        # Draw one value per element of the per-sample shape. For shapes such
        # as (None, C, 1, 1) this equals C, i.e. the same draw as before; for
        # larger spatial shapes it avoids a .view() size mismatch.
        depth = int(numpy.prod(per_sample_shape))
        return standard_z_sample(size, depth, seed=seed).view(
            (size,) + per_sample_shape)
    # Examine the first conv/linear layer to determine input feature size.
    layer_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear)
    first_layer = next(
        (module for module in model.modules()
         if isinstance(module, layer_types)),
        None)
    if first_layer is None:
        raise ValueError(
            'cannot infer z shape: model has no input_shape attribute and '
            'no Conv2d, ConvTranspose2d or Linear layer')
    # 2d input if the first layer is linear, 4d input if convolutional.
    if isinstance(first_layer, torch.nn.Linear):
        return standard_z_sample(size, first_layer.in_features, seed=seed)
    return standard_z_sample(
        size, first_layer.in_channels, seed=seed)[:, :, None, None]
def standard_z_sample(size, depth, seed=1, device=None):
    """Return a deterministic ``(size, depth)`` float32 tensor of N(0, 1) draws.

    Values come from a private ``numpy.random.RandomState(seed)``, so the
    global numpy/torch random state is never touched. Draws fill the tensor
    row by row, so row ``i`` depends only on ``seed``, ``depth`` and ``i``.
    For example, the first z is identical whatever ``size`` is.

    Args:
        size: number of rows (samples).
        depth: number of columns (z dimension).
        seed: seed for the private RandomState.
        device: optional torch device; when given, the tensor is moved there.
    """
    generator = numpy.random.RandomState(seed)
    draws = generator.standard_normal(size * depth).reshape(size, depth)
    z = torch.from_numpy(draws).float()
    return z if device is None else z.to(device)