File size: 1,775 Bytes
21fb9ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cf4417
 
 
21fb9ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cf4417
21fb9ff
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import io
from typing import Dict, Any

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageOps

from src.models.food_classification_model import FoodClassificationModel


class VGG16(FoodClassificationModel):
    """Interface for accessing the VGG-16 model architecture.

    The network is built from torchvision's VGG-16 with the requested
    weights. If ``num_classes`` differs from the output size of the loaded
    classification head, the final ``Linear`` layer is replaced with a freshly
    (randomly) initialized one, so its predictions are not meaningful until
    trained weights for that layer are loaded.
    """

    def __init__(self, weights: str = "IMAGENET1K_V1", num_classes: int = 101):
        """
        Initialize VGG-16 strictly from torchvision weights (no local checkpoints).

        Note: This will not be Food-101 fine-tuned unless you use a hub-published
        VGG-16 checkpoint. Consider switching to hub-based models for best results.

        Args:
            weights: torchvision weight identifier forwarded to
                ``models.vgg16`` (e.g. ``"IMAGENET1K_V1"``).
            num_classes: Number of output classes. When it equals the loaded
                head's ``out_features`` the pretrained head is kept; otherwise
                ``classifier[6]`` is replaced by a new, untrained ``nn.Linear``.

        Raises:
            ValueError: If ``num_classes`` is less than 1.
        """
        if num_classes < 1:
            # A zero-width head would only fail later, inside classify()'s argmax.
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        # NOTE(review): FoodClassificationModel.__init__ is never called here;
        # confirm the base class requires no initialization.
        self.model = models.vgg16(weights=weights)

        # classifier[6] is the final Linear layer of torchvision's VGG-16 head.
        head = self.model.classifier[6]
        if head.out_features != num_classes:
            # Only swap the head when the class count changes, so a matching
            # pretrained head is preserved instead of replaced by random weights.
            self.model.classifier[6] = nn.Linear(head.in_features, num_classes)

        # Inference only: puts the classifier's Dropout layers in eval mode.
        self.model.eval()

        # Resize / center-crop to 224x224 and normalize with ImageNet statistics.
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def classify(self, image: bytes) -> int:
        """Predict the class index for an encoded image.

        Args:
            image: Encoded image bytes in any format Pillow can open.

        Returns:
            Index of the highest-scoring output, in ``range(num_classes)``.

        Raises:
            PIL.UnidentifiedImageError: If ``Image.open`` cannot identify the
                image format.
        """
        input_batch = self._preprocess(image)

        with torch.inference_mode():
            logits = self.model(input_batch)

        return int(torch.argmax(logits, dim=1).item())

    def _preprocess(self, image: bytes) -> torch.Tensor:
        """Decode image bytes into a normalized batch tensor of shape (1, 3, 224, 224)."""
        with Image.open(io.BytesIO(image)) as pil_image:
            # Honor the EXIF Orientation tag (common in phone photos) so the
            # model sees the picture upright rather than as stored by the sensor.
            upright = ImageOps.exif_transpose(pil_image)
            if upright.mode != "RGB":
                upright = upright.convert("RGB")
            # ToTensor copies pixel data, so the tensor outlives the closed image.
            tensor = self.transform(upright)
        return tensor.unsqueeze(0)  # add batch dimension