Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from IJEPA_finetune import ViTIJEPA | |
| import torch | |
| from einops import rearrange | |
| from torchvision.transforms import Compose | |
| import torchvision | |
# Output labels, in the order of the model's output units (index i of the
# logits corresponds to classes[i]). Entries mix genera and tribes.
classes = [
    'Acanthostichus', 'Aenictus', 'Amblyopone', 'Attini',
    'Bothriomyrmecini', 'Camponotini', 'Cerapachys', 'Cheliomyrmex',
    'Crematogastrini', 'Cylindromyrmex', 'Dolichoderini', 'Dorylus',
    'Eciton', 'Ectatommini', 'Formicini', 'Fulakora',
    'Gesomyrmecini', 'Gigantiopini', 'Heteroponerini', 'Labidus',
    'Lasiini', 'Leptomyrmecini', 'Lioponera', 'Melophorini',
    'Myopopone', 'Myrmecia', 'Myrmelachistini', 'Myrmicini',
    'Myrmoteratini', 'Mystrium', 'Neivamyrmex', 'Nomamyrmex',
    'Oecophyllini', 'Ooceraea', 'Paraponera', 'Parasyscia',
    'Plagiolepidini', 'Platythyreini', 'Pogonomyrmecini', 'Ponerini',
    'Prionopelta', 'Probolomyrmecini', 'Proceratiini', 'Pseudomyrmex',
    'Solenopsidini', 'Stenammini', 'Stigmatomma', 'Syscia',
    'Tapinomini', 'Tetraponera', 'Zasphinctus',
]
# NOTE: despite its name, this maps output index -> class name
# (e.g. 0 -> 'Acanthostichus'); the name is kept for existing callers.
class_to_idx = dict(enumerate(classes))
# Preprocessing applied to every batch before inference: antialiased resize
# to 64x64 pixels.
tf = Compose([torchvision.transforms.Resize((64, 64), antialias=True)])

# NOTE(review): the positional ViTIJEPA arguments are defined in
# IJEPA_finetune; the leading 64 presumably matches the 64x64 resize above —
# confirm there before changing either. The last argument sizes the output
# layer to the number of labels.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
# torch.load unpickles the checkpoint; only load this trusted, bundled file.
# map_location forces CPU tensors so the app runs on CPU-only hosts.
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
# Inference mode: without this, training-only behaviour (e.g. dropout,
# batch-norm statistic updates — if the model has such layers) stays active
# and predictions for the same image can differ between calls.
model.eval()
def ant_genus_classification(image):
    """Classify an ant photo into the labels listed in ``classes``.

    Args:
        image: Image from the Gradio ``"image"`` input. The code treats it as
            channels-last ``(height, width, channels)`` (see the ``rearrange``
            pattern below); presumably an RGB ``numpy.ndarray`` — confirm
            against the Gradio version in use.

    Returns:
        dict mapping every class name to its softmax probability, the format
        ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided (e.g. Submit pressed on an empty
            input), so the UI shows a readable message.
    """
    if image is None:
        # Without this guard the tensor constructor fails with an opaque
        # TypeError that the user cannot act on.
        raise gr.Error("Please upload an image of an ant.")

    # float32 matches the original torch.Tensor(...) conversion. Pixel values
    # are NOT rescaled; presumably the model was trained on the same raw
    # range — confirm before adding normalization.
    pixels = torch.as_tensor(image, dtype=torch.float32)
    # Add a batch axis and move channels first: (1, C, H, W).
    batch = rearrange(pixels.unsqueeze(0), 'b h w c -> b c h w')
    batch = tf(batch)

    with torch.no_grad():
        logits = model(batch)[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)

    # Output unit i corresponds to classes[i].
    return {name: float(p) for name, p in zip(classes, probabilities.tolist())}
# Web UI: one image in, the three most probable labels out.
demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=3),
)

if __name__ == "__main__":
    # debug=True keeps the process attached and prints errors to the console.
    demo.launch(debug=True)