import importlib
import os
import sys
from typing import Dict, List, Tuple, Union

import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
import torch.nn.functional as F

from yolox.boxes import postprocess


def define_model(config_name: str = "graphic_element_v1", verbose: bool = True) -> nn.Module:
    """
    Defines and initializes the model based on the configuration.

    Args:
        config_name (str): Configuration name. Defaults to "graphic_element_v1".
        verbose (bool): Whether to print verbose output. Defaults to True.

    Returns:
        nn.Module: The initialized YOLOX model, wrapped in a YoloXWrapper.
    """
    # Import the experiment file dynamically so configs can live anywhere on disk.
    sys.path.append(os.path.dirname(config_name))
    exp_module = importlib.import_module(os.path.basename(config_name).split(".")[0])

    config = exp_module.Exp()
    model = config.get_model()

    if verbose:
        print(" -> Loading weights from", config.ckpt)

    ckpt = torch.load(config.ckpt, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"], strict=True)

    model = YoloXWrapper(model, config)
    return model.eval().to(config.device)
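
# Example usage (a minimal sketch; assumes a "graphic_element_v1" experiment
# file exposing an Exp class, and the checkpoint it references, exist locally):
#
#   model = define_model("graphic_element_v1", verbose=True)
#   # `model` is a YoloXWrapper in eval mode on config.device, ready for inference.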


def resize_pad(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """
    Resizes and pads an image to a given size.
    The goal is to preserve the aspect ratio of the image.

    Args:
        img (torch.Tensor [C x H x W]): The image to resize and pad.
        size (tuple [2]): The (height, width) to resize and pad the image to.

    Returns:
        torch.Tensor: The resized and padded image.
    """
    img = img.float()
    _, h, w = img.shape

    # Scale so the longer edge fits within `size` while keeping the aspect ratio.
    scale = min(size[0] / h, size[1] / w)
    nh = int(h * scale)
    nw = int(w * scale)

    img = F.interpolate(
        img.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False
    ).squeeze(0)
    img = torch.clamp(img, 0, 255)

    # Pad the bottom and right edges with the YOLOX gray value (114).
    pad_b = size[0] - nh
    pad_r = size[1] - nw
    img = F.pad(img, (0, pad_r, 0, pad_b), value=114.0)
    return img
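
# Worked example (shapes only, a sketch): for a 3 x 500 x 1000 image and
# size=(640, 640), scale = min(640 / 500, 640 / 1000) = 0.64, so the image is
# resized to 3 x 320 x 640, then padded (pad_b=320, pad_r=0) to 3 x 640 x 640.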


class YoloXWrapper(nn.Module):
    """
    Wrapper for YoloX models.
    """

    def __init__(self, model: nn.Module, config) -> None:
        """
        Constructor.

        Args:
            model (nn.Module): YOLOX model.
            config (Config): Config object containing model parameters.
        """
        super().__init__()
        self.model = model
        self.config = config

        # Inference parameters pulled from the experiment config.
        self.device = config.device
        self.img_size = config.size
        self.min_bbox_size = config.min_bbox_size
        self.normalize_boxes = config.normalize_boxes
        self.conf_thresh = config.conf_thresh
        self.iou_thresh = config.iou_thresh
        self.class_agnostic = config.class_agnostic
        self.threshold = config.threshold
        self.labels = config.labels
        self.num_classes = config.num_classes

    def reformat_input(
        self,
        x: torch.Tensor,
        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reformats the input data and original sizes to the correct format.

        Args:
            x (torch.Tensor [BS x C x H x W]): Input image batch.
            orig_sizes (torch.Tensor, list, tuple or np.ndarray): Original image sizes.

        Returns:
            torch.Tensor [BS x C x H x W]: Input image batch.
            torch.Tensor [BS x 2]: Original image sizes (before resizing and padding).
        """
        # Convert sizes to a tensor, dropping a trailing channel dimension
        # if (H, W, C) shapes were passed.
        if isinstance(orig_sizes, (list, tuple)):
            orig_sizes = np.array(orig_sizes)
        if orig_sizes.shape[-1] == 3:
            orig_sizes = orig_sizes[..., :2]
        if isinstance(orig_sizes, np.ndarray):
            orig_sizes = torch.from_numpy(orig_sizes)
        # Move sizes to the model device so later box ops line up with predictions.
        orig_sizes = orig_sizes.to(self.device)

        # Add a batch dimension to unbatched inputs.
        if len(x.size()) == 3:
            x = x.unsqueeze(0)
        if len(orig_sizes.size()) == 1:
            orig_sizes = orig_sizes.unsqueeze(0)

        return x, orig_sizes
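
    # Accepted `orig_sizes` formats (a sketch; each normalizes to a [BS x 2]
    # tensor of (height, width) pairs):
    #
    #   [(480, 640, 3)]          -> tensor([[480, 640]])   # channel dim dropped
    #   np.array([[480, 640]])   -> tensor([[480, 640]])
    #   torch.tensor([480, 640]) -> tensor([[480, 640]])   # batch dim added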

    def preprocess(self, image: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
        """
        YoloX preprocessing function:
        - Resizes the longest edge to img_size while preserving the aspect ratio
        - Pads the shortest edge to img_size

        Args:
            image (torch.Tensor or np.ndarray [H x W x 3]): Input image in uint8 format.

        Returns:
            torch.Tensor [3 x H x W]: Processed image.
        """
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image)
        # Convert from channels-last (H x W x C) to channels-first (C x H x W).
        image = image.permute(2, 0, 1)
        image = resize_pad(image, self.img_size)
        return image.float()
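
    # Shape flow (a sketch): a 480 x 640 x 3 uint8 image becomes a float
    # 3 x img_size[0] x img_size[1] tensor, letterboxed with gray (114) padding.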

    def forward(
        self,
        x: torch.Tensor,
        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray],
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Forward pass of the model.
        Applies NMS and reformats the predictions.

        Args:
            x (torch.Tensor [BS x C x H x W]): Input image batch.
            orig_sizes (torch.Tensor, list, tuple or np.ndarray): Original image sizes.

        Returns:
            list[dict]: List of prediction dictionaries. Each dictionary contains:
                - labels (torch.Tensor [N]): Class labels.
                - boxes (torch.Tensor [N x 4]): Bounding boxes.
                - scores (torch.Tensor [N]): Confidence scores.
        """
        x, orig_sizes = self.reformat_input(x, orig_sizes)

        # YOLOX expects pixel values in [0, 255]; rescale normalized inputs.
        if x.max() <= 1:
            x *= 255

        pred_boxes = self.model(x.to(self.device))

        # Confidence filtering + NMS. Entries are None when nothing survives.
        pred_boxes = postprocess(
            pred_boxes,
            self.config.num_classes,
            self.conf_thresh,
            self.iou_thresh,
            class_agnostic=self.class_agnostic,
        )

        preds = []
        for p, size in zip(pred_boxes, orig_sizes):
            if p is None:
                preds.append({
                    "labels": torch.empty(0),
                    "boxes": torch.empty((0, 4)),
                    "scores": torch.empty(0),
                })
                continue

            p = p.view(-1, p.size(-1))

            # Undo the resize_pad scaling to map boxes back to the original image.
            ratio = min(self.img_size[0] / size[0], self.img_size[1] / size[1])
            boxes = p[:, :4] / ratio

            # Clip to the original image bounds.
            boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, size[1])
            boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, size[0])

            # Drop boxes smaller than min_bbox_size on either side.
            kept = (
                (boxes[:, 2] - boxes[:, 0] > self.min_bbox_size)
                & (boxes[:, 3] - boxes[:, 1] > self.min_bbox_size)
            )
            boxes = boxes[kept]
            p = p[kept]

            # Optionally normalize coordinates to [0, 1].
            if self.normalize_boxes:
                boxes[:, [0, 2]] /= size[1]
                boxes[:, [1, 3]] /= size[0]

            # score = objectness * class confidence; label is the class index.
            scores = p[:, 4] * p[:, 5]
            labels = p[:, 6]

            preds.append({"labels": labels, "boxes": boxes, "scores": scores})

        return preds
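

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes the "graphic_element_v1" experiment
    # file and its checkpoint are available locally, which is not shipped here).
    model = define_model("graphic_element_v1", verbose=True)

    # Random uint8 image standing in for a real input.
    image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    x = model.preprocess(image)

    with torch.no_grad():
        preds = model(x, orig_sizes=[image.shape])

    print({k: tuple(v.shape) for k, v in preds[0].items()})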