Theo Viel committed
Commit 8baa7cf · 1 Parent(s): bbf44bf
graphic_element_v1.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import torch.nn as nn
+ from typing import List, Tuple
+
+
+ class Exp:
+     """
+     Configuration class for the graphic element model.
+
+     This class contains all configuration parameters for the YOLOX-based
+     page element detection model, including architecture settings, inference
+     parameters, and class-specific thresholds.
+     """
+
+     def __init__(self) -> None:
+         """Initialize the configuration with default parameters."""
+         self.name: str = "graphic-element-v1"
+         self.ckpt: str = "weights.pth"
+         self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+         # YOLOX architecture parameters
+         self.act: str = "silu"
+         self.depth: float = 1.00
+         self.width: float = 1.00
+         self.labels: List[str] = [
+             "chart_title",
+             "x_title",
+             "y_title",
+             "xlabel",
+             "ylabel",
+             "other",
+             "legend_label",
+             "legend_title",
+             "mark_label",
+             "value_label",
+         ]
+         self.num_classes: int = len(self.labels)
+
+         # Inference parameters
+         self.size: Tuple[int, int] = (1024, 1024)
+         self.min_bbox_size: int = 0
+         self.normalize_boxes: bool = True
+
+         # NMS & thresholding. These can be updated
+         self.conf_thresh: float = 0.01
+         self.iou_thresh: float = 0.25
+         self.class_agnostic: bool = True  # False
+
+         self.threshold: float = 0.1
+
+     def get_model(self) -> nn.Module:
+         """
+         Get the YOLOX model.
+
+         Builds and returns a YOLOX model with the configured architecture.
+         Also updates batch normalization parameters for optimal inference.
+
+         Returns:
+             nn.Module: The YOLOX model with configured parameters.
+         """
+         from yolox import YOLOX, YOLOPAFPN, YOLOXHead
+
+         # Build model
+         if getattr(self, "model", None) is None:
+             in_channels = [256, 512, 1024]
+             backbone = YOLOPAFPN(
+                 self.depth, self.width, in_channels=in_channels, act=self.act
+             )
+             head = YOLOXHead(
+                 self.num_classes, self.width, in_channels=in_channels, act=self.act
+             )
+             self.model = YOLOX(backbone, head)
+
+         # Update batch-norm parameters
+         def init_yolo(M: nn.Module) -> None:
+             for m in M.modules():
+                 if isinstance(m, nn.BatchNorm2d):
+                     m.eps = 1e-3
+                     m.momentum = 0.03
+
+         self.model.apply(init_yolo)
+
+         return self.model
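For orientation, a minimal usage sketch of this configuration, assuming the files added in this commit (in particular the `yolox` package) are importable. The snippet is illustrative and not part of the commit:

    # Hypothetical usage sketch
    from graphic_element_v1 import Exp

    config = Exp()
    model = config.get_model()               # builds the YOLOX backbone + head
    print(config.num_classes, config.size)   # 10 classes, (1024, 1024) input resolution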
model.py ADDED
@@ -0,0 +1,219 @@
+ import os
+ import sys
+ import torch
+ import importlib
+ import numpy as np
+ import numpy.typing as npt
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Dict, List, Tuple, Union
+ from yolox.boxes import postprocess
+
+
+ def define_model(config_name: str = "graphic_element_v1", verbose: bool = True) -> nn.Module:
+     """
+     Defines and initializes the model based on the configuration.
+
+     Args:
+         config_name (str): Configuration name. Defaults to "graphic_element_v1".
+         verbose (bool): Whether to print verbose output. Defaults to True.
+
+     Returns:
+         torch.nn.Module: The initialized YOLOX model.
+     """
+     # Load model from exp_file
+     sys.path.append(os.path.dirname(config_name))
+     exp_module = importlib.import_module(os.path.basename(config_name).split(".")[0])
+
+     config = exp_module.Exp()
+     model = config.get_model()
+
+     # Load weights
+     if verbose:
+         print(" -> Loading weights from", config.ckpt)
+
+     ckpt = torch.load(config.ckpt, map_location="cpu", weights_only=False)
+     model.load_state_dict(ckpt["model"], strict=True)
+
+     model = YoloXWrapper(model, config)
+     return model.eval().to(config.device)
+
+
+ def resize_pad(img: torch.Tensor, size: tuple) -> torch.Tensor:
+     """
+     Resizes and pads an image to a given size.
+     The goal is to preserve the aspect ratio of the image.
+
+     Args:
+         img (torch.Tensor[C x H x W]): The image to resize and pad.
+         size (tuple[2]): The size to resize and pad the image to.
+
+     Returns:
+         torch.Tensor: The resized and padded image.
+     """
+     img = img.float()
+     _, h, w = img.shape
+     scale = min(size[0] / h, size[1] / w)
+     nh = int(h * scale)
+     nw = int(w * scale)
+     img = F.interpolate(
+         img.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False
+     ).squeeze(0)
+     img = torch.clamp(img, 0, 255)
+     pad_b = size[0] - nh
+     pad_r = size[1] - nw
+     img = F.pad(img, (0, pad_r, 0, pad_b), value=114.0)
+     return img
+
+
+ class YoloXWrapper(nn.Module):
+     """
+     Wrapper for YoloX models.
+     """
+     def __init__(self, model: nn.Module, config) -> None:
+         """
+         Constructor
+
+         Args:
+             model (torch model): Yolo model.
+             config (Config): Config object containing model parameters.
+         """
+         super().__init__()
+         self.model = model
+         self.config = config
+
+         # Copy config parameters
+         self.device = config.device
+         self.img_size = config.size
+         self.min_bbox_size = config.min_bbox_size
+         self.normalize_boxes = config.normalize_boxes
+         self.conf_thresh = config.conf_thresh
+         self.iou_thresh = config.iou_thresh
+         self.class_agnostic = config.class_agnostic
+         self.threshold = config.threshold
+         self.labels = config.labels
+         self.num_classes = config.num_classes
+
+     def reformat_input(
+         self,
+         x: torch.Tensor,
+         orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Reformats the input data and original sizes to the correct format.
+
+         Args:
+             x (torch.Tensor[BS x C x H x W]): Input image batch.
+             orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.
+         Returns:
+             torch tensor [BS x C x H x W]: Input image batch.
+             torch tensor [BS x 2]: Original image sizes (before resizing and padding).
+         """
+         # Convert image size to tensor
+         if isinstance(orig_sizes, (list, tuple)):
+             orig_sizes = np.array(orig_sizes)
+         if orig_sizes.shape[-1] == 3:  # remove channel
+             orig_sizes = orig_sizes[..., :2]
+         if isinstance(orig_sizes, np.ndarray):
+             orig_sizes = torch.from_numpy(orig_sizes).to(self.device)
+
+         # Add batch dimension if not present
+         if len(x.size()) == 3:
+             x = x.unsqueeze(0)
+         if len(orig_sizes.size()) == 1:
+             orig_sizes = orig_sizes.unsqueeze(0)
+
+         return x, orig_sizes
+
+     def preprocess(self, image: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
+         """
+         YoloX preprocessing function:
+         - Resizes the longest edge to img_size while preserving the aspect ratio
+         - Pads the shortest edge to img_size
+
+         Args:
+             image (torch tensor or np array [H x W x 3]): Input image in uint8 format.
+
+         Returns:
+             torch tensor [3 x H x W]: Processed image.
+         """
+         if not isinstance(image, torch.Tensor):
+             image = torch.from_numpy(image)
+         image = image.permute(2, 0, 1)  # [H, W, 3] -> [3, H, W]
+         image = resize_pad(image, self.img_size)
+         return image.float()
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
+     ) -> List[Dict[str, torch.Tensor]]:
+         """
+         Forward pass of the model.
+         Applies NMS and reformats the predictions.
+
+         Args:
+             x (torch.Tensor[BS x C x H x W]): Input image batch.
+             orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.
+
+         Returns:
+             list[dict]: List of prediction dictionaries. Each dictionary contains:
+                 - labels (torch.Tensor[N]): Class labels
+                 - boxes (torch.Tensor[N x 4]): Bounding boxes
+                 - scores (torch.Tensor[N]): Confidence scores.
+         """
+         x, orig_sizes = self.reformat_input(x, orig_sizes)
+
+         # Scale to 0-255 if in range 0-1
+         if x.max() <= 1:
+             x *= 255
+
+         pred_boxes = self.model(x.to(self.device))
+
+         # NMS
+         pred_boxes = postprocess(
+             pred_boxes,
+             self.config.num_classes,
+             self.conf_thresh,
+             self.iou_thresh,
+             class_agnostic=self.class_agnostic,
+         )
+
+         # Reformat output
+         preds = []
+         for i, (p, size) in enumerate(zip(pred_boxes, orig_sizes)):
+             if p is None:  # No detections
+                 preds.append({
+                     "labels": torch.empty(0),
+                     "boxes": torch.empty((0, 4)),
+                     "scores": torch.empty(0),
+                 })
+                 continue
+
+             p = p.view(-1, p.size(-1))
+             ratio = min(self.img_size[0] / size[0], self.img_size[1] / size[1])
+             boxes = p[:, :4] / ratio
+
+             # Clip
+             boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, size[1])
+             boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, size[0])
+
+             # Remove too small
+             kept = (
+                 (boxes[:, 2] - boxes[:, 0] > self.min_bbox_size) &
+                 (boxes[:, 3] - boxes[:, 1] > self.min_bbox_size)
+             )
+             boxes = boxes[kept]
+             p = p[kept]
+
+             # Normalize to 0-1
+             if self.normalize_boxes:
+                 boxes[:, [0, 2]] /= size[1]
+                 boxes[:, [1, 3]] /= size[0]
+
+             scores = p[:, 4] * p[:, 5]
+             labels = p[:, 6]
+
+             preds.append({"labels": labels, "boxes": boxes, "scores": scores})
+
+         return preds
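A hedged end-to-end inference sketch built on `define_model` and `YoloXWrapper` above. It assumes `weights.pth` is present next to the config so the checkpoint can be loaded, and it uses a random array as a stand-in for a real chart image:

    import numpy as np
    import torch

    from model import define_model

    model = define_model("graphic_element_v1")

    # Stand-in for a real H x W x 3 uint8 chart image
    image = np.random.randint(0, 255, (768, 1024, 3), dtype=np.uint8)

    x = model.preprocess(image)                   # resize + pad to (1024, 1024)
    with torch.no_grad():
        preds = model(x, orig_sizes=image.shape)  # one dict per image
    print(preds[0]["boxes"].shape, preds[0]["scores"].shape)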
post_processing/graphic_elt_pp.py ADDED
@@ -0,0 +1,118 @@
+ import numpy as np
+ import numpy.typing as npt
+ from typing import Tuple, List
+
+
+ def bb_iou_array(
+     boxes: npt.NDArray[np.float64], new_box: npt.NDArray[np.float64]
+ ) -> npt.NDArray[np.float64]:
+     """
+     Calculates the Intersection over Union (IoU) between a box and an array of boxes.
+
+     Args:
+         boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
+         new_box (numpy.ndarray): A single bounding box [x_min, y_min, x_max, y_max].
+
+     Returns:
+         numpy.ndarray: Array of IoU values between the new_box and each box in the array.
+     """
+     # bb intersection over union
+     xA = np.maximum(boxes[:, 0], new_box[0])
+     yA = np.maximum(boxes[:, 1], new_box[1])
+     xB = np.minimum(boxes[:, 2], new_box[2])
+     yB = np.minimum(boxes[:, 3], new_box[3])
+
+     interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
+
+     # compute the area of both the prediction and ground-truth rectangles
+     boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+     boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
+
+     iou = interArea / (boxAArea + boxBArea - interArea)
+
+     return iou
+
+
+ def expand_boxes(
+     boxes: npt.NDArray[np.float64],
+     r_x: Tuple[float, float] = (1, 1),
+     r_y: Tuple[float, float] = (1, 1),
+     size_agnostic: bool = True,
+ ) -> npt.NDArray[np.float64]:
+     """
+     Expands bounding boxes by a specified ratio.
+     Expected box format is normalized [x_min, y_min, x_max, y_max].
+
+     Args:
+         boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
+         r_x (tuple, optional): Left, right expansion ratios. Defaults to (1, 1) (no expansion).
+         r_y (tuple, optional): Up, down expansion ratios. Defaults to (1, 1) (no expansion).
+         size_agnostic (bool, optional): Expand independently of the box shape. Defaults to True.
+
+     Returns:
+         numpy.ndarray: Adjusted bounding boxes clipped to the [0, 1] range.
+     """
+     old_boxes = boxes.copy()
+
+     if not size_agnostic:
+         h = boxes[:, 3] - boxes[:, 1]
+         w = boxes[:, 2] - boxes[:, 0]
+     else:
+         h, w = 1, 1
+
+     boxes[:, 0] -= w * (r_x[0] - 1)  # left
+     boxes[:, 2] += w * (r_x[1] - 1)  # right
+     boxes[:, 1] -= h * (r_y[0] - 1)  # up
+     boxes[:, 3] += h * (r_y[1] - 1)  # down
+
+     boxes = np.clip(boxes, 0, 1)
+
+     # Enforce non-overlapping boxes
+     for i in range(len(boxes)):
+         for j in range(i + 1, len(boxes)):
+             iou = bb_iou_array(boxes[i][None], boxes[j])[0]
+             old_iou = bb_iou_array(old_boxes[i][None], old_boxes[j])[0]
+             if iou > 0.05 and old_iou < 0.1:
+                 if boxes[i, 1] < boxes[j, 1]:  # i above j
+                     boxes[j, 1] = min(old_boxes[j, 1], boxes[i, 3])
+                     if old_iou > 0:
+                         boxes[i, 3] = max(old_boxes[i, 3], boxes[j, 1])
+                 else:
+                     boxes[i, 1] = min(old_boxes[i, 1], boxes[j, 3])
+                     if old_iou > 0:
+                         boxes[j, 3] = max(old_boxes[j, 3], boxes[i, 1])
+
+     return boxes
+
+
+ def retrieve_title(
+     boxes: npt.NDArray[np.float64],
+     labels: npt.NDArray[np.int_],
+     scores: npt.NDArray[np.float64],
+     classes: List[str],
+ ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
+     """
+     Retrieves missed captions by using the biggest `other` box.
+
+     If no chart_title is detected, this function finds the largest box
+     labeled as 'other' (with width > 0.3) and relabels it as 'chart_title'.
+
+     Args:
+         boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
+         labels (numpy.ndarray): Array of labels with shape (N,).
+         scores (numpy.ndarray): Array of confidence scores with shape (N,).
+         classes (list): List of class labels.
+
+     Returns:
+         numpy.ndarray [N x 4]: Array of bounding boxes (unchanged).
+         numpy.ndarray [N]: Array of labels (potentially modified).
+         numpy.ndarray [N]: Array of scores (unchanged).
+     """
+     if len(boxes) and classes.index("chart_title") not in labels:
+         widths = boxes[:, 2] - boxes[:, 0]
+         # Use a separate array so the returned confidence scores stay unchanged
+         title_scores = widths * (labels == classes.index("other")) * (widths > 0.3)
+         replaced = np.argmax(title_scores) if title_scores.max() > 0 else None
+         if replaced is not None:
+             labels[replaced] = classes.index("chart_title")
+     return boxes, labels, scores
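A small, hypothetical example of how `expand_boxes` treats normalized boxes (the values are invented for illustration, and the import assumes the package layout of this commit):

    import numpy as np

    from post_processing.graphic_elt_pp import expand_boxes

    # Two normalized [x_min, y_min, x_max, y_max] boxes, stacked vertically
    boxes = np.array([[0.10, 0.10, 0.40, 0.20],
                      [0.10, 0.35, 0.40, 0.45]])

    # Grow every side by 5% of the image (size_agnostic=True), clipped to [0, 1].
    # If the grown boxes were to overlap while the originals did not,
    # the shared edge would be pushed back so they stay disjoint.
    expanded = expand_boxes(boxes.copy(), r_x=(1.05, 1.05), r_y=(1.05, 1.05))
    print(expanded)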
utils.py ADDED
@@ -0,0 +1,201 @@
+
+ import numpy as np
+ import pandas as pd
+ import numpy.typing as npt
+ import matplotlib.pyplot as plt
+ from matplotlib.patches import Rectangle
+ from typing import Dict, List, Tuple, Optional, Union
+
+
+ COLORS = [
+     "#003EFF",
+     "#FF8F00",
+     "#079700",
+     "#A123FF",
+     "#87CEEB",
+     "#FF5733",
+     "#C70039",
+     "#900C3F",
+     "#581845",
+     "#11998E",
+ ]
+
+
+ def reformat_for_plotting(
+     boxes: npt.NDArray[np.float64],
+     labels: npt.NDArray[np.int_],
+     scores: npt.NDArray[np.float64],
+     shape: Tuple[int, int, int],
+     num_classes: int,
+ ) -> Tuple[List[npt.NDArray[np.int_]], List[npt.NDArray[np.float64]]]:
+     """
+     Reformat YOLOX predictions for plotting.
+
+     Args:
+         boxes (np.ndarray): Array of bounding boxes.
+         labels (np.ndarray): Array of labels.
+         scores (np.ndarray): Array of confidence scores.
+         shape (tuple): Shape of the image.
+         num_classes (int): Number of classes.
+
+     Returns:
+         list[np.ndarray]: List of box bounding boxes per class.
+         list[np.ndarray]: List of confidence scores per class.
+     """
+     boxes_plot = boxes.copy()
+     boxes_plot[:, [0, 2]] *= shape[1]
+     boxes_plot[:, [1, 3]] *= shape[0]
+     boxes_plot = boxes_plot.astype(int)
+     boxes_plot[:, 2] -= boxes_plot[:, 0]
+     boxes_plot[:, 3] -= boxes_plot[:, 1]
+     boxes_plot = [boxes_plot[labels == c] for c in range(num_classes)]
+     confs = [scores[labels == c] for c in range(num_classes)]
+     return boxes_plot, confs
+
+
+ def plot_sample(
+     img: npt.NDArray[np.uint8],
+     boxes_list: List[npt.NDArray[np.int_]],
+     confs_list: List[npt.NDArray[np.float64]],
+     labels: List[str],
+ ) -> None:
+     """
+     Plots an image with bounding boxes.
+     Coordinates are expected in format [x_min, y_min, width, height].
+
+     Args:
+         img (numpy.ndarray): The input image to be plotted.
+         boxes_list (list[np.ndarray]): List of box bounding boxes per class.
+         confs_list (list[np.ndarray]): List of confidence scores per class.
+         labels (list): List of class labels.
+     """
+     plt.imshow(img, cmap="gray")
+     plt.axis(False)
+
+     for boxes, confs, col, l in zip(boxes_list, confs_list, COLORS, labels):
+         for box_idx, box in enumerate(boxes):
+             # Better display around boundaries
+             h, w, _ = img.shape
+             box = np.copy(box)
+             box[:2] = np.clip(box[:2], 2, max(h, w))
+             box[2] = min(box[2], w - 2 - box[0])
+             box[3] = min(box[3], h - 2 - box[1])
+
+             rect = Rectangle(
+                 (box[0], box[1]),
+                 box[2],
+                 box[3],
+                 linewidth=2,
+                 facecolor="none",
+                 edgecolor=col,
+             )
+             plt.gca().add_patch(rect)
+
+             # Add class and index label with proper alignment
+             plt.text(
+                 box[0], box[1],
+                 f"{l}_{box_idx} conf={confs[box_idx]:.3f}",
+                 color='white',
+                 fontsize=8,
+                 bbox=dict(facecolor=col, alpha=1, edgecolor=col, pad=0, linewidth=2),
+                 verticalalignment='bottom',
+                 horizontalalignment='left'
+             )
+
+
+ def reorder_boxes(
+     boxes: npt.NDArray[np.float64],
+     labels: npt.NDArray[np.int_],
+     classes: Optional[List[str]] = None,
+     scores: Optional[npt.NDArray[np.float64]] = None,
+ ) -> Union[
+     Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_]],
+     Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]],
+ ]:
+     """
+     Reorder boxes, labels and scores by box coordinates.
+     Ordering depends on the class.
+
+     Args:
+         boxes (np.ndarray): Array of bounding boxes of shape (N, 4) in format [x1, y1, x2, y2].
+         labels (np.ndarray): Array of labels of shape (N,).
+         classes (list, optional): List of class labels. Defaults to None.
+         scores (np.ndarray, optional): Array of confidences of shape (N,). Defaults to None.
+
+     Returns:
+         np.ndarray [N, 4]: Ordered boxes.
+         np.ndarray [N]: Ordered labels.
+         np.ndarray [N]: Ordered scores if scores is not None.
+     """
+     n_classes = int(labels.max()) + 1 if classes is None else len(classes)
+     classes = [str(c) for c in range(n_classes)] if classes is None else classes
+
+     ordered_boxes, ordered_labels, ordered_scores = [], [], []
+     for c in range(n_classes):
+         boxes_class = boxes[labels == c]
+         if len(boxes_class):
+             # Reorder
+             sort = ["y0", "x0"]
+             ascending = [True, True]
+             if classes[c] == "ylabel":
+                 ascending = [False, True]
+             elif classes[c] == "y_title":
+                 sort = ["x0", "y0"]
+                 ascending = [True, False]
+
+             df_coords = pd.DataFrame({
+                 "y0": np.round(boxes_class[:, 1] - boxes_class[:, 1].min(), 2),
+                 "x0": np.round(boxes_class[:, 0] - boxes_class[:, 0].min(), 2),
+             })
+
+             idxs = df_coords.sort_values(sort, ascending=ascending).index
+
+             ordered_boxes.append(boxes_class[idxs])
+             ordered_labels.append(labels[labels == c][idxs])
+
+             if scores is not None:
+                 ordered_scores.append(scores[labels == c][idxs])
+
+     ordered_boxes = np.concatenate(ordered_boxes)
+     ordered_labels = np.concatenate(ordered_labels)
+     if scores is not None:
+         ordered_scores = np.concatenate(ordered_scores)
+         return ordered_boxes, ordered_labels, ordered_scores
+     return ordered_boxes, ordered_labels
+
+
+ def postprocess_preds_graphic_element(
+     preds: Dict[str, npt.NDArray],
+     threshold: float = 0.1,
+     class_labels: Optional[List[str]] = None,
+     reorder: bool = True,
+ ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
+     """
+     Post-process predictions for the graphic element task.
+     - Applies thresholding
+     - Reorders boxes using the reading order
+
+     Args:
+         preds (dict): Predictions. Keys are "scores", "boxes", "labels".
+         threshold (float, optional): Threshold for the confidence scores. Defaults to 0.1.
+         class_labels (list, optional): List of class labels. Defaults to None.
+         reorder (bool, optional): Whether to apply reordering. Defaults to True.
+
+     Returns:
+         numpy.ndarray [N x 4]: Array of bounding boxes.
+         numpy.ndarray [N]: Array of labels.
+         numpy.ndarray [N]: Array of scores.
+     """
+     boxes = preds["boxes"].cpu().numpy()
+     labels = preds["labels"].cpu().numpy()
+     scores = preds["scores"].cpu().numpy()
+
+     # Threshold
+     boxes = boxes[scores > threshold]
+     labels = labels[scores > threshold]
+     scores = scores[scores > threshold]
+
+     if len(boxes) > 0 and reorder:
+         boxes, labels, scores = reorder_boxes(boxes, labels, class_labels, scores)
+
+     return boxes, labels, scores
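A hedged visualization sketch tying these helpers to the wrapper from model.py. It reuses the hypothetical `image` and `preds` variables from the inference sketch after model.py, and the display depends on the local matplotlib backend:

    import matplotlib.pyplot as plt

    from graphic_element_v1 import Exp
    from utils import postprocess_preds_graphic_element, reformat_for_plotting, plot_sample

    config = Exp()
    boxes, labels, scores = postprocess_preds_graphic_element(
        preds[0], threshold=config.threshold, class_labels=config.labels
    )
    boxes_plot, confs = reformat_for_plotting(
        boxes, labels.astype(int), scores, image.shape, config.num_classes
    )

    plt.figure(figsize=(10, 8))
    plot_sample(image, boxes_plot, confs, config.labels)
    plt.show()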
yolox/__init__.py ADDED
@@ -0,0 +1,7 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+ # Copyright (c) Megvii Inc. All rights reserved.
+
+ from .yolo_head import YOLOXHead
+ from .yolo_pafpn import YOLOPAFPN
+ from .yolox import YOLOX
yolox/boxes.py ADDED
@@ -0,0 +1,55 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Megvii Inc. All rights reserved.
+
+ import torch
+ import torchvision
+
+
+ def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
+     """
+     Copied from YOLOX/yolox/utils/boxes.py
+     """
+     box_corner = prediction.new(prediction.shape)
+     box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+     box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+     box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+     box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+     prediction[:, :, :4] = box_corner[:, :, :4]
+
+     output = [None for _ in range(len(prediction))]
+     for i, image_pred in enumerate(prediction):
+
+         # If none are remaining => process next image
+         if not image_pred.size(0):
+             continue
+         # Get score and class with highest confidence
+         class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)
+
+         conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
+         # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
+         detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
+         detections = detections[conf_mask]
+         if not detections.size(0):
+             continue
+
+         if class_agnostic:
+             nms_out_index = torchvision.ops.nms(
+                 detections[:, :4],
+                 detections[:, 4] * detections[:, 5],
+                 nms_thre,
+             )
+         else:
+             nms_out_index = torchvision.ops.batched_nms(
+                 detections[:, :4],
+                 detections[:, 4] * detections[:, 5],
+                 detections[:, 6],
+                 nms_thre,
+             )
+
+         detections = detections[nms_out_index]
+         if output[i] is None:
+             output[i] = detections
+         else:
+             output[i] = torch.cat((output[i], detections))
+
+     return output
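For reference, a hedged sketch of the tensor layout `postprocess` expects and returns (random values; the class count matches the ten graphic-element classes):

    import torch

    from yolox.boxes import postprocess

    # prediction: [batch, n_anchors, 5 + num_classes] laid out as (cx, cy, w, h, obj, class scores)
    dummy = torch.rand(1, 100, 5 + 10)
    out = postprocess(dummy, num_classes=10, conf_thre=0.01, nms_thre=0.25, class_agnostic=True)
    # out[i] is None when nothing survives thresholding, otherwise an [N x 7] tensor:
    # (x1, y1, x2, y2, obj_conf, class_conf, class_pred)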
yolox/darknet.py ADDED
@@ -0,0 +1,179 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ # Copyright (c) Megvii Inc. All rights reserved.
+
+ from torch import nn
+
+ from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck
+
+
+ class Darknet(nn.Module):
+     # number of blocks from dark2 to dark5.
+     depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}
+
+     def __init__(
+         self,
+         depth,
+         in_channels=3,
+         stem_out_channels=32,
+         out_features=("dark3", "dark4", "dark5"),
+     ):
+         """
+         Args:
+             depth (int): depth of darknet used in model, usually use [21, 53] for this param.
+             in_channels (int): number of input channels, for example, use 3 for RGB image.
+             stem_out_channels (int): number of output channels of darknet stem.
+                 It decides channels of darknet layer2 to layer5.
+             out_features (Tuple[str]): desired output layer name.
+         """
+         super().__init__()
+         assert out_features, "please provide output features of Darknet"
+         self.out_features = out_features
+         self.stem = nn.Sequential(
+             BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
+             *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
+         )
+         in_channels = stem_out_channels * 2  # 64
+
+         num_blocks = Darknet.depth2blocks[depth]
+         # create darknet with `stem_out_channels` and `num_blocks` layers.
+         # to make model structure more clear, we don't use `for` statement in python.
+         self.dark2 = nn.Sequential(
+             *self.make_group_layer(in_channels, num_blocks[0], stride=2)
+         )
+         in_channels *= 2  # 128
+         self.dark3 = nn.Sequential(
+             *self.make_group_layer(in_channels, num_blocks[1], stride=2)
+         )
+         in_channels *= 2  # 256
+         self.dark4 = nn.Sequential(
+             *self.make_group_layer(in_channels, num_blocks[2], stride=2)
+         )
+         in_channels *= 2  # 512
+
+         self.dark5 = nn.Sequential(
+             *self.make_group_layer(in_channels, num_blocks[3], stride=2),
+             *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2),
+         )
+
+     def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
+         "starts with conv layer then has `num_blocks` `ResLayer`"
+         return [
+             BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"),
+             *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)],
+         ]
+
+     def make_spp_block(self, filters_list, in_filters):
+         m = nn.Sequential(
+             *[
+                 BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"),
+                 BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                 SPPBottleneck(
+                     in_channels=filters_list[1],
+                     out_channels=filters_list[0],
+                     activation="lrelu",
+                 ),
+                 BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                 BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"),
+             ]
+         )
+         return m
+
+     def forward(self, x):
+         outputs = {}
+         x = self.stem(x)
+         outputs["stem"] = x
+         x = self.dark2(x)
+         outputs["dark2"] = x
+         x = self.dark3(x)
+         outputs["dark3"] = x
+         x = self.dark4(x)
+         outputs["dark4"] = x
+         x = self.dark5(x)
+         outputs["dark5"] = x
+         return {k: v for k, v in outputs.items() if k in self.out_features}
+
+
+ class CSPDarknet(nn.Module):
+     def __init__(
+         self,
+         dep_mul,
+         wid_mul,
+         out_features=("dark3", "dark4", "dark5"),
+         depthwise=False,
+         act="silu",
+     ):
+         super().__init__()
+         assert out_features, "please provide output features of Darknet"
+         self.out_features = out_features
+         Conv = DWConv if depthwise else BaseConv
+
+         base_channels = int(wid_mul * 64)  # 64
+         base_depth = max(round(dep_mul * 3), 1)  # 3
+
+         # stem
+         self.stem = Focus(3, base_channels, ksize=3, act=act)
+
+         # dark2
+         self.dark2 = nn.Sequential(
+             Conv(base_channels, base_channels * 2, 3, 2, act=act),
+             CSPLayer(
+                 base_channels * 2,
+                 base_channels * 2,
+                 n=base_depth,
+                 depthwise=depthwise,
+                 act=act,
+             ),
+         )
+
+         # dark3
+         self.dark3 = nn.Sequential(
+             Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
+             CSPLayer(
+                 base_channels * 4,
+                 base_channels * 4,
+                 n=base_depth * 3,
+                 depthwise=depthwise,
+                 act=act,
+             ),
+         )
+
+         # dark4
+         self.dark4 = nn.Sequential(
+             Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
+             CSPLayer(
+                 base_channels * 8,
+                 base_channels * 8,
+                 n=base_depth * 3,
+                 depthwise=depthwise,
+                 act=act,
+             ),
+         )
+
+         # dark5
+         self.dark5 = nn.Sequential(
+             Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
+             SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
+             CSPLayer(
+                 base_channels * 16,
+                 base_channels * 16,
+                 n=base_depth,
+                 shortcut=False,
+                 depthwise=depthwise,
+                 act=act,
+             ),
+         )
+
+     def forward(self, x):
+         outputs = {}
+         x = self.stem(x)
+         outputs["stem"] = x
+         x = self.dark2(x)
+         outputs["dark2"] = x
+         x = self.dark3(x)
+         outputs["dark3"] = x
+         x = self.dark4(x)
+         outputs["dark4"] = x
+         x = self.dark5(x)
+         outputs["dark5"] = x
+         return {k: v for k, v in outputs.items() if k in self.out_features}
yolox/network_blocks.py ADDED
@@ -0,0 +1,210 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ # Copyright (c) Megvii Inc. All rights reserved.
+
+ import torch
+ import torch.nn as nn
+
+
+ class SiLU(nn.Module):
+     """export-friendly version of nn.SiLU()"""
+
+     @staticmethod
+     def forward(x):
+         return x * torch.sigmoid(x)
+
+
+ def get_activation(name="silu", inplace=True):
+     if name == "silu":
+         module = nn.SiLU(inplace=inplace)
+     elif name == "relu":
+         module = nn.ReLU(inplace=inplace)
+     elif name == "lrelu":
+         module = nn.LeakyReLU(0.1, inplace=inplace)
+     else:
+         raise AttributeError("Unsupported act type: {}".format(name))
+     return module
+
+
+ class BaseConv(nn.Module):
+     """A Conv2d -> Batchnorm -> silu/leaky relu block"""
+
+     def __init__(
+         self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
+     ):
+         super().__init__()
+         # same padding
+         pad = (ksize - 1) // 2
+         self.conv = nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size=ksize,
+             stride=stride,
+             padding=pad,
+             groups=groups,
+             bias=bias,
+         )
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.act = get_activation(act, inplace=True)
+
+     def forward(self, x):
+         return self.act(self.bn(self.conv(x)))
+
+     def fuseforward(self, x):
+         return self.act(self.conv(x))
+
+
+ class DWConv(nn.Module):
+     """Depthwise Conv + Conv"""
+
+     def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
+         super().__init__()
+         self.dconv = BaseConv(
+             in_channels,
+             in_channels,
+             ksize=ksize,
+             stride=stride,
+             groups=in_channels,
+             act=act,
+         )
+         self.pconv = BaseConv(
+             in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
+         )
+
+     def forward(self, x):
+         x = self.dconv(x)
+         return self.pconv(x)
+
+
+ class Bottleneck(nn.Module):
+     # Standard bottleneck
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         shortcut=True,
+         expansion=0.5,
+         depthwise=False,
+         act="silu",
+     ):
+         super().__init__()
+         hidden_channels = int(out_channels * expansion)
+         Conv = DWConv if depthwise else BaseConv
+         self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+         self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
+         self.use_add = shortcut and in_channels == out_channels
+
+     def forward(self, x):
+         y = self.conv2(self.conv1(x))
+         if self.use_add:
+             y = y + x
+         return y
+
+
+ class ResLayer(nn.Module):
+     "Residual layer with `in_channels` inputs."
+
+     def __init__(self, in_channels: int):
+         super().__init__()
+         mid_channels = in_channels // 2
+         self.layer1 = BaseConv(
+             in_channels, mid_channels, ksize=1, stride=1, act="lrelu"
+         )
+         self.layer2 = BaseConv(
+             mid_channels, in_channels, ksize=3, stride=1, act="lrelu"
+         )
+
+     def forward(self, x):
+         out = self.layer2(self.layer1(x))
+         return x + out
+
+
+ class SPPBottleneck(nn.Module):
+     """Spatial pyramid pooling layer used in YOLOv3-SPP"""
+
+     def __init__(
+         self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"
+     ):
+         super().__init__()
+         hidden_channels = in_channels // 2
+         self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
+         self.m = nn.ModuleList(
+             [
+                 nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
+                 for ks in kernel_sizes
+             ]
+         )
+         conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
+         self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = torch.cat([x] + [m(x) for m in self.m], dim=1)
+         x = self.conv2(x)
+         return x
+
+
+ class CSPLayer(nn.Module):
+     """C3 in yolov5, CSP Bottleneck with 3 convolutions"""
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         n=1,
+         shortcut=True,
+         expansion=0.5,
+         depthwise=False,
+         act="silu",
+     ):
+         """
+         Args:
+             in_channels (int): input channels.
+             out_channels (int): output channels.
+             n (int): number of Bottlenecks. Default value: 1.
+         """
+         # ch_in, ch_out, number, shortcut, groups, expansion
+         super().__init__()
+         hidden_channels = int(out_channels * expansion)  # hidden channels
+         self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+         self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+         self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
+         module_list = [
+             Bottleneck(
+                 hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
+             )
+             for _ in range(n)
+         ]
+         self.m = nn.Sequential(*module_list)
+
+     def forward(self, x):
+         x_1 = self.conv1(x)
+         x_2 = self.conv2(x)
+         x_1 = self.m(x_1)
+         x = torch.cat((x_1, x_2), dim=1)
+         return self.conv3(x)
+
+
+ class Focus(nn.Module):
+     """Focus width and height information into channel space."""
+
+     def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
+         super().__init__()
+         self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
+
+     def forward(self, x):
+         # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
+         patch_top_left = x[..., ::2, ::2]
+         patch_top_right = x[..., ::2, 1::2]
+         patch_bot_left = x[..., 1::2, ::2]
+         patch_bot_right = x[..., 1::2, 1::2]
+         x = torch.cat(
+             (
+                 patch_top_left,
+                 patch_bot_left,
+                 patch_top_right,
+                 patch_bot_right,
+             ),
+             dim=1,
+         )
+         return self.conv(x)
yolox/yolo_fpn.py ADDED
@@ -0,0 +1,84 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ # Copyright (c) Megvii Inc. All rights reserved.
+
+ import torch
+ import torch.nn as nn
+
+ from .darknet import Darknet
+ from .network_blocks import BaseConv
+
+
+ class YOLOFPN(nn.Module):
+     """
+     YOLOFPN module. Darknet 53 is the default backbone of this model.
+     """
+
+     def __init__(
+         self,
+         depth=53,
+         in_features=["dark3", "dark4", "dark5"],
+     ):
+         super().__init__()
+
+         self.backbone = Darknet(depth)
+         self.in_features = in_features
+
+         # out 1
+         self.out1_cbl = self._make_cbl(512, 256, 1)
+         self.out1 = self._make_embedding([256, 512], 512 + 256)
+
+         # out 2
+         self.out2_cbl = self._make_cbl(256, 128, 1)
+         self.out2 = self._make_embedding([128, 256], 256 + 128)
+
+         # upsample
+         self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+     def _make_cbl(self, _in, _out, ks):
+         return BaseConv(_in, _out, ks, stride=1, act="lrelu")
+
+     def _make_embedding(self, filters_list, in_filters):
+         m = nn.Sequential(
+             *[
+                 self._make_cbl(in_filters, filters_list[0], 1),
+                 self._make_cbl(filters_list[0], filters_list[1], 3),
+                 self._make_cbl(filters_list[1], filters_list[0], 1),
+                 self._make_cbl(filters_list[0], filters_list[1], 3),
+                 self._make_cbl(filters_list[1], filters_list[0], 1),
+             ]
+         )
+         return m
+
+     def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
+         with open(filename, "rb") as f:
+             state_dict = torch.load(f, map_location="cpu")
+         print("loading pretrained weights...")
+         self.backbone.load_state_dict(state_dict)
+
+     def forward(self, inputs):
+         """
+         Args:
+             inputs (Tensor): input image.
+
+         Returns:
+             Tuple[Tensor]: FPN output features.
+         """
+         # backbone
+         out_features = self.backbone(inputs)
+         x2, x1, x0 = [out_features[f] for f in self.in_features]
+
+         # yolo branch 1
+         x1_in = self.out1_cbl(x0)
+         x1_in = self.upsample(x1_in)
+         x1_in = torch.cat([x1_in, x1], 1)
+         out_dark4 = self.out1(x1_in)
+
+         # yolo branch 2
+         x2_in = self.out2_cbl(out_dark4)
+         x2_in = self.upsample(x2_in)
+         x2_in = torch.cat([x2_in, x2], 1)
+         out_dark3 = self.out2(x2_in)
+
+         outputs = (out_dark3, out_dark4, x0)
+         return outputs
yolox/yolo_head.py ADDED
@@ -0,0 +1,235 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+ # Copyright (c) Megvii Inc. All rights reserved.
+
+ import torch
+ import torch.nn as nn
+ from .network_blocks import BaseConv, DWConv
+
+
+ _TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
+
+
+ def meshgrid(*tensors):
+     """
+     Copied from YOLOX/yolox/utils/compat.py
+     """
+     if _TORCH_VER >= [1, 10]:
+         return torch.meshgrid(*tensors, indexing="ij")
+     else:
+         return torch.meshgrid(*tensors)
+
+
+ def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
+     """
+     Copied from YOLOX/yolox/utils/boxes.py
+     """
+     if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
+         raise IndexError
+
+     if xyxy:
+         tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
+         br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
+         area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
+         area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
+     else:
+         tl = torch.max(
+             (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
+             (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
+         )
+         br = torch.min(
+             (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
+             (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
+         )
+
+         area_a = torch.prod(bboxes_a[:, 2:], 1)
+         area_b = torch.prod(bboxes_b[:, 2:], 1)
+     en = (tl < br).type(tl.type()).prod(dim=2)
+     area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
+     return area_i / (area_a[:, None] + area_b - area_i)
+
+
+ class YOLOXHead(nn.Module):
+     def __init__(
+         self,
+         num_classes,
+         width=1.0,
+         strides=[8, 16, 32],
+         in_channels=[256, 512, 1024],
+         act="silu",
+         depthwise=False,
+     ):
+         """
+         Args:
+             act (str): activation type of conv. Default value: "silu".
+             depthwise (bool): whether to apply depthwise conv in conv branch. Default value: False.
+         """
+         super().__init__()
+
+         self.num_classes = num_classes
+         self.decode_in_inference = True  # for deploy, set to False
+
+         self.cls_convs = nn.ModuleList()
+         self.reg_convs = nn.ModuleList()
+         self.cls_preds = nn.ModuleList()
+         self.reg_preds = nn.ModuleList()
+         self.obj_preds = nn.ModuleList()
+         self.stems = nn.ModuleList()
+         Conv = DWConv if depthwise else BaseConv
+
+         for i in range(len(in_channels)):
+             self.stems.append(
+                 BaseConv(
+                     in_channels=int(in_channels[i] * width),
+                     out_channels=int(256 * width),
+                     ksize=1,
+                     stride=1,
+                     act=act,
+                 )
+             )
+             self.cls_convs.append(
+                 nn.Sequential(
+                     *[
+                         Conv(
+                             in_channels=int(256 * width),
+                             out_channels=int(256 * width),
+                             ksize=3,
+                             stride=1,
+                             act=act,
+                         ),
+                         Conv(
+                             in_channels=int(256 * width),
+                             out_channels=int(256 * width),
+                             ksize=3,
+                             stride=1,
+                             act=act,
+                         ),
+                     ]
+                 )
+             )
+             self.reg_convs.append(
+                 nn.Sequential(
+                     *[
+                         Conv(
+                             in_channels=int(256 * width),
+                             out_channels=int(256 * width),
+                             ksize=3,
+                             stride=1,
+                             act=act,
+                         ),
+                         Conv(
+                             in_channels=int(256 * width),
+                             out_channels=int(256 * width),
+                             ksize=3,
+                             stride=1,
+                             act=act,
+                         ),
+                     ]
+                 )
+             )
+             self.cls_preds.append(
+                 nn.Conv2d(
+                     in_channels=int(256 * width),
+                     out_channels=self.num_classes,
+                     kernel_size=1,
+                     stride=1,
+                     padding=0,
+                 )
+             )
+             self.reg_preds.append(
+                 nn.Conv2d(
+                     in_channels=int(256 * width),
+                     out_channels=4,
+                     kernel_size=1,
+                     stride=1,
+                     padding=0,
+                 )
+             )
+             self.obj_preds.append(
+                 nn.Conv2d(
+                     in_channels=int(256 * width),
+                     out_channels=1,
+                     kernel_size=1,
+                     stride=1,
+                     padding=0,
+                 )
+             )
+
+         self.use_l1 = False
+         self.l1_loss = nn.L1Loss(reduction="none")
+         self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
+         self.iou_loss = None
+         self.strides = strides
+         self.grids = [torch.zeros(1)] * len(in_channels)
+
+     def forward(self, xin, labels=None, imgs=None):
+         outputs = []
+         for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+             zip(self.cls_convs, self.reg_convs, self.strides, xin)
+         ):
+             x = self.stems[k](x)
+             cls_x = x
+             reg_x = x
+
+             cls_feat = cls_conv(cls_x)
+             cls_output = self.cls_preds[k](cls_feat)
+
+             reg_feat = reg_conv(reg_x)
+             reg_output = self.reg_preds[k](reg_feat)
+             obj_output = self.obj_preds[k](reg_feat)
+
+             output = torch.cat(
+                 [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
+             )
+
+             outputs.append(output)
+
+         self.hw = [x.shape[-2:] for x in outputs]
+         # [batch, n_anchors_all, 85]
+         outputs = torch.cat(
+             [x.flatten(start_dim=2) for x in outputs], dim=2
+         ).permute(0, 2, 1)
+         if self.decode_in_inference:
+             return self.decode_outputs(outputs, dtype=xin[0].type())
+         else:
+             return outputs
+
+     def get_output_and_grid(self, output, k, stride, dtype):
+         grid = self.grids[k]
+
+         batch_size = output.shape[0]
+         n_ch = 5 + self.num_classes
+         hsize, wsize = output.shape[-2:]
+         if grid.shape[2:4] != output.shape[2:4]:
+             yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
+             grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
+             self.grids[k] = grid
+
+         output = output.view(batch_size, 1, n_ch, hsize, wsize)
+         output = output.permute(0, 1, 3, 4, 2).reshape(
+             batch_size, hsize * wsize, -1
+         )
+         grid = grid.view(1, -1, 2)
+         output[..., :2] = (output[..., :2] + grid) * stride
+         output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
+         return output, grid
+
+     def decode_outputs(self, outputs, dtype):
+         grids = []
+         strides = []
+         for (hsize, wsize), stride in zip(self.hw, self.strides):
+             yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
+             grid = torch.stack((xv, yv), 2).view(1, -1, 2)
+             grids.append(grid)
+             shape = grid.shape[:2]
+             strides.append(torch.full((*shape, 1), stride))
+
+         grids = torch.cat(grids, dim=1).type(dtype)
+         strides = torch.cat(strides, dim=1).type(dtype)
+
+         outputs = torch.cat([
+             (outputs[..., 0:2] + grids) * strides,
+             torch.exp(outputs[..., 2:4]) * strides,
+             outputs[..., 4:]
+         ], dim=-1)
+         return outputs
yolox/yolo_pafpn.py ADDED
@@ -0,0 +1,116 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ # Copyright (c) Megvii Inc. All rights reserved.
+
+ import torch
+ import torch.nn as nn
+
+ from .darknet import CSPDarknet
+ from .network_blocks import BaseConv, CSPLayer, DWConv
+
+
+ class YOLOPAFPN(nn.Module):
+     """
+     YOLOv3 model. Darknet 53 is the default backbone of this model.
+     """
+
+     def __init__(
+         self,
+         depth=1.0,
+         width=1.0,
+         in_features=("dark3", "dark4", "dark5"),
+         in_channels=[256, 512, 1024],
+         depthwise=False,
+         act="silu",
+     ):
+         super().__init__()
+         self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
+         self.in_features = in_features
+         self.in_channels = in_channels
+         Conv = DWConv if depthwise else BaseConv
+
+         self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+         self.lateral_conv0 = BaseConv(
+             int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
+         )
+         self.C3_p4 = CSPLayer(
+             int(2 * in_channels[1] * width),
+             int(in_channels[1] * width),
+             round(3 * depth),
+             False,
+             depthwise=depthwise,
+             act=act,
+         )  # cat
+
+         self.reduce_conv1 = BaseConv(
+             int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
+         )
+         self.C3_p3 = CSPLayer(
+             int(2 * in_channels[0] * width),
+             int(in_channels[0] * width),
+             round(3 * depth),
+             False,
+             depthwise=depthwise,
+             act=act,
+         )
+
+         # bottom-up conv
+         self.bu_conv2 = Conv(
+             int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
+         )
+         self.C3_n3 = CSPLayer(
+             int(2 * in_channels[0] * width),
+             int(in_channels[1] * width),
+             round(3 * depth),
+             False,
+             depthwise=depthwise,
+             act=act,
+         )
+
+         # bottom-up conv
+         self.bu_conv1 = Conv(
+             int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
+         )
+         self.C3_n4 = CSPLayer(
+             int(2 * in_channels[1] * width),
+             int(in_channels[2] * width),
+             round(3 * depth),
+             False,
+             depthwise=depthwise,
+             act=act,
+         )
+
+     def forward(self, input):
+         """
+         Args:
+             input: input images.
+
+         Returns:
+             Tuple[Tensor]: FPN feature.
+         """
+
+         # backbone
+         out_features = self.backbone(input)
+         features = [out_features[f] for f in self.in_features]
+         [x2, x1, x0] = features
+
+         fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
+         f_out0 = self.upsample(fpn_out0)  # 512/16
+         f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
+         f_out0 = self.C3_p4(f_out0)  # 1024->512/16
+
+         fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
+         f_out1 = self.upsample(fpn_out1)  # 256/8
+         f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
+         pan_out2 = self.C3_p3(f_out1)  # 512->256/8
+
+         p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
+         p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
+         pan_out1 = self.C3_n3(p_out1)  # 512->512/16
+
+         p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
+         p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
+         pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32
+
+         outputs = (pan_out2, pan_out1, pan_out0)
+         return outputs
yolox/yolox.py ADDED
@@ -0,0 +1,32 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ # Copyright (c) Megvii Inc. All rights reserved.
+
+ import torch.nn as nn
+
+ from .yolo_head import YOLOXHead
+ from .yolo_pafpn import YOLOPAFPN
+
+
+ class YOLOX(nn.Module):
+     """
+     YOLOX model module. The module list is defined by create_yolov3_modules function.
+     The network returns loss values from three YOLO layers during training
+     and detection results during test.
+     """
+
+     def __init__(self, backbone=None, head=None):
+         super().__init__()
+         if backbone is None:
+             backbone = YOLOPAFPN()
+         if head is None:
+             head = YOLOXHead(80)
+
+         self.backbone = backbone
+         self.head = head
+
+     def forward(self, x, targets=None):
+         assert not self.training, "Training mode not supported, please refer to the YOLOX repo"
+         fpn_outs = self.backbone(x)
+         outputs = self.head(fpn_outs)
+         return outputs