Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import timm | |
| from huggingface_hub import login | |
| from torch import no_grad, softmax, topk | |
# Model selection and credentials are read from the environment (e.g. the
# Space's variables / secrets).
MODEL_NAME = os.getenv("MODEL_NAME")
HF_TOKEN = os.getenv("HF_TOKEN")
if not MODEL_NAME:
    # Without this check timm would be asked for "hf_hub:None" and fail
    # with an unhelpful download error.
    raise RuntimeError("The MODEL_NAME environment variable must be set.")
# Per the huggingface_hub docs, login(token=None) falls back to an
# interactive prompt, which cannot be answered in a non-interactive
# deployment; only log in when a token was actually provided.
if HF_TOKEN:
    login(token=HF_TOKEN)
# Load the pretrained classifier from the Hugging Face Hub and switch it to
# evaluation mode for inference.
model = timm.create_model(f"hf_hub:{MODEL_NAME}", pretrained=True)
model.eval()
# Build the preprocessing pipeline (resize/crop/normalize, as configured in
# the model's pretrained config) used by classify_image.
data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)
def _label_for(label_names, idx):
    """Return the display name for class index *idx*.

    ``label_names`` may be a dict keyed by the stringified class id (the form
    this app originally assumed) or a list/tuple indexed by class id; when no
    name is available the index itself is used.
    """
    if isinstance(label_names, dict):
        if str(idx) in label_names:
            return str(label_names[str(idx)])
        return str(label_names.get(idx, idx))
    if isinstance(label_names, (list, tuple)) and 0 <= idx < len(label_names):
        return str(label_names[idx])
    return str(idx)


def classify_image(input, top_k=3):
    """Classify one image with the loaded timm model.

    Args:
        input: PIL image delivered by the ``gr.Image(type="pil")`` component.
            (The name shadows the builtin but is kept for compatibility.)
        top_k: Maximum number of classes to return; clamped to the number of
            classes the model outputs.

    Returns:
        Dict mapping title-cased class label -> probability (Python float),
        the format consumed by the ``gr.Label`` output.
    """
    batch = transform(input).unsqueeze(0)  # add a batch dimension of 1
    with no_grad():
        logits = model(batch)
    probabilities = softmax(logits[0], dim=0)
    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.numel())
    values, indices = topk(probabilities, k)
    label_names = model.pretrained_cfg.get("label_names")
    return {
        _label_for(label_names, idx.item()).title(): prob.item()
        for idx, prob in zip(indices, values)
    }
# Build the web UI around classify_image. Gradio raises at startup when a
# string ``examples`` path does not exist, so the sample-image directory is
# only passed when it is actually present.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", sources=["upload", "clipboard"]),
    outputs=gr.Label(num_top_classes=3),
    allow_flagging="never",
    examples="examples" if os.path.isdir("examples") else None,
)
# Enable Gradio's request queue, then start the server.
demo.queue()
demo.launch(debug=True)