import io

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

from src.models.food_classification_model import FoodClassificationModel


class VGG16(FoodClassificationModel):
    """Interface for accessing the VGG-16 model architecture."""

    def __init__(self, weights: str = "IMAGENET1K_V1", num_classes: int = 101):
        """
        Initialize VGG-16 strictly from torchvision weights (no local checkpoints).

        Note: This will not be Food-101 fine-tuned unless you use a
        hub-published VGG-16 checkpoint. Consider switching to hub-based
        models for best results.
        """
        # Base model with ImageNet weights.
        self.model = models.vgg16(weights=weights)

        # Replace the final classifier layer so the head matches the target
        # number of classes (Food-101 by default). This new layer is randomly
        # initialized and must be fine-tuned to give meaningful predictions.
        num_features = self.model.classifier[6].in_features
        self.model.classifier[6] = nn.Linear(num_features, num_classes)
        self.model.eval()

        # Standard ImageNet preprocessing: resize, center-crop to 224x224,
        # convert to a tensor, and normalize with ImageNet statistics.
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def classify(self, image: bytes) -> int:
        """Classify an image given as raw bytes; return the predicted class index."""
        pil_image = Image.open(io.BytesIO(image))
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        input_tensor = self.transform(pil_image)
        input_batch = input_tensor.unsqueeze(0)  # Add batch dimension

        with torch.no_grad():
            outputs = self.model(input_batch)
            predicted_idx = torch.argmax(outputs, dim=1).item()

        return predicted_idx
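

# Minimal usage sketch (an illustration, not part of the module's API):
# reads an image file from disk, passes the raw bytes to classify(), and
# prints the predicted class index. The path "example.jpg" is hypothetical;
# any image PIL can open and convert to RGB will work.
if __name__ == "__main__":
    classifier = VGG16()
    with open("example.jpg", "rb") as f:  # hypothetical sample image
        image_bytes = f.read()
    print(f"Predicted class index: {classifier.classify(image_bytes)}")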