Hiera_img / app.py
sudhir2016's picture
Update app.py
9d91775
raw
history blame contribute delete
999 Bytes
import gradio as gr
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import hiera
# Class labels: first column of Imagenet.txt. Row i is assumed to be the
# label for model output index i (recognize() indexes it by argmax) — TODO
# confirm the file's ordering matches the checkpoint's class order.
df = pd.read_csv('Imagenet.txt', usecols=[0], header=None)

# Hiera-Base at 224px; the checkpoint name suggests MAE pretraining followed
# by ImageNet-1k fine-tuning.
model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k")
# nn.Module instances start in training mode; switch to inference behaviour
# so any dropout / stochastic-depth layers are disabled at prediction time.
model.eval()

input_size = 224
# Resize so the shorter edge is 256/224 of the crop size (bicubic), then
# center-crop a square of input_size pixels.
transform_list = [
    transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
]
# Full preprocessing: geometry above, then convert to a float tensor and
# normalize with the ImageNet mean/std constants from timm.
transform_norm = transforms.Compose(transform_list + [
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
def recognize(img):
    """Classify an image with the Hiera model and return its label.

    Args:
        img: PIL image from the Gradio image input. It may be ``None`` when
            the user submits without an image.

    Returns:
        The entry in column 0 of ``df`` at the row of the highest-scoring
        class, or ``""`` when no image was provided.
    """
    if img is None:
        return ""
    # Normalize uses 3-channel mean/std, so RGBA, grayscale or palette
    # images must be converted to RGB before building the tensor.
    rgb = img.convert("RGB")
    # transform_norm already resizes and center-crops to input_size; an
    # extra resize to a fixed square first would squash non-square images.
    batch = transform_norm(rgb).unsqueeze(0)  # add batch dimension
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch)
    class_idx = logits.argmax(dim=-1).item()
    return df.iloc[class_idx, 0]
# Web UI: one PIL image in, the predicted label as text out. Explicit
# components instead of string shortcuts keep this stable across Gradio
# versions.
demo = gr.Interface(
    fn=recognize,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    examples=[['Banana.jpg']],
)

# Only start the server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()