mahesh1209 commited on
Commit
c8e57ed
·
verified ·
1 Parent(s): c6270f8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision import datasets, transforms
6
+ import gradio as gr
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ # 🧱 VAE architecture
11
+ class VAE(nn.Module):
12
+ def __init__(self):
13
+ super(VAE, self).__init__()
14
+ self.fc1 = nn.Linear(784, 400)
15
+ self.fc21 = nn.Linear(400, 20)
16
+ self.fc22 = nn.Linear(400, 20)
17
+ self.fc3 = nn.Linear(20, 400)
18
+ self.fc4 = nn.Linear(400, 784)
19
+
20
+ def encode(self, x):
21
+ h1 = F.relu(self.fc1(x))
22
+ return self.fc21(h1), self.fc22(h1)
23
+
24
+ def reparam(self, mu, logvar):
25
+ std = torch.exp(0.5 * logvar)
26
+ eps = torch.randn_like(std)
27
+ return mu + eps * std
28
+
29
+ def decode(self, z):
30
+ return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))
31
+
32
+ def forward(self, x):
33
+ mu, logvar = self.encode(x.view(-1, 784))
34
+ z = self.reparam(mu, logvar)
35
+ return self.decode(z), mu, logvar
36
+
37
+ # 🧪 Loss function
38
+ def vae_loss(recon_x, x, mu, logvar):
39
+ BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
40
+ KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
41
+ return BCE + KLD
42
+
43
+ # 📦 Load MNIST and train
44
+ def train_vae():
45
+ vae = VAE()
46
+ optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
47
+ loader = torch.utils.data.DataLoader(
48
+ datasets.MNIST('.', train=True, download=True,
49
+ transform=transforms.ToTensor()),
50
+ batch_size=128, shuffle=True)
51
+ for epoch in range(1):
52
+ for x, _ in loader:
53
+ optimizer.zero_grad()
54
+ recon, mu, logvar = vae(x)
55
+ loss = vae_loss(recon, x, mu, logvar)
56
+ loss.backward()
57
+ optimizer.step()
58
+ return vae
59
+
60
+ vae_model = train_vae()
61
+
62
+ # 🎨 Gradio UI
63
+ def generate(latent1=0.0, latent2=0.0):
64
+ z = torch.tensor([[latent1]*10 + [latent2]*10])
65
+ img = vae_model.decode(z).view(28, 28).detach().numpy()
66
+ img = (img * 255).astype(np.uint8)
67
+ return Image.fromarray(img, mode='L')
68
+
69
+ demo = gr.Interface(
70
+ fn=generate,
71
+ inputs=[
72
+ gr.Slider(-3, 3, 0, label="Latent dim 1"),
73
+ gr.Slider(-3, 3, 0, label="Latent dim 2")
74
+ ],
75
+ outputs=gr.Image(type="pil", label="Generated Digit"),
76
+ title="🧠 VAE Digit Generator"
77
+ )
78
+
79
+ demo.launch()