"""Residual networks for classification."""

from __future__ import annotations

import torch
import torchvision.models.vgg as _vgg
from torchvision.models._utils import IntermediateLayerGetter

from .base import BaseModel


class VGG(BaseModel):
    """Wrapper for torch vision VGG."""

    def __init__(
        self,
        vgg_name: str,
        trainable_layers: None | int = None,
        pretrained: bool = False,
    ):
        """Initialize the VGG base model from torchvision.

        Args:
            vgg_name (str): name of the VGG variant. Choices in ["vgg11",
                "vgg13", "vgg16", "vgg19", "vgg11_bn", "vgg13_bn", "vgg16_bn",
                "vgg19_bn"].
            trainable_layers (int, optional): Number of layers to train or
                fine-tune. None means all the layers can be fine-tuned.
            pretrained (bool, optional): Whether to load ImageNet
                pre-trained weights. Defaults to False.

        Raises:
            ValueError: The VGG name is not supported
        """
        super().__init__()
        if vgg_name not in [
            "vgg11",
            "vgg13",
            "vgg16",
            "vgg19",
            "vgg11_bn",
            "vgg13_bn",
            "vgg16_bn",
            "vgg19_bn",
        ]:
            raise ValueError("The VGG name is not supported!")

        weights = "IMAGENET1K_V1" if pretrained else None
        vgg = _vgg.__dict__[vgg_name](weights=weights)
        use_bn = vgg_name[-3:] == "_bn"
        self._out_channels: list[int] = []
        returned_layers = []
        last_channel = -1
        layer_counter = 0

        vgg_channels = _vgg.cfgs[
            {"vgg11": "A", "vgg13": "B", "vgg16": "D", "vgg19": "E"}[
                vgg_name[:5]
            ]
        ]
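        # Walk the VGG configuration to locate the pooling layers inside
        # vgg.features: each conv entry contributes Conv2d + (BatchNorm2d) +
        # ReLU modules, while "M" is a single MaxPool2d whose output is
        # collected as one feature pyramid level.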
        for channel in vgg_channels:
            if channel == "M":
                returned_layers.append(layer_counter)
                self._out_channels.append(last_channel)
                layer_counter += 1
            else:
                if use_bn:
                    layer_counter += 3
                else:
                    layer_counter += 2
                last_channel = channel

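        # Optionally freeze the backbone: only the last `trainable_layers`
        # modules of vgg.features keep requires_grad enabled; all earlier
        # parameters are excluded from gradient computation.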
        if trainable_layers is not None:
            for name, parameter in vgg.features.named_parameters():
                layer_ind = int(name.split(".")[0])
                if layer_ind < layer_counter - trainable_layers:
                    parameter.requires_grad_(False)

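        # Map the module indices of the pooling layers to consecutive level
        # names ("0", "1", ...) so that IntermediateLayerGetter returns one
        # tensor per pyramid level.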
        return_layers = {str(v): str(i) for i, v in enumerate(returned_layers)}
        self.body = IntermediateLayerGetter(
            vgg.features, return_layers=return_layers
        )
        self.name = vgg_name

    @property
    def out_channels(self) -> list[int]:
        """Get the number of channels for each level of feature pyramid.

        Returns:
            list[int]: number of channels
        """
        return [3, 3, *self._out_channels]

    def forward(self, images: torch.Tensor) -> list[torch.Tensor]:
        """VGG feature forward without classification head.

        Args:
            images (Tensor[N, C, H, W]): Image input to process. Expected to
                be of type float32 with values in the range 0..255.

        Returns:
            fp (list[torch.Tensor]): The output feature pyramid. The list
            index represents the pyramid level. fp[0] and fp[1] are references
            to the input images; each following level halves the spatial
            resolution, so the last feature map downsamples the input image by
            a factor of 32.
        """
        return [images, images, *self.body(images).values()]
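

# A minimal usage sketch (illustrative only; kept as a comment because this
# module uses relative imports and is meant to be imported as part of a
# package). It assumes BaseModel behaves like a torch nn.Module so the
# instance is callable:
#
#   backbone = VGG("vgg16", trainable_layers=5, pretrained=False)
#   feats = backbone(torch.zeros(2, 3, 224, 224))
#   print(backbone.out_channels)  # [3, 3, 64, 128, 256, 512, 512]
#   print(feats[-1].shape)        # torch.Size([2, 512, 7, 7]) for 224x224 input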