# Source: Hugging Face Spaces mirror (running on Zero), 3,791 bytes, rev 9b33fca.
"""FCN Head for semantic segmentation."""
from __future__ import annotations
from typing import NamedTuple
import torch
import torch.nn.functional as F
from torch import nn
class FCNOut(NamedTuple):
    """Output of the FCN prediction.

    Attributes:
        pred: Final prediction map of shape (N, C, H, W). Note this holds
            per-pixel log-probabilities (``log_softmax`` is applied in
            ``FCNHead.forward``), not raw logits.
        outputs: All transformed feature maps; ``outputs[-1]`` is ``pred``.
    """

    pred: torch.Tensor  # log-probabilities of final prediction, (N, C, H, W)
    outputs: list[torch.Tensor]  # transformed feature maps


class FCNHead(nn.Module):
    """FCN Head made with ResNet base model.

    This is based on the implementation in `torchvision
    <https://github.com/pytorch/vision/blob/torchvision/models/segmentation/
    fcn.py>`_.
    """

    def __init__(
        self,
        in_channels: list[int],
        out_channels: int,
        dropout_prob: float = 0.1,
        resize: tuple[int, int] | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            in_channels (list[int]): Number of channels in multi-level image
                feature.
            out_channels (int): Number of output channels. Usually the number
                of classes.
            dropout_prob (float, optional): Dropout probability. Defaults to
                0.1.
            resize (tuple(int,int), optional): Target (H, W) shape to resize
                each head output to. Defaults to None (no resizing).
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.resize = resize
        # One head per configured feature level; in forward() these consume
        # the *last* len(in_channels) entries of the feature list.
        self.heads = nn.ModuleList(
            [
                self._make_head(channels, self.out_channels, dropout_prob)
                for channels in self.in_channels
            ]
        )

    def _make_head(
        self, in_channels: int, channels: int, dropout_prob: float
    ) -> nn.Module:
        """Generate FCN segmentation head.

        The head is Conv3x3 -> BN -> ReLU -> Dropout -> Conv1x1, with an
        intermediate width of ``in_channels // 4`` (torchvision convention).

        Args:
            in_channels (int): Input feature channels.
            channels (int): Output segmentation channels.
            dropout_prob (float): Dropout probability.

        Returns:
            nn.Module: FCN segmentation head.
        """
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(
                in_channels,
                inter_channels,
                kernel_size=3,
                padding=1,
                bias=False,  # BatchNorm directly after makes a bias redundant
            ),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(inter_channels, channels, kernel_size=1),
        ]
        return nn.Sequential(*layers)

    def forward(self, feats: list[torch.Tensor]) -> FCNOut:
        """Transforms feature maps and returns segmentation prediction.

        The last ``len(self.in_channels)`` feature maps are each run through
        their corresponding head, optionally resized, and normalized with
        ``log_softmax``; any earlier feature maps pass through unchanged.

        Args:
            feats (list[torch.Tensor]): List of multi-level image features.

        Returns:
            FCNOut: ``pred`` is ``outputs[-1]``; ``outputs`` holds one tensor
                of shape (batch_size, self.out_channels, H, W) per FCN stage:
                outputs[-1] ==> main output map
                outputs[-2] ==> aux output map (e.g., used for training)
                outputs[:-2] ==> feats[:-2] (untouched)
        """
        outputs = feats.copy()
        # Heads align with the trailing feature levels, back to front.
        start = len(feats) - len(self.heads)
        for i, head in enumerate(self.heads):
            idx = start + i
            out = head(feats[idx])
            if self.resize is not None:
                out = F.interpolate(
                    out,
                    size=self.resize,
                    mode="bilinear",
                    align_corners=False,
                )
            # Per-pixel log-probabilities over the class dimension.
            outputs[idx] = F.log_softmax(out, dim=1)
        return FCNOut(pred=outputs[-1], outputs=outputs)

    def __call__(self, feats: list[torch.Tensor]) -> FCNOut:
        """Type definition for function call."""
        return super()._call_impl(feats)