asammoud committed · 3f2c461
1 Parent(s): e912493

add redetr

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- rfdetr/__init__.py +12 -0
- rfdetr/__pycache__/__init__.cpython-310.pyc +0 -0
- rfdetr/__pycache__/__init__.cpython-312.pyc +0 -0
- rfdetr/__pycache__/__init__.cpython-39.pyc +0 -0
- rfdetr/__pycache__/config.cpython-310.pyc +0 -0
- rfdetr/__pycache__/config.cpython-312.pyc +0 -0
- rfdetr/__pycache__/config.cpython-39.pyc +0 -0
- rfdetr/__pycache__/detr.cpython-310.pyc +0 -0
- rfdetr/__pycache__/detr.cpython-312.pyc +0 -0
- rfdetr/__pycache__/detr.cpython-39.pyc +0 -0
- rfdetr/__pycache__/engine.cpython-310.pyc +0 -0
- rfdetr/__pycache__/engine.cpython-312.pyc +0 -0
- rfdetr/__pycache__/engine.cpython-39.pyc +0 -0
- rfdetr/__pycache__/main.cpython-310.pyc +0 -0
- rfdetr/__pycache__/main.cpython-312.pyc +0 -0
- rfdetr/__pycache__/main.cpython-39.pyc +0 -0
- rfdetr/cli/main.py +87 -0
- rfdetr/config.py +90 -0
- rfdetr/datasets/__init__.py +36 -0
- rfdetr/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- rfdetr/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
- rfdetr/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- rfdetr/datasets/__pycache__/coco.cpython-310.pyc +0 -0
- rfdetr/datasets/__pycache__/coco.cpython-312.pyc +0 -0
- rfdetr/datasets/__pycache__/coco.cpython-39.pyc +0 -0
- rfdetr/datasets/__pycache__/coco_eval.cpython-310.pyc +0 -0
- rfdetr/datasets/__pycache__/coco_eval.cpython-312.pyc +0 -0
- rfdetr/datasets/__pycache__/coco_eval.cpython-39.pyc +0 -0
- rfdetr/datasets/__pycache__/o365.cpython-310.pyc +0 -0
- rfdetr/datasets/__pycache__/o365.cpython-312.pyc +0 -0
- rfdetr/datasets/__pycache__/o365.cpython-39.pyc +0 -0
- rfdetr/datasets/__pycache__/transforms.cpython-310.pyc +0 -0
- rfdetr/datasets/__pycache__/transforms.cpython-312.pyc +0 -0
- rfdetr/datasets/__pycache__/transforms.cpython-39.pyc +0 -0
- rfdetr/datasets/coco.py +250 -0
- rfdetr/datasets/coco_eval.py +271 -0
- rfdetr/datasets/o365.py +53 -0
- rfdetr/datasets/transforms.py +475 -0
- rfdetr/deploy/__init__.py +0 -0
- rfdetr/deploy/_onnx/__init__.py +13 -0
- rfdetr/deploy/_onnx/optimizer.py +579 -0
- rfdetr/deploy/_onnx/symbolic.py +37 -0
- rfdetr/deploy/benchmark.py +590 -0
- rfdetr/deploy/export.py +276 -0
- rfdetr/deploy/requirements.txt +8 -0
- rfdetr/detr.py +315 -0
- rfdetr/engine.py +256 -0
- rfdetr/main.py +1034 -0
- rfdetr/models/__init__.py +16 -0
- rfdetr/models/__pycache__/__init__.cpython-310.pyc +0 -0
rfdetr/__init__.py
ADDED
@@ -0,0 +1,12 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+
+import os
+if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+from rfdetr.detr import RFDETRBase, RFDETRLarge
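
Not part of the commit: a minimal usage sketch of the two classes exported here. The constructor and the train() keywords mirror how rfdetr/cli/main.py below calls them; the dataset path is a placeholder.

# Sketch only; the dataset path is a placeholder.
import torch
from rfdetr import RFDETRBase

model = RFDETRBase()  # defaults come from RFDETRBaseConfig in rfdetr/config.py
model.train(
    dataset_dir="datasets/my_dataset",
    epochs=1,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
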
rfdetr/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (341 Bytes)

rfdetr/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (441 Bytes)

rfdetr/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (314 Bytes)

rfdetr/__pycache__/config.cpython-310.pyc
ADDED
Binary file (3.76 kB)

rfdetr/__pycache__/config.cpython-312.pyc
ADDED
Binary file (4.82 kB)

rfdetr/__pycache__/config.cpython-39.pyc
ADDED
Binary file (3.73 kB)

rfdetr/__pycache__/detr.cpython-310.pyc
ADDED
Binary file (10.1 kB)

rfdetr/__pycache__/detr.cpython-312.pyc
ADDED
Binary file (16.9 kB)

rfdetr/__pycache__/detr.cpython-39.pyc
ADDED
Binary file (10.2 kB)

rfdetr/__pycache__/engine.cpython-310.pyc
ADDED
Binary file (7.08 kB)

rfdetr/__pycache__/engine.cpython-312.pyc
ADDED
Binary file (12.4 kB)

rfdetr/__pycache__/engine.cpython-39.pyc
ADDED
Binary file (7.13 kB)

rfdetr/__pycache__/main.cpython-310.pyc
ADDED
Binary file (25.9 kB)

rfdetr/__pycache__/main.cpython-312.pyc
ADDED
Binary file (45.4 kB)

rfdetr/__pycache__/main.cpython-39.pyc
ADDED
Binary file (25.3 kB)

rfdetr/cli/main.py
ADDED
@@ -0,0 +1,87 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+# Copyright (c) 2024 Baidu. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+import argparse
+from rf100vl import get_rf100vl_projects
+import roboflow
+from rfdetr import RFDETRBase
+import torch
+import os
+
+def download_dataset(rf_project: roboflow.Project, dataset_version: int):
+    versions = rf_project.versions()
+    if dataset_version is not None:
+        versions = [v for v in versions if v.version == str(dataset_version)]
+        if len(versions) == 0:
+            raise ValueError(f"Dataset version {dataset_version} not found")
+        version = versions[0]
+    else:
+        version = max(versions, key=lambda v: v.id)
+    location = os.path.join("datasets/", rf_project.name + "_v" + version.version)
+    if not os.path.exists(location):
+        location = version.download(
+            model_format="coco", location=location, overwrite=False
+        ).location
+
+    return location
+
+
+def train_from_rf_project(rf_project: roboflow.Project, dataset_version: int):
+    location = download_dataset(rf_project, dataset_version)
+    print(location)
+    rf_detr = RFDETRBase()
+    device_supports_cuda = torch.cuda.is_available()
+    rf_detr.train(
+        dataset_dir=location,
+        epochs=1,
+        device="cuda" if device_supports_cuda else "cpu",
+    )
+
+
+def train_from_coco_dir(coco_dir: str):
+    rf_detr = RFDETRBase()
+    rf_detr.train(
+        dataset_dir=coco_dir,
+        epochs=1,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+
+
+def trainer():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--coco_dir", type=str, required=False)
+    parser.add_argument("--api_key", type=str, required=False)
+    parser.add_argument("--workspace", type=str, required=False, default=None)
+    parser.add_argument("--project_name", type=str, required=False, default=None)
+    parser.add_argument("--dataset_version", type=int, required=False, default=None)
+    args = parser.parse_args()
+
+    if args.coco_dir is not None:
+        train_from_coco_dir(args.coco_dir)
+        return
+
+    if (args.workspace is None and args.project_name is not None) or (
+        args.workspace is not None and args.project_name is None
+    ):
+        raise ValueError(
+            "Either both workspace and project_name must be provided or none of them"
+        )
+
+    if args.workspace is not None:
+        rf = roboflow.Roboflow(api_key=args.api_key)
+        project = rf.workspace(args.workspace).project(args.project_name)
+    else:
+        projects = get_rf100vl_projects(api_key=args.api_key)
+        project = projects[0].rf_project
+
+    train_from_rf_project(project, args.dataset_version)
+
+
+if __name__ == "__main__":
+    trainer()
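
Not part of the commit: the trainer can be run directly (python rfdetr/cli/main.py --coco_dir ..., or with --workspace, --project_name, and --api_key), and the helpers above can also be driven programmatically. A minimal sketch, with placeholder workspace, project, and API key:

# Sketch only; workspace, project, and API key are placeholders.
import roboflow
from rfdetr.cli.main import train_from_coco_dir, train_from_rf_project

# Roboflow-hosted dataset (mirrors the --workspace/--project_name path):
rf = roboflow.Roboflow(api_key="YOUR_API_KEY")
project = rf.workspace("my-workspace").project("my-project")
train_from_rf_project(project, dataset_version=None)

# Or a local COCO-format directory (mirrors the --coco_dir path):
train_from_coco_dir("datasets/my_dataset")
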
rfdetr/config.py
ADDED
@@ -0,0 +1,90 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+
+from pydantic import BaseModel
+from typing import List, Optional, Literal, Type
+import torch
+DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+class ModelConfig(BaseModel):
+    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
+    out_feature_indexes: List[int]
+    dec_layers: int = 3
+    two_stage: bool = True
+    projector_scale: List[Literal["P3", "P4", "P5"]]
+    hidden_dim: int
+    sa_nheads: int
+    ca_nheads: int
+    dec_n_points: int
+    bbox_reparam: bool = True
+    lite_refpoint_refine: bool = True
+    layer_norm: bool = True
+    amp: bool = True
+    num_classes: int = 90
+    pretrain_weights: Optional[str] = None
+    device: Literal["cpu", "cuda", "mps"] = DEVICE
+    resolution: int = 560
+    group_detr: int = 13
+    gradient_checkpointing: bool = False
+
+class RFDETRBaseConfig(ModelConfig):
+    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
+    hidden_dim: int = 256
+    sa_nheads: int = 8
+    ca_nheads: int = 16
+    dec_n_points: int = 2
+    num_queries: int = 300
+    num_select: int = 300
+    projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
+    out_feature_indexes: List[int] = [2, 5, 8, 11]
+    pretrain_weights: Optional[str] = "rf-detr-base.pth"
+
+class RFDETRLargeConfig(RFDETRBaseConfig):
+    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
+    hidden_dim: int = 384
+    sa_nheads: int = 12
+    ca_nheads: int = 24
+    dec_n_points: int = 4
+    projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
+    pretrain_weights: Optional[str] = "rf-detr-large.pth"
+
+class TrainConfig(BaseModel):
+    lr: float = 1e-4
+    lr_encoder: float = 1.5e-4
+    batch_size: int = 4
+    grad_accum_steps: int = 4
+    epochs: int = 100
+    ema_decay: float = 0.993
+    ema_tau: int = 100
+    lr_drop: int = 100
+    checkpoint_interval: int = 10
+    warmup_epochs: int = 0
+    lr_vit_layer_decay: float = 0.8
+    lr_component_decay: float = 0.7
+    drop_path: float = 0.0
+    group_detr: int = 13
+    ia_bce_loss: bool = True
+    cls_loss_coef: float = 1.0
+    num_select: int = 300
+    dataset_file: Literal["coco", "o365", "roboflow"] = "roboflow"
+    square_resize_div_64: bool = True
+    dataset_dir: str
+    output_dir: str = "output"
+    multi_scale: bool = True
+    expanded_scales: bool = True
+    use_ema: bool = True
+    num_workers: int = 2
+    weight_decay: float = 1e-4
+    early_stopping: bool = False
+    early_stopping_patience: int = 10
+    early_stopping_min_delta: float = 0.001
+    early_stopping_use_ema: bool = False
+    tensorboard: bool = True
+    wandb: bool = False
+    project: Optional[str] = None
+    run: Optional[str] = None
+    class_names: List[str] = None
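
Not part of the commit: since these configs are pydantic models, any field above can be overridden by keyword at construction time and the rest fall back to the defaults shown. A minimal sketch (the dataset path is a placeholder; dataset_dir is the only TrainConfig field without a default):

# Sketch only; paths and overrides are illustrative.
from rfdetr.config import RFDETRBaseConfig, TrainConfig

model_cfg = RFDETRBaseConfig(resolution=560, num_classes=90)
print(model_cfg.encoder, model_cfg.hidden_dim)  # dinov2_windowed_small 256

train_cfg = TrainConfig(
    dataset_dir="datasets/my_dataset",  # required: no default in TrainConfig
    epochs=10,
    batch_size=4,
    grad_accum_steps=4,
)
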
rfdetr/datasets/__init__.py
ADDED
@@ -0,0 +1,36 @@
+# ------------------------------------------------------------------------
+# LW-DETR
+# Copyright (c) 2024 Baidu. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+import torch.utils.data
+import torchvision
+
+from .coco import build as build_coco
+from .o365 import build_o365
+from .coco import build_roboflow
+
+
+def get_coco_api_from_dataset(dataset):
+    for _ in range(10):
+        if isinstance(dataset, torch.utils.data.Subset):
+            dataset = dataset.dataset
+    if isinstance(dataset, torchvision.datasets.CocoDetection):
+        return dataset.coco
+
+
+def build_dataset(image_set, args, resolution):
+    if args.dataset_file == 'coco':
+        return build_coco(image_set, args, resolution)
+    if args.dataset_file == 'o365':
+        return build_o365(image_set, args, resolution)
+    if args.dataset_file == 'roboflow':
+        return build_roboflow(image_set, args, resolution)
+    raise ValueError(f'dataset {args.dataset_file} not supported')
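
Not part of the commit: build_dataset dispatches purely on args.dataset_file, so any object exposing the expected attributes works. A minimal sketch using a bare namespace in place of the parsed training arguments (the dataset path is a placeholder):

# Sketch only; args is a stand-in for the parsed training arguments.
from argparse import Namespace
from rfdetr.datasets import build_dataset

args = Namespace(
    dataset_file="roboflow",
    dataset_dir="datasets/my_dataset",  # placeholder COCO-format export
    multi_scale=True,
    expanded_scales=True,
    square_resize_div_64=True,
)
train_dataset = build_dataset(image_set="train", args=args, resolution=560)
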
rfdetr/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (971 Bytes)

rfdetr/datasets/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.49 kB)

rfdetr/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (944 Bytes)

rfdetr/datasets/__pycache__/coco.cpython-310.pyc
ADDED
Binary file (6.28 kB)

rfdetr/datasets/__pycache__/coco.cpython-312.pyc
ADDED
Binary file (10.2 kB)

rfdetr/datasets/__pycache__/coco.cpython-39.pyc
ADDED
Binary file (6.3 kB)

rfdetr/datasets/__pycache__/coco_eval.cpython-310.pyc
ADDED
Binary file (7.26 kB)

rfdetr/datasets/__pycache__/coco_eval.cpython-312.pyc
ADDED
Binary file (11.8 kB)

rfdetr/datasets/__pycache__/coco_eval.cpython-39.pyc
ADDED
Binary file (7.29 kB)

rfdetr/datasets/__pycache__/o365.cpython-310.pyc
ADDED
Binary file (1.31 kB)

rfdetr/datasets/__pycache__/o365.cpython-312.pyc
ADDED
Binary file (1.97 kB)

rfdetr/datasets/__pycache__/o365.cpython-39.pyc
ADDED
Binary file (1.29 kB)

rfdetr/datasets/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (14.9 kB)

rfdetr/datasets/__pycache__/transforms.cpython-312.pyc
ADDED
Binary file (23.9 kB)

rfdetr/datasets/__pycache__/transforms.cpython-39.pyc
ADDED
Binary file (15.1 kB)

rfdetr/datasets/coco.py
ADDED
@@ -0,0 +1,250 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+# Copyright (c) 2024 Baidu. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+COCO dataset which returns image_id for evaluation.
+
+Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
+"""
+from pathlib import Path
+
+import torch
+import torch.utils.data
+import torchvision
+
+import rfdetr.datasets.transforms as T
+
+
+def compute_multi_scale_scales(resolution, expanded_scales=False):
+    if resolution == 640:
+        # assume we're doing the original 640x640 and therefore patch_size is 16
+        patch_size = 16
+    elif resolution % (14 * 4) == 0:
+        # assume we're doing some dinov2 resolution variant and therefore patch_size is 14
+        patch_size = 14
+    elif resolution % (16 * 4) == 0:
+        # assume we're doing some other resolution and therefore patch_size is 16
+        patch_size = 16
+    else:
+        raise ValueError(f"Resolution {resolution} is not divisible by 16*4 or 14*4")
+    # round to the nearest multiple of 4*patch_size to enable both patching and windowing
+    base_num_patches_per_window = resolution // (patch_size * 4)
+    offsets = [-3, -2, -1, 0, 1, 2, 3, 4] if not expanded_scales else [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
+    scales = [base_num_patches_per_window + offset for offset in offsets]
+    proposed_scales = [scale * patch_size * 4 for scale in scales]
+    proposed_scales = [scale for scale in proposed_scales if scale >= patch_size * 4]  # ensure minimum image size
+    return proposed_scales
+
+
+class CocoDetection(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, ann_file, transforms):
+        super(CocoDetection, self).__init__(img_folder, ann_file)
+        self._transforms = transforms
+        self.prepare = ConvertCoco()
+
+    def __getitem__(self, idx):
+        img, target = super(CocoDetection, self).__getitem__(idx)
+        image_id = self.ids[idx]
+        target = {'image_id': image_id, 'annotations': target}
+        img, target = self.prepare(img, target)
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+        return img, target
+
+
+class ConvertCoco(object):
+
+    def __call__(self, image, target):
+        w, h = image.size
+
+        image_id = target["image_id"]
+        image_id = torch.tensor([image_id])
+
+        anno = target["annotations"]
+
+        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
+
+        boxes = [obj["bbox"] for obj in anno]
+        # guard against no boxes via resizing
+        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+        boxes[:, 2:] += boxes[:, :2]
+        boxes[:, 0::2].clamp_(min=0, max=w)
+        boxes[:, 1::2].clamp_(min=0, max=h)
+
+        classes = [obj["category_id"] for obj in anno]
+        classes = torch.tensor(classes, dtype=torch.int64)
+
+        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+        boxes = boxes[keep]
+        classes = classes[keep]
+
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = classes
+        target["image_id"] = image_id
+
+        # for conversion to coco api
+        area = torch.tensor([obj["area"] for obj in anno])
+        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
+        target["area"] = area[keep]
+        target["iscrowd"] = iscrowd[keep]
+
+        target["orig_size"] = torch.as_tensor([int(h), int(w)])
+        target["size"] = torch.as_tensor([int(h), int(w)])
+
+        return image, target
+
+
+def make_coco_transforms(image_set, resolution, multi_scale=False, expanded_scales=False):
+
+    normalize = T.Compose([
+        T.ToTensor(),
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    scales = [resolution]
+    if multi_scale:
+        # scales = [448, 512, 576, 640, 704, 768, 832, 896]
+        scales = compute_multi_scale_scales(resolution, expanded_scales)
+        print(scales)
+
+    if image_set == 'train':
+        return T.Compose([
+            T.RandomHorizontalFlip(),
+            T.RandomSelect(
+                T.RandomResize(scales, max_size=1333),
+                T.Compose([
+                    T.RandomResize([400, 500, 600]),
+                    T.RandomSizeCrop(384, 600),
+                    T.RandomResize(scales, max_size=1333),
+                ])
+            ),
+            normalize,
+        ])
+
+    if image_set == 'val':
+        return T.Compose([
+            T.RandomResize([resolution], max_size=1333),
+            normalize,
+        ])
+    if image_set == 'val_speed':
+        return T.Compose([
+            T.SquareResize([resolution]),
+            normalize,
+        ])
+
+    raise ValueError(f'unknown {image_set}')
+
+
+def make_coco_transforms_square_div_64(image_set, resolution, multi_scale=False, expanded_scales=False):
+    """
+    """
+
+    normalize = T.Compose([
+        T.ToTensor(),
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+
+    scales = [resolution]
+    if multi_scale:
+        # scales = [448, 512, 576, 640, 704, 768, 832, 896]
+        scales = compute_multi_scale_scales(resolution, expanded_scales)
+        print(scales)
+
+    if image_set == 'train':
+        return T.Compose([
+            T.RandomHorizontalFlip(),
+            T.RandomSelect(
+                T.SquareResize(scales),
+                T.Compose([
+                    T.RandomResize([400, 500, 600]),
+                    T.RandomSizeCrop(384, 600),
+                    T.SquareResize(scales),
+                ]),
+            ),
+            normalize,
+        ])
+
+    if image_set == 'val':
+        return T.Compose([
+            T.SquareResize([resolution]),
+            normalize,
+        ])
+    if image_set == 'val_speed':
+        return T.Compose([
+            T.SquareResize([resolution]),
+            normalize,
+        ])
+
+    raise ValueError(f'unknown {image_set}')
+
+def build(image_set, args, resolution):
+    root = Path(args.coco_path)
+    assert root.exists(), f'provided COCO path {root} does not exist'
+    mode = 'instances'
+    PATHS = {
+        "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
+        "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
+        "test": (root / "test2017", root / "annotations" / f'image_info_test-dev2017.json'),
+    }
+
+    img_folder, ann_file = PATHS[image_set.split("_")[0]]
+
+    try:
+        square_resize = args.square_resize
+    except:
+        square_resize = False
+
+    try:
+        square_resize_div_64 = args.square_resize_div_64
+    except:
+        square_resize_div_64 = False
+
+
+    if square_resize_div_64:
+        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
+    else:
+        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
+    return dataset
+
+def build_roboflow(image_set, args, resolution):
+    root = Path(args.dataset_dir)
+    assert root.exists(), f'provided Roboflow path {root} does not exist'
+    mode = 'instances'
+    PATHS = {
+        "train": (root / "train", root / "train" / "_annotations.coco.json"),
+        "val": (root / "valid", root / "valid" / "_annotations.coco.json"),
+        "test": (root / "test", root / "test" / "_annotations.coco.json"),
+    }
+
+    img_folder, ann_file = PATHS[image_set.split("_")[0]]
+
+    try:
+        square_resize = args.square_resize
+    except:
+        square_resize = False
+
+    try:
+        square_resize_div_64 = args.square_resize_div_64
+    except:
+        square_resize_div_64 = False
+
+
+    if square_resize_div_64:
+        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale))
+    else:
+        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale))
+    return dataset
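
Not part of the commit: to make the scale arithmetic in compute_multi_scale_scales concrete, at the default 560-pixel resolution the patch size is 14, so one window spans 14 * 4 = 56 pixels and the base window count is 560 // 56 = 10; the offsets -3..4 then give window counts 7..14, i.e. square sizes 392..784.

# Sketch only: worked example of the multi-scale schedule at resolution 560.
from rfdetr.datasets.coco import compute_multi_scale_scales

print(compute_multi_scale_scales(560))
# [392, 448, 504, 560, 616, 672, 728, 784]
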
rfdetr/datasets/coco_eval.py
ADDED
@@ -0,0 +1,271 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+# Copyright (c) 2024 Baidu. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Copied from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+COCO evaluator that works in distributed mode.
+
+Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
+The difference is that there is less copy-pasting from pycocotools
+in the end of the file, as python3 can suppress prints with contextlib
+"""
+import os
+import contextlib
+import copy
+import numpy as np
+import torch
+
+from pycocotools.cocoeval import COCOeval
+from pycocotools.coco import COCO
+import pycocotools.mask as mask_util
+
+from rfdetr.util.misc import all_gather
+
+
+class CocoEvaluator(object):
+    def __init__(self, coco_gt, iou_types):
+        assert isinstance(iou_types, (list, tuple))
+        coco_gt = copy.deepcopy(coco_gt)
+        self.coco_gt = coco_gt
+
+        self.iou_types = iou_types
+        self.coco_eval = {}
+        for iou_type in iou_types:
+            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
+
+        self.img_ids = []
+        self.eval_imgs = {k: [] for k in iou_types}
+
+    def update(self, predictions):
+        img_ids = list(np.unique(list(predictions.keys())))
+        self.img_ids.extend(img_ids)
+
+        for iou_type in self.iou_types:
+            results = self.prepare(predictions, iou_type)
+
+            # suppress pycocotools prints
+            with open(os.devnull, 'w') as devnull:
+                with contextlib.redirect_stdout(devnull):
+                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
+            coco_eval = self.coco_eval[iou_type]
+
+            coco_eval.cocoDt = coco_dt
+            coco_eval.params.imgIds = list(img_ids)
+            img_ids, eval_imgs = evaluate(coco_eval)
+
+            self.eval_imgs[iou_type].append(eval_imgs)
+
+    def synchronize_between_processes(self):
+        for iou_type in self.iou_types:
+            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
+            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
+
+    def accumulate(self):
+        for coco_eval in self.coco_eval.values():
+            coco_eval.accumulate()
+
+    def summarize(self):
+        for iou_type, coco_eval in self.coco_eval.items():
+            print("IoU metric: {}".format(iou_type))
+            coco_eval.summarize()
+
+    def prepare(self, predictions, iou_type):
+        if iou_type == "bbox":
+            return self.prepare_for_coco_detection(predictions)
+        elif iou_type == "segm":
+            return self.prepare_for_coco_segmentation(predictions)
+        elif iou_type == "keypoints":
+            return self.prepare_for_coco_keypoint(predictions)
+        else:
+            raise ValueError("Unknown iou type {}".format(iou_type))
+
+    def prepare_for_coco_detection(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "bbox": box,
+                        "score": scores[k],
+                    }
+                    for k, box in enumerate(boxes)
+                ]
+            )
+        return coco_results
+
+    def prepare_for_coco_segmentation(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            scores = prediction["scores"]
+            labels = prediction["labels"]
+            masks = prediction["masks"]
+
+            masks = masks > 0.5
+
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            rles = [
+                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
+                for mask in masks
+            ]
+            for rle in rles:
+                rle["counts"] = rle["counts"].decode("utf-8")
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "segmentation": rle,
+                        "score": scores[k],
+                    }
+                    for k, rle in enumerate(rles)
+                ]
+            )
+        return coco_results
+
+    def prepare_for_coco_keypoint(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+            keypoints = prediction["keypoints"]
+            keypoints = keypoints.flatten(start_dim=1).tolist()
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        'keypoints': keypoint,
+                        "score": scores[k],
+                    }
+                    for k, keypoint in enumerate(keypoints)
+                ]
+            )
+        return coco_results
+
+
+def convert_to_xywh(boxes):
+    xmin, ymin, xmax, ymax = boxes.unbind(1)
+    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
+
+
+def merge(img_ids, eval_imgs):
+    all_img_ids = all_gather(img_ids)
+    all_eval_imgs = all_gather(eval_imgs)
+
+    merged_img_ids = []
+    for p in all_img_ids:
+        merged_img_ids.extend(p)
+
+    merged_eval_imgs = []
+    for p in all_eval_imgs:
+        merged_eval_imgs.append(p)
+
+    merged_img_ids = np.array(merged_img_ids)
+    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
+
+    # keep only unique (and in sorted order) images
+    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+    merged_eval_imgs = merged_eval_imgs[..., idx]
+
+    return merged_img_ids, merged_eval_imgs
+
+
+def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
+    img_ids, eval_imgs = merge(img_ids, eval_imgs)
+    img_ids = list(img_ids)
+    eval_imgs = list(eval_imgs.flatten())
+
+    coco_eval.evalImgs = eval_imgs
+    coco_eval.params.imgIds = img_ids
+    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
+
+
+#################################################################
+# From pycocotools, just removed the prints and fixed
+# a Python3 bug about unicode not defined
+#################################################################
+
+
+def evaluate(self):
+    '''
+    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
+    :return: None
+    '''
+    # tic = time.time()
+    # print('Running per image evaluation...')
+    p = self.params
+    # add backward compatibility if useSegm is specified in params
+    if p.useSegm is not None:
+        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
+        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
+    # print('Evaluate annotation type *{}*'.format(p.iouType))
+    p.imgIds = list(np.unique(p.imgIds))
+    if p.useCats:
+        p.catIds = list(np.unique(p.catIds))
+    p.maxDets = sorted(p.maxDets)
+    self.params = p
+
+    self._prepare()
+    # loop through images, area range, max detection number
+    catIds = p.catIds if p.useCats else [-1]
+
+    if p.iouType == 'segm' or p.iouType == 'bbox':
+        computeIoU = self.computeIoU
+    elif p.iouType == 'keypoints':
+        computeIoU = self.computeOks
+    self.ious = {
+        (imgId, catId): computeIoU(imgId, catId)
+        for imgId in p.imgIds
+        for catId in catIds}
+
+    evaluateImg = self.evaluateImg
+    maxDet = p.maxDets[-1]
+    evalImgs = [
+        evaluateImg(imgId, catId, areaRng, maxDet)
+        for catId in catIds
+        for areaRng in p.areaRng
+        for imgId in p.imgIds
+    ]
+    # this is NOT in the pycocotools code, but could be done outside
+    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
+    self._paramsEval = copy.deepcopy(self.params)
+    # toc = time.time()
+    # print('DONE (t={:0.2f}s).'.format(toc-tic))
+    return p.imgIds, evalImgs
+
+#################################################################
+# end of straight copy from pycocotools, just removing the prints
+#################################################################
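
Not part of the commit: a sketch of how the evaluator above is typically driven. The model, post-processing step, dataset, and loader are caller-supplied stand-ins; the actual evaluation loop presumably lives in rfdetr/engine.py, which is not shown in this truncated view.

# Sketch only; model/postprocess/val_dataset/val_loader are stand-ins.
from rfdetr.datasets import get_coco_api_from_dataset
from rfdetr.datasets.coco_eval import CocoEvaluator

def evaluate_bbox(model, postprocess, val_dataset, val_loader):
    coco_gt = get_coco_api_from_dataset(val_dataset)
    evaluator = CocoEvaluator(coco_gt, iou_types=["bbox"])
    for images, targets in val_loader:
        # postprocess(...) must return {image_id: {"boxes", "scores", "labels"}}
        evaluator.update(postprocess(model(images), targets))
    evaluator.synchronize_between_processes()
    evaluator.accumulate()
    evaluator.summarize()
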
rfdetr/datasets/o365.py
ADDED
@@ -0,0 +1,53 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+# Copyright (c) 2024 Baidu. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""Dataset file for Object365."""
+from pathlib import Path
+
+from .coco import (
+    CocoDetection, make_coco_transforms, make_coco_transforms_square_div_64
+)
+
+from PIL import Image
+Image.MAX_IMAGE_PIXELS = None
+
+
+def build_o365_raw(image_set, args, resolution):
+    root = Path(args.coco_path)
+    PATHS = {
+        "train": (root, root / 'zhiyuan_objv2_train_val_wo_5k.json'),
+        "val": (root, root / 'zhiyuan_objv2_minival5k.json'),
+    }
+    img_folder, ann_file = PATHS[image_set]
+
+    try:
+        square_resize = args.square_resize
+    except:
+        square_resize = False
+
+    try:
+        square_resize_div_64 = args.square_resize_div_64
+    except:
+        square_resize_div_64 = False
+
+    if square_resize_div_64:
+        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
+    else:
+        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
+    return dataset
+
+
+def build_o365(image_set, args, resolution):
+    if image_set == 'train':
+        train_ds = build_o365_raw('train', args, resolution=resolution)
+        return train_ds
+    if image_set == 'val':
+        val_ds = build_o365_raw('val', args, resolution=resolution)
+        return val_ds
+    raise ValueError('Unknown image_set: {}'.format(image_set))
rfdetr/datasets/transforms.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# RF-DETR
|
| 3 |
+
# Copyright (c) 2025 Roboflow. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
|
| 7 |
+
# Copyright (c) 2024 Baidu. All Rights Reserved.
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
|
| 10 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
| 11 |
+
# ------------------------------------------------------------------------
|
| 12 |
+
# Copied from DETR (https://github.com/facebookresearch/detr)
|
| 13 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
| 14 |
+
# ------------------------------------------------------------------------
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Transforms and data augmentation for both image + bbox.
|
| 18 |
+
"""
|
| 19 |
+
import random
|
| 20 |
+
|
| 21 |
+
import PIL
|
| 22 |
+
import numpy as np
|
| 23 |
+
try:
|
| 24 |
+
from collections.abc import Sequence
|
| 25 |
+
except Exception:
|
| 26 |
+
from collections import Sequence
|
| 27 |
+
from numbers import Number
|
| 28 |
+
import torch
|
| 29 |
+
import torchvision.transforms as T
|
| 30 |
+
# from detectron2.data import transforms as DT
|
| 31 |
+
import torchvision.transforms.functional as F
|
| 32 |
+
|
| 33 |
+
from rfdetr.util.box_ops import box_xyxy_to_cxcywh
|
| 34 |
+
from rfdetr.util.misc import interpolate
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def crop(image, target, region):
|
| 38 |
+
cropped_image = F.crop(image, *region)
|
| 39 |
+
|
| 40 |
+
target = target.copy()
|
| 41 |
+
i, j, h, w = region
|
| 42 |
+
|
| 43 |
+
# should we do something wrt the original size?
|
| 44 |
+
target["size"] = torch.tensor([h, w])
|
| 45 |
+
|
| 46 |
+
fields = ["labels", "area", "iscrowd"]
|
| 47 |
+
|
| 48 |
+
if "boxes" in target:
|
| 49 |
+
boxes = target["boxes"]
|
| 50 |
+
max_size = torch.as_tensor([w, h], dtype=torch.float32)
|
| 51 |
+
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
|
| 52 |
+
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
|
| 53 |
+
cropped_boxes = cropped_boxes.clamp(min=0)
|
| 54 |
+
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
|
| 55 |
+
target["boxes"] = cropped_boxes.reshape(-1, 4)
|
| 56 |
+
target["area"] = area
|
| 57 |
+
fields.append("boxes")
|
| 58 |
+
|
| 59 |
+
if "masks" in target:
|
| 60 |
+
# FIXME should we update the area here if there are no boxes?
|
| 61 |
+
target['masks'] = target['masks'][:, i:i + h, j:j + w]
|
| 62 |
+
fields.append("masks")
|
| 63 |
+
|
| 64 |
+
# remove elements for which the boxes or masks that have zero area
|
| 65 |
+
if "boxes" in target or "masks" in target:
|
| 66 |
+
# favor boxes selection when defining which elements to keep
|
| 67 |
+
# this is compatible with previous implementation
|
| 68 |
+
if "boxes" in target:
|
| 69 |
+
cropped_boxes = target['boxes'].reshape(-1, 2, 2)
|
| 70 |
+
keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
|
| 71 |
+
else:
|
| 72 |
+
keep = target['masks'].flatten(1).any(1)
|
| 73 |
+
|
| 74 |
+
for field in fields:
|
| 75 |
+
target[field] = target[field][keep]
|
| 76 |
+
|
| 77 |
+
return cropped_image, target
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def hflip(image, target):
|
| 81 |
+
flipped_image = F.hflip(image)
|
| 82 |
+
|
| 83 |
+
w, h = image.size
|
| 84 |
+
|
| 85 |
+
target = target.copy()
|
| 86 |
+
if "boxes" in target:
|
| 87 |
+
boxes = target["boxes"]
|
| 88 |
+
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
|
| 89 |
+
target["boxes"] = boxes
|
| 90 |
+
|
| 91 |
+
if "masks" in target:
|
| 92 |
+
target['masks'] = target['masks'].flip(-1)
|
| 93 |
+
|
| 94 |
+
return flipped_image, target
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def resize(image, target, size, max_size=None):
|
| 98 |
+
# size can be min_size (scalar) or (w, h) tuple
|
| 99 |
+
|
| 100 |
+
def get_size_with_aspect_ratio(image_size, size, max_size=None):
|
| 101 |
+
w, h = image_size
|
| 102 |
+
if max_size is not None:
|
| 103 |
+
min_original_size = float(min((w, h)))
|
| 104 |
+
max_original_size = float(max((w, h)))
|
| 105 |
+
if max_original_size / min_original_size * size > max_size:
|
| 106 |
+
size = int(round(max_size * min_original_size / max_original_size))
|
| 107 |
+
|
| 108 |
+
if (w <= h and w == size) or (h <= w and h == size):
|
| 109 |
+
return (h, w)
|
| 110 |
+
|
| 111 |
+
if w < h:
|
| 112 |
+
ow = size
|
| 113 |
+
oh = int(size * h / w)
|
| 114 |
+
else:
|
| 115 |
+
oh = size
|
| 116 |
+
ow = int(size * w / h)
|
| 117 |
+
|
| 118 |
+
return (oh, ow)
|
| 119 |
+
|
| 120 |
+
def get_size(image_size, size, max_size=None):
|
| 121 |
+
if isinstance(size, (list, tuple)):
|
| 122 |
+
return size[::-1]
|
| 123 |
+
else:
|
| 124 |
+
return get_size_with_aspect_ratio(image_size, size, max_size)
|
| 125 |
+
|
| 126 |
+
size = get_size(image.size, size, max_size)
|
| 127 |
+
rescaled_image = F.resize(image, size)
|
| 128 |
+
|
| 129 |
+
if target is None:
|
| 130 |
+
return rescaled_image, None
|
| 131 |
+
|
| 132 |
+
ratios = tuple(
|
| 133 |
+
float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
|
| 134 |
+
ratio_width, ratio_height = ratios
|
| 135 |
+
|
| 136 |
+
target = target.copy()
|
| 137 |
+
if "boxes" in target:
|
| 138 |
+
boxes = target["boxes"]
|
| 139 |
+
scaled_boxes = boxes * torch.as_tensor(
|
| 140 |
+
[ratio_width, ratio_height, ratio_width, ratio_height])
|
| 141 |
+
target["boxes"] = scaled_boxes
|
| 142 |
+
|
| 143 |
+
if "area" in target:
|
| 144 |
+
area = target["area"]
|
| 145 |
+
scaled_area = area * (ratio_width * ratio_height)
|
| 146 |
+
target["area"] = scaled_area
|
| 147 |
+
|
| 148 |
+
h, w = size
|
| 149 |
+
target["size"] = torch.tensor([h, w])
|
| 150 |
+
|
| 151 |
+
if "masks" in target:
|
| 152 |
+
target['masks'] = interpolate(
|
| 153 |
+
target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
return rescaled_image, target
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def pad(image, target, padding):
|
| 160 |
+
# assumes that we only pad on the bottom right corners
|
| 161 |
+
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
|
| 162 |
+
if target is None:
|
| 163 |
+
return padded_image, None
|
| 164 |
+
target = target.copy()
|
| 165 |
+
# should we do something wrt the original size?
|
| 166 |
+
target["size"] = torch.tensor(padded_image.size[::-1])
|
| 167 |
+
if "masks" in target:
|
| 168 |
+
target['masks'] = torch.nn.functional.pad(
|
| 169 |
+
target['masks'], (0, padding[0], 0, padding[1]))
|
| 170 |
+
return padded_image, target
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class RandomCrop(object):
|
| 174 |
+
def __init__(self, size):
|
| 175 |
+
self.size = size
|
| 176 |
+
|
| 177 |
+
def __call__(self, img, target):
|
| 178 |
+
region = T.RandomCrop.get_params(img, self.size)
|
| 179 |
+
return crop(img, target, region)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class RandomSizeCrop(object):
|
| 183 |
+
def __init__(self, min_size: int, max_size: int):
|
| 184 |
+
self.min_size = min_size
|
| 185 |
+
self.max_size = max_size
|
| 186 |
+
|
| 187 |
+
def __call__(self, img: PIL.Image.Image, target: dict):
|
| 188 |
+
w = random.randint(self.min_size, min(img.width, self.max_size))
|
| 189 |
+
h = random.randint(self.min_size, min(img.height, self.max_size))
|
| 190 |
+
region = T.RandomCrop.get_params(img, [h, w])
|
| 191 |
+
return crop(img, target, region)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class CenterCrop(object):
|
| 195 |
+
def __init__(self, size):
|
| 196 |
+
self.size = size
|
| 197 |
+
|
| 198 |
+
def __call__(self, img, target):
|
| 199 |
+
image_width, image_height = img.size
|
| 200 |
+
crop_height, crop_width = self.size
|
| 201 |
+
crop_top = int(round((image_height - crop_height) / 2.))
|
| 202 |
+
crop_left = int(round((image_width - crop_width) / 2.))
|
| 203 |
+
return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class RandomHorizontalFlip(object):
|
| 207 |
+
def __init__(self, p=0.5):
|
| 208 |
+
self.p = p
|
| 209 |
+
|
| 210 |
+
def __call__(self, img, target):
|
| 211 |
+
if random.random() < self.p:
|
| 212 |
+
return hflip(img, target)
|
| 213 |
+
return img, target
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class RandomResize(object):
|
| 217 |
+
def __init__(self, sizes, max_size=None):
|
| 218 |
+
assert isinstance(sizes, (list, tuple))
|
| 219 |
+
self.sizes = sizes
|
| 220 |
+
self.max_size = max_size
|
| 221 |
+
|
| 222 |
+
def __call__(self, img, target=None):
|
| 223 |
+
size = random.choice(self.sizes)
|
| 224 |
+
return resize(img, target, size, self.max_size)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class SquareResize(object):
|
| 228 |
+
def __init__(self, sizes):
|
| 229 |
+
assert isinstance(sizes, (list, tuple))
|
| 230 |
+
self.sizes = sizes
|
| 231 |
+
|
| 232 |
+
def __call__(self, img, target=None):
|
| 233 |
+
size = random.choice(self.sizes)
|
| 234 |
+
rescaled_img=F.resize(img, (size, size))
|
| 235 |
+
w, h = rescaled_img.size
|
| 236 |
+
if target is None:
|
| 237 |
+
return rescaled_img, None
|
| 238 |
+
ratios = tuple(
|
| 239 |
+
float(s) / float(s_orig) for s, s_orig in zip(rescaled_img.size, img.size))
|
| 240 |
+
ratio_width, ratio_height = ratios
|
| 241 |
+
|
| 242 |
+
target = target.copy()
|
| 243 |
+
if "boxes" in target:
|
| 244 |
+
boxes = target["boxes"]
|
| 245 |
+
scaled_boxes = boxes * torch.as_tensor(
|
| 246 |
+
[ratio_width, ratio_height, ratio_width, ratio_height])
|
| 247 |
+
target["boxes"] = scaled_boxes
|
| 248 |
+
|
| 249 |
+
if "area" in target:
|
| 250 |
+
area = target["area"]
|
| 251 |
+
scaled_area = area * (ratio_width * ratio_height)
|
| 252 |
+
target["area"] = scaled_area
|
| 253 |
+
|
| 254 |
+
target["size"] = torch.tensor([h, w])
|
| 255 |
+
|
| 256 |
+
return rescaled_img, target
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class RandomPad(object):
|
| 260 |
+
def __init__(self, max_pad):
|
| 261 |
+
self.max_pad = max_pad
|
| 262 |
+
|
| 263 |
+
def __call__(self, img, target):
|
| 264 |
+
pad_x = random.randint(0, self.max_pad)
|
| 265 |
+
pad_y = random.randint(0, self.max_pad)
|
| 266 |
+
return pad(img, target, (pad_x, pad_y))
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class PILtoNdArray(object):
|
| 270 |
+
|
| 271 |
+
def __call__(self, img, target):
|
| 272 |
+
return np.asarray(img), target
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class NdArraytoPIL(object):
|
| 276 |
+
|
| 277 |
+
def __call__(self, img, target):
|
| 278 |
+
return F.to_pil_image(img.astype('uint8')), target
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class Pad(object):
|
| 282 |
+
def __init__(self,
|
| 283 |
+
size=None,
|
| 284 |
+
size_divisor=32,
|
| 285 |
+
pad_mode=0,
|
| 286 |
+
offsets=None,
|
| 287 |
+
fill_value=(127.5, 127.5, 127.5)):
|
| 288 |
+
"""
|
| 289 |
+
Pad image to a specified size or multiple of size_divisor.
|
| 290 |
+
Args:
|
| 291 |
+
size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None
|
| 292 |
+
size_divisor (int): size divisor, default 32
|
| 293 |
+
pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. If -1, use the specified offsets;
if 0, only pad to the right and bottom; if 1, pad so the image is centered; if 2, only pad to the left and top.
offsets (list): [offset_x, offset_y], the offset to use while padding; only supported when pad_mode=-1.
fill_value (tuple): RGB value of the padded area, default (127.5, 127.5, 127.5).
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
if size is not None and not isinstance(size, (int, Sequence)):
|
| 300 |
+
raise TypeError(
|
| 301 |
+
"Type of target_size is invalid when random_size is True. \
|
| 302 |
+
Must be List, now is {}".format(type(size)))
|
| 303 |
+
|
| 304 |
+
if isinstance(size, int):
|
| 305 |
+
size = [size, size]
|
| 306 |
+
|
| 307 |
+
assert pad_mode in [
|
| 308 |
+
-1, 0, 1, 2
|
| 309 |
+
], 'currently only supports four modes [-1, 0, 1, 2]'
|
| 310 |
+
if pad_mode == -1:
|
| 311 |
+
assert offsets, 'if pad_mode is -1, offsets should not be None'
|
| 312 |
+
|
| 313 |
+
self.size = size
|
| 314 |
+
self.size_divisor = size_divisor
|
| 315 |
+
self.pad_mode = pad_mode
|
| 316 |
+
self.fill_value = fill_value
|
| 317 |
+
self.offsets = offsets
|
| 318 |
+
|
| 319 |
+
def apply_bbox(self, bbox, offsets):
|
| 320 |
+
return bbox + np.array(offsets * 2, dtype=np.float32)
|
| 321 |
+
|
| 322 |
+
def apply_image(self, image, offsets, im_size, size):
|
| 323 |
+
x, y = offsets
|
| 324 |
+
im_h, im_w = im_size
|
| 325 |
+
h, w = size
|
| 326 |
+
canvas = np.ones((h, w, 3), dtype=np.float32)
|
| 327 |
+
canvas *= np.array(self.fill_value, dtype=np.float32)
|
| 328 |
+
canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
|
| 329 |
+
return canvas
|
| 330 |
+
|
| 331 |
+
def __call__(self, im, target):
|
| 332 |
+
im_h, im_w = im.shape[:2]
|
| 333 |
+
if self.size:
|
| 334 |
+
h, w = self.size
|
| 335 |
+
assert (
|
| 336 |
+
im_h <= h and im_w <= w
|
| 337 |
+
), '(h, w) of target size should be greater than (im_h, im_w)'
|
| 338 |
+
else:
|
| 339 |
+
h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
|
| 340 |
+
w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
|
| 341 |
+
|
| 342 |
+
if h == im_h and w == im_w:
|
| 343 |
+
return im.astype(np.float32), target
|
| 344 |
+
|
| 345 |
+
if self.pad_mode == -1:
|
| 346 |
+
offset_x, offset_y = self.offsets
|
| 347 |
+
elif self.pad_mode == 0:
|
| 348 |
+
offset_y, offset_x = 0, 0
|
| 349 |
+
elif self.pad_mode == 1:
|
| 350 |
+
offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
|
| 351 |
+
else:
|
| 352 |
+
offset_y, offset_x = h - im_h, w - im_w
|
| 353 |
+
|
| 354 |
+
offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
|
| 355 |
+
|
| 356 |
+
im = self.apply_image(im, offsets, im_size, size)
|
| 357 |
+
|
| 358 |
+
if self.pad_mode == 0:
|
| 359 |
+
target["size"] = torch.tensor([h, w])
|
| 360 |
+
return im, target
|
| 361 |
+
if 'boxes' in target and len(target['boxes']) > 0:
|
| 362 |
+
boxes = np.asarray(target["boxes"])
|
| 363 |
+
target["boxes"] = torch.from_numpy(self.apply_bbox(boxes, offsets))
|
| 364 |
+
target["size"] = torch.tensor([h, w])
|
| 365 |
+
|
| 366 |
+
return im, target
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
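# Illustrative note (not part of the upstream file): how the Pad transform above
# places a 480x600 (h x w) image on a 640x640 canvas, using hypothetical numbers.
#   pad_mode=0  -> (offset_y, offset_x) = (0, 0): padding goes to the right/bottom
#   pad_mode=1  -> (offset_y, offset_x) = ((640-480)//2, (640-600)//2) = (80, 20): image centered
#   pad_mode=2  -> (offset_y, offset_x) = (160, 40): padding goes to the left/top
#   pad_mode=-1 -> offsets taken verbatim from the `offsets` argument
# Boxes (when present and pad_mode != 0) are shifted by [offset_x, offset_y, offset_x, offset_y].
#
# pad = Pad(size=640, pad_mode=1, fill_value=(127.5, 127.5, 127.5))
# padded_im, padded_target = pad(np.asarray(img), target)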
class RandomExpand(object):
|
| 370 |
+
"""Random expand the canvas.
|
| 371 |
+
Args:
|
| 372 |
+
ratio (float): maximum expansion ratio.
|
| 373 |
+
prob (float): probability to expand.
|
| 374 |
+
fill_value (Number or Sequence): color value used to fill the canvas, in RGB order.
|
| 375 |
+
"""
|
| 376 |
+
|
| 377 |
+
def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, 127.5, 127.5)):
|
| 378 |
+
assert ratio > 1.01, "expand ratio must be larger than 1.01"
|
| 379 |
+
self.ratio = ratio
|
| 380 |
+
self.prob = prob
|
| 381 |
+
assert isinstance(fill_value, (Number, Sequence)), \
|
| 382 |
+
"fill value must be either float or sequence"
|
| 383 |
+
if isinstance(fill_value, Number):
|
| 384 |
+
fill_value = (fill_value, ) * 3
|
| 385 |
+
if not isinstance(fill_value, tuple):
|
| 386 |
+
fill_value = tuple(fill_value)
|
| 387 |
+
self.fill_value = fill_value
|
| 388 |
+
|
| 389 |
+
def __call__(self, img, target):
|
| 390 |
+
if np.random.uniform(0., 1.) < self.prob:
|
| 391 |
+
return img, target
|
| 392 |
+
|
| 393 |
+
height, width = img.shape[:2]
|
| 394 |
+
ratio = np.random.uniform(1., self.ratio)
|
| 395 |
+
h = int(height * ratio)
|
| 396 |
+
w = int(width * ratio)
|
| 397 |
+
if h <= height or w <= width:
|
| 398 |
+
return img, target
|
| 399 |
+
y = np.random.randint(0, h - height)
|
| 400 |
+
x = np.random.randint(0, w - width)
|
| 401 |
+
offsets, size = [x, y], [h, w]
|
| 402 |
+
|
| 403 |
+
pad = Pad(size,
|
| 404 |
+
pad_mode=-1,
|
| 405 |
+
offsets=offsets,
|
| 406 |
+
fill_value=self.fill_value)
|
| 407 |
+
|
| 408 |
+
return pad(img, target)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class RandomSelect(object):
|
| 412 |
+
"""
|
| 413 |
+
Randomly selects between transforms1 and transforms2,
|
| 414 |
+
with probability p for transforms1 and (1 - p) for transforms2
|
| 415 |
+
"""
|
| 416 |
+
def __init__(self, transforms1, transforms2, p=0.5):
|
| 417 |
+
self.transforms1 = transforms1
|
| 418 |
+
self.transforms2 = transforms2
|
| 419 |
+
self.p = p
|
| 420 |
+
|
| 421 |
+
def __call__(self, img, target):
|
| 422 |
+
if random.random() < self.p:
|
| 423 |
+
return self.transforms1(img, target)
|
| 424 |
+
return self.transforms2(img, target)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class ToTensor(object):
|
| 428 |
+
def __call__(self, img, target):
|
| 429 |
+
return F.to_tensor(img), target
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
class RandomErasing(object):
|
| 433 |
+
|
| 434 |
+
def __init__(self, *args, **kwargs):
|
| 435 |
+
self.eraser = T.RandomErasing(*args, **kwargs)
|
| 436 |
+
|
| 437 |
+
def __call__(self, img, target):
|
| 438 |
+
return self.eraser(img), target
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
class Normalize(object):
|
| 442 |
+
def __init__(self, mean, std):
|
| 443 |
+
self.mean = mean
|
| 444 |
+
self.std = std
|
| 445 |
+
|
| 446 |
+
def __call__(self, image, target=None):
|
| 447 |
+
image = F.normalize(image, mean=self.mean, std=self.std)
|
| 448 |
+
if target is None:
|
| 449 |
+
return image, None
|
| 450 |
+
target = target.copy()
|
| 451 |
+
h, w = image.shape[-2:]
|
| 452 |
+
if "boxes" in target:
|
| 453 |
+
boxes = target["boxes"]
|
| 454 |
+
boxes = box_xyxy_to_cxcywh(boxes)
|
| 455 |
+
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
|
| 456 |
+
target["boxes"] = boxes
|
| 457 |
+
return image, target
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class Compose(object):
|
| 461 |
+
def __init__(self, transforms):
|
| 462 |
+
self.transforms = transforms
|
| 463 |
+
|
| 464 |
+
def __call__(self, image, target):
|
| 465 |
+
for t in self.transforms:
|
| 466 |
+
image, target = t(image, target)
|
| 467 |
+
return image, target
|
| 468 |
+
|
| 469 |
+
def __repr__(self):
|
| 470 |
+
format_string = self.__class__.__name__ + "("
|
| 471 |
+
for t in self.transforms:
|
| 472 |
+
format_string += "\n"
|
| 473 |
+
format_string += " {0}".format(t)
|
| 474 |
+
format_string += "\n)"
|
| 475 |
+
return format_string
|
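The transform classes above follow the DETR convention of callables that operate on an
(image, target) pair and can be chained with Compose. A minimal usage sketch follows; the
actual pipelines are built in rfdetr/datasets/coco.py, which is not part of this hunk, so
the composition below is an assumed example rather than the project's configuration:

    import rfdetr.datasets.transforms as T

    # Square-resolution eval pipeline: resize to 640x640, convert to a tensor, then
    # normalize pixels and rescale boxes to relative cxcywh coordinates.
    eval_transforms = T.Compose([
        T.SquareResize([640]),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    image_tensor, target = eval_transforms(pil_image, target_dict)  # target_dict may be None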
rfdetr/deploy/__init__.py ADDED
File without changes

rfdetr/deploy/_onnx/__init__.py ADDED
@@ -0,0 +1,13 @@
# ------------------------------------------------------------------------
# LW-DETR
# Copyright (c) 2024 Baidu. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
"""
onnx optimizer and symbolic registry
"""
from . import optimizer
from . import symbolic

from .optimizer import OnnxOptimizer
from .symbolic import CustomOpSymbolicRegistry
rfdetr/deploy/_onnx/optimizer.py ADDED
@@ -0,0 +1,579 @@
# ------------------------------------------------------------------------
|
| 2 |
+
# RF-DETR
|
| 3 |
+
# Copyright (c) 2025 Roboflow. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
|
| 7 |
+
# Copyright (c) 2024 Baidu. All Rights Reserved.
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
OnnxOptimizer
|
| 12 |
+
"""
|
| 13 |
+
import os
|
| 14 |
+
from collections import OrderedDict
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import onnx
|
| 19 |
+
import torch
|
| 20 |
+
from onnx import shape_inference
|
| 21 |
+
import onnx_graphsurgeon as gs
|
| 22 |
+
from polygraphy.backend.onnx.loader import fold_constants
|
| 23 |
+
from onnx_graphsurgeon.logger.logger import G_LOGGER
|
| 24 |
+
|
| 25 |
+
from .symbolic import CustomOpSymbolicRegistry
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class OnnxOptimizer():
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
input,
|
| 32 |
+
severity=G_LOGGER.INFO
|
| 33 |
+
):
|
| 34 |
+
if isinstance(input, str):
|
| 35 |
+
onnx_graph = self.load_onnx(input)
|
| 36 |
+
else:
|
| 37 |
+
onnx_graph = input
|
| 38 |
+
self.graph = gs.import_onnx(onnx_graph)
|
| 39 |
+
self.severity = severity
|
| 40 |
+
self.set_severity(severity)
|
| 41 |
+
|
| 42 |
+
def set_severity(self, severity):
|
| 43 |
+
G_LOGGER.severity = severity
|
| 44 |
+
|
| 45 |
+
def load_onnx(self, onnx_path:str):
|
| 46 |
+
"""Load onnx from file
|
| 47 |
+
"""
|
| 48 |
+
assert os.path.isfile(onnx_path), f"not found onnx file: {onnx_path}"
|
| 49 |
+
onnx_graph = onnx.load(onnx_path)
|
| 50 |
+
G_LOGGER.info(f"load onnx file: {onnx_path}")
|
| 51 |
+
return onnx_graph
|
| 52 |
+
|
| 53 |
+
def save_onnx(self, onnx_path:str):
|
| 54 |
+
onnx_graph = gs.export_onnx(self.graph)
|
| 55 |
+
G_LOGGER.info(f"save onnx file: {onnx_path}")
|
| 56 |
+
onnx.save(onnx_graph, onnx_path)
|
| 57 |
+
|
| 58 |
+
def info(self, prefix=''):
|
| 59 |
+
G_LOGGER.verbose(f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs")
|
| 60 |
+
|
| 61 |
+
def cleanup(self, return_onnx=False):
|
| 62 |
+
self.graph.cleanup().toposort()
|
| 63 |
+
if return_onnx:
|
| 64 |
+
return gs.export_onnx(self.graph)
|
| 65 |
+
|
| 66 |
+
def select_outputs(self, keep, names=None):
|
| 67 |
+
self.graph.outputs = [self.graph.outputs[o] for o in keep]
|
| 68 |
+
if names:
|
| 69 |
+
for i, name in enumerate(names):
|
| 70 |
+
self.graph.outputs[i].name = name
|
| 71 |
+
|
| 72 |
+
def find_node_input(self, node, name:str=None, value=None) -> int:
|
| 73 |
+
index = -1
for i, inp in enumerate(node.inputs):
|
| 74 |
+
if isinstance(name, str) and inp.name == name:
|
| 75 |
+
index = i
|
| 76 |
+
elif inp == value:
|
| 77 |
+
index = i
|
| 78 |
+
assert index >= 0, f"not found {name}({value}) in node.inputs"
|
| 79 |
+
return index
|
| 80 |
+
|
| 81 |
+
def find_node_output(self, node, name:str=None, value=None) -> int:
|
| 82 |
+
index = -1
for i, inp in enumerate(node.outputs):
|
| 83 |
+
if isinstance(name, str) and inp.name == name:
|
| 84 |
+
index = i
|
| 85 |
+
elif inp == value:
|
| 86 |
+
index = i
|
| 87 |
+
assert index >= 0, f"not found {name}({value}) in node.outputs"
|
| 88 |
+
return index
|
| 89 |
+
|
| 90 |
+
def common_opt(self, return_onnx=False):
|
| 91 |
+
for fn in CustomOpSymbolicRegistry._OPTIMIZER:
|
| 92 |
+
fn(self)
|
| 93 |
+
self.cleanup()
|
| 94 |
+
onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=False)
|
| 95 |
+
if onnx_graph.ByteSize() > 2147483648:
|
| 96 |
+
raise TypeError("ERROR: model size exceeds supported 2GB limit")
|
| 97 |
+
else:
|
| 98 |
+
onnx_graph = shape_inference.infer_shapes(onnx_graph)
|
| 99 |
+
self.graph = gs.import_onnx(onnx_graph)
|
| 100 |
+
self.cleanup()
|
| 101 |
+
if return_onnx:
|
| 102 |
+
return onnx_graph
|
| 103 |
+
|
| 104 |
+
def resize_fix(self):
|
| 105 |
+
'''
|
| 106 |
+
This function loops through the graph looking for Resize nodes that use scales (i.e. have 3 inputs).
It replaces each such Resize with one that takes the size of the output tensor instead of scales.
It adds a Shape->Slice->Concat (plus a second Shape->Slice branch) subgraph to extract the shape of the output tensor.
|
| 110 |
+
This fix is required for the dynamic shape support.
|
| 111 |
+
'''
|
| 112 |
+
mResizeNodes = 0
|
| 113 |
+
for node in self.graph.nodes:
|
| 114 |
+
if node.op == "Resize" and len(node.inputs) == 3:
|
| 115 |
+
name = node.name + "/"
|
| 116 |
+
|
| 117 |
+
add_node = node.o().o().i(1)
|
| 118 |
+
div_node = node.i()
|
| 119 |
+
|
| 120 |
+
shape_hw_out = gs.Variable(name=name + "shape_hw_out", dtype=np.int64, shape=[4])
|
| 121 |
+
shape_hw = gs.Node(op="Shape", name=name+"shape_hw", inputs=[add_node.outputs[0]], outputs=[shape_hw_out])
|
| 122 |
+
|
| 123 |
+
const_zero = gs.Constant(name=name + "const_zero", values=np.array([0], dtype=np.int64))
|
| 124 |
+
const_two = gs.Constant(name=name + "const_two", values=np.array([2], dtype=np.int64))
|
| 125 |
+
const_four = gs.Constant(name=name + "const_four", values=np.array([4], dtype=np.int64))
|
| 126 |
+
|
| 127 |
+
slice_hw_out = gs.Variable(name=name + "slice_hw_out", dtype=np.int64, shape=[2])
|
| 128 |
+
slice_hw = gs.Node(op="Slice", name=name+"slice_hw", inputs=[shape_hw_out, const_two, const_four, const_zero], outputs=[slice_hw_out])
|
| 129 |
+
|
| 130 |
+
shape_bc_out = gs.Variable(name=name + "shape_bc_out", dtype=np.int64, shape=[2])
|
| 131 |
+
shape_bc = gs.Node(op="Shape", name=name+"shape_bc", inputs=[div_node.outputs[0]], outputs=[shape_bc_out])
|
| 132 |
+
|
| 133 |
+
slice_bc_out = gs.Variable(name=name + "slice_bc_out", dtype=np.int64, shape=[2])
|
| 134 |
+
slice_bc = gs.Node(op="Slice", name=name+"slice_bc", inputs=[shape_bc_out, const_zero, const_two, const_zero], outputs=[slice_bc_out])
|
| 135 |
+
|
| 136 |
+
concat_bchw_out = gs.Variable(name=name + "concat_bchw_out", dtype=np.int64, shape=[4])
|
| 137 |
+
concat_bchw = gs.Node(op="Concat", name=name+"concat_bchw", attrs={"axis": 0}, inputs=[slice_bc_out, slice_hw_out], outputs=[concat_bchw_out])
|
| 138 |
+
|
| 139 |
+
none_var = gs.Variable.empty()
|
| 140 |
+
|
| 141 |
+
resize_bchw = gs.Node(op="Resize", name=name+"resize_bchw", attrs=node.attrs, inputs=[node.inputs[0], none_var, none_var, concat_bchw_out], outputs=[node.outputs[0]])
|
| 142 |
+
|
| 143 |
+
self.graph.nodes.extend([shape_hw, slice_hw, shape_bc, slice_bc, concat_bchw, resize_bchw])
|
| 144 |
+
|
| 145 |
+
node.inputs = []
|
| 146 |
+
node.outputs = []
|
| 147 |
+
|
| 148 |
+
mResizeNodes += 1
|
| 149 |
+
|
| 150 |
+
self.cleanup()
|
| 151 |
+
return mResizeNodes
|
| 152 |
+
|
| 153 |
+
def adjustAddNode(self):
|
| 154 |
+
nAdjustAddNode = 0
|
| 155 |
+
for node in self.graph.nodes:
|
| 156 |
+
# Change the bias const to the second input to allow Gemm+BiasAdd fusion in TRT.
|
| 157 |
+
if node.op in ["Add"] and isinstance(node.inputs[0], gs.ir.tensor.Constant):
|
| 158 |
+
tensor = node.inputs[1]
|
| 159 |
+
bias = node.inputs[0]
|
| 160 |
+
node.inputs = [tensor, bias]
|
| 161 |
+
nAdjustAddNode += 1
|
| 162 |
+
|
| 163 |
+
self.cleanup()
|
| 164 |
+
return nAdjustAddNode
|
| 165 |
+
|
| 166 |
+
def decompose_instancenorms(self):
|
| 167 |
+
nRemoveInstanceNorm = 0
|
| 168 |
+
for node in self.graph.nodes:
|
| 169 |
+
if node.op == "InstanceNormalization":
|
| 170 |
+
name = node.name + "/"
|
| 171 |
+
input_tensor = node.inputs[0]
|
| 172 |
+
output_tensor = node.outputs[0]
|
| 173 |
+
mean_out = gs.Variable(name=name + "mean_out")
|
| 174 |
+
mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out])
|
| 175 |
+
sub_out = gs.Variable(name=name + "sub_out")
|
| 176 |
+
sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out])
|
| 177 |
+
pow_out = gs.Variable(name=name + "pow_out")
|
| 178 |
+
pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32))
|
| 179 |
+
pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out])
|
| 180 |
+
mean2_out = gs.Variable(name=name + "mean2_out")
|
| 181 |
+
mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out])
|
| 182 |
+
epsilon_out = gs.Variable(name=name + "epsilon_out")
|
| 183 |
+
epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([node.attrs["epsilon"]], dtype=np.float32))
|
| 184 |
+
epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out])
|
| 185 |
+
sqrt_out = gs.Variable(name=name + "sqrt_out")
|
| 186 |
+
sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out])
|
| 187 |
+
div_out = gs.Variable(name=name + "div_out")
|
| 188 |
+
div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out])
|
| 189 |
+
constantScale = gs.Constant("InstanceNormScaleV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[1].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
|
| 190 |
+
constantBias = gs.Constant("InstanceBiasV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[2].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
|
| 191 |
+
mul_out = gs.Variable(name=name + "mul_out")
|
| 192 |
+
mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out])
|
| 193 |
+
add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor])
|
| 194 |
+
self.graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node])
|
| 195 |
+
node.inputs = []
|
| 196 |
+
node.outputs = []
|
| 197 |
+
nRemoveInstanceNorm += 1
|
| 198 |
+
|
| 199 |
+
self.cleanup()
|
| 200 |
+
return nRemoveInstanceNorm
|
| 201 |
+
|
| 202 |
+
def insert_groupnorm_plugin(self):
|
| 203 |
+
nGroupNormPlugin = 0
|
| 204 |
+
for node in self.graph.nodes:
|
| 205 |
+
if node.op == "Reshape" and node.outputs != [] and \
|
| 206 |
+
node.o().op == "ReduceMean" and node.o(1).op == "Sub" and node.o().o() == node.o(1) and \
|
| 207 |
+
node.o().o().o().o().o().o().o().o().o().o().o().op == "Mul" and \
|
| 208 |
+
node.o().o().o().o().o().o().o().o().o().o().o().o().op == "Add" and \
|
| 209 |
+
len(node.o().o().o().o().o().o().o().o().inputs[1].values.shape) == 3:
|
| 210 |
+
# "node.outputs != []" is added for VAE
|
| 211 |
+
|
| 212 |
+
inputTensor = node.inputs[0]
|
| 213 |
+
|
| 214 |
+
gammaNode = node.o().o().o().o().o().o().o().o().o().o().o()
|
| 215 |
+
index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
|
| 216 |
+
gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
|
| 217 |
+
constantGamma = gs.Constant("groupNormGamma-" + str(nGroupNormPlugin), np.ascontiguousarray(gamma.reshape(-1))) # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!
|
| 218 |
+
|
| 219 |
+
betaNode = gammaNode.o()
|
| 220 |
+
index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
|
| 221 |
+
beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
|
| 222 |
+
constantBeta = gs.Constant("groupNormBeta-" + str(nGroupNormPlugin), np.ascontiguousarray(beta.reshape(-1)))
|
| 223 |
+
|
| 224 |
+
epsilon = node.o().o().o().o().o().inputs[1].values.tolist()[0]
|
| 225 |
+
|
| 226 |
+
if betaNode.o().op == "Sigmoid": # need Swish
|
| 227 |
+
bSwish = True
|
| 228 |
+
lastNode = betaNode.o().o() # Mul node of Swish
|
| 229 |
+
else:
|
| 230 |
+
bSwish = False
|
| 231 |
+
lastNode = betaNode # Cast node after Group Norm
|
| 232 |
+
|
| 233 |
+
if lastNode.o().op == "Cast":
|
| 234 |
+
lastNode = lastNode.o()
|
| 235 |
+
inputList = [inputTensor, constantGamma, constantBeta]
|
| 236 |
+
groupNormV = gs.Variable("GroupNormV-" + str(nGroupNormPlugin), np.dtype(np.float16), inputTensor.shape)
|
| 237 |
+
groupNormN = gs.Node("GroupNorm", "GroupNormN-" + str(nGroupNormPlugin), inputs=inputList, outputs=[groupNormV], attrs=OrderedDict([('epsilon', epsilon), ('bSwish', int(bSwish))]))
|
| 238 |
+
self.graph.nodes.append(groupNormN)
|
| 239 |
+
|
| 240 |
+
for subNode in self.graph.nodes:
|
| 241 |
+
if lastNode.outputs[0] in subNode.inputs:
|
| 242 |
+
index = subNode.inputs.index(lastNode.outputs[0])
|
| 243 |
+
subNode.inputs[index] = groupNormV
|
| 244 |
+
node.inputs = []
|
| 245 |
+
lastNode.outputs = []
|
| 246 |
+
nGroupNormPlugin += 1
|
| 247 |
+
|
| 248 |
+
self.cleanup()
|
| 249 |
+
return nGroupNormPlugin
|
| 250 |
+
|
| 251 |
+
def insert_layernorm_plugin(self):
|
| 252 |
+
nLayerNormPlugin = 0
|
| 253 |
+
for node in self.graph.nodes:
|
| 254 |
+
if node.op == 'ReduceMean' and \
|
| 255 |
+
node.o().op == 'Sub' and node.o().inputs[0] == node.inputs[0] and \
|
| 256 |
+
node.o().o(0).op =='Pow' and node.o().o(1).op =='Div' and \
|
| 257 |
+
node.o().o(0).o().op == 'ReduceMean' and \
|
| 258 |
+
node.o().o(0).o().o().op == 'Add' and \
|
| 259 |
+
node.o().o(0).o().o().o().op == 'Sqrt' and \
|
| 260 |
+
node.o().o(0).o().o().o().o().op == 'Div' and node.o().o(0).o().o().o().o() == node.o().o(1) and \
|
| 261 |
+
node.o().o(0).o().o().o().o().o().op == 'Mul' and \
|
| 262 |
+
node.o().o(0).o().o().o().o().o().o().op == 'Add' and \
|
| 263 |
+
len(node.o().o(0).o().o().o().o().o().inputs[1].values.shape) == 1:
|
| 264 |
+
|
| 265 |
+
if node.i().op == "Add":
|
| 266 |
+
inputTensor = node.inputs[0] # CLIP
|
| 267 |
+
else:
|
| 268 |
+
inputTensor = node.i().inputs[0] # UNet and VAE
|
| 269 |
+
|
| 270 |
+
gammaNode = node.o().o().o().o().o().o().o()
|
| 271 |
+
index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
|
| 272 |
+
gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
|
| 273 |
+
constantGamma = gs.Constant("LayerNormGamma-" + str(nLayerNormPlugin), np.ascontiguousarray(gamma.reshape(-1))) # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!
|
| 274 |
+
|
| 275 |
+
betaNode = gammaNode.o()
|
| 276 |
+
index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
|
| 277 |
+
beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
|
| 278 |
+
constantBeta = gs.Constant("LayerNormBeta-" + str(nLayerNormPlugin), np.ascontiguousarray(beta.reshape(-1)))
|
| 279 |
+
|
| 280 |
+
inputList = [inputTensor, constantGamma, constantBeta]
|
| 281 |
+
layerNormV = gs.Variable("LayerNormV-" + str(nLayerNormPlugin), np.dtype(np.float32), inputTensor.shape)
|
| 282 |
+
layerNormN = gs.Node("LayerNorm", "LayerNormN-" + str(nLayerNormPlugin), inputs=inputList, attrs=OrderedDict([('epsilon', 1.e-5)]), outputs=[layerNormV])
|
| 283 |
+
self.graph.nodes.append(layerNormN)
|
| 284 |
+
nLayerNormPlugin += 1
|
| 285 |
+
|
| 286 |
+
if betaNode.outputs[0] in self.graph.outputs:
|
| 287 |
+
index = self.graph.outputs.index(betaNode.outputs[0])
|
| 288 |
+
self.graph.outputs[index] = layerNormV
|
| 289 |
+
else:
|
| 290 |
+
if betaNode.o().op == "Cast":
|
| 291 |
+
lastNode = betaNode.o()
|
| 292 |
+
else:
|
| 293 |
+
lastNode = betaNode
|
| 294 |
+
for subNode in self.graph.nodes:
|
| 295 |
+
if lastNode.outputs[0] in subNode.inputs:
|
| 296 |
+
index = subNode.inputs.index(lastNode.outputs[0])
|
| 297 |
+
subNode.inputs[index] = layerNormV
|
| 298 |
+
lastNode.outputs = []
|
| 299 |
+
|
| 300 |
+
self.cleanup()
|
| 301 |
+
return nLayerNormPlugin
|
| 302 |
+
|
| 303 |
+
def fuse_kv(self, node_k, node_v, fused_kv_idx, heads, num_dynamic=0):
|
| 304 |
+
# Get weights of K
|
| 305 |
+
weights_k = node_k.inputs[1].values
|
| 306 |
+
# Get weights of V
|
| 307 |
+
weights_v = node_v.inputs[1].values
|
| 308 |
+
# Input number of channels to K and V
|
| 309 |
+
C = weights_k.shape[0]
|
| 310 |
+
# Number of heads
|
| 311 |
+
H = heads
|
| 312 |
+
# Dimension per head
|
| 313 |
+
D = weights_k.shape[1] // H
|
| 314 |
+
|
| 315 |
+
# Concat and interleave weights such that the output of fused KV GEMM has [b, s_kv, h, 2, d] shape
|
| 316 |
+
weights_kv = np.dstack([weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 2 * H * D)
|
| 317 |
+
|
| 318 |
+
# K and V have the same input
|
| 319 |
+
input_tensor = node_k.inputs[0]
|
| 320 |
+
# K and V must have the same output which we feed into fmha plugin
|
| 321 |
+
output_tensor_k = node_k.outputs[0]
|
| 322 |
+
# Create tensor
|
| 323 |
+
constant_weights_kv = gs.Constant("Weights_KV_{}".format(fused_kv_idx), np.ascontiguousarray(weights_kv))
|
| 324 |
+
|
| 325 |
+
# Create fused KV node
|
| 326 |
+
fused_kv_node = gs.Node(op="MatMul", name="MatMul_KV_{}".format(fused_kv_idx), inputs=[input_tensor, constant_weights_kv], outputs=[output_tensor_k])
|
| 327 |
+
self.graph.nodes.append(fused_kv_node)
|
| 328 |
+
|
| 329 |
+
# Connect the output of fused node to the inputs of the nodes after K and V
|
| 330 |
+
node_v.o(num_dynamic).inputs[0] = output_tensor_k
|
| 331 |
+
node_k.o(num_dynamic).inputs[0] = output_tensor_k
|
| 332 |
+
for i in range(0,num_dynamic):
|
| 333 |
+
node_v.o().inputs.clear()
|
| 334 |
+
node_k.o().inputs.clear()
|
| 335 |
+
|
| 336 |
+
# Clear inputs and outputs of K and V so these nodes get removed during cleanup
|
| 337 |
+
node_k.outputs.clear()
|
| 338 |
+
node_v.outputs.clear()
|
| 339 |
+
node_k.inputs.clear()
|
| 340 |
+
node_v.inputs.clear()
|
| 341 |
+
|
| 342 |
+
self.cleanup()
|
| 343 |
+
return fused_kv_node
|
| 344 |
+
|
| 345 |
+
def insert_fmhca(self, node_q, node_kv, final_tranpose, mhca_idx, heads, num_dynamic=0):
|
| 346 |
+
# Get inputs and outputs for the fMHCA plugin
|
| 347 |
+
# We take an output of reshape that follows the Q GEMM
|
| 348 |
+
output_q = node_q.o(num_dynamic).o().inputs[0]
|
| 349 |
+
output_kv = node_kv.o().inputs[0]
|
| 350 |
+
output_final_tranpose = final_tranpose.outputs[0]
|
| 351 |
+
|
| 352 |
+
# Clear the inputs of the nodes that follow the Q and KV GEMM
|
| 353 |
+
# to delete these subgraphs (it will be substituted by fMHCA plugin)
|
| 354 |
+
node_kv.outputs[0].outputs[0].inputs.clear()
|
| 355 |
+
node_kv.outputs[0].outputs[0].inputs.clear()
|
| 356 |
+
node_q.o(num_dynamic).o().inputs.clear()
|
| 357 |
+
for i in range(0,num_dynamic):
|
| 358 |
+
node_q.o(i).o().o(1).inputs.clear()
|
| 359 |
+
|
| 360 |
+
weights_kv = node_kv.inputs[1].values
|
| 361 |
+
dims_per_head = weights_kv.shape[1] // (heads * 2)
|
| 362 |
+
|
| 363 |
+
# Reshape dims
|
| 364 |
+
shape = gs.Constant("Shape_KV_{}".format(mhca_idx), np.ascontiguousarray(np.array([0, 0, heads, 2, dims_per_head], dtype=np.int64)))
|
| 365 |
+
|
| 366 |
+
# Reshape output tensor
|
| 367 |
+
output_reshape = gs.Variable("ReshapeKV_{}".format(mhca_idx), np.dtype(np.float16), None)
|
| 368 |
+
# Create fMHA plugin
|
| 369 |
+
reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mhca_idx), inputs=[output_kv, shape], outputs=[output_reshape])
|
| 370 |
+
# Insert node
|
| 371 |
+
self.graph.nodes.append(reshape)
|
| 372 |
+
|
| 373 |
+
# Create fMHCA plugin
|
| 374 |
+
fmhca = gs.Node(op="fMHCA", name="fMHCA_{}".format(mhca_idx), inputs=[output_q, output_reshape], outputs=[output_final_tranpose])
|
| 375 |
+
# Insert node
|
| 376 |
+
self.graph.nodes.append(fmhca)
|
| 377 |
+
|
| 378 |
+
# Connect input of fMHCA to output of Q GEMM
|
| 379 |
+
node_q.o(num_dynamic).outputs[0] = output_q
|
| 380 |
+
|
| 381 |
+
if num_dynamic > 0:
|
| 382 |
+
reshape2_input1_out = gs.Variable("Reshape2_fmhca{}_out".format(mhca_idx), np.dtype(np.int64), None)
|
| 383 |
+
reshape2_input1_shape = gs.Node("Shape", "Reshape2_fmhca{}_shape".format(mhca_idx), inputs=[node_q.inputs[0]], outputs=[reshape2_input1_out])
|
| 384 |
+
self.graph.nodes.append(reshape2_input1_shape)
|
| 385 |
+
final_tranpose.o().inputs[1] = reshape2_input1_out
|
| 386 |
+
|
| 387 |
+
# Clear outputs of transpose to get this subgraph cleared
|
| 388 |
+
final_tranpose.outputs.clear()
|
| 389 |
+
|
| 390 |
+
self.cleanup()
|
| 391 |
+
|
| 392 |
+
def fuse_qkv(self, node_q, node_k, node_v, fused_qkv_idx, heads, num_dynamic=0):
|
| 393 |
+
# Get weights of Q
|
| 394 |
+
weights_q = node_q.inputs[1].values
|
| 395 |
+
# Get weights of K
|
| 396 |
+
weights_k = node_k.inputs[1].values
|
| 397 |
+
# Get weights of V
|
| 398 |
+
weights_v = node_v.inputs[1].values
|
| 399 |
+
|
| 400 |
+
# Input number of channels to Q, K and V
|
| 401 |
+
C = weights_k.shape[0]
|
| 402 |
+
# Number of heads
|
| 403 |
+
H = heads
|
| 404 |
+
# Hidden dimension per head
|
| 405 |
+
D = weights_k.shape[1] // H
|
| 406 |
+
|
| 407 |
+
# Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
|
| 408 |
+
weights_qkv = np.dstack([weights_q.reshape(C, H, D), weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 3 * H * D)
|
| 409 |
+
|
| 410 |
+
input_tensor = node_k.inputs[0] # K and V have the same input
|
| 411 |
+
# Q, K and V must have the same output which we feed into fmha plugin
|
| 412 |
+
output_tensor_k = node_k.outputs[0]
|
| 413 |
+
# Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
|
| 414 |
+
constant_weights_qkv = gs.Constant("Weights_QKV_{}".format(fused_qkv_idx), np.ascontiguousarray(weights_qkv))
|
| 415 |
+
|
| 416 |
+
# Created a fused node
|
| 417 |
+
fused_qkv_node = gs.Node(op="MatMul", name="MatMul_QKV_{}".format(fused_qkv_idx), inputs=[input_tensor, constant_weights_qkv], outputs=[output_tensor_k])
|
| 418 |
+
self.graph.nodes.append(fused_qkv_node)
|
| 419 |
+
|
| 420 |
+
# Connect the output of the fused node to the inputs of the nodes after Q, K and V
|
| 421 |
+
node_q.o(num_dynamic).inputs[0] = output_tensor_k
|
| 422 |
+
node_k.o(num_dynamic).inputs[0] = output_tensor_k
|
| 423 |
+
node_v.o(num_dynamic).inputs[0] = output_tensor_k
|
| 424 |
+
for i in range(0,num_dynamic):
|
| 425 |
+
node_q.o().inputs.clear()
|
| 426 |
+
node_k.o().inputs.clear()
|
| 427 |
+
node_v.o().inputs.clear()
|
| 428 |
+
|
| 429 |
+
# Clear inputs and outputs of Q, K and V so these nodes get removed during cleanup
|
| 430 |
+
node_q.outputs.clear()
|
| 431 |
+
node_k.outputs.clear()
|
| 432 |
+
node_v.outputs.clear()
|
| 433 |
+
|
| 434 |
+
node_q.inputs.clear()
|
| 435 |
+
node_k.inputs.clear()
|
| 436 |
+
node_v.inputs.clear()
|
| 437 |
+
|
| 438 |
+
self.cleanup()
|
| 439 |
+
return fused_qkv_node
|
| 440 |
+
|
| 441 |
+
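# Illustrative note (not part of the upstream file): a shape sketch for the
# fuse_qkv interleave above, with hypothetical sizes C=256 channels, H=8 heads,
# D=32 dims per head. Each projection weight is (C, H*D) = (256, 256);
# np.dstack of the three (C, H, D) views gives (C, H, 3*D), which reshapes to
# (C, 3*H*D) = (256, 768). A single GEMM with this matrix therefore yields an
# activation that can be viewed as [batch, seq, H, 3, D], the layout expected
# by the fMHA plugin.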
def insert_fmha(self, node_qkv, final_tranpose, mha_idx, heads, num_dynamic=0):
|
| 442 |
+
# Get inputs and outputs for the fMHA plugin
|
| 443 |
+
output_qkv = node_qkv.o().inputs[0]
|
| 444 |
+
output_final_tranpose = final_tranpose.outputs[0]
|
| 445 |
+
|
| 446 |
+
# Clear the inputs of the nodes that follow the QKV GEMM
|
| 447 |
+
# to delete these subgraphs (it will be substituted by fMHA plugin)
|
| 448 |
+
node_qkv.outputs[0].outputs[2].inputs.clear()
|
| 449 |
+
node_qkv.outputs[0].outputs[1].inputs.clear()
|
| 450 |
+
node_qkv.outputs[0].outputs[0].inputs.clear()
|
| 451 |
+
|
| 452 |
+
weights_qkv = node_qkv.inputs[1].values
|
| 453 |
+
dims_per_head = weights_qkv.shape[1] // (heads * 3)
|
| 454 |
+
|
| 455 |
+
# Reshape dims
|
| 456 |
+
shape = gs.Constant("Shape_QKV_{}".format(mha_idx), np.ascontiguousarray(np.array([0, 0, heads, 3, dims_per_head], dtype=np.int64)))
|
| 457 |
+
|
| 458 |
+
# Reshape output tensor
|
| 459 |
+
output_shape = gs.Variable("ReshapeQKV_{}".format(mha_idx), np.dtype(np.float16), None)
|
| 460 |
+
# Create fMHA plugin
|
| 461 |
+
reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mha_idx), inputs=[output_qkv, shape], outputs=[output_shape])
|
| 462 |
+
# Insert node
|
| 463 |
+
self.graph.nodes.append(reshape)
|
| 464 |
+
|
| 465 |
+
# Create fMHA plugin
|
| 466 |
+
fmha = gs.Node(op="fMHA_V2", name="fMHA_{}".format(mha_idx), inputs=[output_shape], outputs=[output_final_tranpose])
|
| 467 |
+
# Insert node
|
| 468 |
+
self.graph.nodes.append(fmha)
|
| 469 |
+
|
| 470 |
+
if num_dynamic > 0:
|
| 471 |
+
reshape2_input1_out = gs.Variable("Reshape2_{}_out".format(mha_idx), np.dtype(np.int64), None)
|
| 472 |
+
reshape2_input1_shape = gs.Node("Shape", "Reshape2_{}_shape".format(mha_idx), inputs=[node_qkv.inputs[0]], outputs=[reshape2_input1_out])
|
| 473 |
+
self.graph.nodes.append(reshape2_input1_shape)
|
| 474 |
+
final_tranpose.o().inputs[1] = reshape2_input1_out
|
| 475 |
+
|
| 476 |
+
# Clear outputs of transpose to get this subgraph cleared
|
| 477 |
+
final_tranpose.outputs.clear()
|
| 478 |
+
|
| 479 |
+
self.cleanup()
|
| 480 |
+
|
| 481 |
+
def mha_mhca_detected(self, node, mha):
|
| 482 |
+
# Go from the V GEMM down to the S*V MatMul and all the way up to the K GEMM.
# If we are looking for MHCA, the inputs of the two MatMuls (K and V) must be equal.
# If we are looking for MHA, the inputs (K and V) must not be equal.
|
| 485 |
+
if node.op == "MatMul" and len(node.outputs) == 1 and \
|
| 486 |
+
((mha and len(node.inputs[0].inputs) > 0 and node.i().op == "Add") or \
|
| 487 |
+
(not mha and len(node.inputs[0].inputs) == 0)):
|
| 488 |
+
|
| 489 |
+
if node.o().op == 'Shape':
|
| 490 |
+
if node.o(1).op == 'Shape':
|
| 491 |
+
num_dynamic_kv = 3 if node.o(2).op == 'Shape' else 2
|
| 492 |
+
else:
|
| 493 |
+
num_dynamic_kv = 1
|
| 494 |
+
# For Cross-Attention, if batch axis is dynamic (in QKV), assume H*W (in Q) is dynamic as well
|
| 495 |
+
num_dynamic_q = num_dynamic_kv if mha else num_dynamic_kv + 1
|
| 496 |
+
else:
|
| 497 |
+
num_dynamic_kv = 0
|
| 498 |
+
num_dynamic_q = 0
|
| 499 |
+
|
| 500 |
+
o = node.o(num_dynamic_kv)
|
| 501 |
+
if o.op == "Reshape" and \
|
| 502 |
+
o.o().op == "Transpose" and \
|
| 503 |
+
o.o().o().op == "Reshape" and \
|
| 504 |
+
o.o().o().o().op == "MatMul" and \
|
| 505 |
+
o.o().o().o().i(0).op == "Softmax" and \
|
| 506 |
+
o.o().o().o().i(1).op == "Reshape" and \
|
| 507 |
+
o.o().o().o().i(0).i().op == "Mul" and \
|
| 508 |
+
o.o().o().o().i(0).i().i().op == "MatMul" and \
|
| 509 |
+
o.o().o().o().i(0).i().i().i(0).op == "Reshape" and \
|
| 510 |
+
o.o().o().o().i(0).i().i().i(1).op == "Transpose" and \
|
| 511 |
+
o.o().o().o().i(0).i().i().i(1).i().op == "Reshape" and \
|
| 512 |
+
o.o().o().o().i(0).i().i().i(1).i().i().op == "Transpose" and \
|
| 513 |
+
o.o().o().o().i(0).i().i().i(1).i().i().i().op == "Reshape" and \
|
| 514 |
+
o.o().o().o().i(0).i().i().i(1).i().i().i().i().op == "MatMul" and \
|
| 515 |
+
node.name != o.o().o().o().i(0).i().i().i(1).i().i().i().i().name:
|
| 516 |
+
# "len(node.outputs) == 1" to make sure we are not in the already fused node
|
| 517 |
+
node_q = o.o().o().o().i(0).i().i().i(0).i().i().i()
|
| 518 |
+
node_k = o.o().o().o().i(0).i().i().i(1).i().i().i().i()
|
| 519 |
+
node_v = node
|
| 520 |
+
final_tranpose = o.o().o().o().o(num_dynamic_q).o()
|
| 521 |
+
# Sanity check to make sure that the graph looks like expected
|
| 522 |
+
if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
|
| 523 |
+
return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose
|
| 524 |
+
return False, 0, 0, None, None, None, None
|
| 525 |
+
|
| 526 |
+
def fuse_kv_insert_fmhca(self, heads, mhca_index, sm):
|
| 527 |
+
nodes = self.graph.nodes
|
| 528 |
+
# Iterate over graph and search for MHCA pattern
|
| 529 |
+
for idx, _ in enumerate(nodes):
|
| 530 |
+
# fMHCA can't be in the last 2 layers of the network; this guards against out-of-bounds node access
|
| 531 |
+
if idx + 1 > len(nodes) or idx + 2 > len(nodes):
|
| 532 |
+
continue
|
| 533 |
+
|
| 534 |
+
# Get anchor nodes for fusion and fMHCA plugin insertion if the MHCA is detected
|
| 535 |
+
detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
|
| 536 |
+
self.mha_mhca_detected(nodes[idx], mha=False)
|
| 537 |
+
if detected:
|
| 538 |
+
assert num_dynamic_q == 0 or num_dynamic_q == num_dynamic_kv + 1
|
| 539 |
+
# Skip the fMHCA plugin on SM75 when the dim per head is 160.
|
| 540 |
+
if sm == 75 and node_q.inputs[1].shape[1] // heads == 160:
|
| 541 |
+
continue
|
| 542 |
+
# Fuse K and V GEMMS
|
| 543 |
+
node_kv = self.fuse_kv(node_k, node_v, mhca_index, heads, num_dynamic_kv)
|
| 544 |
+
# Insert fMHCA plugin
|
| 545 |
+
self.insert_fmhca(node_q, node_kv, final_tranpose, mhca_index, heads, num_dynamic_q)
|
| 546 |
+
return True
|
| 547 |
+
return False
|
| 548 |
+
|
| 549 |
+
def fuse_qkv_insert_fmha(self, heads, mha_index):
|
| 550 |
+
nodes = self.graph.nodes
|
| 551 |
+
# Iterate over graph and search for MHA pattern
|
| 552 |
+
for idx, _ in enumerate(nodes):
|
| 553 |
+
# fMHA can't be in the last 2 layers of the network; this guards against out-of-bounds node access
|
| 554 |
+
if idx + 1 > len(nodes) or idx + 2 > len(nodes):
|
| 555 |
+
continue
|
| 556 |
+
|
| 557 |
+
# Get anchor nodes for fusion and fMHA plugin insertion if the MHA is detected
|
| 558 |
+
detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
|
| 559 |
+
self.mha_mhca_detected(nodes[idx], mha=True)
|
| 560 |
+
if detected:
|
| 561 |
+
assert num_dynamic_q == num_dynamic_kv
|
| 562 |
+
# Fuse Q, K and V GEMMS
|
| 563 |
+
node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv)
|
| 564 |
+
# Insert fMHA plugin
|
| 565 |
+
self.insert_fmha(node_qkv, final_tranpose, mha_index, heads, num_dynamic_kv)
|
| 566 |
+
return True
|
| 567 |
+
return False
|
| 568 |
+
|
| 569 |
+
def insert_fmhca_plugin(self, num_heads, sm):
|
| 570 |
+
mhca_index = 0
|
| 571 |
+
while self.fuse_kv_insert_fmhca(num_heads, mhca_index, sm):
|
| 572 |
+
mhca_index += 1
|
| 573 |
+
return mhca_index
|
| 574 |
+
|
| 575 |
+
def insert_fmha_plugin(self, num_heads):
|
| 576 |
+
mha_index = 0
|
| 577 |
+
while self.fuse_qkv_insert_fmha(num_heads, mha_index):
|
| 578 |
+
mha_index += 1
|
| 579 |
+
return mha_index
|
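For orientation, a minimal driving sketch for the optimizer above. It is presumably invoked
from the deploy export path (rfdetr/deploy/export.py, not shown in this hunk), so the call
sequence and file names below are assumptions rather than the project's exact flow:

    from rfdetr.deploy._onnx import OnnxOptimizer

    opt = OnnxOptimizer("model.onnx")        # hypothetical input path
    opt.info("before")
    opt.common_opt()                         # registered passes + constant folding + shape inference
    n_ln = opt.insert_layernorm_plugin()     # optional TensorRT plugin rewrite
    print(f"replaced {n_ln} LayerNorm subgraphs")
    opt.save_onnx("model.opt.onnx")          # hypothetical output path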
rfdetr/deploy/_onnx/symbolic.py ADDED
@@ -0,0 +1,37 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
"""
CustomOpSymbolicRegistry class
"""
from copy import deepcopy

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes
from torch.autograd import Function


class CustomOpSymbolicRegistry:
    # _SYMBOLICS = {}
    _OPTIMIZER = []

    @classmethod
    def optimizer(cls, fn):
        cls._OPTIMIZER.append(fn)


def register_optimizer():
    def optimizer_wrapper(fn):
        CustomOpSymbolicRegistry.optimizer(fn)
        return fn
    return optimizer_wrapper
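A short usage sketch for the registry above (the pass body and its name are hypothetical;
real passes, if any, are registered elsewhere in the deploy code):

    from rfdetr.deploy._onnx.symbolic import CustomOpSymbolicRegistry, register_optimizer

    @register_optimizer()
    def drop_dead_identity_nodes(opt):
        # `opt` is the OnnxOptimizer instance; every registered pass is invoked
        # by OnnxOptimizer.common_opt() before constant folding.
        for node in opt.graph.nodes:
            if node.op == "Identity" and not node.outputs:
                node.inputs.clear()
        opt.cleanup()

    assert drop_dead_identity_nodes in CustomOpSymbolicRegistry._OPTIMIZER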
rfdetr/deploy/benchmark.py ADDED
@@ -0,0 +1,590 @@
# ------------------------------------------------------------------------
|
| 2 |
+
# RF-DETR
|
| 3 |
+
# Copyright (c) 2025 Roboflow. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
|
| 7 |
+
# Copyright (c) 2024 Baidu. All Rights Reserved.
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
This tool provides performance benchmarks by using ONNX Runtime and TensorRT
|
| 12 |
+
to run inference on a given model with the COCO validation set. It offers
|
| 13 |
+
reliable measurements of inference latency using ONNX Runtime or TensorRT
|
| 14 |
+
on the device.
|
| 15 |
+
"""
|
| 16 |
+
import argparse
|
| 17 |
+
import copy
|
| 18 |
+
import contextlib
|
| 19 |
+
import datetime
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
import os.path as osp
|
| 23 |
+
import random
|
| 24 |
+
import time
|
| 25 |
+
import ast
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from collections import namedtuple, OrderedDict
|
| 28 |
+
|
| 29 |
+
from pycocotools.cocoeval import COCOeval
|
| 30 |
+
from pycocotools.coco import COCO
|
| 31 |
+
import pycocotools.mask as mask_util
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
from PIL import Image
|
| 35 |
+
import torch
|
| 36 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 37 |
+
import torchvision.transforms as T
|
| 38 |
+
import torchvision.transforms.functional as F
|
| 39 |
+
import tqdm
|
| 40 |
+
|
| 41 |
+
import pycuda.driver as cuda
|
| 42 |
+
import pycuda.autoinit
|
| 43 |
+
import onnxruntime as nxrun
|
| 44 |
+
import tensorrt as trt
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def parser_args():
|
| 48 |
+
parser = argparse.ArgumentParser('performance benchmark tool for onnx/trt model')
|
| 49 |
+
parser.add_argument('--path', type=str, help='engine file path')
|
| 50 |
+
parser.add_argument('--coco_path', type=str, default="data/coco", help='coco dataset path')
|
| 51 |
+
parser.add_argument('--device', default=0, type=int)
|
| 52 |
+
parser.add_argument('--run_benchmark', action='store_true', help='repeat the inference to benchmark the latency')
|
| 53 |
+
parser.add_argument('--disable_eval', action='store_true', help='disable evaluation')
|
| 54 |
+
return parser.parse_args()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CocoEvaluator(object):
|
| 58 |
+
def __init__(self, coco_gt, iou_types):
|
| 59 |
+
assert isinstance(iou_types, (list, tuple))
|
| 60 |
+
coco_gt = COCO(coco_gt)
|
| 61 |
+
coco_gt = copy.deepcopy(coco_gt)
|
| 62 |
+
self.coco_gt = coco_gt
|
| 63 |
+
|
| 64 |
+
self.iou_types = iou_types
|
| 65 |
+
self.coco_eval = {}
|
| 66 |
+
for iou_type in iou_types:
|
| 67 |
+
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
|
| 68 |
+
|
| 69 |
+
self.img_ids = []
|
| 70 |
+
self.eval_imgs = {k: [] for k in iou_types}
|
| 71 |
+
|
| 72 |
+
def update(self, predictions):
|
| 73 |
+
img_ids = list(np.unique(list(predictions.keys())))
|
| 74 |
+
self.img_ids.extend(img_ids)
|
| 75 |
+
|
| 76 |
+
for iou_type in self.iou_types:
|
| 77 |
+
results = self.prepare(predictions, iou_type)
|
| 78 |
+
|
| 79 |
+
# suppress pycocotools prints
|
| 80 |
+
with open(os.devnull, 'w') as devnull:
|
| 81 |
+
with contextlib.redirect_stdout(devnull):
|
| 82 |
+
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
|
| 83 |
+
coco_eval = self.coco_eval[iou_type]
|
| 84 |
+
|
| 85 |
+
coco_eval.cocoDt = coco_dt
|
| 86 |
+
coco_eval.params.imgIds = list(img_ids)
|
| 87 |
+
img_ids, eval_imgs = evaluate(coco_eval)
|
| 88 |
+
|
| 89 |
+
self.eval_imgs[iou_type].append(eval_imgs)
|
| 90 |
+
|
| 91 |
+
def synchronize_between_processes(self):
|
| 92 |
+
for iou_type in self.iou_types:
|
| 93 |
+
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
|
| 94 |
+
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
|
| 95 |
+
|
| 96 |
+
def accumulate(self):
|
| 97 |
+
for coco_eval in self.coco_eval.values():
|
| 98 |
+
coco_eval.accumulate()
|
| 99 |
+
|
| 100 |
+
def summarize(self):
|
| 101 |
+
for iou_type, coco_eval in self.coco_eval.items():
|
| 102 |
+
print("IoU metric: {}".format(iou_type))
|
| 103 |
+
coco_eval.summarize()
|
| 104 |
+
|
| 105 |
+
def prepare(self, predictions, iou_type):
|
| 106 |
+
if iou_type == "bbox":
|
| 107 |
+
return self.prepare_for_coco_detection(predictions)
|
| 108 |
+
else:
|
| 109 |
+
raise ValueError("Unknown iou type {}".format(iou_type))
|
| 110 |
+
|
| 111 |
+
def prepare_for_coco_detection(self, predictions):
|
| 112 |
+
coco_results = []
|
| 113 |
+
for original_id, prediction in predictions.items():
|
| 114 |
+
if len(prediction) == 0:
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
boxes = prediction["boxes"]
|
| 118 |
+
boxes = convert_to_xywh(boxes).tolist()
|
| 119 |
+
scores = prediction["scores"].tolist()
|
| 120 |
+
labels = prediction["labels"].tolist()
|
| 121 |
+
|
| 122 |
+
coco_results.extend(
|
| 123 |
+
[
|
| 124 |
+
{
|
| 125 |
+
"image_id": original_id,
|
| 126 |
+
"category_id": labels[k],
|
| 127 |
+
"bbox": box,
|
| 128 |
+
"score": scores[k],
|
| 129 |
+
}
|
| 130 |
+
for k, box in enumerate(boxes)
|
| 131 |
+
]
|
| 132 |
+
)
|
| 133 |
+
return coco_results
|
| 134 |
+
|
| 135 |
+
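# Illustrative note (not part of the upstream file): the evaluator above is
# typically driven in four steps (a sketch; the actual loop presumably lives
# further down in this file):
#   evaluator = CocoEvaluator(ann_file, iou_types=["bbox"])
#   evaluator.update({image_id: {"boxes": ..., "scores": ..., "labels": ...}})  # once per batch
#   evaluator.synchronize_between_processes()
#   evaluator.accumulate(); evaluator.summarize()  # prints the COCO AP table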
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
|
| 136 |
+
img_ids = list(img_ids)
|
| 137 |
+
eval_imgs = list(eval_imgs.flatten())
|
| 138 |
+
|
| 139 |
+
coco_eval.evalImgs = eval_imgs
|
| 140 |
+
coco_eval.params.imgIds = img_ids
|
| 141 |
+
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
|
| 142 |
+
|
| 143 |
+
def evaluate(self):
|
| 144 |
+
'''
|
| 145 |
+
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
|
| 146 |
+
:return: None
|
| 147 |
+
'''
|
| 148 |
+
# Running per image evaluation...
|
| 149 |
+
p = self.params
|
| 150 |
+
# add backward compatibility if useSegm is specified in params
|
| 151 |
+
if p.useSegm is not None:
|
| 152 |
+
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
|
| 153 |
+
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
|
| 154 |
+
# print('Evaluate annotation type *{}*'.format(p.iouType))
|
| 155 |
+
p.imgIds = list(np.unique(p.imgIds))
|
| 156 |
+
if p.useCats:
|
| 157 |
+
p.catIds = list(np.unique(p.catIds))
|
| 158 |
+
p.maxDets = sorted(p.maxDets)
|
| 159 |
+
self.params = p
|
| 160 |
+
|
| 161 |
+
self._prepare()
|
| 162 |
+
# loop through images, area range, max detection number
|
| 163 |
+
catIds = p.catIds if p.useCats else [-1]
|
| 164 |
+
|
| 165 |
+
if p.iouType == 'segm' or p.iouType == 'bbox':
|
| 166 |
+
computeIoU = self.computeIoU
|
| 167 |
+
elif p.iouType == 'keypoints':
|
| 168 |
+
computeIoU = self.computeOks
|
| 169 |
+
self.ious = {
|
| 170 |
+
(imgId, catId): computeIoU(imgId, catId)
|
| 171 |
+
for imgId in p.imgIds
|
| 172 |
+
for catId in catIds}
|
| 173 |
+
|
| 174 |
+
evaluateImg = self.evaluateImg
|
| 175 |
+
maxDet = p.maxDets[-1]
|
| 176 |
+
evalImgs = [
|
| 177 |
+
evaluateImg(imgId, catId, areaRng, maxDet)
|
| 178 |
+
for catId in catIds
|
| 179 |
+
for areaRng in p.areaRng
|
| 180 |
+
for imgId in p.imgIds
|
| 181 |
+
]
|
| 182 |
+
# this is NOT in the pycocotools code, but could be done outside
|
| 183 |
+
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
|
| 184 |
+
self._paramsEval = copy.deepcopy(self.params)
|
| 185 |
+
return p.imgIds, evalImgs
|
| 186 |
+
|
| 187 |
+
def convert_to_xywh(boxes):
|
| 188 |
+
boxes[:, 2:] -= boxes[:, :2]
|
| 189 |
+
return boxes
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_image_list(ann_file):
|
| 193 |
+
with open(ann_file, 'r') as fin:
|
| 194 |
+
data = json.load(fin)
|
| 195 |
+
return data['images']
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def load_image(file_path):
|
| 199 |
+
return Image.open(file_path).convert("RGB")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class Compose(object):
|
| 203 |
+
def __init__(self, transforms):
|
| 204 |
+
self.transforms = transforms
|
| 205 |
+
|
| 206 |
+
def __call__(self, image, target):
|
| 207 |
+
for t in self.transforms:
|
| 208 |
+
image, target = t(image, target)
|
| 209 |
+
return image, target
|
| 210 |
+
|
| 211 |
+
def __repr__(self):
|
| 212 |
+
format_string = self.__class__.__name__ + "("
|
| 213 |
+
for t in self.transforms:
|
| 214 |
+
format_string += "\n"
|
| 215 |
+
format_string += " {0}".format(t)
|
| 216 |
+
format_string += "\n)"
|
| 217 |
+
return format_string
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ToTensor(object):
|
| 221 |
+
def __call__(self, img, target):
|
| 222 |
+
return F.to_tensor(img), target
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class Normalize(object):
|
| 226 |
+
def __init__(self, mean, std):
|
| 227 |
+
self.mean = mean
|
| 228 |
+
self.std = std
|
| 229 |
+
|
| 230 |
+
def __call__(self, image, target=None):
|
| 231 |
+
image = F.normalize(image, mean=self.mean, std=self.std)
|
| 232 |
+
if target is None:
|
| 233 |
+
return image, None
|
| 234 |
+
target = target.copy()
|
| 235 |
+
h, w = image.shape[-2:]
|
| 236 |
+
if "boxes" in target:
|
| 237 |
+
boxes = target["boxes"]
|
| 238 |
+
boxes = box_xyxy_to_cxcywh(boxes)
|
| 239 |
+
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
|
| 240 |
+
target["boxes"] = boxes
|
| 241 |
+
return image, target
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class SquareResize(object):
|
| 245 |
+
def __init__(self, sizes):
|
| 246 |
+
assert isinstance(sizes, (list, tuple))
|
| 247 |
+
self.sizes = sizes
|
| 248 |
+
|
| 249 |
+
def __call__(self, img, target=None):
|
| 250 |
+
size = random.choice(self.sizes)
|
| 251 |
+
rescaled_img = F.resize(img, (size, size))

|
| 252 |
+
w, h = rescaled_img.size
|
| 253 |
+
if target is None:
|
| 254 |
+
return rescaled_img, None
|
| 255 |
+
ratios = tuple(
|
| 256 |
+
float(s) / float(s_orig) for s, s_orig in zip(rescaled_img.size, img.size))
|
| 257 |
+
ratio_width, ratio_height = ratios
|
| 258 |
+
|
| 259 |
+
target = target.copy()
|
| 260 |
+
if "boxes" in target:
|
| 261 |
+
boxes = target["boxes"]
|
| 262 |
+
scaled_boxes = boxes * torch.as_tensor(
|
| 263 |
+
[ratio_width, ratio_height, ratio_width, ratio_height])
|
| 264 |
+
target["boxes"] = scaled_boxes
|
| 265 |
+
|
| 266 |
+
if "area" in target:
|
| 267 |
+
area = target["area"]
|
| 268 |
+
scaled_area = area * (ratio_width * ratio_height)
|
| 269 |
+
target["area"] = scaled_area
|
| 270 |
+
|
| 271 |
+
target["size"] = torch.tensor([h, w])
|
| 272 |
+
|
| 273 |
+
return rescaled_img, target
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def infer_transforms():
|
| 277 |
+
normalize = Compose([
|
| 278 |
+
ToTensor(),
|
| 279 |
+
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 280 |
+
])
|
| 281 |
+
return Compose([
|
| 282 |
+
SquareResize([640]),
|
| 283 |
+
normalize,
|
| 284 |
+
])
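# Illustrative use of infer_transforms above (the image path is a placeholder):
# a PIL image of any size comes out as a normalized float tensor of shape (3, 640, 640).
img = load_image("example.jpg")           # placeholder path
tensor, _ = infer_transforms()(img, None)
print(tensor.shape)                       # torch.Size([3, 640, 640])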
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def box_cxcywh_to_xyxy(x):
|
| 288 |
+
x_c, y_c, w, h = x.unbind(-1)
|
| 289 |
+
b = [(x_c - 0.5 * w.clamp(min=0.0)), (y_c - 0.5 * h.clamp(min=0.0)),
|
| 290 |
+
(x_c + 0.5 * w.clamp(min=0.0)), (y_c + 0.5 * h.clamp(min=0.0))]
|
| 291 |
+
return torch.stack(b, dim=-1)
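# Worked example for box_cxcywh_to_xyxy above (values are illustrative):
# center (0.5, 0.5) with width 0.2 and height 0.4 maps to corners (0.4, 0.3, 0.6, 0.7).
import torch
box = torch.tensor([0.5, 0.5, 0.2, 0.4])
print(box_cxcywh_to_xyxy(box))  # tensor([0.4000, 0.3000, 0.6000, 0.7000])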
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def post_process(outputs, target_sizes):
|
| 295 |
+
out_logits, out_bbox = outputs['labels'], outputs['dets']
|
| 296 |
+
|
| 297 |
+
assert len(out_logits) == len(target_sizes)
|
| 298 |
+
assert target_sizes.shape[1] == 2
|
| 299 |
+
|
| 300 |
+
prob = out_logits.sigmoid()
|
| 301 |
+
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
|
| 302 |
+
scores = topk_values
|
| 303 |
+
topk_boxes = topk_indexes // out_logits.shape[2]
|
| 304 |
+
labels = topk_indexes % out_logits.shape[2]
|
| 305 |
+
boxes = box_cxcywh_to_xyxy(out_bbox)
|
| 306 |
+
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
|
| 307 |
+
|
| 308 |
+
# and from relative [0, 1] to absolute [0, height] coordinates
|
| 309 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 310 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
| 311 |
+
boxes = boxes * scale_fct[:, None, :]
|
| 312 |
+
|
| 313 |
+
results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
|
| 314 |
+
|
| 315 |
+
return results
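# Sketch of consuming post_process above (assumes `outputs` and `target_sizes` are built as in
# infer_onnx / infer_engine below); the 0.5 score threshold is an arbitrary example value.
results = post_process(outputs, target_sizes)
keep = results[0]["scores"] > 0.5
boxes_kept = results[0]["boxes"][keep]
labels_kept = results[0]["labels"][keep]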
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def infer_onnx(sess, coco_evaluator, time_profile, prefix, img_list, device, repeats=1):
|
| 319 |
+
time_list = []
|
| 320 |
+
for img_dict in tqdm.tqdm(img_list):
|
| 321 |
+
image = load_image(os.path.join(prefix, img_dict['file_name']))
|
| 322 |
+
width, height = image.size
|
| 323 |
+
orig_target_sizes = torch.Tensor([height, width])
|
| 324 |
+
image_tensor, _ = infer_transforms()(image, None) # target is None
|
| 325 |
+
|
| 326 |
+
samples = image_tensor[None].numpy()
|
| 327 |
+
|
| 328 |
+
time_profile.reset()
|
| 329 |
+
with time_profile:
|
| 330 |
+
for _ in range(repeats):
|
| 331 |
+
res = sess.run(None, {"input": samples})
|
| 332 |
+
time_list.append(time_profile.total / repeats)
|
| 333 |
+
outputs = {}
|
| 334 |
+
outputs['labels'] = torch.Tensor(res[1]).to(device)
|
| 335 |
+
outputs['dets'] = torch.Tensor(res[0]).to(device)
|
| 336 |
+
|
| 337 |
+
orig_target_sizes = torch.stack([orig_target_sizes], dim=0).to(device)
|
| 338 |
+
results = post_process(outputs, orig_target_sizes)
|
| 339 |
+
res = {img_dict['id']: results[0]}
|
| 340 |
+
if coco_evaluator is not None:
|
| 341 |
+
coco_evaluator.update(res)
|
| 342 |
+
|
| 343 |
+
print("Model latency with ONNX Runtime: {}ms".format(1000 * sum(time_list) / len(img_list)))
|
| 344 |
+
|
| 345 |
+
# accumulate predictions from all images
|
| 346 |
+
stats = {}
|
| 347 |
+
if coco_evaluator is not None:
|
| 348 |
+
coco_evaluator.synchronize_between_processes()
|
| 349 |
+
coco_evaluator.accumulate()
|
| 350 |
+
coco_evaluator.summarize()
|
| 351 |
+
stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
|
| 352 |
+
print(stats)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def infer_engine(model, coco_evaluator, time_profile, prefix, img_list, device, repeats=1):
|
| 356 |
+
time_list = []
|
| 357 |
+
for img_dict in tqdm.tqdm(img_list):
|
| 358 |
+
image = load_image(os.path.join(prefix, img_dict['file_name']))
|
| 359 |
+
width, height = image.size
|
| 360 |
+
orig_target_sizes = torch.Tensor([height, width])
|
| 361 |
+
image_tensor, _ = infer_transforms()(image, None) # target is None
|
| 362 |
+
|
| 363 |
+
samples = image_tensor[None].to(device)
|
| 364 |
+
_, _, h, w = samples.shape
|
| 365 |
+
im_shape = torch.Tensor(np.array([h, w]).reshape((1, 2)).astype(np.float32)).to(device)
|
| 366 |
+
scale_factor = torch.Tensor(np.array([h / height, w / width]).reshape((1, 2)).astype(np.float32)).to(device)
|
| 367 |
+
|
| 368 |
+
time_profile.reset()
|
| 369 |
+
with time_profile:
|
| 370 |
+
for _ in range(repeats):
|
| 371 |
+
outputs = model({"input": samples})
|
| 372 |
+
|
| 373 |
+
time_list.append(time_profile.total / repeats)
|
| 374 |
+
orig_target_sizes = torch.stack([orig_target_sizes], dim=0).to(device)
|
| 375 |
+
if coco_evaluator is not None:
|
| 376 |
+
results = post_process(outputs, orig_target_sizes)
|
| 377 |
+
res = {img_dict['id']: results[0]}
|
| 378 |
+
coco_evaluator.update(res)
|
| 379 |
+
|
| 380 |
+
print("Model latency with TensorRT: {}ms".format(1000 * sum(time_list) / len(img_list)))
|
| 381 |
+
|
| 382 |
+
# accumulate predictions from all images
|
| 383 |
+
stats = {}
|
| 384 |
+
if coco_evaluator is not None:
|
| 385 |
+
coco_evaluator.synchronize_between_processes()
|
| 386 |
+
coco_evaluator.accumulate()
|
| 387 |
+
coco_evaluator.summarize()
|
| 388 |
+
stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
|
| 389 |
+
print(stats)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class TRTInference(object):
|
| 393 |
+
"""TensorRT inference engine
|
| 394 |
+
"""
|
| 395 |
+
def __init__(self, engine_path='dino.engine', device='cuda:0', sync_mode:bool=False, max_batch_size=32, verbose=False):
|
| 396 |
+
self.engine_path = engine_path
|
| 397 |
+
self.device = device
|
| 398 |
+
self.sync_mode = sync_mode
|
| 399 |
+
self.max_batch_size = max_batch_size
|
| 400 |
+
|
| 401 |
+
self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO)
|
| 402 |
+
|
| 403 |
+
self.engine = self.load_engine(engine_path)
|
| 404 |
+
|
| 405 |
+
self.context = self.engine.create_execution_context()
|
| 406 |
+
|
| 407 |
+
self.bindings = self.get_bindings(self.engine, self.context, self.max_batch_size, self.device)
|
| 408 |
+
self.bindings_addr = OrderedDict((n, v.ptr) for n, v in self.bindings.items())
|
| 409 |
+
|
| 410 |
+
self.input_names = self.get_input_names()
|
| 411 |
+
self.output_names = self.get_output_names()
|
| 412 |
+
|
| 413 |
+
if not self.sync_mode:
|
| 414 |
+
self.stream = cuda.Stream()
|
| 415 |
+
|
| 416 |
+
# self.time_profile = TimeProfiler()
|
| 417 |
+
self.time_profile = None
|
| 418 |
+
|
| 419 |
+
def get_dummy_input(self, batch_size:int):
|
| 420 |
+
blob = {}
|
| 421 |
+
for name, binding in self.bindings.items():
|
| 422 |
+
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
|
| 423 |
+
print(f"make dummy input {name} with shape {binding.shape}")
|
| 424 |
+
blob[name] = torch.rand(batch_size, *binding.shape[1:]).float().to('cuda:0')
|
| 425 |
+
return blob
|
| 426 |
+
|
| 427 |
+
def load_engine(self, path):
|
| 428 |
+
'''load engine
|
| 429 |
+
'''
|
| 430 |
+
trt.init_libnvinfer_plugins(self.logger, '')
|
| 431 |
+
with open(path, 'rb') as f, trt.Runtime(self.logger) as runtime:
|
| 432 |
+
return runtime.deserialize_cuda_engine(f.read())
|
| 433 |
+
|
| 434 |
+
def get_input_names(self, ):
|
| 435 |
+
names = []
|
| 436 |
+
for _, name in enumerate(self.engine):
|
| 437 |
+
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
|
| 438 |
+
names.append(name)
|
| 439 |
+
return names
|
| 440 |
+
|
| 441 |
+
def get_output_names(self, ):
|
| 442 |
+
names = []
|
| 443 |
+
for _, name in enumerate(self.engine):
|
| 444 |
+
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
|
| 445 |
+
names.append(name)
|
| 446 |
+
return names
|
| 447 |
+
|
| 448 |
+
def get_bindings(self, engine, context, max_batch_size=32, device=None):
|
| 449 |
+
'''build bindings
|
| 450 |
+
'''
|
| 451 |
+
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
| 452 |
+
bindings = OrderedDict()
|
| 453 |
+
|
| 454 |
+
for i, name in enumerate(engine):
|
| 455 |
+
shape = engine.get_tensor_shape(name)
|
| 456 |
+
dtype = trt.nptype(engine.get_tensor_dtype(name))
|
| 457 |
+
|
| 458 |
+
if shape[0] == -1:
|
| 459 |
+
raise NotImplementedError
|
| 460 |
+
|
| 461 |
+
if False:
|
| 462 |
+
if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
|
| 463 |
+
data = np.random.randn(*shape).astype(dtype)
|
| 464 |
+
ptr = cuda.mem_alloc(data.nbytes)
|
| 465 |
+
bindings[name] = Binding(name, dtype, shape, data, ptr)
|
| 466 |
+
else:
|
| 467 |
+
data = cuda.pagelocked_empty(trt.volume(shape), dtype)
|
| 468 |
+
ptr = cuda.mem_alloc(data.nbytes)
|
| 469 |
+
bindings[name] = Binding(name, dtype, shape, data, ptr)
|
| 470 |
+
|
| 471 |
+
else:
|
| 472 |
+
data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
| 473 |
+
bindings[name] = Binding(name, dtype, shape, data, data.data_ptr())
|
| 474 |
+
|
| 475 |
+
return bindings
|
| 476 |
+
|
| 477 |
+
def run_sync(self, blob):
|
| 478 |
+
self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names})
|
| 479 |
+
self.context.execute_v2(list(self.bindings_addr.values()))
|
| 480 |
+
outputs = {n: self.bindings[n].data for n in self.output_names}
|
| 481 |
+
return outputs
|
| 482 |
+
|
| 483 |
+
def run_async(self, blob):
|
| 484 |
+
self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names})
|
| 485 |
+
bindings_addr = [int(v) for _, v in self.bindings_addr.items()]
|
| 486 |
+
self.context.execute_async_v2(bindings=bindings_addr, stream_handle=self.stream.handle)
|
| 487 |
+
outputs = {n: self.bindings[n].data for n in self.output_names}
|
| 488 |
+
self.stream.synchronize()
|
| 489 |
+
return outputs
|
| 490 |
+
|
| 491 |
+
def __call__(self, blob):
|
| 492 |
+
if self.sync_mode:
|
| 493 |
+
return self.run_sync(blob)
|
| 494 |
+
else:
|
| 495 |
+
return self.run_async(blob)
|
| 496 |
+
|
| 497 |
+
def synchronize(self, ):
|
| 498 |
+
if not self.sync_mode and torch.cuda.is_available():
|
| 499 |
+
torch.cuda.synchronize()
|
| 500 |
+
elif self.sync_mode:
|
| 501 |
+
self.stream.synchronize()
|
| 502 |
+
|
| 503 |
+
def speed(self, blob, n):
|
| 504 |
+
self.time_profile.reset()
|
| 505 |
+
with self.time_profile:
|
| 506 |
+
for _ in range(n):
|
| 507 |
+
_ = self(blob)
|
| 508 |
+
return self.time_profile.total / n
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def build_engine(self, onnx_file_path, engine_file_path, max_batch_size=32):
|
| 512 |
+
'''Takes an ONNX file and creates a TensorRT engine to run inference with
|
| 513 |
+
http://gitlab.baidu.com/paddle-inference/benchmark/blob/main/backend_trt.py#L57
|
| 514 |
+
'''
|
| 515 |
+
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
| 516 |
+
with trt.Builder(self.logger) as builder, \
|
| 517 |
+
builder.create_network(EXPLICIT_BATCH) as network, \
|
| 518 |
+
trt.OnnxParser(network, self.logger) as parser, \
|
| 519 |
+
builder.create_builder_config() as config:
|
| 520 |
+
|
| 521 |
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1024 MiB
|
| 522 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
| 523 |
+
|
| 524 |
+
with open(onnx_file_path, 'rb') as model:
|
| 525 |
+
if not parser.parse(model.read()):
|
| 526 |
+
print('ERROR: Failed to parse the ONNX file.')
|
| 527 |
+
for error in range(parser.num_errors):
|
| 528 |
+
print(parser.get_error(error))
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
serialized_engine = builder.build_serialized_network(network, config)
|
| 532 |
+
with open(engine_file_path, 'wb') as f:
|
| 533 |
+
f.write(serialized_engine)
|
| 534 |
+
|
| 535 |
+
return serialized_engine
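# Minimal usage sketch for TRTInference above; the engine path is a placeholder, and the
# output names ("dets", "labels") follow the convention used by infer_engine earlier in this file.
import torch
trt_model = TRTInference("inference_model.sim.engine", device="cuda:0", sync_mode=True)
blob = {"input": torch.rand(1, 3, 640, 640, device="cuda:0")}
outputs = trt_model(blob)   # dict of output tensors, e.g. outputs["dets"], outputs["labels"]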
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class TimeProfiler(contextlib.ContextDecorator):
|
| 539 |
+
def __init__(self, ):
|
| 540 |
+
self.total = 0
|
| 541 |
+
|
| 542 |
+
def __enter__(self, ):
|
| 543 |
+
self.start = self.time()
|
| 544 |
+
return self
|
| 545 |
+
|
| 546 |
+
def __exit__(self, type, value, traceback):
|
| 547 |
+
self.total += self.time() - self.start
|
| 548 |
+
|
| 549 |
+
def reset(self, ):
|
| 550 |
+
self.total = 0
|
| 551 |
+
|
| 552 |
+
def time(self, ):
|
| 553 |
+
if torch.cuda.is_available():
|
| 554 |
+
torch.cuda.synchronize()
|
| 555 |
+
return time.perf_counter()
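# Illustrative use of TimeProfiler above to average latency over n runs,
# mirroring how infer_onnx / infer_engine use it earlier in this file.
profiler = TimeProfiler()
profiler.reset()
with profiler:
    for _ in range(10):
        pass  # replace with the call being timed
print(profiler.total / 10)  # mean seconds per run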
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def main(args):
|
| 559 |
+
print(args)
|
| 560 |
+
|
| 561 |
+
coco_gt = osp.join(args.coco_path, 'annotations/instances_val2017.json')
|
| 562 |
+
img_list = get_image_list(coco_gt)
|
| 563 |
+
prefix = osp.join(args.coco_path, 'val2017')
|
| 564 |
+
if args.run_benchmark:
|
| 565 |
+
repeats = 10
|
| 566 |
+
print('Inference for each image will be repeated 10 times to obtain '
|
| 567 |
+
'a reliable measurement of inference latency.')
|
| 568 |
+
else:
|
| 569 |
+
repeats = 1
|
| 570 |
+
|
| 571 |
+
if args.disable_eval:
|
| 572 |
+
coco_evaluator = None
|
| 573 |
+
else:
|
| 574 |
+
coco_evaluator = CocoEvaluator(coco_gt, ('bbox',))
|
| 575 |
+
|
| 576 |
+
time_profile = TimeProfiler()
|
| 577 |
+
|
| 578 |
+
if args.path.endswith(".onnx"):
|
| 579 |
+
sess = nxrun.InferenceSession(args.path, providers=['CUDAExecutionProvider'])
|
| 580 |
+
infer_onnx(sess, coco_evaluator, time_profile, prefix, img_list, device=f'cuda:{args.device}', repeats=repeats)
|
| 581 |
+
elif args.path.endswith(".engine"):
|
| 582 |
+
model = TRTInference(args.path, sync_mode=True, device=f'cuda:{args.device}')
|
| 583 |
+
infer_engine(model, coco_evaluator, time_profile, prefix, img_list, device=f'cuda:{args.device}', repeats=repeats)
|
| 584 |
+
else:
|
| 585 |
+
raise NotImplementedError('Only model file names ending with ".onnx" and ".engine" are supported.')
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
if __name__ == '__main__':
|
| 589 |
+
args = parser_args()
|
| 590 |
+
main(args)
|
rfdetr/deploy/export.py
ADDED
|
@@ -0,0 +1,276 @@
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# RF-DETR
|
| 3 |
+
# Copyright (c) 2025 Roboflow. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
|
| 7 |
+
# Copyright (c) 2024 Baidu. All Rights Reserved.
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
export ONNX model and TensorRT engine for deployment
|
| 12 |
+
"""
|
| 13 |
+
import os
|
| 14 |
+
import ast
|
| 15 |
+
import random
|
| 16 |
+
import argparse
|
| 17 |
+
import subprocess
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
import time
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
|
| 23 |
+
import onnx
|
| 24 |
+
import torch
|
| 25 |
+
import onnxsim
|
| 26 |
+
import numpy as np
|
| 27 |
+
from PIL import Image
|
| 28 |
+
|
| 29 |
+
import rfdetr.util.misc as utils
|
| 30 |
+
import rfdetr.datasets.transforms as T
|
| 31 |
+
from rfdetr.models import build_model
|
| 32 |
+
from rfdetr.deploy._onnx import OnnxOptimizer
|
| 33 |
+
import re
|
| 34 |
+
import sys
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def run_command_shell(command, dry_run: bool = False) -> subprocess.CompletedProcess:
|
| 38 |
+
if dry_run:
|
| 39 |
+
print("")
|
| 40 |
+
print(f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']} {command}")
|
| 41 |
+
print("")
|
| 42 |
+
try:
|
| 43 |
+
result = subprocess.run(command, shell=True, capture_output=True, text=True, check=True)
|
| 44 |
+
return result
|
| 45 |
+
except subprocess.CalledProcessError as e:
|
| 46 |
+
print(f"Command failed with exit code {e.returncode}")
|
| 47 |
+
print(f"Error output:\n{e.stderr.decode('utf-8')}")
|
| 48 |
+
raise
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def make_infer_image(infer_dir, shape, batch_size, device="cuda"):
|
| 52 |
+
if infer_dir is None:
|
| 53 |
+
dummy = np.random.randint(0, 256, (shape[0], shape[1], 3), dtype=np.uint8)
|
| 54 |
+
image = Image.fromarray(dummy, mode="RGB")
|
| 55 |
+
else:
|
| 56 |
+
image = Image.open(infer_dir).convert("RGB")
|
| 57 |
+
|
| 58 |
+
transforms = T.Compose([
|
| 59 |
+
T.SquareResize([shape[0]]),
|
| 60 |
+
T.ToTensor(),
|
| 61 |
+
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 62 |
+
])
|
| 63 |
+
|
| 64 |
+
inps, _ = transforms(image, None)
|
| 65 |
+
inps = inps.to(device)
|
| 66 |
+
# inps = utils.nested_tensor_from_tensor_list([inps for _ in range(args.batch_size)])
|
| 67 |
+
inps = torch.stack([inps for _ in range(batch_size)])
|
| 68 |
+
return inps
|
| 69 |
+
|
| 70 |
+
def export_onnx(output_dir, model, input_names, input_tensors, output_names, dynamic_axes, backbone_only=False, verbose=True, opset_version=17):
|
| 71 |
+
export_name = "backbone_model" if backbone_only else "inference_model"
|
| 72 |
+
output_file = os.path.join(output_dir, f"{export_name}.onnx")
|
| 73 |
+
|
| 74 |
+
# Prepare model for export
|
| 75 |
+
if hasattr(model, "export"):
|
| 76 |
+
model.export()
|
| 77 |
+
|
| 78 |
+
torch.onnx.export(
|
| 79 |
+
model,
|
| 80 |
+
input_tensors,
|
| 81 |
+
output_file,
|
| 82 |
+
input_names=input_names,
|
| 83 |
+
output_names=output_names,
|
| 84 |
+
export_params=True,
|
| 85 |
+
keep_initializers_as_inputs=False,
|
| 86 |
+
do_constant_folding=True,
|
| 87 |
+
verbose=verbose,
|
| 88 |
+
opset_version=opset_version,
|
| 89 |
+
dynamic_axes=dynamic_axes)
|
| 90 |
+
|
| 91 |
+
print(f'\nSuccessfully exported ONNX model: {output_file}')
|
| 92 |
+
return output_file
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def onnx_simplify(onnx_dir:str, input_names, input_tensors, force=False):
|
| 96 |
+
sim_onnx_dir = onnx_dir.replace(".onnx", ".sim.onnx")
|
| 97 |
+
if os.path.isfile(sim_onnx_dir) and not force:
|
| 98 |
+
return sim_onnx_dir
|
| 99 |
+
|
| 100 |
+
if isinstance(input_tensors, torch.Tensor):
|
| 101 |
+
input_tensors = [input_tensors]
|
| 102 |
+
|
| 103 |
+
print(f'Simplifying ONNX model: {onnx_dir}')
|
| 104 |
+
opt = OnnxOptimizer(onnx_dir)
|
| 105 |
+
opt.info('Model: original')
|
| 106 |
+
opt.common_opt()
|
| 107 |
+
opt.info('Model: optimized')
|
| 108 |
+
opt.save_onnx(sim_onnx_dir)
|
| 109 |
+
input_dict = {name: tensor.detach().cpu().numpy() for name, tensor in zip(input_names, input_tensors)}
|
| 110 |
+
model_opt, check_ok = onnxsim.simplify(
|
| 111 |
+
onnx_dir,
|
| 112 |
+
check_n=3,
|
| 113 |
+
input_data=input_dict,
|
| 114 |
+
dynamic_input_shape=False)
|
| 115 |
+
if check_ok:
|
| 116 |
+
onnx.save(model_opt, sim_onnx_dir)
|
| 117 |
+
else:
|
| 118 |
+
raise RuntimeError("Failed to simplify ONNX model.")
|
| 119 |
+
print(f'Successfully simplified ONNX model: {sim_onnx_dir}')
|
| 120 |
+
return sim_onnx_dir
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def trtexec(onnx_dir:str, args) -> None:
|
| 124 |
+
engine_dir = onnx_dir.replace(".onnx", f".engine")
|
| 125 |
+
|
| 126 |
+
# Base trtexec command
|
| 127 |
+
trt_command = " ".join([
|
| 128 |
+
"trtexec",
|
| 129 |
+
f"--onnx={onnx_dir}",
|
| 130 |
+
f"--saveEngine={engine_dir}",
|
| 131 |
+
f"--memPoolSize=workspace:4096 --fp16",
|
| 132 |
+
f"--useCudaGraph --useSpinWait --warmUp=500 --avgRuns=1000 --duration=10",
|
| 133 |
+
f"{'--verbose' if args.verbose else ''}"])
|
| 134 |
+
|
| 135 |
+
if args.profile:
|
| 136 |
+
profile_dir = onnx_dir.replace(".onnx", f".nsys-rep")
|
| 137 |
+
# Wrap with nsys profile command
|
| 138 |
+
command = " ".join([
|
| 139 |
+
"nsys profile",
|
| 140 |
+
f"--output={profile_dir}",
|
| 141 |
+
"--trace=cuda,nvtx",
|
| 142 |
+
"--force-overwrite true",
|
| 143 |
+
trt_command
|
| 144 |
+
])
|
| 145 |
+
print(f'Profile data will be saved to: {profile_dir}')
|
| 146 |
+
else:
|
| 147 |
+
command = trt_command
|
| 148 |
+
|
| 149 |
+
output = run_command_shell(command, args.dry_run)
|
| 150 |
+
stats = parse_trtexec_output(output.stdout)
|
| 151 |
+
|
| 152 |
+
def parse_trtexec_output(output_text):
|
| 153 |
+
print(output_text)
|
| 154 |
+
# Common patterns in trtexec output
|
| 155 |
+
gpu_compute_pattern = r"GPU Compute Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms, median = (\d+\.\d+) ms"
|
| 156 |
+
h2d_pattern = r"Host to Device Transfer Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
|
| 157 |
+
d2h_pattern = r"Device to Host Transfer Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
|
| 158 |
+
latency_pattern = r"Latency: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
|
| 159 |
+
throughput_pattern = r"Throughput: (\d+\.\d+) qps"
|
| 160 |
+
|
| 161 |
+
stats = {}
|
| 162 |
+
|
| 163 |
+
# Extract compute times
|
| 164 |
+
if match := re.search(gpu_compute_pattern, output_text):
|
| 165 |
+
stats.update({
|
| 166 |
+
'compute_min_ms': float(match.group(1)),
|
| 167 |
+
'compute_max_ms': float(match.group(2)),
|
| 168 |
+
'compute_mean_ms': float(match.group(3)),
|
| 169 |
+
'compute_median_ms': float(match.group(4))
|
| 170 |
+
})
|
| 171 |
+
|
| 172 |
+
# Extract H2D times
|
| 173 |
+
if match := re.search(h2d_pattern, output_text):
|
| 174 |
+
stats.update({
|
| 175 |
+
'h2d_min_ms': float(match.group(1)),
|
| 176 |
+
'h2d_max_ms': float(match.group(2)),
|
| 177 |
+
'h2d_mean_ms': float(match.group(3))
|
| 178 |
+
})
|
| 179 |
+
|
| 180 |
+
# Extract D2H times
|
| 181 |
+
if match := re.search(d2h_pattern, output_text):
|
| 182 |
+
stats.update({
|
| 183 |
+
'd2h_min_ms': float(match.group(1)),
|
| 184 |
+
'd2h_max_ms': float(match.group(2)),
|
| 185 |
+
'd2h_mean_ms': float(match.group(3))
|
| 186 |
+
})
|
| 187 |
+
|
| 188 |
+
if match := re.search(latency_pattern, output_text):
|
| 189 |
+
stats.update({
|
| 190 |
+
'latency_min_ms': float(match.group(1)),
|
| 191 |
+
'latency_max_ms': float(match.group(2)),
|
| 192 |
+
'latency_mean_ms': float(match.group(3))
|
| 193 |
+
})
|
| 194 |
+
|
| 195 |
+
# Extract throughput
|
| 196 |
+
if match := re.search(throughput_pattern, output_text):
|
| 197 |
+
stats['throughput_qps'] = float(match.group(1))
|
| 198 |
+
|
| 199 |
+
return stats
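# Example of the kind of line parse_trtexec_output above expects (a synthetic string shaped
# like trtexec's summary output, not captured from a real run):
sample = "Latency: min = 1.50 ms, max = 3.00 ms, mean = 2.10 ms"
print(parse_trtexec_output(sample))
# -> {'latency_min_ms': 1.5, 'latency_max_ms': 3.0, 'latency_mean_ms': 2.1}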
|
| 200 |
+
|
| 201 |
+
def no_batch_norm(model):
|
| 202 |
+
for module in model.modules():
|
| 203 |
+
if isinstance(module, nn.BatchNorm2d):
|
| 204 |
+
raise ValueError("BatchNorm2d found in the model. Please remove it.")
|
| 205 |
+
|
| 206 |
+
def main(args):
|
| 207 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
| 208 |
+
print(args)
|
| 209 |
+
# convert device to device_id
|
| 210 |
+
if args.device == 'cuda':
|
| 211 |
+
device_id = "0"
|
| 212 |
+
elif args.device == 'cpu':
|
| 213 |
+
device_id = ""
|
| 214 |
+
else:
|
| 215 |
+
device_id = str(int(args.device))
|
| 216 |
+
args.device = f"cuda:{device_id}"
|
| 217 |
+
|
| 218 |
+
# device for export onnx
|
| 219 |
+
# TODO: export onnx with cuda failed with onnx error
|
| 220 |
+
device = torch.device("cpu")
|
| 221 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = device_id
|
| 222 |
+
|
| 223 |
+
# fix the seed for reproducibility
|
| 224 |
+
seed = args.seed + utils.get_rank()
|
| 225 |
+
torch.manual_seed(seed)
|
| 226 |
+
np.random.seed(seed)
|
| 227 |
+
random.seed(seed)
|
| 228 |
+
|
| 229 |
+
model, criterion, postprocessors = build_model(args)
|
| 230 |
+
n_parameters = sum(p.numel() for p in model.parameters())
|
| 231 |
+
print(f"number of parameters: {n_parameters}")
|
| 232 |
+
n_backbone_parameters = sum(p.numel() for p in model.backbone.parameters())
|
| 233 |
+
print(f"number of backbone parameters: {n_backbone_parameters}")
|
| 234 |
+
n_projector_parameters = sum(p.numel() for p in model.backbone[0].projector.parameters())
|
| 235 |
+
print(f"number of projector parameters: {n_projector_parameters}")
|
| 236 |
+
n_backbone_encoder_parameters = sum(p.numel() for p in model.backbone[0].encoder.parameters())
|
| 237 |
+
print(f"number of backbone encoder parameters: {n_backbone_encoder_parameters}")
|
| 238 |
+
n_transformer_parameters = sum(p.numel() for p in model.transformer.parameters())
|
| 239 |
+
print(f"number of transformer parameters: {n_transformer_parameters}")
|
| 240 |
+
if args.resume:
|
| 241 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
| 242 |
+
model.load_state_dict(checkpoint['model'], strict=True)
|
| 243 |
+
print(f"load checkpoints {args.resume}")
|
| 244 |
+
|
| 245 |
+
if args.layer_norm:
|
| 246 |
+
no_batch_norm(model)
|
| 247 |
+
|
| 248 |
+
model.to(device)
|
| 249 |
+
|
| 250 |
+
input_tensors = make_infer_image(args.infer_dir, args.shape, args.batch_size, device)  # matches the make_infer_image signature above
|
| 251 |
+
input_names = ['input']
|
| 252 |
+
output_names = ['features'] if args.backbone_only else ['dets', 'labels']
|
| 253 |
+
dynamic_axes = None
|
| 254 |
+
# Run model inference in pytorch mode
|
| 255 |
+
model.eval().to("cuda")
|
| 256 |
+
input_tensors = input_tensors.to("cuda")
|
| 257 |
+
with torch.no_grad():
|
| 258 |
+
if args.backbone_only:
|
| 259 |
+
features = model(input_tensors)
|
| 260 |
+
print(f"PyTorch inference output shape: {features.shape}")
|
| 261 |
+
else:
|
| 262 |
+
outputs = model(input_tensors)
|
| 263 |
+
dets = outputs['pred_boxes']
|
| 264 |
+
labels = outputs['pred_logits']
|
| 265 |
+
print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
|
| 266 |
+
model.cpu()
|
| 267 |
+
input_tensors = input_tensors.cpu()
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
output_file = export_onnx(args.output_dir, model, input_names, input_tensors, output_names, dynamic_axes, backbone_only=args.backbone_only, verbose=args.verbose)  # args.output_dir assumed from the CLI parser
|
| 271 |
+
|
| 272 |
+
if args.simplify:
|
| 273 |
+
output_file = onnx_simplify(output_file, input_names, input_tensors)  # force defaults to False in the signature above
|
| 274 |
+
|
| 275 |
+
if args.tensorrt:
|
| 276 |
+
output_file = trtexec(output_file, args)
|
rfdetr/deploy/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
| 1 |
+
pycuda
|
| 2 |
+
onnx
|
| 3 |
+
onnxsim
|
| 4 |
+
onnxruntime
|
| 5 |
+
onnxruntime-gpu
|
| 6 |
+
onnx_graphsurgeon
|
| 7 |
+
tensorrt>=8.6.1
|
| 8 |
+
polygraphy
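These deployment dependencies live in their own requirements file and can be installed with pip install -r rfdetr/deploy/requirements.txt; pycuda and tensorrt assume a CUDA-capable environment.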
|
rfdetr/detr.py
ADDED
|
@@ -0,0 +1,315 @@
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# RF-DETR
|
| 3 |
+
# Copyright (c) 2025 Roboflow. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from logging import getLogger
|
| 12 |
+
from typing import Union, List
|
| 13 |
+
from copy import deepcopy
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import supervision as sv
|
| 17 |
+
import torch
|
| 18 |
+
import torchvision.transforms.functional as F
|
| 19 |
+
from PIL import Image
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
torch.set_float32_matmul_precision('high')
|
| 23 |
+
except:
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
from rfdetr.config import RFDETRBaseConfig, RFDETRLargeConfig, TrainConfig, ModelConfig
|
| 27 |
+
from rfdetr.main import Model, download_pretrain_weights
|
| 28 |
+
from rfdetr.util.metrics import MetricsPlotSink, MetricsTensorBoardSink, MetricsWandBSink
|
| 29 |
+
from rfdetr.util.coco_classes import COCO_CLASSES
|
| 30 |
+
|
| 31 |
+
logger = getLogger(__name__)
|
| 32 |
+
class RFDETR:
|
| 33 |
+
means = [0.485, 0.456, 0.406]
|
| 34 |
+
stds = [0.229, 0.224, 0.225]
|
| 35 |
+
|
| 36 |
+
def __init__(self, **kwargs):
|
| 37 |
+
self.model_config = self.get_model_config(**kwargs)
|
| 38 |
+
self.maybe_download_pretrain_weights()
|
| 39 |
+
self.model = self.get_model(self.model_config)
|
| 40 |
+
self.callbacks = defaultdict(list)
|
| 41 |
+
|
| 42 |
+
self.model.inference_model = None
|
| 43 |
+
self._is_optimized_for_inference = False
|
| 44 |
+
self._has_warned_about_not_being_optimized_for_inference = False
|
| 45 |
+
self._optimized_has_been_compiled = False
|
| 46 |
+
self._optimized_batch_size = None
|
| 47 |
+
self._optimized_resolution = None
|
| 48 |
+
self._optimized_dtype = None
|
| 49 |
+
|
| 50 |
+
def maybe_download_pretrain_weights(self):
|
| 51 |
+
download_pretrain_weights(self.model_config.pretrain_weights)
|
| 52 |
+
|
| 53 |
+
def get_model_config(self, **kwargs):
|
| 54 |
+
return ModelConfig(**kwargs)
|
| 55 |
+
|
| 56 |
+
def train(self, **kwargs):
|
| 57 |
+
config = self.get_train_config(**kwargs)
|
| 58 |
+
self.train_from_config(config, **kwargs)
|
| 59 |
+
|
| 60 |
+
def optimize_for_inference(self, compile=True, batch_size=1, dtype=torch.float32):
|
| 61 |
+
self.remove_optimized_model()
|
| 62 |
+
|
| 63 |
+
self.model.inference_model = deepcopy(self.model.model)
|
| 64 |
+
self.model.inference_model.eval()
|
| 65 |
+
self.model.inference_model.export()
|
| 66 |
+
|
| 67 |
+
self._optimized_resolution = self.model.resolution
|
| 68 |
+
self._is_optimized_for_inference = True
|
| 69 |
+
|
| 70 |
+
self.model.inference_model = self.model.inference_model.to(dtype=dtype)
|
| 71 |
+
self._optimized_dtype = dtype
|
| 72 |
+
|
| 73 |
+
if compile:
|
| 74 |
+
self.model.inference_model = torch.jit.trace(
|
| 75 |
+
self.model.inference_model,
|
| 76 |
+
torch.randn(
|
| 77 |
+
batch_size, 3, self.model.resolution, self.model.resolution,
|
| 78 |
+
device=self.model.device,
|
| 79 |
+
dtype=dtype
|
| 80 |
+
)
|
| 81 |
+
)
|
| 82 |
+
self._optimized_has_been_compiled = True
|
| 83 |
+
self._optimized_batch_size = batch_size
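# Example call for optimize_for_inference above (batch size and dtype are illustrative choices);
# once compiled, predict() must be called with the same batch size and resolution.
model = RFDETRBase()   # RFDETRBase is defined later in this file
model.optimize_for_inference(compile=True, batch_size=1, dtype=torch.float16)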
|
| 84 |
+
|
| 85 |
+
def remove_optimized_model(self):
|
| 86 |
+
self.model.inference_model = None
|
| 87 |
+
self._is_optimized_for_inference = False
|
| 88 |
+
self._optimized_has_been_compiled = False
|
| 89 |
+
self._optimized_batch_size = None
|
| 90 |
+
self._optimized_resolution = None
|
| 91 |
+
self._optimized_dtype = None
|
| 92 |
+
|
| 93 |
+
def export(self, **kwargs):
|
| 94 |
+
self.model.export(**kwargs)
|
| 95 |
+
|
| 96 |
+
def train_from_config(self, config: TrainConfig, **kwargs):
|
| 97 |
+
with open(
|
| 98 |
+
os.path.join(config.dataset_dir, "train", "_annotations.coco.json"), "r"
|
| 99 |
+
) as f:
|
| 100 |
+
anns = json.load(f)
|
| 101 |
+
num_classes = len(anns["categories"])
|
| 102 |
+
class_names = [c["name"] for c in anns["categories"] if c["supercategory"] != "none"]
|
| 103 |
+
self.model.class_names = class_names
|
| 104 |
+
|
| 105 |
+
if self.model_config.num_classes != num_classes:
|
| 106 |
+
logger.warning(
|
| 107 |
+
f"num_classes mismatch: model has {self.model_config.num_classes} classes, but your dataset has {num_classes} classes\n"
|
| 108 |
+
f"reinitializing your detection head with {num_classes} classes."
|
| 109 |
+
)
|
| 110 |
+
self.model.reinitialize_detection_head(num_classes)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
train_config = config.dict()
|
| 114 |
+
model_config = self.model_config.dict()
|
| 115 |
+
model_config.pop("num_classes")
|
| 116 |
+
if "class_names" in model_config:
|
| 117 |
+
model_config.pop("class_names")
|
| 118 |
+
|
| 119 |
+
if "class_names" in train_config and train_config["class_names"] is None:
|
| 120 |
+
train_config["class_names"] = class_names
|
| 121 |
+
|
| 122 |
+
for k, v in train_config.items():
|
| 123 |
+
if k in model_config:
|
| 124 |
+
model_config.pop(k)
|
| 125 |
+
if k in kwargs:
|
| 126 |
+
kwargs.pop(k)
|
| 127 |
+
|
| 128 |
+
all_kwargs = {**model_config, **train_config, **kwargs, "num_classes": num_classes}
|
| 129 |
+
|
| 130 |
+
metrics_plot_sink = MetricsPlotSink(output_dir=config.output_dir)
|
| 131 |
+
self.callbacks["on_fit_epoch_end"].append(metrics_plot_sink.update)
|
| 132 |
+
self.callbacks["on_train_end"].append(metrics_plot_sink.save)
|
| 133 |
+
|
| 134 |
+
if config.tensorboard:
|
| 135 |
+
metrics_tensor_board_sink = MetricsTensorBoardSink(output_dir=config.output_dir)
|
| 136 |
+
self.callbacks["on_fit_epoch_end"].append(metrics_tensor_board_sink.update)
|
| 137 |
+
self.callbacks["on_train_end"].append(metrics_tensor_board_sink.close)
|
| 138 |
+
|
| 139 |
+
if config.wandb:
|
| 140 |
+
metrics_wandb_sink = MetricsWandBSink(
|
| 141 |
+
output_dir=config.output_dir,
|
| 142 |
+
project=config.project,
|
| 143 |
+
run=config.run,
|
| 144 |
+
config=config.model_dump()
|
| 145 |
+
)
|
| 146 |
+
self.callbacks["on_fit_epoch_end"].append(metrics_wandb_sink.update)
|
| 147 |
+
self.callbacks["on_train_end"].append(metrics_wandb_sink.close)
|
| 148 |
+
|
| 149 |
+
if config.early_stopping:
|
| 150 |
+
from rfdetr.util.early_stopping import EarlyStoppingCallback
|
| 151 |
+
early_stopping_callback = EarlyStoppingCallback(
|
| 152 |
+
model=self.model,
|
| 153 |
+
patience=config.early_stopping_patience,
|
| 154 |
+
min_delta=config.early_stopping_min_delta,
|
| 155 |
+
use_ema=config.early_stopping_use_ema
|
| 156 |
+
)
|
| 157 |
+
self.callbacks["on_fit_epoch_end"].append(early_stopping_callback.update)
|
| 158 |
+
|
| 159 |
+
self.model.train(
|
| 160 |
+
**all_kwargs,
|
| 161 |
+
callbacks=self.callbacks,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
def get_train_config(self, **kwargs):
|
| 165 |
+
return TrainConfig(**kwargs)
|
| 166 |
+
|
| 167 |
+
def get_model(self, config: ModelConfig):
|
| 168 |
+
return Model(**config.dict())
|
| 169 |
+
|
| 170 |
+
# Get class_names from the model
|
| 171 |
+
@property
|
| 172 |
+
def class_names(self):
|
| 173 |
+
if hasattr(self.model, 'class_names') and self.model.class_names:
|
| 174 |
+
return {i+1: name for i, name in enumerate(self.model.class_names)}
|
| 175 |
+
|
| 176 |
+
return COCO_CLASSES
|
| 177 |
+
|
| 178 |
+
def predict(
|
| 179 |
+
self,
|
| 180 |
+
images: Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]],
|
| 181 |
+
threshold: float = 0.5,
|
| 182 |
+
**kwargs,
|
| 183 |
+
) -> Union[sv.Detections, List[sv.Detections]]:
|
| 184 |
+
"""Performs object detection on the input images and returns bounding box
|
| 185 |
+
predictions.
|
| 186 |
+
|
| 187 |
+
This method accepts a single image or a list of images in various formats
|
| 188 |
+
(file path, PIL Image, NumPy array, or torch.Tensor). The images should be in
|
| 189 |
+
RGB channel order. If a torch.Tensor is provided, it must already be normalized
|
| 190 |
+
to values in the [0, 1] range and have the shape (C, H, W).
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
images (Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]]):
|
| 194 |
+
A single image or a list of images to process. Images can be provided
|
| 195 |
+
as file paths, PIL Images, NumPy arrays, or torch.Tensors.
|
| 196 |
+
threshold (float, optional):
|
| 197 |
+
The minimum confidence score needed to consider a detected bounding box valid.
|
| 198 |
+
**kwargs:
|
| 199 |
+
Additional keyword arguments.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Union[sv.Detections, List[sv.Detections]]: A single or multiple Detections
|
| 203 |
+
objects, each containing bounding box coordinates, confidence scores,
|
| 204 |
+
and class IDs.
|
| 205 |
+
"""
|
| 206 |
+
if not self._is_optimized_for_inference and not self._has_warned_about_not_being_optimized_for_inference:
|
| 207 |
+
logger.warning(
|
| 208 |
+
"Model is not optimized for inference. "
|
| 209 |
+
"Latency may be higher than expected. "
|
| 210 |
+
"You can optimize the model for inference by calling model.optimize_for_inference()."
|
| 211 |
+
)
|
| 212 |
+
self._has_warned_about_not_being_optimized_for_inference = True
|
| 213 |
+
|
| 214 |
+
self.model.model.eval()
|
| 215 |
+
|
| 216 |
+
if not isinstance(images, list):
|
| 217 |
+
images = [images]
|
| 218 |
+
|
| 219 |
+
orig_sizes = []
|
| 220 |
+
processed_images = []
|
| 221 |
+
|
| 222 |
+
for img in images:
|
| 223 |
+
|
| 224 |
+
if isinstance(img, str):
|
| 225 |
+
img = Image.open(img)
|
| 226 |
+
|
| 227 |
+
if not isinstance(img, torch.Tensor):
|
| 228 |
+
img = F.to_tensor(img)
|
| 229 |
+
|
| 230 |
+
if (img > 1).any():
|
| 231 |
+
raise ValueError(
|
| 232 |
+
"Image has pixel values above 1. Please ensure the image is "
|
| 233 |
+
"normalized (scaled to [0, 1])."
|
| 234 |
+
)
|
| 235 |
+
if img.shape[0] != 3:
|
| 236 |
+
raise ValueError(
|
| 237 |
+
f"Invalid image shape. Expected 3 channels (RGB), but got "
|
| 238 |
+
f"{img.shape[0]} channels."
|
| 239 |
+
)
|
| 240 |
+
img_tensor = img
|
| 241 |
+
|
| 242 |
+
h, w = img_tensor.shape[1:]
|
| 243 |
+
orig_sizes.append((h, w))
|
| 244 |
+
|
| 245 |
+
img_tensor = img_tensor.to(self.model.device)
|
| 246 |
+
img_tensor = F.normalize(img_tensor, self.means, self.stds)
|
| 247 |
+
img_tensor = F.resize(img_tensor, (self.model.resolution, self.model.resolution))
|
| 248 |
+
|
| 249 |
+
processed_images.append(img_tensor)
|
| 250 |
+
|
| 251 |
+
batch_tensor = torch.stack(processed_images)
|
| 252 |
+
|
| 253 |
+
if self._is_optimized_for_inference:
|
| 254 |
+
if self._optimized_resolution != batch_tensor.shape[2]:
|
| 255 |
+
# this could happen if someone manually changes self.model.resolution after optimizing the model
|
| 256 |
+
raise ValueError(f"Resolution mismatch. "
|
| 257 |
+
f"Model was optimized for resolution {self._optimized_resolution}, "
|
| 258 |
+
f"but got {batch_tensor.shape[2]}. "
|
| 259 |
+
"You can explicitly remove the optimized model by calling model.remove_optimized_model().")
|
| 260 |
+
if self._optimized_has_been_compiled:
|
| 261 |
+
if self._optimized_batch_size != batch_tensor.shape[0]:
|
| 262 |
+
raise ValueError(f"Batch size mismatch. "
|
| 263 |
+
f"Optimized model was compiled for batch size {self._optimized_batch_size}, "
|
| 264 |
+
f"but got {batch_tensor.shape[0]}. "
|
| 265 |
+
"You can explicitly remove the optimized model by calling model.remove_optimized_model(). "
|
| 266 |
+
"Alternatively, you can recompile the optimized model for a different batch size "
|
| 267 |
+
"by calling model.optimize_for_inference(batch_size=<new_batch_size>).")
|
| 268 |
+
|
| 269 |
+
with torch.inference_mode():
|
| 270 |
+
if self._is_optimized_for_inference:
|
| 271 |
+
predictions = self.model.inference_model(batch_tensor.to(dtype=self._optimized_dtype))
|
| 272 |
+
else:
|
| 273 |
+
predictions = self.model.model(batch_tensor)
|
| 274 |
+
if isinstance(predictions, tuple):
|
| 275 |
+
predictions = {
|
| 276 |
+
"pred_logits": predictions[1],
|
| 277 |
+
"pred_boxes": predictions[0]
|
| 278 |
+
}
|
| 279 |
+
target_sizes = torch.tensor(orig_sizes, device=self.model.device)
|
| 280 |
+
results = self.model.postprocessors["bbox"](predictions, target_sizes=target_sizes)
|
| 281 |
+
|
| 282 |
+
detections_list = []
|
| 283 |
+
for result in results:
|
| 284 |
+
scores = result["scores"]
|
| 285 |
+
labels = result["labels"]
|
| 286 |
+
boxes = result["boxes"]
|
| 287 |
+
|
| 288 |
+
keep = scores > threshold
|
| 289 |
+
scores = scores[keep]
|
| 290 |
+
labels = labels[keep]
|
| 291 |
+
boxes = boxes[keep]
|
| 292 |
+
|
| 293 |
+
detections = sv.Detections(
|
| 294 |
+
xyxy=boxes.float().cpu().numpy(),
|
| 295 |
+
confidence=scores.float().cpu().numpy(),
|
| 296 |
+
class_id=labels.cpu().numpy(),
|
| 297 |
+
)
|
| 298 |
+
detections_list.append(detections)
|
| 299 |
+
|
| 300 |
+
return detections_list if len(detections_list) > 1 else detections_list[0]
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class RFDETRBase(RFDETR):
|
| 304 |
+
def get_model_config(self, **kwargs):
|
| 305 |
+
return RFDETRBaseConfig(**kwargs)
|
| 306 |
+
|
| 307 |
+
def get_train_config(self, **kwargs):
|
| 308 |
+
return TrainConfig(**kwargs)
|
| 309 |
+
|
| 310 |
+
class RFDETRLarge(RFDETR):
|
| 311 |
+
def get_model_config(self, **kwargs):
|
| 312 |
+
return RFDETRLargeConfig(**kwargs)
|
| 313 |
+
|
| 314 |
+
def get_train_config(self, **kwargs):
|
| 315 |
+
return TrainConfig(**kwargs)
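# Minimal end-to-end sketch for the predict() API above; the image path and threshold are placeholders.
from rfdetr.detr import RFDETRBase

model = RFDETRBase()                                    # downloads pretrained weights if missing
detections = model.predict("image.jpg", threshold=0.5)
print(detections.xyxy, detections.class_id, detections.confidence)   # fields of sv.Detections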
|
rfdetr/engine.py
ADDED
|
@@ -0,0 +1,256 @@
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# RF-DETR
|
| 3 |
+
# Copyright (c) 2025 Roboflow. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
|
| 7 |
+
# Copyright (c) 2024 Baidu. All Rights Reserved.
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
# Conditional DETR
|
| 10 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 12 |
+
# ------------------------------------------------------------------------
|
| 13 |
+
# Copied from DETR (https://github.com/facebookresearch/detr)
|
| 14 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
Train and eval functions used in main.py
|
| 19 |
+
"""
|
| 20 |
+
import math
|
| 21 |
+
import sys
|
| 22 |
+
from typing import Iterable
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
import rfdetr.util.misc as utils
|
| 27 |
+
from rfdetr.datasets.coco_eval import CocoEvaluator
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from torch.amp import autocast, GradScaler
|
| 31 |
+
DEPRECATED_AMP = False
|
| 32 |
+
except ImportError:
|
| 33 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 34 |
+
DEPRECATED_AMP = True
|
| 35 |
+
from typing import DefaultDict, List, Callable
|
| 36 |
+
from rfdetr.util.misc import NestedTensor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_autocast_args(args):
|
| 41 |
+
if DEPRECATED_AMP:
|
| 42 |
+
return {'enabled': args.amp, 'dtype': torch.bfloat16}
|
| 43 |
+
else:
|
| 44 |
+
return {'device_type': 'cuda', 'enabled': args.amp, 'dtype': torch.bfloat16}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def train_one_epoch(
|
| 48 |
+
model: torch.nn.Module,
|
| 49 |
+
criterion: torch.nn.Module,
|
| 50 |
+
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 51 |
+
data_loader: Iterable,
|
| 52 |
+
optimizer: torch.optim.Optimizer,
|
| 53 |
+
device: torch.device,
|
| 54 |
+
epoch: int,
|
| 55 |
+
batch_size: int,
|
| 56 |
+
max_norm: float = 0,
|
| 57 |
+
ema_m: torch.nn.Module = None,
|
| 58 |
+
schedules: dict = {},
|
| 59 |
+
num_training_steps_per_epoch=None,
|
| 60 |
+
vit_encoder_num_layers=None,
|
| 61 |
+
args=None,
|
| 62 |
+
callbacks: DefaultDict[str, List[Callable]] = None,
|
| 63 |
+
):
|
| 64 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
| 65 |
+
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
|
| 66 |
+
metric_logger.add_meter(
|
| 67 |
+
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
|
| 68 |
+
)
|
| 69 |
+
header = "Epoch: [{}]".format(epoch)
|
| 70 |
+
print_freq = 10
|
| 71 |
+
start_steps = epoch * num_training_steps_per_epoch
|
| 72 |
+
|
| 73 |
+
print("Grad accum steps: ", args.grad_accum_steps)
|
| 74 |
+
print("Total batch size: ", batch_size * utils.get_world_size())
|
| 75 |
+
|
| 76 |
+
# Add gradient scaler for AMP
|
| 77 |
+
if DEPRECATED_AMP:
|
| 78 |
+
scaler = GradScaler(enabled=args.amp)
|
| 79 |
+
else:
|
| 80 |
+
scaler = GradScaler('cuda', enabled=args.amp)
|
| 81 |
+
|
| 82 |
+
optimizer.zero_grad()
|
| 83 |
+
assert batch_size % args.grad_accum_steps == 0
|
| 84 |
+
sub_batch_size = batch_size // args.grad_accum_steps
|
| 85 |
+
print("LENGTH OF DATA LOADER:", len(data_loader))
|
| 86 |
+
for data_iter_step, (samples, targets) in enumerate(
|
| 87 |
+
metric_logger.log_every(data_loader, print_freq, header)
|
| 88 |
+
):
|
| 89 |
+
it = start_steps + data_iter_step
|
| 90 |
+
callback_dict = {
|
| 91 |
+
"step": it,
|
| 92 |
+
"model": model,
|
| 93 |
+
"epoch": epoch,
|
| 94 |
+
}
|
| 95 |
+
for callback in callbacks["on_train_batch_start"]:
|
| 96 |
+
callback(callback_dict)
|
| 97 |
+
if "dp" in schedules:
|
| 98 |
+
if args.distributed:
|
| 99 |
+
model.module.update_drop_path(
|
| 100 |
+
schedules["dp"][it], vit_encoder_num_layers
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
model.update_drop_path(schedules["dp"][it], vit_encoder_num_layers)
|
| 104 |
+
if "do" in schedules:
|
| 105 |
+
if args.distributed:
|
| 106 |
+
model.module.update_dropout(schedules["do"][it])
|
| 107 |
+
else:
|
| 108 |
+
model.update_dropout(schedules["do"][it])
|
| 109 |
+
|
| 110 |
+
for i in range(args.grad_accum_steps):
|
| 111 |
+
start_idx = i * sub_batch_size
|
| 112 |
+
final_idx = start_idx + sub_batch_size
|
| 113 |
+
new_samples_tensors = samples.tensors[start_idx:final_idx]
|
| 114 |
+
new_samples = NestedTensor(new_samples_tensors, samples.mask[start_idx:final_idx])
|
| 115 |
+
new_samples = new_samples.to(device)
|
| 116 |
+
new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]]
|
| 117 |
+
|
| 118 |
+
with autocast(**get_autocast_args(args)):
|
| 119 |
+
outputs = model(new_samples, new_targets)
|
| 120 |
+
loss_dict = criterion(outputs, new_targets)
|
| 121 |
+
weight_dict = criterion.weight_dict
|
| 122 |
+
losses = sum(
|
| 123 |
+
(1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k]
|
| 124 |
+
for k in loss_dict.keys()
|
| 125 |
+
if k in weight_dict
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
scaler.scale(losses).backward()
|
| 130 |
+
|
| 131 |
+
# reduce losses over all GPUs for logging purposes
|
| 132 |
+
loss_dict_reduced = utils.reduce_dict(loss_dict)
|
| 133 |
+
loss_dict_reduced_unscaled = {
|
| 134 |
+
f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
|
| 135 |
+
}
|
| 136 |
+
loss_dict_reduced_scaled = {
|
| 137 |
+
k: v * weight_dict[k]
|
| 138 |
+
for k, v in loss_dict_reduced.items()
|
| 139 |
+
if k in weight_dict
|
| 140 |
+
}
|
| 141 |
+
losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
|
| 142 |
+
|
| 143 |
+
loss_value = losses_reduced_scaled.item()
|
| 144 |
+
|
| 145 |
+
if not math.isfinite(loss_value):
|
| 146 |
+
print(loss_dict_reduced)
|
| 147 |
+
raise ValueError("Loss is {}, stopping training".format(loss_value))
|
| 148 |
+
|
| 149 |
+
if max_norm > 0:
|
| 150 |
+
scaler.unscale_(optimizer)
|
| 151 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
| 152 |
+
|
| 153 |
+
scaler.step(optimizer)
|
| 154 |
+
scaler.update()
|
| 155 |
+
lr_scheduler.step()
|
| 156 |
+
optimizer.zero_grad()
|
| 157 |
+
if ema_m is not None:
|
| 158 |
+
if epoch >= 0:
|
| 159 |
+
ema_m.update(model)
|
| 160 |
+
metric_logger.update(
|
| 161 |
+
loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
|
| 162 |
+
)
|
| 163 |
+
metric_logger.update(class_error=loss_dict_reduced["class_error"])
|
| 164 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
| 165 |
+
# gather the stats from all processes
|
| 166 |
+
metric_logger.synchronize_between_processes()
|
| 167 |
+
print("Averaged stats:", metric_logger)
|
| 168 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
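# Note on the gradient-accumulation loop above: batch_size must be divisible by
# args.grad_accum_steps; e.g. batch_size=16 with grad_accum_steps=4 gives sub_batch_size=4,
# each sub-batch loss is scaled by 1/4, and the optimizer steps once per outer iteration.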
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None):
|
| 172 |
+
model.eval()
|
| 173 |
+
if args.fp16_eval:
|
| 174 |
+
model.half()
|
| 175 |
+
criterion.eval()
|
| 176 |
+
|
| 177 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
| 178 |
+
metric_logger.add_meter(
|
| 179 |
+
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
|
| 180 |
+
)
|
| 181 |
+
header = "Test:"
|
| 182 |
+
|
| 183 |
+
iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
|
| 184 |
+
coco_evaluator = CocoEvaluator(base_ds, iou_types)
|
| 185 |
+
|
| 186 |
+
for samples, targets in metric_logger.log_every(data_loader, 10, header):
|
| 187 |
+
samples = samples.to(device)
|
| 188 |
+
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
| 189 |
+
|
| 190 |
+
if args.fp16_eval:
|
| 191 |
+
samples.tensors = samples.tensors.half()
|
| 192 |
+
|
| 193 |
+
# Add autocast for evaluation
|
| 194 |
+
with autocast(**get_autocast_args(args)):
|
| 195 |
+
outputs = model(samples)
|
| 196 |
+
|
| 197 |
+
if args.fp16_eval:
|
| 198 |
+
for key in outputs.keys():
|
| 199 |
+
if key == "enc_outputs":
|
| 200 |
+
for sub_key in outputs[key].keys():
|
| 201 |
+
outputs[key][sub_key] = outputs[key][sub_key].float()
|
| 202 |
+
elif key == "aux_outputs":
|
| 203 |
+
for idx in range(len(outputs[key])):
|
| 204 |
+
for sub_key in outputs[key][idx].keys():
|
| 205 |
+
outputs[key][idx][sub_key] = outputs[key][idx][
|
| 206 |
+
sub_key
|
| 207 |
+
].float()
|
| 208 |
+
else:
|
| 209 |
+
outputs[key] = outputs[key].float()
|
| 210 |
+
|
| 211 |
+
loss_dict = criterion(outputs, targets)
|
| 212 |
+
weight_dict = criterion.weight_dict
|
| 213 |
+
|
| 214 |
+
# reduce losses over all GPUs for logging purposes
|
| 215 |
+
loss_dict_reduced = utils.reduce_dict(loss_dict)
|
| 216 |
+
loss_dict_reduced_scaled = {
|
| 217 |
+
k: v * weight_dict[k]
|
| 218 |
+
for k, v in loss_dict_reduced.items()
|
| 219 |
+
if k in weight_dict
|
| 220 |
+
}
|
| 221 |
+
loss_dict_reduced_unscaled = {
|
| 222 |
+
f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
|
| 223 |
+
}
|
| 224 |
+
metric_logger.update(
|
| 225 |
+
loss=sum(loss_dict_reduced_scaled.values()),
|
| 226 |
+
**loss_dict_reduced_scaled,
|
| 227 |
+
**loss_dict_reduced_unscaled,
|
| 228 |
+
)
|
| 229 |
+
metric_logger.update(class_error=loss_dict_reduced["class_error"])
|
| 230 |
+
|
| 231 |
+
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
|
| 232 |
+
results = postprocessors["bbox"](outputs, orig_target_sizes)
|
| 233 |
+
res = {
|
| 234 |
+
target["image_id"].item(): output
|
| 235 |
+
for target, output in zip(targets, results)
|
| 236 |
+
}
|
| 237 |
+
if coco_evaluator is not None:
|
| 238 |
+
coco_evaluator.update(res)
|
| 239 |
+
|
| 240 |
+
# gather the stats from all processes
|
| 241 |
+
metric_logger.synchronize_between_processes()
|
| 242 |
+
print("Averaged stats:", metric_logger)
|
| 243 |
+
if coco_evaluator is not None:
|
| 244 |
+
coco_evaluator.synchronize_between_processes()
|
| 245 |
+
|
| 246 |
+
# accumulate predictions from all images
|
| 247 |
+
if coco_evaluator is not None:
|
| 248 |
+
coco_evaluator.accumulate()
|
| 249 |
+
coco_evaluator.summarize()
|
| 250 |
+
stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 251 |
+
if coco_evaluator is not None:
|
| 252 |
+
if "bbox" in postprocessors.keys():
|
| 253 |
+
stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
|
| 254 |
+
if "segm" in postprocessors.keys():
|
| 255 |
+
stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
|
| 256 |
+
return stats, coco_evaluator
|
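
# Note: stats["coco_eval_bbox"] above is the 12-element summary list produced by
# pycocotools' COCOeval.summarize(); index 0 is AP@[0.50:0.95] and index 1 is AP@0.50,
# which is how the training loop in rfdetr/main.py below reads its mAP values.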
rfdetr/main.py
ADDED
@@ -0,0 +1,1034 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

"""
cleaned main file
"""
import argparse
import ast
import copy
import datetime
import json
import math
import os
import random
import shutil
import time
from copy import deepcopy
from logging import getLogger
from pathlib import Path
from typing import DefaultDict, List, Callable

import numpy as np
import torch
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader, DistributedSampler

import rfdetr.util.misc as utils
from rfdetr.datasets import build_dataset, get_coco_api_from_dataset
from rfdetr.engine import evaluate, train_one_epoch
from rfdetr.models import build_model, build_criterion_and_postprocessors
from rfdetr.util.benchmark import benchmark
from rfdetr.util.drop_scheduler import drop_scheduler
from rfdetr.util.files import download_file
from rfdetr.util.get_param_dicts import get_param_dict
from rfdetr.util.utils import ModelEma, BestMetricHolder, clean_state_dict

if str(os.environ.get("USE_FILE_SYSTEM_SHARING", "False")).lower() in ["true", "1"]:
    import torch.multiprocessing
    torch.multiprocessing.set_sharing_strategy('file_system')

logger = getLogger(__name__)

HOSTED_MODELS = {
    "rf-detr-base.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
    # below is a less converged model that may be better for finetuning but worse for inference
    "rf-detr-base-2.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
    "rf-detr-large.pth": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth"
}

def download_pretrain_weights(pretrain_weights: str, redownload=False):
    if pretrain_weights in HOSTED_MODELS:
        if redownload or not os.path.exists(pretrain_weights):
            logger.info(
                f"Downloading pretrained weights for {pretrain_weights}"
            )
            download_file(
                HOSTED_MODELS[pretrain_weights],
                pretrain_weights,
            )

class Model:
    def __init__(self, **kwargs):
        args = populate_args(**kwargs)
        self.resolution = args.resolution
        self.model = build_model(args)
        self.device = torch.device(args.device)
        if args.pretrain_weights is not None:
            print("Loading pretrain weights")
            try:
                checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
            except Exception as e:
                print(f"Failed to load pretrain weights: {e}")
                # re-download weights if they are corrupted
                print("Failed to load pretrain weights, re-downloading")
                download_pretrain_weights(args.pretrain_weights, redownload=True)
                checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)

            # Extract class_names from checkpoint if available
            if 'args' in checkpoint and hasattr(checkpoint['args'], 'class_names'):
                self.class_names = checkpoint['args'].class_names

            checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0]
            if checkpoint_num_classes != args.num_classes + 1:
                logger.warning(
                    f"num_classes mismatch: pretrain weights has {checkpoint_num_classes - 1} classes, but your model has {args.num_classes} classes\n"
                    f"reinitializing detection head with {checkpoint_num_classes - 1} classes"
                )
                self.reinitialize_detection_head(checkpoint_num_classes)
            # add support to exclude_keys
            # e.g., when load object365 pretrain, do not load `class_embed.[weight, bias]`
            if args.pretrain_exclude_keys is not None:
                assert isinstance(args.pretrain_exclude_keys, list)
                for exclude_key in args.pretrain_exclude_keys:
                    checkpoint['model'].pop(exclude_key)
            if args.pretrain_keys_modify_to_load is not None:
                from util.obj365_to_coco_model import get_coco_pretrain_from_obj365
                assert isinstance(args.pretrain_keys_modify_to_load, list)
                for modify_key_to_load in args.pretrain_keys_modify_to_load:
                    try:
                        checkpoint['model'][modify_key_to_load] = get_coco_pretrain_from_obj365(
                            model_without_ddp.state_dict()[modify_key_to_load],
                            checkpoint['model'][modify_key_to_load]
                        )
                    except:
                        print(f"Failed to load {modify_key_to_load}, deleting from checkpoint")
                        checkpoint['model'].pop(modify_key_to_load)

            # we may want to resume training with a smaller number of groups for group detr
            num_desired_queries = args.num_queries * args.group_detr
            query_param_names = ["refpoint_embed.weight", "query_feat.weight"]
            for name, state in checkpoint['model'].items():
                if any(name.endswith(x) for x in query_param_names):
                    checkpoint['model'][name] = state[:num_desired_queries]

            self.model.load_state_dict(checkpoint['model'], strict=False)

        if args.backbone_lora:
            print("Applying LORA to backbone")
            lora_config = LoraConfig(
                r=16,
                lora_alpha=16,
                use_dora=True,
                target_modules=[
                    "q_proj", "v_proj", "k_proj",  # covers OWL-ViT
                    "qkv",  # covers open_clip ie Siglip2
                    "query", "key", "value", "cls_token", "register_tokens",  # covers Dinov2 with windowed attn
                ]
            )
            self.model.backbone[0].encoder = get_peft_model(self.model.backbone[0].encoder, lora_config)
        self.model = self.model.to(self.device)
        self.criterion, self.postprocessors = build_criterion_and_postprocessors(args)
        self.stop_early = False

    def reinitialize_detection_head(self, num_classes):
        self.model.reinitialize_detection_head(num_classes)

    def request_early_stop(self):
        self.stop_early = True
        print("Early stopping requested, will complete current epoch and stop")

    def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
        currently_supported_callbacks = ["on_fit_epoch_end", "on_train_batch_start", "on_train_end"]
        for key in callbacks.keys():
            if key not in currently_supported_callbacks:
                raise ValueError(
                    f"Callback {key} is not currently supported, please file an issue if you need it!\n"
                    f"Currently supported callbacks: {currently_supported_callbacks}"
                )
        args = populate_args(**kwargs)
        utils.init_distributed_mode(args)
        print("git:\n {}\n".format(utils.get_sha()))
        print(args)
        device = torch.device(args.device)

        # fix the seed for reproducibility
        seed = args.seed + utils.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        criterion, postprocessors = build_criterion_and_postprocessors(args)
        model = self.model
        model.to(device)

        model_without_ddp = model
        if args.distributed:
            if args.sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('number of params:', n_parameters)
        param_dicts = get_param_dict(args, model_without_ddp)

        param_dicts = [p for p in param_dicts if p['params'].requires_grad]

        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                      weight_decay=args.weight_decay)
        # Choose the learning rate scheduler based on the new argument

        dataset_train = build_dataset(image_set='train', args=args, resolution=args.resolution)
        dataset_val = build_dataset(image_set='val', args=args, resolution=args.resolution)

        # for cosine annealing, calculate total training steps and warmup steps
        total_batch_size_for_lr = args.batch_size * utils.get_world_size() * args.grad_accum_steps
        num_training_steps_per_epoch_lr = (len(dataset_train) + total_batch_size_for_lr - 1) // total_batch_size_for_lr
        total_training_steps_lr = num_training_steps_per_epoch_lr * args.epochs
        warmup_steps_lr = num_training_steps_per_epoch_lr * args.warmup_epochs
        def lr_lambda(current_step: int):
            if current_step < warmup_steps_lr:
                # Linear warmup
                return float(current_step) / float(max(1, warmup_steps_lr))
            else:
                # Cosine annealing from multiplier 1.0 down to lr_min_factor
                if args.lr_scheduler == 'cosine':
                    progress = float(current_step - warmup_steps_lr) / float(max(1, total_training_steps_lr - warmup_steps_lr))
                    return args.lr_min_factor + (1 - args.lr_min_factor) * 0.5 * (1 + math.cos(math.pi * progress))
                elif args.lr_scheduler == 'step':
                    if current_step < args.lr_drop * num_training_steps_per_epoch_lr:
                        return 1.0
                    else:
                        return 0.1
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
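        # Sanity check of the lr_lambda multiplier above, assuming (for illustration
        # only) lr_scheduler == 'cosine', lr_min_factor == 0.0, warmup_steps_lr == 100
        # and total_training_steps_lr == 1000:
        #   step 50   -> 50 / 100                  = 0.5  (linear warmup)
        #   step 100  -> 0.5 * (1 + cos(0))        = 1.0  (warmup finished)
        #   step 550  -> 0.5 * (1 + cos(pi * 0.5)) = 0.5  (halfway through the decay)
        #   step 1000 -> 0.5 * (1 + cos(pi))       = 0.0  (end of schedule)
        # LambdaLR multiplies args.lr by this factor at every scheduler step; the
        # 'step' branch instead returns 1.0 until the lr_drop epoch and 0.1 afterwards.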

        if args.distributed:
            sampler_train = DistributedSampler(dataset_train)
            sampler_val = DistributedSampler(dataset_val, shuffle=False)
        else:
            sampler_train = torch.utils.data.RandomSampler(dataset_train)
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

        effective_batch_size = args.batch_size * args.grad_accum_steps
        min_batches = kwargs.get('min_batches', 5)
        if len(dataset_train) < effective_batch_size * min_batches:
            logger.info(
                f"Training with uniform sampler because dataset is too small: {len(dataset_train)} < {effective_batch_size * min_batches}"
            )
            sampler = torch.utils.data.RandomSampler(
                dataset_train,
                replacement=True,
                num_samples=effective_batch_size * min_batches,
            )
            data_loader_train = DataLoader(
                dataset_train,
                batch_size=effective_batch_size,
                collate_fn=utils.collate_fn,
                num_workers=args.num_workers,
                sampler=sampler,
            )
        else:
            batch_sampler_train = torch.utils.data.BatchSampler(
                sampler_train, effective_batch_size, drop_last=True)
            data_loader_train = DataLoader(
                dataset_train,
                batch_sampler=batch_sampler_train,
                collate_fn=utils.collate_fn,
                num_workers=args.num_workers
            )

        data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                     drop_last=False, collate_fn=utils.collate_fn,
                                     num_workers=args.num_workers)

        base_ds = get_coco_api_from_dataset(dataset_val)

        if args.use_ema:
            self.ema_m = ModelEma(model_without_ddp, decay=args.ema_decay, tau=args.ema_tau)
        else:
            self.ema_m = None


        output_dir = Path(args.output_dir)

        if utils.is_main_process():
            print("Get benchmark")
            if args.do_benchmark:
                benchmark_model = copy.deepcopy(model_without_ddp)
                bm = benchmark(benchmark_model.float(), dataset_val, output_dir)
                print(json.dumps(bm, indent=2))
                del benchmark_model

        if args.resume:
            checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
            model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
            if args.use_ema:
                if 'ema_model' in checkpoint:
                    self.ema_m.module.load_state_dict(clean_state_dict(checkpoint['ema_model']))
                else:
                    del self.ema_m
                    self.ema_m = ModelEma(model, decay=args.ema_decay, tau=args.ema_tau)
            if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                args.start_epoch = checkpoint['epoch'] + 1

        if args.eval:
            test_stats, coco_evaluator = evaluate(
                model, criterion, postprocessors, data_loader_val, base_ds, device, args)
            if args.output_dir:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
            return

        # for drop
        total_batch_size = effective_batch_size * utils.get_world_size()
        num_training_steps_per_epoch = (len(dataset_train) + total_batch_size - 1) // total_batch_size
        schedules = {}
        if args.dropout > 0:
            schedules['do'] = drop_scheduler(
                args.dropout, args.epochs, num_training_steps_per_epoch,
                args.cutoff_epoch, args.drop_mode, args.drop_schedule)
            print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do'])))

        if args.drop_path > 0:
            schedules['dp'] = drop_scheduler(
                args.drop_path, args.epochs, num_training_steps_per_epoch,
                args.cutoff_epoch, args.drop_mode, args.drop_schedule)
            print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))

        print("Start training")
        start_time = time.time()
        best_map_holder = BestMetricHolder(use_ema=args.use_ema)
        best_map_5095 = 0
        best_map_50 = 0
        best_map_ema_5095 = 0
        best_map_ema_50 = 0
        for epoch in range(args.start_epoch, args.epochs):
            epoch_start_time = time.time()
            if args.distributed:
                sampler_train.set_epoch(epoch)

            model.train()
            criterion.train()
            train_stats = train_one_epoch(
                model, criterion, lr_scheduler, data_loader_train, optimizer, device, epoch,
                effective_batch_size, args.clip_max_norm, ema_m=self.ema_m, schedules=schedules,
                num_training_steps_per_epoch=num_training_steps_per_epoch,
                vit_encoder_num_layers=args.vit_encoder_num_layers, args=args, callbacks=callbacks)
            train_epoch_time = time.time() - epoch_start_time
            train_epoch_time_str = str(datetime.timedelta(seconds=int(train_epoch_time)))
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                # extra checkpoint before LR drop and every `checkpoint_interval` epochs
                if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % args.checkpoint_interval == 0:
                    checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
                for checkpoint_path in checkpoint_paths:
                    weights = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }
                    if args.use_ema:
                        weights.update({
                            'ema_model': self.ema_m.module.state_dict(),
                        })
                    if not args.dont_save_weights:
                        # create checkpoint dir
                        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

                        utils.save_on_master(weights, checkpoint_path)

            with torch.inference_mode():
                test_stats, coco_evaluator = evaluate(
                    model, criterion, postprocessors, data_loader_val, base_ds, device, args=args
                )

            map_regular = test_stats['coco_eval_bbox'][0]
            _isbest = best_map_holder.update(map_regular, epoch, is_ema=False)
            if _isbest:
                best_map_5095 = max(best_map_5095, map_regular)
                best_map_50 = max(best_map_50, test_stats["coco_eval_bbox"][1])
                checkpoint_path = output_dir / 'checkpoint0009.pth'
                if not args.dont_save_weights:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
            if args.use_ema:
                ema_test_stats, _ = evaluate(
                    self.ema_m.module, criterion, postprocessors, data_loader_val, base_ds, device, args=args
                )
                log_stats.update({f'ema_test_{k}': v for k,v in ema_test_stats.items()})
                map_ema = ema_test_stats['coco_eval_bbox'][0]
                best_map_ema_5095 = max(best_map_ema_5095, map_ema)
                _isbest = best_map_holder.update(map_ema, epoch, is_ema=True)
                if _isbest:
                    best_map_ema_50 = max(best_map_ema_50, ema_test_stats["coco_eval_bbox"][1])
                    checkpoint_path = output_dir / 'checkpoint_best_ema.pth'
                    if not args.dont_save_weights:
                        utils.save_on_master({
                            'model': self.ema_m.module.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'lr_scheduler': lr_scheduler.state_dict(),
                            'epoch': epoch,
                            'args': args,
                        }, checkpoint_path)
            log_stats.update(best_map_holder.summary())

            # epoch parameters
            ep_paras = {
                'epoch': epoch,
                'n_parameters': n_parameters
            }
            log_stats.update(ep_paras)
            try:
                log_stats.update({'now_time': str(datetime.datetime.now())})
            except:
                pass
            log_stats['train_epoch_time'] = train_epoch_time_str
            epoch_time = time.time() - epoch_start_time
            epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
            log_stats['epoch_time'] = epoch_time_str
            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (output_dir / 'eval').mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ['latest.pth']
                        if epoch % 50 == 0:
                            filenames.append(f'{epoch:03}.pth')
                        for name in filenames:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       output_dir / "eval" / name)

            for callback in callbacks["on_fit_epoch_end"]:
                callback(log_stats)

            if self.stop_early:
                print(f"Early stopping requested, stopping at epoch {epoch}")
                break

        best_is_ema = best_map_ema_5095 > best_map_5095

        if utils.is_main_process():
            if best_is_ema:
                shutil.copy2(output_dir / 'checkpoint_best_ema.pth', output_dir / 'checkpoint_best_total.pth')
            else:
                shutil.copy2(output_dir / 'checkpoint0009.pth', output_dir / 'checkpoint_best_total.pth')

            utils.strip_checkpoint(output_dir / 'checkpoint_best_total.pth')

        best_map_5095 = max(best_map_5095, best_map_ema_5095)
        best_map_50 = max(best_map_50, best_map_ema_50)

        results_json = {
            "map95": best_map_5095,
            "map50": best_map_50,
            "class": "all"
        }
        results = {
            "class_map": {
                "valid": [
                    results_json
                ]
            }
        }
        with open(output_dir / "results.json", "w") as f:
            json.dump(results, f)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
        print('Results saved to {}'.format(output_dir / "results.json"))

        if best_is_ema:
            self.model = self.ema_m.module
        self.model.eval()

        for callback in callbacks["on_train_end"]:
            callback()

    def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs):
        """Export the trained model to ONNX format"""
        print(f"Exporting model to ONNX format")
        try:
            from rfdetr.deploy.export import export_onnx, onnx_simplify, make_infer_image
        except ImportError:
            print("It seems some dependencies for ONNX export are missing. Please run `pip install rfdetr[onnxexport]` and try again.")
            raise


        device = self.device
        model = deepcopy(self.model.to("cpu"))
        model.to(device)

        os.makedirs(output_dir, exist_ok=True)
        output_dir = Path(output_dir)
        if shape is None:
            shape = (self.resolution, self.resolution)
        else:
            if shape[0] % 14 != 0 or shape[1] % 14 != 0:
                raise ValueError("Shape must be divisible by 14")

        input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)
        input_names = ['input']
        output_names = ['features'] if backbone_only else ['dets', 'labels']
        dynamic_axes = None
        self.model.eval()
        with torch.no_grad():
            if backbone_only:
                features = model(input_tensors)
                print(f"PyTorch inference output shape: {features.shape}")
            else:
                outputs = model(input_tensors)
                dets = outputs['pred_boxes']
                labels = outputs['pred_logits']
                print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
        model.cpu()
        input_tensors = input_tensors.cpu()

        # Export to ONNX
        output_file = export_onnx(
            output_dir=output_dir,
            model=model,
            input_names=input_names,
            input_tensors=input_tensors,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            backbone_only=backbone_only,
            verbose=verbose,
            opset_version=opset_version
        )

        print(f"Successfully exported ONNX model to: {output_file}")

        if simplify:
            sim_output_file = onnx_simplify(
                onnx_dir=output_file,
                input_names=input_names,
                input_tensors=input_tensors,
                force=force
            )
            print(f"Successfully simplified ONNX model to: {sim_output_file}")

        print("ONNX export completed successfully")
        self.model = self.model.to(device)
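
    # Rough usage sketch of this class (illustrative only; the values below are
    # placeholders, and which dataset arguments are required depends on args.dataset_file):
    #
    #   from collections import defaultdict
    #   m = Model(pretrain_weights="rf-detr-base.pth", num_classes=90)
    #   m.train(callbacks=defaultdict(list), dataset_dir="path/to/dataset",
    #           epochs=12, batch_size=2, output_dir="output")
    #   m.export(output_dir="output", simplify=True)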


if __name__ == '__main__':
    parser = argparse.ArgumentParser('LWDETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    config = vars(args)  # Convert Namespace to dictionary

    if args.subcommand == 'distill':
        distill(**config)
    elif args.subcommand is None:
        main(**config)
    elif args.subcommand == 'export_model':
        filter_keys = [
            "num_classes",
            "grad_accum_steps",
            "lr",
            "lr_encoder",
            "weight_decay",
            "epochs",
            "lr_drop",
            "clip_max_norm",
            "lr_vit_layer_decay",
            "lr_component_decay",
            "dropout",
            "drop_path",
            "drop_mode",
            "drop_schedule",
            "cutoff_epoch",
            "pretrained_encoder",
            "pretrain_weights",
            "pretrain_exclude_keys",
            "pretrain_keys_modify_to_load",
            "freeze_florence",
            "freeze_aimv2",
            "decoder_norm",
            "set_cost_class",
            "set_cost_bbox",
            "set_cost_giou",
            "cls_loss_coef",
            "bbox_loss_coef",
            "giou_loss_coef",
            "focal_alpha",
            "aux_loss",
            "sum_group_losses",
            "use_varifocal_loss",
            "use_position_supervised_loss",
            "ia_bce_loss",
            "dataset_file",
            "coco_path",
            "dataset_dir",
            "square_resize_div_64",
            "output_dir",
            "checkpoint_interval",
            "seed",
            "resume",
            "start_epoch",
            "eval",
            "use_ema",
            "ema_decay",
            "ema_tau",
            "num_workers",
            "device",
            "world_size",
            "dist_url",
            "sync_bn",
            "fp16_eval",
            "infer_dir",
            "verbose",
            "opset_version",
            "dry_run",
            "shape",
        ]
        for key in filter_keys:
            config.pop(key, None)  # Use pop with None to avoid KeyError

        from deploy.export import main as export_main
        if args.batch_size != 1:
            config['batch_size'] = 1
            print(f"Only batch_size 1 is supported for onnx export, \
                but got batchsize = {args.batch_size}. batch_size is forcibly set to 1.")
        export_main(**config)

def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--num_classes', default=2, type=int)
    parser.add_argument('--grad_accum_steps', default=1, type=int)
    parser.add_argument('--amp', default=False, type=bool)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_encoder', default=1.5e-4, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=12, type=int)
    parser.add_argument('--lr_drop', default=11, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')
    parser.add_argument('--lr_vit_layer_decay', default=0.8, type=float)
    parser.add_argument('--lr_component_decay', default=1.0, type=float)
    parser.add_argument('--do_benchmark', action='store_true', help='benchmark the model')

    # drop args
    # dropout and stochastic depth drop rate; set at most one to non-zero
    parser.add_argument('--dropout', type=float, default=0,
                        help='Drop path rate (default: 0.0)')
    parser.add_argument('--drop_path', type=float, default=0,
                        help='Drop path rate (default: 0.0)')

    # early / late dropout and stochastic depth settings
    parser.add_argument('--drop_mode', type=str, default='standard',
                        choices=['standard', 'early', 'late'], help='drop mode')
    parser.add_argument('--drop_schedule', type=str, default='constant',
                        choices=['constant', 'linear'],
                        help='drop schedule for early dropout / s.d. only')
    parser.add_argument('--cutoff_epoch', type=int, default=0,
                        help='if drop_mode is early / late, this is the epoch where dropout ends / starts')

    # Model parameters
    parser.add_argument('--pretrained_encoder', type=str, default=None,
                        help="Path to the pretrained encoder.")
    parser.add_argument('--pretrain_weights', type=str, default=None,
                        help="Path to the pretrained model.")
    parser.add_argument('--pretrain_exclude_keys', type=str, default=None, nargs='+',
                        help="Keys you do not want to load.")
    parser.add_argument('--pretrain_keys_modify_to_load', type=str, default=None, nargs='+',
                        help="Keys you want to modify to load. Only used when loading objects365 pre-trained weights.")

    # * Backbone
    parser.add_argument('--encoder', default='vit_tiny', type=str,
                        help="Name of the transformer or convolutional encoder to use")
    parser.add_argument('--vit_encoder_num_layers', default=12, type=int,
                        help="Number of layers used in ViT encoder")
    parser.add_argument('--window_block_indexes', default=None, type=int, nargs='+')
    parser.add_argument('--position_embedding', default='sine', type=str,
                        choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    parser.add_argument('--out_feature_indexes', default=[-1], type=int, nargs='+', help='only for vit now')
    parser.add_argument("--freeze_encoder", action="store_true", dest="freeze_encoder")
    parser.add_argument("--layer_norm", action="store_true", dest="layer_norm")
    parser.add_argument("--rms_norm", action="store_true", dest="rms_norm")
    parser.add_argument("--backbone_lora", action="store_true", dest="backbone_lora")
    parser.add_argument("--force_no_pretrain", action="store_true", dest="force_no_pretrain")

    # * Transformer
    parser.add_argument('--dec_layers', default=3, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--sa_nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's self-attentions")
    parser.add_argument('--ca_nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's cross-attentions")
    parser.add_argument('--num_queries', default=300, type=int,
                        help="Number of query slots")
    parser.add_argument('--group_detr', default=13, type=int,
                        help="Number of groups to speed up detr training")
    parser.add_argument('--two_stage', action='store_true')
    parser.add_argument('--projector_scale', default='P4', type=str, nargs='+', choices=('P3', 'P4', 'P5', 'P6'))
    parser.add_argument('--lite_refpoint_refine', action='store_true', help='lite refpoint refine mode for speed-up')
    parser.add_argument('--num_select', default=100, type=int,
                        help='the number of predictions selected for evaluation')
    parser.add_argument('--dec_n_points', default=4, type=int,
                        help='the number of sampling points')
    parser.add_argument('--decoder_norm', default='LN', type=str)
    parser.add_argument('--bbox_reparam', action='store_true')
    parser.add_argument('--freeze_batch_norm', action='store_true')
    # * Matcher
    parser.add_argument('--set_cost_class', default=2, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")

    # * Loss coefficients
    parser.add_argument('--cls_loss_coef', default=2, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--focal_alpha', default=0.25, type=float)

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    parser.add_argument('--sum_group_losses', action='store_true',
                        help="To sum losses across groups or mean losses.")
    parser.add_argument('--use_varifocal_loss', action='store_true')
    parser.add_argument('--use_position_supervised_loss', action='store_true')
    parser.add_argument('--ia_bce_loss', action='store_true')

    # dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', type=str)
    parser.add_argument('--dataset_dir', type=str)
    parser.add_argument('--square_resize_div_64', action='store_true')

    parser.add_argument('--output_dir', default='output',
                        help='path where to save, empty for no saving')
    parser.add_argument('--dont_save_weights', action='store_true')
    parser.add_argument('--checkpoint_interval', default=10, type=int,
                        help='epoch interval to save checkpoint')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--use_ema', action='store_true')
    parser.add_argument('--ema_decay', default=0.9997, type=float)
    parser.add_argument('--ema_tau', default=0, type=float)

    parser.add_argument('--num_workers', default=2, type=int)

    # distributed training parameters
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--sync_bn', default=True, type=bool,
                        help='setup synchronized BatchNorm for distributed training')

    # fp16
    parser.add_argument('--fp16_eval', default=False, action='store_true',
                        help='evaluate in fp16 precision.')

    # custom args
    parser.add_argument('--encoder_only', action='store_true', help='Export and benchmark encoder only')
    parser.add_argument('--backbone_only', action='store_true', help='Export and benchmark backbone only')
    parser.add_argument('--resolution', type=int, default=640, help="input resolution")
    parser.add_argument('--use_cls_token', action='store_true', help='use cls token')
    parser.add_argument('--multi_scale', action='store_true', help='use multi scale')
    parser.add_argument('--expanded_scales', action='store_true', help='use expanded scales')
    parser.add_argument('--warmup_epochs', default=1, type=float,
                        help='Number of warmup epochs for linear warmup before cosine annealing')
    # Add scheduler type argument: 'step' or 'cosine'
    parser.add_argument(
        '--lr_scheduler',
        default='step',
        choices=['step', 'cosine'],
        help="Type of learning rate scheduler to use: 'step' (default) or 'cosine'"
    )
    parser.add_argument('--lr_min_factor', default=0.0, type=float,
                        help='Minimum learning rate factor (as a fraction of initial lr) at the end of cosine annealing')
    # Early stopping parameters
    parser.add_argument('--early_stopping', action='store_true',
                        help='Enable early stopping based on mAP improvement')
    parser.add_argument('--early_stopping_patience', default=10, type=int,
                        help='Number of epochs with no improvement after which training will be stopped')
    parser.add_argument('--early_stopping_min_delta', default=0.001, type=float,
                        help='Minimum change in mAP to qualify as an improvement')
    parser.add_argument('--early_stopping_use_ema', action='store_true',
                        help='Use EMA model metrics for early stopping')
    # subparsers
    subparsers = parser.add_subparsers(title='sub-commands', dest='subcommand',
                                       description='valid subcommands', help='additional help')

    # subparser for export model
    parser_export = subparsers.add_parser('export_model', help='LWDETR model export')
    parser_export.add_argument('--infer_dir', type=str, default=None)
    parser_export.add_argument('--verbose', type=ast.literal_eval, default=False, nargs="?", const=True)
    parser_export.add_argument('--opset_version', type=int, default=17)
    parser_export.add_argument('--simplify', action='store_true', help="Simplify onnx model")
    parser_export.add_argument('--tensorrt', '--trtexec', '--trt', action='store_true',
                               help="build tensorrt engine")
    parser_export.add_argument('--dry-run', '--test', '-t', action='store_true', help="just print command")
    parser_export.add_argument('--profile', action='store_true', help='Run nsys profiling during TensorRT export')
    parser_export.add_argument('--shape', type=int, nargs=2, default=(640, 640), help="input shape (width, height)")
    return parser

def populate_args(
    # Basic training parameters
    num_classes=2,
    grad_accum_steps=1,
    amp=False,
    lr=1e-4,
    lr_encoder=1.5e-4,
    batch_size=2,
    weight_decay=1e-4,
    epochs=12,
    lr_drop=11,
    clip_max_norm=0.1,
    lr_vit_layer_decay=0.8,
    lr_component_decay=1.0,
    do_benchmark=False,

    # Drop parameters
    dropout=0,
    drop_path=0,
    drop_mode='standard',
    drop_schedule='constant',
    cutoff_epoch=0,

    # Model parameters
    pretrained_encoder=None,
    pretrain_weights=None,
    pretrain_exclude_keys=None,
    pretrain_keys_modify_to_load=None,
    pretrained_distiller=None,

    # Backbone parameters
    encoder='vit_tiny',
    vit_encoder_num_layers=12,
    window_block_indexes=None,
    position_embedding='sine',
    out_feature_indexes=[-1],
    freeze_encoder=False,
    layer_norm=False,
    rms_norm=False,
    backbone_lora=False,
    force_no_pretrain=False,

    # Transformer parameters
    dec_layers=3,
    dim_feedforward=2048,
    hidden_dim=256,
    sa_nheads=8,
    ca_nheads=8,
    num_queries=300,
    group_detr=13,
    two_stage=False,
    projector_scale='P4',
    lite_refpoint_refine=False,
    num_select=100,
    dec_n_points=4,
    decoder_norm='LN',
    bbox_reparam=False,
    freeze_batch_norm=False,

    # Matcher parameters
    set_cost_class=2,
    set_cost_bbox=5,
    set_cost_giou=2,

    # Loss coefficients
    cls_loss_coef=2,
    bbox_loss_coef=5,
    giou_loss_coef=2,
    focal_alpha=0.25,
    aux_loss=True,
    sum_group_losses=False,
    use_varifocal_loss=False,
    use_position_supervised_loss=False,
    ia_bce_loss=False,

    # Dataset parameters
    dataset_file='coco',
    coco_path=None,
    dataset_dir=None,
    square_resize_div_64=False,

    # Output parameters
    output_dir='output',
    dont_save_weights=False,
    checkpoint_interval=10,
    seed=42,
    resume='',
    start_epoch=0,
    eval=False,
    use_ema=False,
    ema_decay=0.9997,
    ema_tau=0,
    num_workers=2,

    # Distributed training parameters
    device='cuda',
    world_size=1,
    dist_url='env://',
    sync_bn=True,

    # FP16
    fp16_eval=False,

    # Custom args
    encoder_only=False,
    backbone_only=False,
    resolution=640,
    use_cls_token=False,
    multi_scale=False,
    expanded_scales=False,
    warmup_epochs=1,
    lr_scheduler='step',
    lr_min_factor=0.0,
    # Early stopping parameters
    early_stopping=True,
    early_stopping_patience=10,
    early_stopping_min_delta=0.001,
    early_stopping_use_ema=False,
    gradient_checkpointing=False,
    # Additional
    subcommand=None,
    **extra_kwargs  # To handle any unexpected arguments
):
    args = argparse.Namespace(
        num_classes=num_classes,
        grad_accum_steps=grad_accum_steps,
        amp=amp,
        lr=lr,
        lr_encoder=lr_encoder,
        batch_size=batch_size,
        weight_decay=weight_decay,
        epochs=epochs,
        lr_drop=lr_drop,
        clip_max_norm=clip_max_norm,
        lr_vit_layer_decay=lr_vit_layer_decay,
        lr_component_decay=lr_component_decay,
        do_benchmark=do_benchmark,
        dropout=dropout,
        drop_path=drop_path,
        drop_mode=drop_mode,
        drop_schedule=drop_schedule,
        cutoff_epoch=cutoff_epoch,
        pretrained_encoder=pretrained_encoder,
        pretrain_weights=pretrain_weights,
        pretrain_exclude_keys=pretrain_exclude_keys,
        pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
        pretrained_distiller=pretrained_distiller,
        encoder=encoder,
        vit_encoder_num_layers=vit_encoder_num_layers,
        window_block_indexes=window_block_indexes,
        position_embedding=position_embedding,
        out_feature_indexes=out_feature_indexes,
        freeze_encoder=freeze_encoder,
        layer_norm=layer_norm,
        rms_norm=rms_norm,
        backbone_lora=backbone_lora,
        force_no_pretrain=force_no_pretrain,
        dec_layers=dec_layers,
        dim_feedforward=dim_feedforward,
        hidden_dim=hidden_dim,
        sa_nheads=sa_nheads,
        ca_nheads=ca_nheads,
        num_queries=num_queries,
        group_detr=group_detr,
        two_stage=two_stage,
        projector_scale=projector_scale,
        lite_refpoint_refine=lite_refpoint_refine,
        num_select=num_select,
        dec_n_points=dec_n_points,
        decoder_norm=decoder_norm,
        bbox_reparam=bbox_reparam,
        freeze_batch_norm=freeze_batch_norm,
        set_cost_class=set_cost_class,
        set_cost_bbox=set_cost_bbox,
        set_cost_giou=set_cost_giou,
        cls_loss_coef=cls_loss_coef,
        bbox_loss_coef=bbox_loss_coef,
        giou_loss_coef=giou_loss_coef,
        focal_alpha=focal_alpha,
        aux_loss=aux_loss,
        sum_group_losses=sum_group_losses,
        use_varifocal_loss=use_varifocal_loss,
        use_position_supervised_loss=use_position_supervised_loss,
        ia_bce_loss=ia_bce_loss,
        dataset_file=dataset_file,
        coco_path=coco_path,
        dataset_dir=dataset_dir,
        square_resize_div_64=square_resize_div_64,
        output_dir=output_dir,
        dont_save_weights=dont_save_weights,
        checkpoint_interval=checkpoint_interval,
        seed=seed,
        resume=resume,
        start_epoch=start_epoch,
        eval=eval,
        use_ema=use_ema,
        ema_decay=ema_decay,
        ema_tau=ema_tau,
        num_workers=num_workers,
        device=device,
        world_size=world_size,
        dist_url=dist_url,
        sync_bn=sync_bn,
        fp16_eval=fp16_eval,
        encoder_only=encoder_only,
        backbone_only=backbone_only,
        resolution=resolution,
        use_cls_token=use_cls_token,
        multi_scale=multi_scale,
        expanded_scales=expanded_scales,
        warmup_epochs=warmup_epochs,
        lr_scheduler=lr_scheduler,
        lr_min_factor=lr_min_factor,
        early_stopping=early_stopping,
        early_stopping_patience=early_stopping_patience,
        early_stopping_min_delta=early_stopping_min_delta,
        early_stopping_use_ema=early_stopping_use_ema,
        gradient_checkpointing=gradient_checkpointing,
        **extra_kwargs
    )
    return args
rfdetr/models/__init__.py
ADDED
@@ -0,0 +1,16 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Copied from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

from .lwdetr import build_model, build_criterion_and_postprocessors
rfdetr/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (277 Bytes). View file