Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import json | |
| from pathlib import Path | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
class ModelPredictor:
    """Top-k image classification with a checkpoint downloaded from the Hugging Face Hub.

    The checkpoint is fetched with ``hf_hub_download`` and its ``state_dict`` is
    loaded into ``pl_train.ImageNetModule``. Images are preprocessed with a
    resize(256) / center-crop(224) / normalize pipeline before inference.
    """

    # Labels JSON (a list of names ordered by class index), resolved relative to
    # the current working directory.
    DEFAULT_LABELS_PATH = Path("data/imagenet-simple-labels.json")

    def __init__(
        self,
        model_repo: str,
        model_filename: str,
        device: "str | None" = None,
    ):
        """Download the checkpoint, build the model, transforms and label map.

        Args:
            model_repo: Hugging Face Hub repo id that holds the checkpoint.
            model_filename: Checkpoint file name inside ``model_repo``.
            device: Torch device string. When falsy, uses ``"cuda"`` if
                available, otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        self.model = self.load_model(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()

        # Evaluation preprocessing; mean/std are the usual ImageNet normalization constants.
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        self.class_labels = self.load_imagenet_labels()

    def load_model(self, checkpoint_path: str):
        """Build an ``ImageNetModule`` and load the checkpoint's ``state_dict`` into it.

        Args:
            checkpoint_path: Local path to the downloaded checkpoint file.

        Returns:
            The ``ImageNetModule`` instance with weights loaded (not yet moved
            to ``self.device`` or put in eval mode; ``__init__`` does that).
        """
        from pl_train import ImageNetModule  # project-local training module

        # NOTE(review): torch.load unpickles the file, so only load checkpoints from
        # a trusted repo. Newer torch releases default to weights_only=True, which
        # may reject checkpoints holding non-tensor objects; confirm against the
        # pinned torch version.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Constructor arguments presumably matter only for training; they are
        # filled in to satisfy ImageNetModule's signature. TODO confirm in pl_train.
        model = ImageNetModule(
            learning_rate=0.156,
            batch_size=1,
            num_workers=0,  # Set to 0 for Gradio
            max_epochs=40,
            train_path="",
            val_path="",
            checkpoint_dir="",
        )
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def load_imagenet_labels(self, labels_path=None):
        """Return a mapping from 0-based class index (as a string) to label name.

        Args:
            labels_path: JSON file containing a list of label names ordered by
                class index. Defaults to ``DEFAULT_LABELS_PATH``.

        Returns:
            ``{"0": name_0, "1": name_1, ...}``. If the file does not exist,
            placeholder names ``"class_<i>"`` for 1000 classes.
        """
        path = Path(labels_path) if labels_path is not None else self.DEFAULT_LABELS_PATH
        if not path.exists():
            return {str(i): f"class_{i}" for i in range(1000)}  # Fallback
        with open(path, encoding="utf-8") as f:
            names = json.load(f)
        # Keys must be 0-based: predict() looks labels up by the model's output index.
        return {str(i): name for i, name in enumerate(names)}

    @staticmethod
    def _to_rgb_image(image):
        """Convert a numpy array, file path or PIL image into an RGB PIL image."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        elif isinstance(image, str):
            # Decode fully inside the context manager so the file handle is released.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def predict(self, image, top_k: int = 5):
        """Classify a single image and return its most probable classes.

        Args:
            image: A PIL image, a numpy array (cast to uint8 before conversion)
                or a path to an image file.
            top_k: Maximum number of classes to return; capped at the number
                of model outputs.

        Returns:
            Dict mapping label name to softmax probability, ordered from most to
            least probable. If any step fails, the error is printed and
            ``{"error": 1.0}`` is returned instead.
        """
        try:
            rgb_image = self._to_rgb_image(image)
            image_tensor = self.transform(rgb_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

            # torch.topk raises if k exceeds the number of classes.
            k = min(top_k, probabilities.size(1))
            top_probs, top_indices = torch.topk(probabilities, k)

            results = {}
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
                name = self.class_labels.get(str(idx), f"class_{idx}")
                # If two classes share a name, keep both instead of overwriting.
                if name in results:
                    name = f"{name} ({idx})"
                results[name] = float(prob)
            return results
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            return {"error": 1.0}
# Build the predictor once at import time. If that fails, keep the app importable
# and bind `predictor = None` so requests get an explicit "model failed to load"
# error instead of a NameError masked as an image-processing failure.
try:
    predictor = ModelPredictor(
        model_repo="Adityak204/ResNetVision-1K",  # Replace with your repo
        model_filename="resnet50-epoch36-acc60.3506.ckpt",  # Replace with your model filename
    )
except Exception as e:
    print(f"Error initializing predictor: {str(e)}")
    predictor = None


def predict_image(image):
    """Gradio callback: classify ``image`` and return label -> probability.

    Args:
        image: The image delivered by the Gradio input component, or None when
            nothing was uploaded.

    Returns:
        The dict produced by ``predictor.predict``, or a single-entry dict whose
        key describes the error (no image, model not loaded, processing failure).
    """
    if image is None:
        return {"Error: No image provided": 1.0}
    if predictor is None:
        return {"Error: Model failed to load": 1.0}
    try:
        return predictor.predict(image)
    except Exception as e:
        print(f"Error in predict_image: {str(e)}")
        return {"Error: Failed to process image": 1.0}
# Gradio UI: an image input delivered to predict_image as a PIL image, and a
# label output showing up to 5 classes.
# TODO: to show example images, pass a list of single-item lists filtered to
# files that exist, e.g.
#     examples=[[p] for p in EXAMPLE_PATHS if Path(p).exists()] or None
# Call Path() on each path string, not on the inner lists.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="ImageNet-1K Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories",
    analytics_enabled=False,
)
def main():
    """Start the Gradio web server for the ``iface`` interface."""
    iface.launch()


# Only launch when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()