Theo Viel committed on
Commit · 8baa7cf
Parent(s): bbf44bf

Files changed:
- graphic_element_v1.py +83 -0
- model.py +219 -0
- post_processing/graphic_elt_pp.py +118 -0
- utils.py +201 -0
- yolox/__init__.py +7 -0
- yolox/boxes.py +55 -0
- yolox/darknet.py +179 -0
- yolox/network_blocks.py +210 -0
- yolox/yolo_fpn.py +84 -0
- yolox/yolo_head.py +235 -0
- yolox/yolo_pafpn.py +116 -0
- yolox/yolox.py +32 -0
graphic_element_v1.py
ADDED
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from typing import List, Tuple


class Exp:
    """
    Configuration class for the graphic element model.

    This class contains all configuration parameters for the YOLOX-based
    page element detection model, including architecture settings, inference
    parameters, and class-specific thresholds.
    """

    def __init__(self) -> None:
        """Initialize the configuration with default parameters."""
        self.name: str = "graphic-element-v1"
        self.ckpt: str = "weights.pth"
        self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters
        self.act: str = "silu"
        self.depth: float = 1.00
        self.width: float = 1.00
        self.labels: List[str] = [
            "chart_title",
            "x_title",
            "y_title",
            "xlabel",
            "ylabel",
            "other",
            "legend_label",
            "legend_title",
            "mark_label",
            "value_label",
        ]
        self.num_classes: int = len(self.labels)

        # Inference parameters
        self.size: Tuple[int, int] = (1024, 1024)
        self.min_bbox_size: int = 0
        self.normalize_boxes: bool = True

        # NMS & thresholding. These can be updated
        self.conf_thresh: float = 0.01
        self.iou_thresh: float = 0.25
        self.class_agnostic: bool = True  # False

        self.threshold: float = 0.1

    def get_model(self) -> nn.Module:
        """
        Get the YOLOX model.

        Builds and returns a YOLOX model with the configured architecture.
        Also updates batch normalization parameters for optimal inference.

        Returns:
            nn.Module: The YOLOX model with configured parameters.
        """
        from yolox import YOLOX, YOLOPAFPN, YOLOXHead

        # Build model
        if getattr(self, "model", None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(
                self.depth, self.width, in_channels=in_channels, act=self.act
            )
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=in_channels, act=self.act
            )
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters
        def init_yolo(M: nn.Module) -> None:
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        self.model.apply(init_yolo)

        return self.model
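For context, Exp follows the YOLOX experiment-file convention: define_model in model.py imports this module by name and calls get_model(). A minimal sketch of using the config directly (an illustration, not part of the committed files; it assumes this file and the yolox/ package are importable from the working directory):

from graphic_element_v1 import Exp

config = Exp()
model = config.get_model()  # YOLOX(YOLOPAFPN backbone, YOLOXHead with 10 classes)
n_params = sum(p.numel() for p in model.parameters())
print(config.name, f"{config.num_classes} classes, {n_params / 1e6:.1f}M parameters")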
model.py
ADDED
@@ -0,0 +1,219 @@
import os
import sys
import torch
import importlib
import numpy as np
import numpy.typing as npt
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Union
from yolox.boxes import postprocess


def define_model(config_name: str = "graphic_element_v1", verbose: bool = True) -> nn.Module:
    """
    Defines and initializes the model based on the configuration.

    Args:
        config_name (str): Configuration name. Defaults to "graphic_element_v1".
        verbose (bool): Whether to print verbose output. Defaults to True.

    Returns:
        torch.nn.Module: The initialized YOLOX model.
    """
    # Load model from exp_file
    sys.path.append(os.path.dirname(config_name))
    exp_module = importlib.import_module(os.path.basename(config_name).split(".")[0])

    config = exp_module.Exp()
    model = config.get_model()

    # Load weights
    if verbose:
        print(" -> Loading weights from", config.ckpt)

    ckpt = torch.load(config.ckpt, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"], strict=True)

    model = YoloXWrapper(model, config)
    return model.eval().to(config.device)


def resize_pad(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """
    Resizes and pads an image to a given size.
    The goal is to preserve the aspect ratio of the image.

    Args:
        img (torch.Tensor[C x H x W]): The image to resize and pad.
        size (tuple[2]): The size to resize and pad the image to.

    Returns:
        torch.Tensor: The resized and padded image.
    """
    img = img.float()
    _, h, w = img.shape
    scale = min(size[0] / h, size[1] / w)
    nh = int(h * scale)
    nw = int(w * scale)
    img = F.interpolate(
        img.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False
    ).squeeze(0)
    img = torch.clamp(img, 0, 255)
    pad_b = size[0] - nh
    pad_r = size[1] - nw
    img = F.pad(img, (0, pad_r, 0, pad_b), value=114.0)
    return img


class YoloXWrapper(nn.Module):
    """
    Wrapper for YoloX models.
    """
    def __init__(self, model: nn.Module, config) -> None:
        """
        Constructor

        Args:
            model (torch model): Yolo model.
            config (Config): Config object containing model parameters.
        """
        super().__init__()
        self.model = model
        self.config = config

        # Copy config parameters
        self.device = config.device
        self.img_size = config.size
        self.min_bbox_size = config.min_bbox_size
        self.normalize_boxes = config.normalize_boxes
        self.conf_thresh = config.conf_thresh
        self.iou_thresh = config.iou_thresh
        self.class_agnostic = config.class_agnostic
        self.threshold = config.threshold
        self.labels = config.labels
        self.num_classes = config.num_classes

    def reformat_input(
        self,
        x: torch.Tensor,
        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reformats the input data and original sizes to the correct format.

        Args:
            x (torch.Tensor[BS x C x H x W]): Input image batch.
            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.

        Returns:
            torch tensor [BS x C x H x W]: Input image batch.
            torch tensor [BS x 2]: Original image sizes (before resizing and padding).
        """
        # Convert image size to tensor
        if isinstance(orig_sizes, (list, tuple)):
            orig_sizes = np.array(orig_sizes)
        if orig_sizes.shape[-1] == 3:  # remove channel
            orig_sizes = orig_sizes[..., :2]
        if isinstance(orig_sizes, np.ndarray):
            orig_sizes = torch.from_numpy(orig_sizes).to(self.device)

        # Add batch dimension if not present
        if len(x.size()) == 3:
            x = x.unsqueeze(0)
        if len(orig_sizes.size()) == 1:
            orig_sizes = orig_sizes.unsqueeze(0)

        return x, orig_sizes

    def preprocess(self, image: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
        """
        YoloX preprocessing function:
        - Resizes the longest edge to img_size while preserving the aspect ratio
        - Pads the shortest edge to img_size

        Args:
            image (torch tensor or np array [H x W x 3]): Input image in uint8 format.

        Returns:
            torch tensor [3 x H x W]: Processed image.
        """
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image)
        image = image.permute(2, 0, 1)  # [H, W, 3] -> [3, H, W]
        image = resize_pad(image, self.img_size)
        return image.float()

    def forward(
        self,
        x: torch.Tensor,
        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Forward pass of the model.
        Applies NMS and reformats the predictions.

        Args:
            x (torch.Tensor[BS x C x H x W]): Input image batch.
            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.

        Returns:
            list[dict]: List of prediction dictionaries. Each dictionary contains:
                - labels (torch.Tensor[N]): Class labels
                - boxes (torch.Tensor[N x 4]): Bounding boxes
                - scores (torch.Tensor[N]): Confidence scores.
        """
        x, orig_sizes = self.reformat_input(x, orig_sizes)

        # Scale to 0-255 if in range 0-1
        if x.max() <= 1:
            x *= 255

        pred_boxes = self.model(x.to(self.device))

        # NMS
        pred_boxes = postprocess(
            pred_boxes,
            self.config.num_classes,
            self.conf_thresh,
            self.iou_thresh,
            class_agnostic=self.class_agnostic,
        )

        # Reformat output
        preds = []
        for i, (p, size) in enumerate(zip(pred_boxes, orig_sizes)):
            if p is None:  # No detections
                preds.append({
                    "labels": torch.empty(0),
                    "boxes": torch.empty((0, 4)),
                    "scores": torch.empty(0),
                })
                continue

            p = p.view(-1, p.size(-1))
            ratio = min(self.img_size[0] / size[0], self.img_size[1] / size[1])
            boxes = p[:, :4] / ratio

            # Clip
            boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, size[1])
            boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, size[0])

            # Remove too small
            kept = (
                (boxes[:, 2] - boxes[:, 0] > self.min_bbox_size) &
                (boxes[:, 3] - boxes[:, 1] > self.min_bbox_size)
            )
            boxes = boxes[kept]
            p = p[kept]

            # Normalize to 0-1
            if self.normalize_boxes:
                boxes[:, [0, 2]] /= size[1]
                boxes[:, [1, 3]] /= size[0]

            scores = p[:, 4] * p[:, 5]
            labels = p[:, 6]

            preds.append({"labels": labels, "boxes": boxes, "scores": scores})

        return preds
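For reference, a minimal end-to-end inference sketch wiring define_model, the wrapper and the post-processing together (an illustration, not part of the committed files; it assumes weights.pth is present as configured in graphic_element_v1.py, that Pillow is installed, and that "chart.png" is a hypothetical input image):

import numpy as np
from PIL import Image

from model import define_model
from utils import postprocess_preds_graphic_element

model = define_model("graphic_element_v1")                # builds the YOLOX model and loads weights.pth
image = np.array(Image.open("chart.png").convert("RGB"))  # hypothetical H x W x 3 uint8 input

x = model.preprocess(image)                               # resize + pad to 1024 x 1024, [3 x H x W] float
preds = model(x, [image.shape])[0]                        # {"boxes", "labels", "scores"} for this image

boxes, labels, scores = postprocess_preds_graphic_element(
    preds, threshold=model.threshold, class_labels=model.labels
)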
post_processing/graphic_elt_pp.py
ADDED
@@ -0,0 +1,118 @@
import numpy as np
import numpy.typing as npt
from typing import Tuple, List


def bb_iou_array(
    boxes: npt.NDArray[np.float64], new_box: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """
    Calculates the Intersection over Union (IoU) between a box and an array of boxes.

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        new_box (numpy.ndarray): A single bounding box [x_min, y_min, x_max, y_max].

    Returns:
        numpy.ndarray: Array of IoU values between the new_box and each box in the array.
    """
    # bb intersection over union
    xA = np.maximum(boxes[:, 0], new_box[0])
    yA = np.maximum(boxes[:, 1], new_box[1])
    xB = np.minimum(boxes[:, 2], new_box[2])
    yB = np.minimum(boxes[:, 3], new_box[3])

    interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)

    # compute the area of both the prediction and ground-truth rectangles
    boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])

    iou = interArea / (boxAArea + boxBArea - interArea)

    return iou


def expand_boxes(
    boxes: npt.NDArray[np.float64],
    r_x: Tuple[float, float] = (1, 1),
    r_y: Tuple[float, float] = (1, 1),
    size_agnostic: bool = True,
) -> npt.NDArray[np.float64]:
    """
    Expands bounding boxes by a specified ratio.
    Expected box format is normalized [x_min, y_min, x_max, y_max].

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        r_x (tuple, optional): Left, right expansion ratios. Defaults to (1, 1) (no expansion).
        r_y (tuple, optional): Up, down expansion ratios. Defaults to (1, 1) (no expansion).
        size_agnostic (bool, optional): Expand independently of the box shape. Defaults to True.

    Returns:
        numpy.ndarray: Adjusted bounding boxes clipped to the [0, 1] range.
    """
    old_boxes = boxes.copy()

    if not size_agnostic:
        h = boxes[:, 3] - boxes[:, 1]
        w = boxes[:, 2] - boxes[:, 0]
    else:
        h, w = 1, 1

    boxes[:, 0] -= w * (r_x[0] - 1)  # left
    boxes[:, 2] += w * (r_x[1] - 1)  # right
    boxes[:, 1] -= h * (r_y[0] - 1)  # up
    boxes[:, 3] += h * (r_y[1] - 1)  # down

    boxes = np.clip(boxes, 0, 1)

    # Enforce non-overlapping boxes
    for i in range(len(boxes)):
        for j in range(i + 1, len(boxes)):
            iou = bb_iou_array(boxes[i][None], boxes[j])[0]
            old_iou = bb_iou_array(old_boxes[i][None], old_boxes[j])[0]
            # print(iou, old_iou)
            if iou > 0.05 and old_iou < 0.1:
                if boxes[i, 1] < boxes[j, 1]:  # i above j
                    boxes[j, 1] = min(old_boxes[j, 1], boxes[i, 3])
                    if old_iou > 0:
                        boxes[i, 3] = max(old_boxes[i, 3], boxes[j, 1])
                else:
                    boxes[i, 1] = min(old_boxes[i, 1], boxes[j, 3])
                    if old_iou > 0:
                        boxes[j, 3] = max(old_boxes[j, 3], boxes[i, 1])

    return boxes


def retrieve_title(
    boxes: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    scores: npt.NDArray[np.float64],
    classes: List[str],
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
    """
    Retrieves missed captions by using the biggest `other` box.

    If no chart_title is detected, this function finds the largest box
    labeled as 'other' (with width > 0.3) and relabels it as 'chart_title'.

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        labels (numpy.ndarray): Array of labels with shape (N,).
        scores (numpy.ndarray): Array of confidence scores with shape (N,).
        classes (list): List of class labels.

    Returns:
        numpy.ndarray [N x 4]: Array of bounding boxes (unchanged).
        numpy.ndarray [N]: Array of labels (potentially modified).
        numpy.ndarray [N]: Array of scores (unchanged).
    """
    if classes.index("chart_title") not in labels:
        widths = boxes[:, 2] - boxes[:, 0]
        title_scores = widths * (labels == classes.index("other")) * (widths > 0.3)
        replaced = np.argmax(title_scores) if max(title_scores) > 0 else None
        if replaced is not None:
            labels[replaced] = classes.index("chart_title")
    return boxes, labels, scores
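As a quick illustration of expand_boxes (a sketch, not part of the committed files): with the default size_agnostic=True the ratios act in absolute normalized units, so r_x=(1.0, 1.1) pushes every right edge out by 0.1 before clipping and overlap resolution.

import numpy as np
from post_processing.graphic_elt_pp import expand_boxes  # import path assumes the repo root is on sys.path

boxes = np.array([
    [0.10, 0.10, 0.40, 0.20],  # normalized [x_min, y_min, x_max, y_max]
    [0.10, 0.30, 0.40, 0.40],
])
expanded = expand_boxes(boxes.copy(), r_x=(1.0, 1.1), r_y=(1.0, 1.0))
# Right edges move from 0.40 to 0.50; results stay clipped to [0, 1], and boxes that
# only started overlapping because of the expansion are pushed apart again.
print(expanded)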
utils.py
ADDED
@@ -0,0 +1,201 @@
import numpy as np
import pandas as pd
import numpy.typing as npt
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from typing import Dict, List, Tuple, Optional, Union


COLORS = [
    "#003EFF",
    "#FF8F00",
    "#079700",
    "#A123FF",
    "#87CEEB",
    "#FF5733",
    "#C70039",
    "#900C3F",
    "#581845",
    "#11998E",
]


def reformat_for_plotting(
    boxes: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    scores: npt.NDArray[np.float64],
    shape: Tuple[int, int, int],
    num_classes: int,
) -> Tuple[List[npt.NDArray[np.int_]], List[npt.NDArray[np.float64]]]:
    """
    Reformat YOLOX predictions for plotting.

    Args:
        boxes (np.ndarray): Array of bounding boxes.
        labels (np.ndarray): Array of labels.
        scores (np.ndarray): Array of confidence scores.
        shape (tuple): Shape of the image.
        num_classes (int): Number of classes.

    Returns:
        list[np.ndarray]: List of bounding boxes per class.
        list[np.ndarray]: List of confidence scores per class.
    """
    boxes_plot = boxes.copy()
    boxes_plot[:, [0, 2]] *= shape[1]
    boxes_plot[:, [1, 3]] *= shape[0]
    boxes_plot = boxes_plot.astype(int)
    boxes_plot[:, 2] -= boxes_plot[:, 0]
    boxes_plot[:, 3] -= boxes_plot[:, 1]
    boxes_plot = [boxes_plot[labels == c] for c in range(num_classes)]
    confs = [scores[labels == c] for c in range(num_classes)]
    return boxes_plot, confs


def plot_sample(
    img: npt.NDArray[np.uint8],
    boxes_list: List[npt.NDArray[np.int_]],
    confs_list: List[npt.NDArray[np.float64]],
    labels: List[str],
) -> None:
    """
    Plots an image with bounding boxes.
    Coordinates are expected in format [x_min, y_min, width, height].

    Args:
        img (numpy.ndarray): The input image to be plotted.
        boxes_list (list[np.ndarray]): List of bounding boxes per class.
        confs_list (list[np.ndarray]): List of confidence scores per class.
        labels (list): List of class labels.
    """
    plt.imshow(img, cmap="gray")
    plt.axis(False)

    for boxes, confs, col, l in zip(boxes_list, confs_list, COLORS, labels):
        for box_idx, box in enumerate(boxes):
            # Better display around boundaries
            h, w, _ = img.shape
            box = np.copy(box)
            box[:2] = np.clip(box[:2], 2, max(h, w))
            box[2] = min(box[2], w - 2 - box[0])
            box[3] = min(box[3], h - 2 - box[1])

            rect = Rectangle(
                (box[0], box[1]),
                box[2],
                box[3],
                linewidth=2,
                facecolor="none",
                edgecolor=col,
            )
            plt.gca().add_patch(rect)

            # Add class and index label with proper alignment
            plt.text(
                box[0], box[1],
                f"{l}_{box_idx} conf={confs[box_idx]:.3f}",
                color='white',
                fontsize=8,
                bbox=dict(facecolor=col, alpha=1, edgecolor=col, pad=0, linewidth=2),
                verticalalignment='bottom',
                horizontalalignment='left'
            )


def reorder_boxes(
    boxes: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    classes: Optional[List[str]] = None,
    scores: Optional[npt.NDArray[np.float64]] = None,
) -> Union[
    Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_]],
    Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]],
]:
    """
    Reorder boxes, labels and scores by box coordinates.
    Ordering depends on the class.

    Args:
        boxes (np.ndarray): Array of bounding boxes of shape (N, 4) in format [x1, y1, x2, y2].
        labels (np.ndarray): Array of labels of shape (N,).
        classes (list, optional): List of class labels. Defaults to None.
        scores (np.ndarray, optional): Array of confidences of shape (N,). Defaults to None.

    Returns:
        np.ndarray [N, 4]: Ordered boxes.
        np.ndarray [N]: Ordered labels.
        np.ndarray [N]: Ordered scores if scores is not None.
    """
    n_classes = labels.max() if classes is None else len(classes)
    classes = labels.unique() if classes is None else classes

    ordered_boxes, ordered_labels, ordered_scores = [], [], []
    for c in range(n_classes):
        boxes_class = boxes[labels == c]
        if len(boxes_class):
            # Reorder
            sort = ["y0", "x0"]
            ascending = [True, True]
            if classes[c] == "ylabel":
                ascending = [False, True]
            elif classes[c] == "y_title":
                sort = ["x0", "y0"]
                ascending = [True, False]

            df_coords = pd.DataFrame({
                "y0": np.round(boxes_class[:, 1] - boxes_class[:, 1].min(), 2),
                "x0": np.round(boxes_class[:, 0] - boxes_class[:, 0].min(), 2),
            })

            idxs = df_coords.sort_values(sort, ascending=ascending).index

            ordered_boxes.append(boxes_class[idxs])
            ordered_labels.append(labels[labels == c][idxs])

            if scores is not None:
                ordered_scores.append(scores[labels == c][idxs])

    ordered_boxes = np.concatenate(ordered_boxes)
    ordered_labels = np.concatenate(ordered_labels)
    if scores is not None:
        ordered_scores = np.concatenate(ordered_scores)
        return ordered_boxes, ordered_labels, ordered_scores
    return ordered_boxes, ordered_labels


def postprocess_preds_graphic_element(
    preds: Dict[str, npt.NDArray],
    threshold: float = 0.1,
    class_labels: Optional[List[str]] = None,
    reorder: bool = True,
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
    """
    Post process predictions for the page element task.
    - Applies thresholding
    - Reorders boxes using the reading order

    Args:
        preds (dict): Predictions. Keys are "scores", "boxes", "labels".
        threshold (float, optional): Threshold for the confidence scores. Defaults to 0.1.
        class_labels (list, optional): List of class labels. Defaults to None.
        reorder (bool, optional): Whether to apply reordering. Defaults to True.

    Returns:
        numpy.ndarray [N x 4]: Array of bounding boxes.
        numpy.ndarray [N]: Array of labels.
        numpy.ndarray [N]: Array of scores.
    """
    boxes = preds["boxes"].cpu().numpy()
    labels = preds["labels"].cpu().numpy()
    scores = preds["scores"].cpu().numpy()

    # Threshold
    boxes = boxes[scores > threshold]
    labels = labels[scores > threshold]
    scores = scores[scores > threshold]

    if len(boxes) > 0 and reorder:
        boxes, labels, scores = reorder_boxes(boxes, labels, class_labels, scores)

    return boxes, labels, scores
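A short sketch of how the plotting helpers compose with the thresholded predictions (hypothetical variable names, continuing the inference sketch shown after model.py; not part of the committed files):

import matplotlib.pyplot as plt
from utils import reformat_for_plotting, plot_sample

# boxes / labels / scores as returned by postprocess_preds_graphic_element,
# image as the original H x W x 3 array that was fed to the model.
boxes_plot, confs = reformat_for_plotting(
    boxes, labels.astype(int), scores, image.shape, num_classes=len(model.labels)
)
plt.figure(figsize=(10, 10))
plot_sample(image, boxes_plot, confs, model.labels)
plt.show()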
yolox/__init__.py
ADDED
@@ -0,0 +1,7 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN
from .yolox import YOLOX
yolox/boxes.py
ADDED
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.

import torch
import torchvision


def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
    """
    Copied from YOLOX/yolox/utils/boxes.py
    """
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):

        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)

        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        if class_agnostic:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                nms_thre,
            )
        else:
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                detections[:, 6],
                nms_thre,
            )

        detections = detections[nms_out_index]
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))

    return output
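Each entry of the returned list is either None (nothing kept for that image) or an [N x 7] tensor laid out as (x1, y1, x2, y2, obj_conf, class_conf, class_pred); the wrapper in model.py multiplies columns 4 and 5 to obtain the final score. A toy decode of one entry, for reference (hypothetical tensor det; not part of the committed files):

xyxy = det[:, :4]               # corner-format boxes in network input coordinates
score = det[:, 4] * det[:, 5]   # objectness * class confidence
cls_id = det[:, 6].long()       # predicted class index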
yolox/darknet.py
ADDED
@@ -0,0 +1,179 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

from torch import nn

from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck


class Darknet(nn.Module):
    # number of blocks from dark2 to dark5.
    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}

    def __init__(
        self,
        depth,
        in_channels=3,
        stem_out_channels=32,
        out_features=("dark3", "dark4", "dark5"),
    ):
        """
        Args:
            depth (int): depth of darknet used in model, usually use [21, 53] for this param.
            in_channels (int): number of input channels, for example, use 3 for RGB image.
            stem_out_channels (int): number of output channels of darknet stem.
                It decides channels of darknet layer2 to layer5.
            out_features (Tuple[str]): desired output layer name.
        """
        super().__init__()
        assert out_features, "please provide output features of Darknet"
        self.out_features = out_features
        self.stem = nn.Sequential(
            BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
        )
        in_channels = stem_out_channels * 2  # 64

        num_blocks = Darknet.depth2blocks[depth]
        # create darknet with `stem_out_channels` and `num_blocks` layers.
        # to make model structure more clear, we don't use `for` statement in python.
        self.dark2 = nn.Sequential(
            *self.make_group_layer(in_channels, num_blocks[0], stride=2)
        )
        in_channels *= 2  # 128
        self.dark3 = nn.Sequential(
            *self.make_group_layer(in_channels, num_blocks[1], stride=2)
        )
        in_channels *= 2  # 256
        self.dark4 = nn.Sequential(
            *self.make_group_layer(in_channels, num_blocks[2], stride=2)
        )
        in_channels *= 2  # 512

        self.dark5 = nn.Sequential(
            *self.make_group_layer(in_channels, num_blocks[3], stride=2),
            *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2),
        )

    def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
        "starts with conv layer then has `num_blocks` `ResLayer`"
        return [
            BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"),
            *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)],
        ]

    def make_spp_block(self, filters_list, in_filters):
        m = nn.Sequential(
            *[
                BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"),
                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
                SPPBottleneck(
                    in_channels=filters_list[1],
                    out_channels=filters_list[0],
                    activation="lrelu",
                ),
                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
                BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"),
            ]
        )
        return m

    def forward(self, x):
        outputs = {}
        x = self.stem(x)
        outputs["stem"] = x
        x = self.dark2(x)
        outputs["dark2"] = x
        x = self.dark3(x)
        outputs["dark3"] = x
        x = self.dark4(x)
        outputs["dark4"] = x
        x = self.dark5(x)
        outputs["dark5"] = x
        return {k: v for k, v in outputs.items() if k in self.out_features}


class CSPDarknet(nn.Module):
    def __init__(
        self,
        dep_mul,
        wid_mul,
        out_features=("dark3", "dark4", "dark5"),
        depthwise=False,
        act="silu",
    ):
        super().__init__()
        assert out_features, "please provide output features of Darknet"
        self.out_features = out_features
        Conv = DWConv if depthwise else BaseConv

        base_channels = int(wid_mul * 64)  # 64
        base_depth = max(round(dep_mul * 3), 1)  # 3

        # stem
        self.stem = Focus(3, base_channels, ksize=3, act=act)

        # dark2
        self.dark2 = nn.Sequential(
            Conv(base_channels, base_channels * 2, 3, 2, act=act),
            CSPLayer(
                base_channels * 2,
                base_channels * 2,
                n=base_depth,
                depthwise=depthwise,
                act=act,
            ),
        )

        # dark3
        self.dark3 = nn.Sequential(
            Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
            CSPLayer(
                base_channels * 4,
                base_channels * 4,
                n=base_depth * 3,
                depthwise=depthwise,
                act=act,
            ),
        )

        # dark4
        self.dark4 = nn.Sequential(
            Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
            CSPLayer(
                base_channels * 8,
                base_channels * 8,
                n=base_depth * 3,
                depthwise=depthwise,
                act=act,
            ),
        )

        # dark5
        self.dark5 = nn.Sequential(
            Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
            SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
            CSPLayer(
                base_channels * 16,
                base_channels * 16,
                n=base_depth,
                shortcut=False,
                depthwise=depthwise,
                act=act,
            ),
        )

    def forward(self, x):
        outputs = {}
        x = self.stem(x)
        outputs["stem"] = x
        x = self.dark2(x)
        outputs["dark2"] = x
        x = self.dark3(x)
        outputs["dark3"] = x
        x = self.dark4(x)
        outputs["dark4"] = x
        x = self.dark5(x)
        outputs["dark5"] = x
        return {k: v for k, v in outputs.items() if k in self.out_features}
yolox/network_blocks.py
ADDED
@@ -0,0 +1,210 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import torch
import torch.nn as nn


class SiLU(nn.Module):
    """export-friendly version of nn.SiLU()"""

    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


def get_activation(name="silu", inplace=True):
    if name == "silu":
        module = nn.SiLU(inplace=inplace)
    elif name == "relu":
        module = nn.ReLU(inplace=inplace)
    elif name == "lrelu":
        module = nn.LeakyReLU(0.1, inplace=inplace)
    else:
        raise AttributeError("Unsupported act type: {}".format(name))
    return module


class BaseConv(nn.Module):
    """A Conv2d -> Batchnorm -> silu/leaky relu block"""

    def __init__(
        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
    ):
        super().__init__()
        # same padding
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))


class DWConv(nn.Module):
    """Depthwise Conv + Conv"""

    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
        super().__init__()
        self.dconv = BaseConv(
            in_channels,
            in_channels,
            ksize=ksize,
            stride=stride,
            groups=in_channels,
            act=act,
        )
        self.pconv = BaseConv(
            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
        )

    def forward(self, x):
        x = self.dconv(x)
        return self.pconv(x)


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(
        self,
        in_channels,
        out_channels,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        Conv = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        y = self.conv2(self.conv1(x))
        if self.use_add:
            y = y + x
        return y


class ResLayer(nn.Module):
    "Residual layer with `in_channels` inputs."

    def __init__(self, in_channels: int):
        super().__init__()
        mid_channels = in_channels // 2
        self.layer1 = BaseConv(
            in_channels, mid_channels, ksize=1, stride=1, act="lrelu"
        )
        self.layer2 = BaseConv(
            mid_channels, in_channels, ksize=3, stride=1, act="lrelu"
        )

    def forward(self, x):
        out = self.layer2(self.layer1(x))
        return x + out


class SPPBottleneck(nn.Module):
    """Spatial pyramid pooling layer used in YOLOv3-SPP"""

    def __init__(
        self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"
    ):
        super().__init__()
        hidden_channels = in_channels // 2
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
        self.m = nn.ModuleList(
            [
                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
                for ks in kernel_sizes
            ]
        )
        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
        x = self.conv2(x)
        return x


class CSPLayer(nn.Module):
    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""

    def __init__(
        self,
        in_channels,
        out_channels,
        n=1,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        """
        Args:
            in_channels (int): input channels.
            out_channels (int): output channels.
            n (int): number of Bottlenecks. Default value: 1.
        """
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden_channels = int(out_channels * expansion)  # hidden channels
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
        module_list = [
            Bottleneck(
                hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
            )
            for _ in range(n)
        ]
        self.m = nn.Sequential(*module_list)

    def forward(self, x):
        x_1 = self.conv1(x)
        x_2 = self.conv2(x)
        x_1 = self.m(x_1)
        x = torch.cat((x_1, x_2), dim=1)
        return self.conv3(x)


class Focus(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
        super().__init__()
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)

    def forward(self, x):
        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
        patch_top_left = x[..., ::2, ::2]
        patch_top_right = x[..., ::2, 1::2]
        patch_bot_left = x[..., 1::2, ::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat(
            (
                patch_top_left,
                patch_bot_left,
                patch_top_right,
                patch_bot_right,
            ),
            dim=1,
        )
        return self.conv(x)
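The Focus stem trades resolution for channels: each 2x2 neighborhood is split into four interleaved patches concatenated along the channel axis, so a (B, C, H, W) input becomes (B, 4C, H/2, W/2) before the convolution. A quick shape check (a sketch, not part of the committed files):

import torch
from yolox.network_blocks import Focus

focus = Focus(in_channels=3, out_channels=64, ksize=3)
x = torch.randn(1, 3, 1024, 1024)
y = focus(x)
print(y.shape)  # torch.Size([1, 64, 512, 512]) -- 4*3=12 channels feed the 3x3 conv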
yolox/yolo_fpn.py
ADDED
@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import torch
import torch.nn as nn

from .darknet import Darknet
from .network_blocks import BaseConv


class YOLOFPN(nn.Module):
    """
    YOLOFPN module. Darknet 53 is the default backbone of this model.
    """

    def __init__(
        self,
        depth=53,
        in_features=["dark3", "dark4", "dark5"],
    ):
        super().__init__()

        self.backbone = Darknet(depth)
        self.in_features = in_features

        # out 1
        self.out1_cbl = self._make_cbl(512, 256, 1)
        self.out1 = self._make_embedding([256, 512], 512 + 256)

        # out 2
        self.out2_cbl = self._make_cbl(256, 128, 1)
        self.out2 = self._make_embedding([128, 256], 256 + 128)

        # upsample
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def _make_cbl(self, _in, _out, ks):
        return BaseConv(_in, _out, ks, stride=1, act="lrelu")

    def _make_embedding(self, filters_list, in_filters):
        m = nn.Sequential(
            *[
                self._make_cbl(in_filters, filters_list[0], 1),
                self._make_cbl(filters_list[0], filters_list[1], 3),
                self._make_cbl(filters_list[1], filters_list[0], 1),
                self._make_cbl(filters_list[0], filters_list[1], 3),
                self._make_cbl(filters_list[1], filters_list[0], 1),
            ]
        )
        return m

    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
        with open(filename, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        print("loading pretrained weights...")
        self.backbone.load_state_dict(state_dict)

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): input image.

        Returns:
            Tuple[Tensor]: FPN output features.
        """
        # backbone
        out_features = self.backbone(inputs)
        x2, x1, x0 = [out_features[f] for f in self.in_features]

        # yolo branch 1
        x1_in = self.out1_cbl(x0)
        x1_in = self.upsample(x1_in)
        x1_in = torch.cat([x1_in, x1], 1)
        out_dark4 = self.out1(x1_in)

        # yolo branch 2
        x2_in = self.out2_cbl(out_dark4)
        x2_in = self.upsample(x2_in)
        x2_in = torch.cat([x2_in, x2], 1)
        out_dark3 = self.out2(x2_in)

        outputs = (out_dark3, out_dark4, x0)
        return outputs
yolox/yolo_head.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
# Copyright (c) Megvii Inc. All rights reserved.
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from .network_blocks import BaseConv, DWConv
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def meshgrid(*tensors):
|
| 14 |
+
"""
|
| 15 |
+
Copied from YOLOX/yolox/utils/compat.py
|
| 16 |
+
"""
|
| 17 |
+
if _TORCH_VER >= [1, 10]:
|
| 18 |
+
return torch.meshgrid(*tensors, indexing="ij")
|
| 19 |
+
else:
|
| 20 |
+
return torch.meshgrid(*tensors)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
|
| 24 |
+
"""
|
| 25 |
+
Copied from YOLOX/yolox/utils/boxes.py
|
| 26 |
+
"""
|
| 27 |
+
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
|
| 28 |
+
raise IndexError
|
| 29 |
+
|
| 30 |
+
if xyxy:
|
| 31 |
+
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
|
| 32 |
+
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
|
| 33 |
+
        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        tl = torch.max(
            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
        )
        br = torch.min(
            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
        )

        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)
    en = (tl < br).type(tl.type()).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
    return area_i / (area_a[:, None] + area_b - area_i)


class YOLOXHead(nn.Module):
    def __init__(
        self,
        num_classes,
        width=1.0,
        strides=[8, 16, 32],
        in_channels=[256, 512, 1024],
        act="silu",
        depthwise=False,
    ):
        """
        Args:
            act (str): activation type of conv. Default value: "silu".
            depthwise (bool): whether to apply depthwise conv in the conv branch. Default value: False.
        """
        super().__init__()

        self.num_classes = num_classes
        self.decode_in_inference = True  # for deploy, set to False

        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()
        Conv = DWConv if depthwise else BaseConv

        for i in range(len(in_channels)):
            self.stems.append(
                BaseConv(
                    in_channels=int(in_channels[i] * width),
                    out_channels=int(256 * width),
                    ksize=1,
                    stride=1,
                    act=act,
                )
            )
            self.cls_convs.append(
                nn.Sequential(
                    *[
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                    ]
                )
            )
            self.reg_convs.append(
                nn.Sequential(
                    *[
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                    ]
                )
            )
            self.cls_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=self.num_classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )
            self.reg_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=4,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )
            self.obj_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=1,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )

        self.use_l1 = False
        self.l1_loss = nn.L1Loss(reduction="none")
        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.iou_loss = None
        self.strides = strides
        self.grids = [torch.zeros(1)] * len(in_channels)

    def forward(self, xin, labels=None, imgs=None):
        outputs = []
        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
            zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)
            cls_x = x
            reg_x = x

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)

            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            output = torch.cat(
                [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
            )

            outputs.append(output)

        self.hw = [x.shape[-2:] for x in outputs]
        # [batch, n_anchors_all, 5 + num_classes]
        outputs = torch.cat(
            [x.flatten(start_dim=2) for x in outputs], dim=2
        ).permute(0, 2, 1)
        if self.decode_in_inference:
            return self.decode_outputs(outputs, dtype=xin[0].type())
        else:
            return outputs

    def get_output_and_grid(self, output, k, stride, dtype):
        grid = self.grids[k]

        batch_size = output.shape[0]
        n_ch = 5 + self.num_classes
        hsize, wsize = output.shape[-2:]
        if grid.shape[2:4] != output.shape[2:4]:
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
            self.grids[k] = grid

        output = output.view(batch_size, 1, n_ch, hsize, wsize)
        output = output.permute(0, 1, 3, 4, 2).reshape(
            batch_size, hsize * wsize, -1
        )
        grid = grid.view(1, -1, 2)
        output[..., :2] = (output[..., :2] + grid) * stride
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
        return output, grid

    def decode_outputs(self, outputs, dtype):
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(torch.full((*shape, 1), stride))

        grids = torch.cat(grids, dim=1).type(dtype)
        strides = torch.cat(strides, dim=1).type(dtype)

        outputs = torch.cat([
            (outputs[..., 0:2] + grids) * strides,
            torch.exp(outputs[..., 2:4]) * strides,
            outputs[..., 4:]
        ], dim=-1)
        return outputs
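For reference, decode_outputs turns each raw regression output (tx, ty, tw, th) at grid cell (gx, gy) of a level with stride s into an absolute box via cx = (tx + gx) * s, cy = (ty + gy) * s, w = exp(tw) * s, h = exp(th) * s. The sketch below is not part of this commit and uses made-up values; it only illustrates that arithmetic for a single stride-16 cell.

# Illustration only, not part of yolo_head.py: decode one hypothetical prediction.
import torch

tx, ty, tw, th = 0.3, 0.7, 0.0, 0.5           # assumed raw outputs for one anchor point
gx, gy, s = 10, 4, 16                          # grid cell (10, 4) on the stride-16 level
cx, cy = (tx + gx) * s, (ty + gy) * s          # 164.8, 75.2 pixels (box center)
w, h = torch.exp(torch.tensor([tw, th])) * s   # 16.0, ~26.4 pixels (box size)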
yolox/yolo_pafpn.py
ADDED
@@ -0,0 +1,116 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import torch
import torch.nn as nn

from .darknet import CSPDarknet
from .network_blocks import BaseConv, CSPLayer, DWConv


class YOLOPAFPN(nn.Module):
    """
    YOLOv3-style PAFPN. CSPDarknet is the default backbone of this model.
    """

    def __init__(
        self,
        depth=1.0,
        width=1.0,
        in_features=("dark3", "dark4", "dark5"),
        in_channels=[256, 512, 1024],
        depthwise=False,
        act="silu",
    ):
        super().__init__()
        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
        self.in_features = in_features
        self.in_channels = in_channels
        Conv = DWConv if depthwise else BaseConv

        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.lateral_conv0 = BaseConv(
            int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
        )
        self.C3_p4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )  # cat

        self.reduce_conv1 = BaseConv(
            int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
        )
        self.C3_p3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[0] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # bottom-up conv
        self.bu_conv2 = Conv(
            int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
        )
        self.C3_n3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # bottom-up conv
        self.bu_conv1 = Conv(
            int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
        )
        self.C3_n4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[2] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

    def forward(self, input):
        """
        Args:
            input: input images.

        Returns:
            Tuple[Tensor]: FPN features.
        """

        # backbone
        out_features = self.backbone(input)
        features = [out_features[f] for f in self.in_features]
        [x2, x1, x0] = features

        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
        f_out0 = self.upsample(fpn_out0)  # 512/16
        f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
        f_out0 = self.C3_p4(f_out0)  # 1024->512/16

        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
        f_out1 = self.upsample(fpn_out1)  # 256/8
        f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
        pan_out2 = self.C3_p3(f_out1)  # 512->256/8

        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
        pan_out1 = self.C3_n3(p_out1)  # 512->512/16

        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32

        outputs = (pan_out2, pan_out1, pan_out0)
        return outputs
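As a sanity check on the PAFPN wiring, the sketch below (not part of this commit; random weights assumed) runs a 1024x1024 dummy image through YOLOPAFPN as exported by this package. With width=1.0 the three outputs are the stride-8, 16 and 32 maps with 256, 512 and 1024 channels.

# Illustration only, not part of yolo_pafpn.py.
import torch
from yolox import YOLOPAFPN

fpn = YOLOPAFPN(depth=1.0, width=1.0).eval()
with torch.no_grad():
    p3, p4, p5 = fpn(torch.zeros(1, 3, 1024, 1024))
print(p3.shape, p4.shape, p5.shape)
# torch.Size([1, 256, 128, 128]) torch.Size([1, 512, 64, 64]) torch.Size([1, 1024, 32, 32])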
yolox/yolox.py
ADDED
@@ -0,0 +1,32 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import torch.nn as nn

from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN


class YOLOX(nn.Module):
    """
    YOLOX model module: a YOLOPAFPN backbone followed by a YOLOXHead detection head.
    This stripped-down module only supports inference and returns decoded detection
    results; for the training pipeline, refer to the official YOLOX repository.
    """

    def __init__(self, backbone=None, head=None):
        super().__init__()
        if backbone is None:
            backbone = YOLOPAFPN()
        if head is None:
            head = YOLOXHead(80)

        self.backbone = backbone
        self.head = head

    def forward(self, x, targets=None):
        assert not self.training, "Training mode not supported, please refer to the YOLOX repo"
        fpn_outs = self.backbone(x)
        outputs = self.head(fpn_outs)
        return outputs
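A minimal end-to-end inference sketch (not part of this commit; the class count and input size below are arbitrary) assembling the exported pieces. With decode_in_inference left at True, the head returns decoded boxes directly.

# Illustration only: backbone + head + decoded predictions on a dummy input.
import torch
from yolox import YOLOX, YOLOPAFPN, YOLOXHead

in_channels = [256, 512, 1024]
model = YOLOX(
    YOLOPAFPN(1.0, 1.0, in_channels=in_channels),
    YOLOXHead(10, 1.0, in_channels=in_channels),
).eval()
with torch.no_grad():
    preds = model(torch.zeros(1, 3, 1024, 1024))
print(preds.shape)  # (1, 21504, 15): cx, cy, w, h, objectness, 10 class scores per anchor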