Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from model import * | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
# ---------------------------------------------------------------------------
# Text shown in the Gradio interface.
# ---------------------------------------------------------------------------
title = "Digit Classifier"
description = (
    "Multilayer-Perceptron built for the fast.ai 'Deep Learning' course"
    " to classify handwritten digits"
    " from the MNIST dataset. "
)

# ---------------------------------------------------------------------------
# Interface components: an image goes in, a label/confidence widget comes out.
# ---------------------------------------------------------------------------
inputs = gr.components.Image()
outputs = gr.components.Label()

# Handed to the interface as its `examples=` argument.
examples = "examples"

# ---------------------------------------------------------------------------
# Model and class names.
# ---------------------------------------------------------------------------
# NOTE(review): torch.load unpickles the file, so only a trusted checkpoint
# should live at this path. Presumably the wildcard import from `model`
# supplies the class definition needed to unpickle it -- confirm.
# map_location remaps any stored tensors onto the CPU.
model = torch.load(
    "model/digit_classifier.pt",
    map_location=torch.device("cpu"),
)

# One label per digit: "0" .. "9".
labels = list(map(str, range(10)))

# ---------------------------------------------------------------------------
# Preprocessing applied to every incoming image before it reaches the model.
# ---------------------------------------------------------------------------
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        # Take the first plane of the tensor, then put a leading size-1
        # dimension back in front of it.
        transforms.Lambda(lambda x: x[0].unsqueeze(0)),
    ]
)
def predict_digit(img):
    """Classify an image of a handwritten digit.

    The image is converted to a PIL image, run through the module-level
    ``transform`` pipeline, and fed to the module-level ``model``. The model
    output is turned into probabilities with a softmax over dimension 1.

    Args:
        img: Array-like image accepted by ``PIL.Image.fromarray``
            (presumably the NumPy array produced by the Gradio image
            input -- confirm against the component configuration).

    Returns:
        dict: Maps each entry of the module-level ``labels`` to the
        corresponding probability as a Python ``float``.
    """
    pil_image = Image.fromarray(img)
    model_input = transform(pil_image)

    # Inference only: disable autograd so no graph is recorded for a forward
    # pass that will never be back-propagated. The numbers are unchanged.
    with torch.no_grad():
        output = model(model_input)
        # assumes `output` is 2-D with the class scores along dim 1 -- TODO confirm
        probs = torch.nn.functional.softmax(output, dim=1)

    # Flatten and keep the first ten values so they pair up with the labels.
    return dict(zip(labels, map(float, probs.flatten()[:10])))
# Assemble the UI: a single tab that hosts the digit-prediction interface.
with gr.Blocks() as demo:
    with gr.Tab("Digit Prediction"):
        gr.Interface(
            fn=predict_digit,
            inputs=inputs,
            outputs=outputs,
            examples=examples,
            title=title,
            description=description,
        )

# Configure the queue on `demo`, the object that is actually launched.
# Previously `.queue(...)` was called on the nested Interface and its result
# discarded, so the concurrency limit was not attached to the served app.
demo.queue(default_concurrency_limit=5)
demo.launch()