|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import datasets, transforms |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input through one hidden ReLU layer to the
    mean and log-variance of a diagonal Gaussian over the latent code. The
    decoder maps a latent code through one hidden ReLU layer and a sigmoid,
    so every reconstructed element lies in (0, 1).

    Args:
        input_dim: Number of elements in one flattened input sample.
        hidden_dim: Width of the hidden layer used by encoder and decoder.
        latent_dim: Size of the latent code.

    The defaults (784, 400, 20) reproduce the original fixed architecture.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Layer creation order is kept stable so that a fixed RNG seed
        # yields the same initial weights as before.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparam(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in (0, 1)."""
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))

    def forward(self, x):
        """Encode, sample and decode a batch.

        ``x`` is flattened to ``(-1, input_dim)`` first, so image-shaped
        batches are accepted directly.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparam(mu, logvar)
        return self.decode(z), mu, logvar
|
|
|
|
|
|
|
|
def vae_loss(recon_x, x, mu, logvar):
    """Return the VAE training loss summed over the whole batch.

    The loss is the binary cross-entropy between the reconstruction and the
    input plus the KL divergence between the encoder's diagonal Gaussian
    ``N(mu, exp(logvar))`` and a standard normal.

    Args:
        recon_x: Reconstruction with values in [0, 1].
        x: Original input; reshaped to ``recon_x.shape``, so it may be
            image-shaped as long as the element count matches.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``BCE + KLD`` (sum reduction, not averaged).
    """
    # Take the target shape from the reconstruction instead of assuming
    # 784 elements per sample.
    target = x.reshape(recon_x.shape)
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over all elements.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld
|
|
|
|
|
|
|
|
def train_vae(epochs=1, batch_size=128, lr=1e-3, data_root='.'):
    """Train a freshly initialised VAE on the MNIST training split.

    Args:
        epochs: Number of full passes over the training data.
        batch_size: Number of images per optimisation step.
        lr: Learning rate for the Adam optimiser.
        data_root: Directory where MNIST is stored; the dataset is
            downloaded there when it is not already present.

    Returns:
        The trained ``VAE`` instance.

    The defaults reproduce the original behaviour: one epoch, batches of
    128, learning rate 1e-3, data stored in the current directory.
    """
    vae = VAE()
    optimizer = torch.optim.Adam(vae.parameters(), lr=lr)
    loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_root, train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True)

    for _epoch in range(epochs):
        # Labels are ignored: the VAE is trained without supervision.
        for x, _labels in loader:
            optimizer.zero_grad()
            recon, mu, logvar = vae(x)
            loss = vae_loss(recon, x, mu, logvar)
            loss.backward()
            optimizer.step()
    return vae
|
|
|
|
|
# Train once when the module is loaded; generate() reads this global.
vae_model = train_vae()
|
|
|
|
|
|
|
|
def generate(latent1=0.0, latent2=0.0):
    """Decode a digit image from two scalar controls.

    The first half of the latent vector is filled with ``latent1`` and the
    remainder with ``latent2`` (10 + 10 entries for the default 20-d model).

    Args:
        latent1: Value copied into the first half of the latent code.
        latent2: Value copied into the second half of the latent code.

    Returns:
        A 28x28 8-bit grayscale ``PIL.Image``.
    """
    decoder_in = vae_model.fc3
    latent_dim = decoder_in.in_features
    half = latent_dim // 2
    # Build the tensor in the decoder's dtype: all-integer arguments such as
    # generate(0, 0) would otherwise produce an int64 tensor, which the float
    # linear layers reject.
    z = torch.tensor(
        [[latent1] * half + [latent2] * (latent_dim - half)],
        dtype=decoder_in.weight.dtype,
    )
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        img = vae_model.decode(z).view(28, 28).numpy()
    # Decoder output lies in (0, 1); scale to the 8-bit pixel range.
    img = (img * 255).astype(np.uint8)
    return Image.fromarray(img, mode='L')
|
|
|
|
|
# Web UI: two sliders feed generate(), and the decoded digit is shown as an
# image. Each slider is created with the positional arguments (-3, 3, 0).
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Slider(-3, 3, 0, label="Latent dim 1"),
        gr.Slider(-3, 3, 0, label="Latent dim 2"),
    ],
    outputs=gr.Image(type="pil", label="Generated Digit"),
    title="🧠 VAE Digit Generator",
)

demo.launch()
|
|
|