import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import io
from src.models.food_classification_model import FoodClassificationModel


class VGG16(FoodClassificationModel):
    """Interface for accessing the VGG-16 model architecture."""

    def __init__(self, weights: str = "IMAGENET1K_V1", num_classes: int = 101):
        """
        Initialize VGG-16 strictly from torchvision weights (no local checkpoints).

        Note: This will not be Food-101 fine-tuned unless you use a hub-published
        VGG-16 checkpoint. Consider switching to hub-based models for best results.
        """
        # Base model with ImageNet weights
        self.model = models.vgg16(weights=weights)
        # Replace the final classifier layer to match the target class count
        num_features = self.model.classifier[6].in_features
        self.model.classifier[6] = nn.Linear(num_features, num_classes)
        self.model.eval()
        # Standard ImageNet preprocessing: resize, center-crop, normalize
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def classify(self, image: bytes) -> int:
        """Classify raw image bytes; return the predicted class index."""
        pil_image = Image.open(io.BytesIO(image))
        # VGG-16 expects 3-channel input; convert grayscale/RGBA/palette images
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        input_tensor = self.transform(pil_image)
        input_batch = input_tensor.unsqueeze(0)  # Add batch dimension
        with torch.no_grad():  # Inference only; no gradient tracking needed
            outputs = self.model(input_batch)
        predicted_idx = torch.argmax(outputs, dim=1).item()
        return predicted_idx
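
# Sketch (an assumption, not part of the original interface): the __init__
# docstring points at hub-published checkpoints for Food-101 accuracy. One way
# to load such a checkpoint into this class, assuming a hub repo that ships a
# plain state_dict; the repo_id and filename below are hypothetical placeholders.
#
#   from huggingface_hub import hf_hub_download
#
#   vgg = VGG16()
#   ckpt_path = hf_hub_download(
#       repo_id="<user>/vgg16-food101",  # hypothetical repo
#       filename="vgg16_food101.pth",    # hypothetical filename
#   )
#   vgg.model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))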
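
# Minimal usage sketch. "food.jpg" is a placeholder path, not a file shipped
# with the project; the returned index follows whatever class ordering the
# loaded checkpoint was trained with (Food-101 here).
if __name__ == "__main__":
    classifier = VGG16()
    with open("food.jpg", "rb") as f:
        image_bytes = f.read()
    print(f"Predicted class index: {classifier.classify(image_bytes)}")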