| import math | |
| import torch | |
| import torchvision | |
| import gradio as gr | |
| from PIL import Image | |
| from cli import iterative_refinement | |
| from viz import grid_of_images_default | |
# Checkpoint file backing each model choice offered in the UI dropdown.
# Insertion order matters: it is the order the dropdown lists the models.
_CHECKPOINTS = {
    "ConvAE": "convae.th",
    "Deep ConvAE": "deep_convae.th",
    "Dense K-Sparse": "fc_sparse.th",
}

# Load every model once at import time, onto the CPU.
# NOTE(review): torch.load unpickles these files. If they hold full pickled
# nn.Module objects (as the attribute access in gen() suggests), PyTorch >= 2.6
# defaults to weights_only=True and will refuse them; pass weights_only=False
# there (only for trusted checkpoints) — confirm the installed torch version.
models = {
    name: torch.load(path, map_location="cpu")
    for name, path in _CHECKPOINTS.items()
}
def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg, binarize, binarize_threshold):
    """Sample from the selected model and render the result as one image.

    This is the Gradio callback; each argument is the value of one input
    component, in the order they are declared in the interface.

    Args:
        md: Value of the leading Markdown description component (unused).
        model_name: Key into ``models`` selecting the model to sample from.
        seed: Seed for ``torch.manual_seed`` so runs are reproducible.
        nb_iter: Number of iterations passed to ``iterative_refinement``.
        nb_samples: Number of examples to generate.
        width: Width of each generated sample.
        height: Height of each generated sample.
        nb_active: Written to ``model.nb_active``; only used when
            ``model_name`` is "Dense K-Sparse".
        only_last: If true, show only ``samples[-1]``; otherwise tile all of
            ``samples`` in a ``samples.shape[0] x samples.shape[1]`` grid.
        black_bg: If false, intensities are inverted (``1 - samples``).
        binarize: If true, ``binarize_threshold`` is forwarded to
            ``iterative_refinement``; otherwise ``None`` is passed.
        binarize_threshold: Threshold used when ``binarize`` is set.

    Returns:
        A ``PIL.Image`` built from the uint8 sample grid.
    """
    # Gradio numeric components may deliver floats; normalize every count /
    # size input to int once, up front.
    seed = int(seed)
    nb_iter = int(nb_iter)
    nb_samples = int(nb_samples)
    width = int(width)
    height = int(height)

    torch.manual_seed(seed)
    bs = 64
    model = models[model_name]
    if model_name == "Dense K-Sparse":
        # The number of activations to keep is a count, so it must be an int
        # like the other numeric inputs. Note this mutates the shared,
        # module-level model object, so the value persists across calls.
        model.nb_active = int(nb_active)
    samples = iterative_refinement(
        model,
        nb_iter=nb_iter,
        nb_examples=nb_samples,
        w=width, h=height, c=1,
        batch_size=bs,
        binarize_threshold=binarize_threshold if binarize else None,
    )
    if not black_bg:
        samples = 1 - samples
    # NOTE(review): samples appears to be indexed (iteration, example, ...),
    # inferred from samples[-1] and the reshape below — confirm against
    # cli.iterative_refinement.
    if only_last:
        # NOTE(review): for a non-square nb_samples this s x s grid has fewer
        # cells than samples; check how grid_of_images_default handles that.
        s = int(math.sqrt(nb_samples))
        grid = grid_of_images_default(samples[-1].numpy(), shape=(s, s))
    else:
        nb_rows, nb_cols = samples.shape[0], samples.shape[1]
        flat = samples.reshape((nb_rows * nb_cols, height, width, 1))
        grid = grid_of_images_default(flat.numpy(), shape=(nb_rows, nb_cols))
    # Clamp to [0, 1] before scaling so out-of-range intensities saturate
    # instead of wrapping around in the uint8 conversion.
    grid = (grid.clip(0, 1) * 255).astype("uint8")
    return Image.fromarray(grid)
# Markdown description shown in the UI: passed to gr.Markdown as the first
# input component of the interface (and thus as gen()'s ``md`` argument).
text = """
This interface supports generation of samples from:
- ConvAE model (from [`Digits that are not: Generating new types through deep neural nets`](https://arxiv.org/pdf/1606.04345.pdf))
- DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`)
- Dense K-Sparse model (from [`Out-of-class novelty generation`](https://openreview.net/forum?id=r1QXQkSYg))
These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
NB: `nb_active` is only used for the Dense K-Sparse, specifying nb of activations to keep in the last layer.
"""
# Build the web UI. Components are listed one per line, in exactly the order
# of gen()'s parameters, since Gradio passes their values positionally.
iface = gr.Interface(
    fn=gen,
    inputs=[
        gr.Markdown(text),                                    # md
        gr.Dropdown(list(models), value="Deep ConvAE"),       # model_name
        gr.Number(value=0),                                   # seed
        gr.Number(value=25),                                  # nb_iter
        gr.Number(value=1),                                   # nb_samples
        gr.Number(value=28),                                  # width
        gr.Number(value=28),                                  # height
        gr.Slider(minimum=0, maximum=800, value=800, step=1), # nb_active
        gr.Checkbox(value=False, label="Only show last iteration"),  # only_last
        gr.Checkbox(value=True, label="Black background"),    # black_bg
        gr.Checkbox(value=False, label="binarize"),           # binarize
        gr.Number(value=0.5),                                 # binarize_threshold
    ],
    outputs="image",
)
# Start the server as soon as the script runs.
iface.launch()