# app.py import torch import torch.nn as nn import torch.nn.functional as F from torchvision import datasets, transforms import gradio as gr import numpy as np from PIL import Image # 🧱 VAE architecture class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.fc1 = nn.Linear(784, 400) self.fc21 = nn.Linear(400, 20) self.fc22 = nn.Linear(400, 20) self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 784) def encode(self, x): h1 = F.relu(self.fc1(x)) return self.fc21(h1), self.fc22(h1) def reparam(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return mu + eps * std def decode(self, z): return torch.sigmoid(self.fc4(F.relu(self.fc3(z)))) def forward(self, x): mu, logvar = self.encode(x.view(-1, 784)) z = self.reparam(mu, logvar) return self.decode(z), mu, logvar # 🧪 Loss function def vae_loss(recon_x, x, mu, logvar): BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return BCE + KLD # 📦 Load MNIST and train def train_vae(): vae = VAE() optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3) loader = torch.utils.data.DataLoader( datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()), batch_size=128, shuffle=True) for epoch in range(1): for x, _ in loader: optimizer.zero_grad() recon, mu, logvar = vae(x) loss = vae_loss(recon, x, mu, logvar) loss.backward() optimizer.step() return vae vae_model = train_vae() # 🎨 Gradio UI def generate(latent1=0.0, latent2=0.0): z = torch.tensor([[latent1]*10 + [latent2]*10]) img = vae_model.decode(z).view(28, 28).detach().numpy() img = (img * 255).astype(np.uint8) return Image.fromarray(img, mode='L') demo = gr.Interface( fn=generate, inputs=[ gr.Slider(-3, 3, 0, label="Latent dim 1"), gr.Slider(-3, 3, 0, label="Latent dim 2") ], outputs=gr.Image(type="pil", label="Generated Digit"), title="🧠 VAE Digit Generator" ) demo.launch()