File size: 4,326 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Cross Stage Partial Layer.

Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
"""

from __future__ import annotations

import torch
from torch import nn

from .conv2d import Conv2d


class DarknetBottleneck(nn.Module):
    """Two-convolution bottleneck block with an optional skip connection.

    The input goes through a 1x1 ``Conv2d`` and then a 3x3 ``Conv2d``
    (stride 1, padding 1). Both are created without bias and are handed a
    ``BatchNorm2d`` (eps=0.001, momentum=0.03) as norm and an in-place
    ``SiLU`` as activation. If the skip connection is active, the block
    input is added to the result of the second convolution.

    Args:
        in_channels (int): Number of channels of the input features.
        out_channels (int): Number of channels of the output features.
        expansion (float, optional): Ratio of the hidden (intermediate)
            channels to ``out_channels``; the hidden width is
            ``int(out_channels * expansion)``. Defaults to 0.5.
        add_identity (bool, optional): Whether to add the input to the
            output. Only takes effect when ``in_channels`` equals
            ``out_channels``. Defaults to True.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion: float = 0.5,
        add_identity: bool = True,
    ):
        """Create the two convolutions and resolve the skip connection."""
        super().__init__()
        # Width of the intermediate feature map between the two convs.
        width = int(out_channels * expansion)

        def batch_norm(channels: int) -> nn.BatchNorm2d:
            # Both convs share the same normalization hyper-parameters.
            return nn.BatchNorm2d(channels, eps=0.001, momentum=0.03)

        # NOTE: creation order and attribute names are kept as-is so that
        # parameter initialization order and state-dict keys do not change.
        reduce_kwargs = {
            "bias": False,
            "norm": batch_norm(width),
            "activation": nn.SiLU(inplace=True),
        }
        self.conv1 = Conv2d(in_channels, width, 1, **reduce_kwargs)

        spatial_kwargs = {
            "stride": 1,
            "padding": 1,
            "bias": False,
            "norm": batch_norm(out_channels),
            "activation": nn.SiLU(inplace=True),
        }
        self.conv2 = Conv2d(width, out_channels, 3, **spatial_kwargs)

        # A residual sum is only possible if the channel counts match.
        self.add_identity = add_identity and in_channels == out_channels

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Apply both convolutions and, if enabled, the skip connection.

        Args:
            features (torch.Tensor): Input features.

        Returns:
            torch.Tensor: Output of the second convolution, with the input
                added to it when the skip connection is active.
        """
        transformed = self.conv2(self.conv1(features))
        return transformed + features if self.add_identity else transformed


class CSPLayer(nn.Module):
    """Cross Stage Partial Layer.

    The input is projected by two separate 1x1 convolutions. One projection
    (the main branch) is further processed by a stack of
    ``DarknetBottleneck`` blocks, the other (the short branch) is left
    untouched. Both results are concatenated along dimension 1 and fused by
    a final 1x1 convolution.

    Args:
        in_channels (int): Number of channels of the input features.
        out_channels (int): Number of channels of the output features.
        expand_ratio (float, optional): Ratio of the per-branch hidden
            channels to ``out_channels``; each branch has
            ``int(out_channels * expand_ratio)`` channels. Defaults to 0.5.
        num_blocks (int, optional): Number of bottleneck blocks in the main
            branch. Defaults to 1.
        add_identity (bool, optional): Whether the bottleneck blocks add
            their input to their output. Defaults to True.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expand_ratio: float = 0.5,
        num_blocks: int = 1,
        add_identity: bool = True,
    ):
        """Create the branch projections, the fusion conv and the blocks."""
        super().__init__()
        branch_channels = int(out_channels * expand_ratio)

        def pointwise(src_channels: int, dst_channels: int) -> nn.Module:
            # Bias-free 1x1 conv handed a BatchNorm2d and an in-place SiLU.
            return Conv2d(
                src_channels,
                dst_channels,
                1,
                bias=False,
                norm=nn.BatchNorm2d(dst_channels, eps=0.001, momentum=0.03),
                activation=nn.SiLU(inplace=True),
            )

        # NOTE: creation order and attribute names are kept as-is so that
        # parameter initialization order and state-dict keys do not change.
        self.main_conv = pointwise(in_channels, branch_channels)
        self.short_conv = pointwise(in_channels, branch_channels)
        # The fusion conv consumes both branches stacked together.
        self.final_conv = pointwise(2 * branch_channels, out_channels)

        # Bottlenecks keep the branch width (expansion of 1.0).
        self.blocks = nn.Sequential(
            *(
                DarknetBottleneck(
                    branch_channels,
                    branch_channels,
                    expansion=1.0,
                    add_identity=add_identity,
                )
                for _ in range(num_blocks)
            )
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Run both branches and fuse their concatenated outputs.

        Args:
            features (torch.Tensor): Input features.

        Returns:
            torch.Tensor: Result of the final convolution applied to the
                concatenation of the main and short branch outputs.
        """
        shortcut = self.short_conv(features)
        main = self.blocks(self.main_conv(features))
        # Main branch first, then the shortcut, joined along dimension 1.
        return self.final_conv(torch.cat((main, shortcut), dim=1))