asammoud committed on
Commit 3f2c461 · 1 Parent(s): e912493

add redetr

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. rfdetr/__init__.py +12 -0
  2. rfdetr/__pycache__/__init__.cpython-310.pyc +0 -0
  3. rfdetr/__pycache__/__init__.cpython-312.pyc +0 -0
  4. rfdetr/__pycache__/__init__.cpython-39.pyc +0 -0
  5. rfdetr/__pycache__/config.cpython-310.pyc +0 -0
  6. rfdetr/__pycache__/config.cpython-312.pyc +0 -0
  7. rfdetr/__pycache__/config.cpython-39.pyc +0 -0
  8. rfdetr/__pycache__/detr.cpython-310.pyc +0 -0
  9. rfdetr/__pycache__/detr.cpython-312.pyc +0 -0
  10. rfdetr/__pycache__/detr.cpython-39.pyc +0 -0
  11. rfdetr/__pycache__/engine.cpython-310.pyc +0 -0
  12. rfdetr/__pycache__/engine.cpython-312.pyc +0 -0
  13. rfdetr/__pycache__/engine.cpython-39.pyc +0 -0
  14. rfdetr/__pycache__/main.cpython-310.pyc +0 -0
  15. rfdetr/__pycache__/main.cpython-312.pyc +0 -0
  16. rfdetr/__pycache__/main.cpython-39.pyc +0 -0
  17. rfdetr/cli/main.py +87 -0
  18. rfdetr/config.py +90 -0
  19. rfdetr/datasets/__init__.py +36 -0
  20. rfdetr/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  21. rfdetr/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
  22. rfdetr/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  23. rfdetr/datasets/__pycache__/coco.cpython-310.pyc +0 -0
  24. rfdetr/datasets/__pycache__/coco.cpython-312.pyc +0 -0
  25. rfdetr/datasets/__pycache__/coco.cpython-39.pyc +0 -0
  26. rfdetr/datasets/__pycache__/coco_eval.cpython-310.pyc +0 -0
  27. rfdetr/datasets/__pycache__/coco_eval.cpython-312.pyc +0 -0
  28. rfdetr/datasets/__pycache__/coco_eval.cpython-39.pyc +0 -0
  29. rfdetr/datasets/__pycache__/o365.cpython-310.pyc +0 -0
  30. rfdetr/datasets/__pycache__/o365.cpython-312.pyc +0 -0
  31. rfdetr/datasets/__pycache__/o365.cpython-39.pyc +0 -0
  32. rfdetr/datasets/__pycache__/transforms.cpython-310.pyc +0 -0
  33. rfdetr/datasets/__pycache__/transforms.cpython-312.pyc +0 -0
  34. rfdetr/datasets/__pycache__/transforms.cpython-39.pyc +0 -0
  35. rfdetr/datasets/coco.py +250 -0
  36. rfdetr/datasets/coco_eval.py +271 -0
  37. rfdetr/datasets/o365.py +53 -0
  38. rfdetr/datasets/transforms.py +475 -0
  39. rfdetr/deploy/__init__.py +0 -0
  40. rfdetr/deploy/_onnx/__init__.py +13 -0
  41. rfdetr/deploy/_onnx/optimizer.py +579 -0
  42. rfdetr/deploy/_onnx/symbolic.py +37 -0
  43. rfdetr/deploy/benchmark.py +590 -0
  44. rfdetr/deploy/export.py +276 -0
  45. rfdetr/deploy/requirements.txt +8 -0
  46. rfdetr/detr.py +315 -0
  47. rfdetr/engine.py +256 -0
  48. rfdetr/main.py +1034 -0
  49. rfdetr/models/__init__.py +16 -0
  50. rfdetr/models/__pycache__/__init__.cpython-310.pyc +0 -0
rfdetr/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+
+
+ import os
+ if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
+     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+ from rfdetr.detr import RFDETRBase, RFDETRLarge
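For orientation, a minimal usage sketch of the classes this package exports, mirroring how the CLI further down drives training; the dataset path and epoch count are illustrative placeholders, not values from this commit.

# Sketch only: train the base model on a local COCO-format dataset.
import torch
from rfdetr import RFDETRBase

model = RFDETRBase()  # per RFDETRBaseConfig, defaults to the rf-detr-base.pth weights
model.train(
    dataset_dir="datasets/my_dataset",   # hypothetical path
    epochs=1,
    device="cuda" if torch.cuda.is_available() else "cpu",
)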
rfdetr/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (341 Bytes).
rfdetr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (441 Bytes).
rfdetr/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (314 Bytes).
rfdetr/__pycache__/config.cpython-310.pyc ADDED
Binary file (3.76 kB).
rfdetr/__pycache__/config.cpython-312.pyc ADDED
Binary file (4.82 kB).
rfdetr/__pycache__/config.cpython-39.pyc ADDED
Binary file (3.73 kB).
rfdetr/__pycache__/detr.cpython-310.pyc ADDED
Binary file (10.1 kB).
rfdetr/__pycache__/detr.cpython-312.pyc ADDED
Binary file (16.9 kB).
rfdetr/__pycache__/detr.cpython-39.pyc ADDED
Binary file (10.2 kB).
rfdetr/__pycache__/engine.cpython-310.pyc ADDED
Binary file (7.08 kB).
rfdetr/__pycache__/engine.cpython-312.pyc ADDED
Binary file (12.4 kB).
rfdetr/__pycache__/engine.cpython-39.pyc ADDED
Binary file (7.13 kB).
rfdetr/__pycache__/main.cpython-310.pyc ADDED
Binary file (25.9 kB).
rfdetr/__pycache__/main.cpython-312.pyc ADDED
Binary file (45.4 kB).
rfdetr/__pycache__/main.cpython-39.pyc ADDED
Binary file (25.3 kB).
rfdetr/cli/main.py ADDED
@@ -0,0 +1,87 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ import argparse
+ from rf100vl import get_rf100vl_projects
+ import roboflow
+ from rfdetr import RFDETRBase
+ import torch
+ import os
+
+ def download_dataset(rf_project: roboflow.Project, dataset_version: int):
+     versions = rf_project.versions()
+     if dataset_version is not None:
+         versions = [v for v in versions if v.version == str(dataset_version)]
+         if len(versions) == 0:
+             raise ValueError(f"Dataset version {dataset_version} not found")
+         version = versions[0]
+     else:
+         version = max(versions, key=lambda v: v.id)
+     location = os.path.join("datasets/", rf_project.name + "_v" + version.version)
+     if not os.path.exists(location):
+         location = version.download(
+             model_format="coco", location=location, overwrite=False
+         ).location
+
+     return location
+
+
+ def train_from_rf_project(rf_project: roboflow.Project, dataset_version: int):
+     location = download_dataset(rf_project, dataset_version)
+     print(location)
+     rf_detr = RFDETRBase()
+     device_supports_cuda = torch.cuda.is_available()
+     rf_detr.train(
+         dataset_dir=location,
+         epochs=1,
+         device="cuda" if device_supports_cuda else "cpu",
+     )
+
+
+ def train_from_coco_dir(coco_dir: str):
+     rf_detr = RFDETRBase()
+     rf_detr.train(
+         dataset_dir=coco_dir,
+         epochs=1,
+         device="cuda" if torch.cuda.is_available() else "cpu",
+     )
+
+
+ def trainer():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--coco_dir", type=str, required=False)
+     parser.add_argument("--api_key", type=str, required=False)
+     parser.add_argument("--workspace", type=str, required=False, default=None)
+     parser.add_argument("--project_name", type=str, required=False, default=None)
+     parser.add_argument("--dataset_version", type=int, required=False, default=None)
+     args = parser.parse_args()
+
+     if args.coco_dir is not None:
+         train_from_coco_dir(args.coco_dir)
+         return
+
+     if (args.workspace is None and args.project_name is not None) or (
+         args.workspace is not None and args.project_name is None
+     ):
+         raise ValueError(
+             "Either both workspace and project_name must be provided or none of them"
+         )
+
+     if args.workspace is not None:
+         rf = roboflow.Roboflow(api_key=args.api_key)
+         project = rf.workspace(args.workspace).project(args.project_name)
+     else:
+         projects = get_rf100vl_projects(api_key=args.api_key)
+         project = projects[0].rf_project
+
+     train_from_rf_project(project, args.dataset_version)
+
+
+ if __name__ == "__main__":
+     trainer()
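A hedged sketch of driving this entry point programmatically; the dataset directory is a placeholder and the optional rf100vl/roboflow dependencies are assumed to be installed.

# Sketch only: invoke the trainer with a local COCO-format directory,
# bypassing the Roboflow project branch. The argv values are placeholders.
import sys
from rfdetr.cli.main import trainer

sys.argv = ["rfdetr-train", "--coco_dir", "datasets/my_dataset"]
trainer()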
rfdetr/config.py ADDED
@@ -0,0 +1,90 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+
+
+ from pydantic import BaseModel
+ from typing import List, Optional, Literal, Type
+ import torch
+ DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+ class ModelConfig(BaseModel):
+     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
+     out_feature_indexes: List[int]
+     dec_layers: int = 3
+     two_stage: bool = True
+     projector_scale: List[Literal["P3", "P4", "P5"]]
+     hidden_dim: int
+     sa_nheads: int
+     ca_nheads: int
+     dec_n_points: int
+     bbox_reparam: bool = True
+     lite_refpoint_refine: bool = True
+     layer_norm: bool = True
+     amp: bool = True
+     num_classes: int = 90
+     pretrain_weights: Optional[str] = None
+     device: Literal["cpu", "cuda", "mps"] = DEVICE
+     resolution: int = 560
+     group_detr: int = 13
+     gradient_checkpointing: bool = False
+
+ class RFDETRBaseConfig(ModelConfig):
+     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
+     hidden_dim: int = 256
+     sa_nheads: int = 8
+     ca_nheads: int = 16
+     dec_n_points: int = 2
+     num_queries: int = 300
+     num_select: int = 300
+     projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
+     out_feature_indexes: List[int] = [2, 5, 8, 11]
+     pretrain_weights: Optional[str] = "rf-detr-base.pth"
+
+ class RFDETRLargeConfig(RFDETRBaseConfig):
+     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
+     hidden_dim: int = 384
+     sa_nheads: int = 12
+     ca_nheads: int = 24
+     dec_n_points: int = 4
+     projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
+     pretrain_weights: Optional[str] = "rf-detr-large.pth"
+
+ class TrainConfig(BaseModel):
+     lr: float = 1e-4
+     lr_encoder: float = 1.5e-4
+     batch_size: int = 4
+     grad_accum_steps: int = 4
+     epochs: int = 100
+     ema_decay: float = 0.993
+     ema_tau: int = 100
+     lr_drop: int = 100
+     checkpoint_interval: int = 10
+     warmup_epochs: int = 0
+     lr_vit_layer_decay: float = 0.8
+     lr_component_decay: float = 0.7
+     drop_path: float = 0.0
+     group_detr: int = 13
+     ia_bce_loss: bool = True
+     cls_loss_coef: float = 1.0
+     num_select: int = 300
+     dataset_file: Literal["coco", "o365", "roboflow"] = "roboflow"
+     square_resize_div_64: bool = True
+     dataset_dir: str
+     output_dir: str = "output"
+     multi_scale: bool = True
+     expanded_scales: bool = True
+     use_ema: bool = True
+     num_workers: int = 2
+     weight_decay: float = 1e-4
+     early_stopping: bool = False
+     early_stopping_patience: int = 10
+     early_stopping_min_delta: float = 0.001
+     early_stopping_use_ema: bool = False
+     tensorboard: bool = True
+     wandb: bool = False
+     project: Optional[str] = None
+     run: Optional[str] = None
+     class_names: List[str] = None
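A short sketch of how these pydantic configs compose; the overrides shown are illustrative, not defaults taken from elsewhere in the commit.

# Sketch: configs are plain pydantic models, so any declared field can be
# overridden at construction time and is validated against its type.
from rfdetr.config import RFDETRBaseConfig, TrainConfig

model_cfg = RFDETRBaseConfig(num_classes=3, resolution=560)    # illustrative overrides
train_cfg = TrainConfig(dataset_dir="datasets/my_dataset",     # hypothetical path
                        epochs=20, batch_size=2)
print(model_cfg.encoder, model_cfg.projector_scale, train_cfg.lr)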
rfdetr/datasets/__init__.py ADDED
@@ -0,0 +1,36 @@
+ # ------------------------------------------------------------------------
+ # LW-DETR
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Copied from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ import torch.utils.data
+ import torchvision
+
+ from .coco import build as build_coco
+ from .o365 import build_o365
+ from .coco import build_roboflow
+
+
+ def get_coco_api_from_dataset(dataset):
+     for _ in range(10):
+         if isinstance(dataset, torch.utils.data.Subset):
+             dataset = dataset.dataset
+     if isinstance(dataset, torchvision.datasets.CocoDetection):
+         return dataset.coco
+
+
+ def build_dataset(image_set, args, resolution):
+     if args.dataset_file == 'coco':
+         return build_coco(image_set, args, resolution)
+     if args.dataset_file == 'o365':
+         return build_o365(image_set, args, resolution)
+     if args.dataset_file == 'roboflow':
+         return build_roboflow(image_set, args, resolution)
+     raise ValueError(f'dataset {args.dataset_file} not supported')
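For reference, a hedged sketch of the argparse-style namespace build_dataset expects; the attribute values below are placeholders.

# Sketch: build_dataset dispatches on args.dataset_file and reads a handful
# of attributes from an argparse-like namespace (values are placeholders).
from types import SimpleNamespace
from rfdetr.datasets import build_dataset

args = SimpleNamespace(
    dataset_file="roboflow",
    dataset_dir="datasets/my_dataset",   # hypothetical COCO-format export
    multi_scale=True,
    expanded_scales=True,
    square_resize_div_64=True,
)
dataset = build_dataset("train", args, resolution=560)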
rfdetr/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (971 Bytes).
rfdetr/datasets/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.49 kB).
rfdetr/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (944 Bytes).
rfdetr/datasets/__pycache__/coco.cpython-310.pyc ADDED
Binary file (6.28 kB).
rfdetr/datasets/__pycache__/coco.cpython-312.pyc ADDED
Binary file (10.2 kB).
rfdetr/datasets/__pycache__/coco.cpython-39.pyc ADDED
Binary file (6.3 kB).
rfdetr/datasets/__pycache__/coco_eval.cpython-310.pyc ADDED
Binary file (7.26 kB).
rfdetr/datasets/__pycache__/coco_eval.cpython-312.pyc ADDED
Binary file (11.8 kB).
rfdetr/datasets/__pycache__/coco_eval.cpython-39.pyc ADDED
Binary file (7.29 kB).
rfdetr/datasets/__pycache__/o365.cpython-310.pyc ADDED
Binary file (1.31 kB).
rfdetr/datasets/__pycache__/o365.cpython-312.pyc ADDED
Binary file (1.97 kB).
rfdetr/datasets/__pycache__/o365.cpython-39.pyc ADDED
Binary file (1.29 kB).
rfdetr/datasets/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (14.9 kB).
rfdetr/datasets/__pycache__/transforms.cpython-312.pyc ADDED
Binary file (23.9 kB).
rfdetr/datasets/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (15.1 kB).
rfdetr/datasets/coco.py ADDED
@@ -0,0 +1,250 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Copied from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ """
+ COCO dataset which returns image_id for evaluation.
+
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
+ """
+ from pathlib import Path
+
+ import torch
+ import torch.utils.data
+ import torchvision
+
+ import rfdetr.datasets.transforms as T
+
+
+ def compute_multi_scale_scales(resolution, expanded_scales=False):
+     if resolution == 640:
+         # assume we're doing the original 640x640 and therefore patch_size is 16
+         patch_size = 16
+     elif resolution % (14 * 4) == 0:
+         # assume we're doing some dinov2 resolution variant and therefore patch_size is 14
+         patch_size = 14
+     elif resolution % (16 * 4) == 0:
+         # assume we're doing some other resolution and therefore patch_size is 16
+         patch_size = 16
+     else:
+         raise ValueError(f"Resolution {resolution} is not divisible by 16*4 or 14*4")
+     # round to the nearest multiple of 4*patch_size to enable both patching and windowing
+     base_num_patches_per_window = resolution // (patch_size * 4)
+     offsets = [-3, -2, -1, 0, 1, 2, 3, 4] if not expanded_scales else [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
+     scales = [base_num_patches_per_window + offset for offset in offsets]
+     proposed_scales = [scale * patch_size * 4 for scale in scales]
+     proposed_scales = [scale for scale in proposed_scales if scale >= patch_size * 4]  # ensure minimum image size
+     return proposed_scales
+
+
+ class CocoDetection(torchvision.datasets.CocoDetection):
+     def __init__(self, img_folder, ann_file, transforms):
+         super(CocoDetection, self).__init__(img_folder, ann_file)
+         self._transforms = transforms
+         self.prepare = ConvertCoco()
+
+     def __getitem__(self, idx):
+         img, target = super(CocoDetection, self).__getitem__(idx)
+         image_id = self.ids[idx]
+         target = {'image_id': image_id, 'annotations': target}
+         img, target = self.prepare(img, target)
+         if self._transforms is not None:
+             img, target = self._transforms(img, target)
+         return img, target
+
+
+ class ConvertCoco(object):
+
+     def __call__(self, image, target):
+         w, h = image.size
+
+         image_id = target["image_id"]
+         image_id = torch.tensor([image_id])
+
+         anno = target["annotations"]
+
+         anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
+
+         boxes = [obj["bbox"] for obj in anno]
+         # guard against no boxes via resizing
+         boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+         boxes[:, 2:] += boxes[:, :2]
+         boxes[:, 0::2].clamp_(min=0, max=w)
+         boxes[:, 1::2].clamp_(min=0, max=h)
+
+         classes = [obj["category_id"] for obj in anno]
+         classes = torch.tensor(classes, dtype=torch.int64)
+
+         keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+         boxes = boxes[keep]
+         classes = classes[keep]
+
+         target = {}
+         target["boxes"] = boxes
+         target["labels"] = classes
+         target["image_id"] = image_id
+
+         # for conversion to coco api
+         area = torch.tensor([obj["area"] for obj in anno])
+         iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
+         target["area"] = area[keep]
+         target["iscrowd"] = iscrowd[keep]
+
+         target["orig_size"] = torch.as_tensor([int(h), int(w)])
+         target["size"] = torch.as_tensor([int(h), int(w)])
+
+         return image, target
+
+
+ def make_coco_transforms(image_set, resolution, multi_scale=False, expanded_scales=False):
+
+     normalize = T.Compose([
+         T.ToTensor(),
+         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+
+     scales = [resolution]
+     if multi_scale:
+         # scales = [448, 512, 576, 640, 704, 768, 832, 896]
+         scales = compute_multi_scale_scales(resolution, expanded_scales)
+         print(scales)
+
+     if image_set == 'train':
+         return T.Compose([
+             T.RandomHorizontalFlip(),
+             T.RandomSelect(
+                 T.RandomResize(scales, max_size=1333),
+                 T.Compose([
+                     T.RandomResize([400, 500, 600]),
+                     T.RandomSizeCrop(384, 600),
+                     T.RandomResize(scales, max_size=1333),
+                 ])
+             ),
+             normalize,
+         ])
+
+     if image_set == 'val':
+         return T.Compose([
+             T.RandomResize([resolution], max_size=1333),
+             normalize,
+         ])
+     if image_set == 'val_speed':
+         return T.Compose([
+             T.SquareResize([resolution]),
+             normalize,
+         ])
+
+     raise ValueError(f'unknown {image_set}')
+
+
+ def make_coco_transforms_square_div_64(image_set, resolution, multi_scale=False, expanded_scales=False):
+     """
+     """
+
+     normalize = T.Compose([
+         T.ToTensor(),
+         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+
+
+     scales = [resolution]
+     if multi_scale:
+         # scales = [448, 512, 576, 640, 704, 768, 832, 896]
+         scales = compute_multi_scale_scales(resolution, expanded_scales)
+         print(scales)
+
+     if image_set == 'train':
+         return T.Compose([
+             T.RandomHorizontalFlip(),
+             T.RandomSelect(
+                 T.SquareResize(scales),
+                 T.Compose([
+                     T.RandomResize([400, 500, 600]),
+                     T.RandomSizeCrop(384, 600),
+                     T.SquareResize(scales),
+                 ]),
+             ),
+             normalize,
+         ])
+
+     if image_set == 'val':
+         return T.Compose([
+             T.SquareResize([resolution]),
+             normalize,
+         ])
+     if image_set == 'val_speed':
+         return T.Compose([
+             T.SquareResize([resolution]),
+             normalize,
+         ])
+
+     raise ValueError(f'unknown {image_set}')
+
+ def build(image_set, args, resolution):
+     root = Path(args.coco_path)
+     assert root.exists(), f'provided COCO path {root} does not exist'
+     mode = 'instances'
+     PATHS = {
+         "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
+         "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
+         "test": (root / "test2017", root / "annotations" / f'image_info_test-dev2017.json'),
+     }
+
+     img_folder, ann_file = PATHS[image_set.split("_")[0]]
+
+     try:
+         square_resize = args.square_resize
+     except:
+         square_resize = False
+
+     try:
+         square_resize_div_64 = args.square_resize_div_64
+     except:
+         square_resize_div_64 = False
+
+
+     if square_resize_div_64:
+         dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
+     else:
+         dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
+     return dataset
+
+ def build_roboflow(image_set, args, resolution):
+     root = Path(args.dataset_dir)
+     assert root.exists(), f'provided Roboflow path {root} does not exist'
+     mode = 'instances'
+     PATHS = {
+         "train": (root / "train", root / "train" / "_annotations.coco.json"),
+         "val": (root / "valid", root / "valid" / "_annotations.coco.json"),
+         "test": (root / "test", root / "test" / "_annotations.coco.json"),
+     }
+
+     img_folder, ann_file = PATHS[image_set.split("_")[0]]
+
+     try:
+         square_resize = args.square_resize
+     except:
+         square_resize = False
+
+     try:
+         square_resize_div_64 = args.square_resize_div_64
+     except:
+         square_resize_div_64 = False
+
+
+     if square_resize_div_64:
+         dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale))
+     else:
+         dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale))
+     return dataset
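To make the multi-scale logic concrete, a small worked check of compute_multi_scale_scales at the default 560 resolution (DINOv2 patch size 14, windows of 4 patches):

# resolution=560 is divisible by 14*4, so patch_size=14 and
# base_num_patches_per_window = 560 // 56 = 10; offsets -3..4 give windows 7..14.
from rfdetr.datasets.coco import compute_multi_scale_scales

print(compute_multi_scale_scales(560))
# -> [392, 448, 504, 560, 616, 672, 728, 784]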
rfdetr/datasets/coco_eval.py ADDED
@@ -0,0 +1,271 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Copied from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Copied from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ """
+ COCO evaluator that works in distributed mode.
+
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
+ The difference is that there is less copy-pasting from pycocotools
+ in the end of the file, as python3 can suppress prints with contextlib
+ """
+ import os
+ import contextlib
+ import copy
+ import numpy as np
+ import torch
+
+ from pycocotools.cocoeval import COCOeval
+ from pycocotools.coco import COCO
+ import pycocotools.mask as mask_util
+
+ from rfdetr.util.misc import all_gather
+
+
+ class CocoEvaluator(object):
+     def __init__(self, coco_gt, iou_types):
+         assert isinstance(iou_types, (list, tuple))
+         coco_gt = copy.deepcopy(coco_gt)
+         self.coco_gt = coco_gt
+
+         self.iou_types = iou_types
+         self.coco_eval = {}
+         for iou_type in iou_types:
+             self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
+
+         self.img_ids = []
+         self.eval_imgs = {k: [] for k in iou_types}
+
+     def update(self, predictions):
+         img_ids = list(np.unique(list(predictions.keys())))
+         self.img_ids.extend(img_ids)
+
+         for iou_type in self.iou_types:
+             results = self.prepare(predictions, iou_type)
+
+             # suppress pycocotools prints
+             with open(os.devnull, 'w') as devnull:
+                 with contextlib.redirect_stdout(devnull):
+                     coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
+             coco_eval = self.coco_eval[iou_type]
+
+             coco_eval.cocoDt = coco_dt
+             coco_eval.params.imgIds = list(img_ids)
+             img_ids, eval_imgs = evaluate(coco_eval)
+
+             self.eval_imgs[iou_type].append(eval_imgs)
+
+     def synchronize_between_processes(self):
+         for iou_type in self.iou_types:
+             self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
+             create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
+
+     def accumulate(self):
+         for coco_eval in self.coco_eval.values():
+             coco_eval.accumulate()
+
+     def summarize(self):
+         for iou_type, coco_eval in self.coco_eval.items():
+             print("IoU metric: {}".format(iou_type))
+             coco_eval.summarize()
+
+     def prepare(self, predictions, iou_type):
+         if iou_type == "bbox":
+             return self.prepare_for_coco_detection(predictions)
+         elif iou_type == "segm":
+             return self.prepare_for_coco_segmentation(predictions)
+         elif iou_type == "keypoints":
+             return self.prepare_for_coco_keypoint(predictions)
+         else:
+             raise ValueError("Unknown iou type {}".format(iou_type))
+
+     def prepare_for_coco_detection(self, predictions):
+         coco_results = []
+         for original_id, prediction in predictions.items():
+             if len(prediction) == 0:
+                 continue
+
+             boxes = prediction["boxes"]
+             boxes = convert_to_xywh(boxes).tolist()
+             scores = prediction["scores"].tolist()
+             labels = prediction["labels"].tolist()
+
+             coco_results.extend(
+                 [
+                     {
+                         "image_id": original_id,
+                         "category_id": labels[k],
+                         "bbox": box,
+                         "score": scores[k],
+                     }
+                     for k, box in enumerate(boxes)
+                 ]
+             )
+         return coco_results
+
+     def prepare_for_coco_segmentation(self, predictions):
+         coco_results = []
+         for original_id, prediction in predictions.items():
+             if len(prediction) == 0:
+                 continue
+
+             scores = prediction["scores"]
+             labels = prediction["labels"]
+             masks = prediction["masks"]
+
+             masks = masks > 0.5
+
+             scores = prediction["scores"].tolist()
+             labels = prediction["labels"].tolist()
+
+             rles = [
+                 mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
+                 for mask in masks
+             ]
+             for rle in rles:
+                 rle["counts"] = rle["counts"].decode("utf-8")
+
+             coco_results.extend(
+                 [
+                     {
+                         "image_id": original_id,
+                         "category_id": labels[k],
+                         "segmentation": rle,
+                         "score": scores[k],
+                     }
+                     for k, rle in enumerate(rles)
+                 ]
+             )
+         return coco_results
+
+     def prepare_for_coco_keypoint(self, predictions):
+         coco_results = []
+         for original_id, prediction in predictions.items():
+             if len(prediction) == 0:
+                 continue
+
+             boxes = prediction["boxes"]
+             boxes = convert_to_xywh(boxes).tolist()
+             scores = prediction["scores"].tolist()
+             labels = prediction["labels"].tolist()
+             keypoints = prediction["keypoints"]
+             keypoints = keypoints.flatten(start_dim=1).tolist()
+
+             coco_results.extend(
+                 [
+                     {
+                         "image_id": original_id,
+                         "category_id": labels[k],
+                         'keypoints': keypoint,
+                         "score": scores[k],
+                     }
+                     for k, keypoint in enumerate(keypoints)
+                 ]
+             )
+         return coco_results
+
+
+ def convert_to_xywh(boxes):
+     xmin, ymin, xmax, ymax = boxes.unbind(1)
+     return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
+
+
+ def merge(img_ids, eval_imgs):
+     all_img_ids = all_gather(img_ids)
+     all_eval_imgs = all_gather(eval_imgs)
+
+     merged_img_ids = []
+     for p in all_img_ids:
+         merged_img_ids.extend(p)
+
+     merged_eval_imgs = []
+     for p in all_eval_imgs:
+         merged_eval_imgs.append(p)
+
+     merged_img_ids = np.array(merged_img_ids)
+     merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
+
+     # keep only unique (and in sorted order) images
+     merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+     merged_eval_imgs = merged_eval_imgs[..., idx]
+
+     return merged_img_ids, merged_eval_imgs
+
+
+ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
+     img_ids, eval_imgs = merge(img_ids, eval_imgs)
+     img_ids = list(img_ids)
+     eval_imgs = list(eval_imgs.flatten())
+
+     coco_eval.evalImgs = eval_imgs
+     coco_eval.params.imgIds = img_ids
+     coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
+
+
+ #################################################################
+ # From pycocotools, just removed the prints and fixed
+ # a Python3 bug about unicode not defined
+ #################################################################
+
+
+ def evaluate(self):
+     '''
+     Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
+     :return: None
+     '''
+     # tic = time.time()
+     # print('Running per image evaluation...')
+     p = self.params
+     # add backward compatibility if useSegm is specified in params
+     if p.useSegm is not None:
+         p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
+         print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
+     # print('Evaluate annotation type *{}*'.format(p.iouType))
+     p.imgIds = list(np.unique(p.imgIds))
+     if p.useCats:
+         p.catIds = list(np.unique(p.catIds))
+     p.maxDets = sorted(p.maxDets)
+     self.params = p
+
+     self._prepare()
+     # loop through images, area range, max detection number
+     catIds = p.catIds if p.useCats else [-1]
+
+     if p.iouType == 'segm' or p.iouType == 'bbox':
+         computeIoU = self.computeIoU
+     elif p.iouType == 'keypoints':
+         computeIoU = self.computeOks
+     self.ious = {
+         (imgId, catId): computeIoU(imgId, catId)
+         for imgId in p.imgIds
+         for catId in catIds}
+
+     evaluateImg = self.evaluateImg
+     maxDet = p.maxDets[-1]
+     evalImgs = [
+         evaluateImg(imgId, catId, areaRng, maxDet)
+         for catId in catIds
+         for areaRng in p.areaRng
+         for imgId in p.imgIds
+     ]
+     # this is NOT in the pycocotools code, but could be done outside
+     evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
+     self._paramsEval = copy.deepcopy(self.params)
+     # toc = time.time()
+     # print('DONE (t={:0.2f}s).'.format(toc-tic))
+     return p.imgIds, evalImgs
+
+ #################################################################
+ # end of straight copy from pycocotools, just removing the prints
+ #################################################################
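A hedged sketch of the evaluator's call sequence during validation; the ground-truth object, data loader, and postprocessing helper are placeholders shaped like the inputs this class expects, not names from this commit.

# Sketch of the expected call order: per-batch update, then a single
# synchronize/accumulate/summarize pass at the end of validation.
from rfdetr.datasets.coco_eval import CocoEvaluator

evaluator = CocoEvaluator(coco_gt, iou_types=("bbox",))   # coco_gt: a pycocotools COCO object (placeholder)
for images, targets in data_loader:                       # placeholder loop
    # hypothetical helper returning {image_id: {"boxes", "scores", "labels"}}
    predictions = run_model_and_postprocess(images)
    evaluator.update(predictions)
evaluator.synchronize_between_processes()
evaluator.accumulate()
evaluator.summarize()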
rfdetr/datasets/o365.py ADDED
@@ -0,0 +1,53 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ """Dataset file for Object365."""
+ from pathlib import Path
+
+ from .coco import (
+     CocoDetection, make_coco_transforms, make_coco_transforms_square_div_64
+ )
+
+ from PIL import Image
+ Image.MAX_IMAGE_PIXELS = None
+
+
+ def build_o365_raw(image_set, args, resolution):
+     root = Path(args.coco_path)
+     PATHS = {
+         "train": (root, root / 'zhiyuan_objv2_train_val_wo_5k.json'),
+         "val": (root, root / 'zhiyuan_objv2_minival5k.json'),
+     }
+     img_folder, ann_file = PATHS[image_set]
+
+     try:
+         square_resize = args.square_resize
+     except:
+         square_resize = False
+
+     try:
+         square_resize_div_64 = args.square_resize_div_64
+     except:
+         square_resize_div_64 = False
+
+     if square_resize_div_64:
+         dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
+     else:
+         dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
+     return dataset
+
+
+ def build_o365(image_set, args, resolution):
+     if image_set == 'train':
+         train_ds = build_o365_raw('train', args, resolution=resolution)
+         return train_ds
+     if image_set == 'val':
+         val_ds = build_o365_raw('val', args, resolution=resolution)
+         return val_ds
+     raise ValueError('Unknown image_set: {}'.format(image_set))
rfdetr/datasets/transforms.py ADDED
@@ -0,0 +1,475 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Copied from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ """
+ Transforms and data augmentation for both image + bbox.
+ """
+ import random
+
+ import PIL
+ import numpy as np
+ try:
+     from collections.abc import Sequence
+ except Exception:
+     from collections import Sequence
+ from numbers import Number
+ import torch
+ import torchvision.transforms as T
+ # from detectron2.data import transforms as DT
+ import torchvision.transforms.functional as F
+
+ from rfdetr.util.box_ops import box_xyxy_to_cxcywh
+ from rfdetr.util.misc import interpolate
+
+
+ def crop(image, target, region):
+     cropped_image = F.crop(image, *region)
+
+     target = target.copy()
+     i, j, h, w = region
+
+     # should we do something wrt the original size?
+     target["size"] = torch.tensor([h, w])
+
+     fields = ["labels", "area", "iscrowd"]
+
+     if "boxes" in target:
+         boxes = target["boxes"]
+         max_size = torch.as_tensor([w, h], dtype=torch.float32)
+         cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
+         cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
+         cropped_boxes = cropped_boxes.clamp(min=0)
+         area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
+         target["boxes"] = cropped_boxes.reshape(-1, 4)
+         target["area"] = area
+         fields.append("boxes")
+
+     if "masks" in target:
+         # FIXME should we update the area here if there are no boxes?
+         target['masks'] = target['masks'][:, i:i + h, j:j + w]
+         fields.append("masks")
+
+     # remove elements for which the boxes or masks that have zero area
+     if "boxes" in target or "masks" in target:
+         # favor boxes selection when defining which elements to keep
+         # this is compatible with previous implementation
+         if "boxes" in target:
+             cropped_boxes = target['boxes'].reshape(-1, 2, 2)
+             keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
+         else:
+             keep = target['masks'].flatten(1).any(1)
+
+         for field in fields:
+             target[field] = target[field][keep]
+
+     return cropped_image, target
+
+
+ def hflip(image, target):
+     flipped_image = F.hflip(image)
+
+     w, h = image.size
+
+     target = target.copy()
+     if "boxes" in target:
+         boxes = target["boxes"]
+         boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
+         target["boxes"] = boxes
+
+     if "masks" in target:
+         target['masks'] = target['masks'].flip(-1)
+
+     return flipped_image, target
+
+
+ def resize(image, target, size, max_size=None):
+     # size can be min_size (scalar) or (w, h) tuple
+
+     def get_size_with_aspect_ratio(image_size, size, max_size=None):
+         w, h = image_size
+         if max_size is not None:
+             min_original_size = float(min((w, h)))
+             max_original_size = float(max((w, h)))
+             if max_original_size / min_original_size * size > max_size:
+                 size = int(round(max_size * min_original_size / max_original_size))
+
+         if (w <= h and w == size) or (h <= w and h == size):
+             return (h, w)
+
+         if w < h:
+             ow = size
+             oh = int(size * h / w)
+         else:
+             oh = size
+             ow = int(size * w / h)
+
+         return (oh, ow)
+
+     def get_size(image_size, size, max_size=None):
+         if isinstance(size, (list, tuple)):
+             return size[::-1]
+         else:
+             return get_size_with_aspect_ratio(image_size, size, max_size)
+
+     size = get_size(image.size, size, max_size)
+     rescaled_image = F.resize(image, size)
+
+     if target is None:
+         return rescaled_image, None
+
+     ratios = tuple(
+         float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+     ratio_width, ratio_height = ratios
+
+     target = target.copy()
+     if "boxes" in target:
+         boxes = target["boxes"]
+         scaled_boxes = boxes * torch.as_tensor(
+             [ratio_width, ratio_height, ratio_width, ratio_height])
+         target["boxes"] = scaled_boxes
+
+     if "area" in target:
+         area = target["area"]
+         scaled_area = area * (ratio_width * ratio_height)
+         target["area"] = scaled_area
+
+     h, w = size
+     target["size"] = torch.tensor([h, w])
+
+     if "masks" in target:
+         target['masks'] = interpolate(
+             target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
+
+
+     return rescaled_image, target
+
+
+ def pad(image, target, padding):
+     # assumes that we only pad on the bottom right corners
+     padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
+     if target is None:
+         return padded_image, None
+     target = target.copy()
+     # should we do something wrt the original size?
+     target["size"] = torch.tensor(padded_image.size[::-1])
+     if "masks" in target:
+         target['masks'] = torch.nn.functional.pad(
+             target['masks'], (0, padding[0], 0, padding[1]))
+     return padded_image, target
+
+
+ class RandomCrop(object):
+     def __init__(self, size):
+         self.size = size
+
+     def __call__(self, img, target):
+         region = T.RandomCrop.get_params(img, self.size)
+         return crop(img, target, region)
+
+
+ class RandomSizeCrop(object):
+     def __init__(self, min_size: int, max_size: int):
+         self.min_size = min_size
+         self.max_size = max_size
+
+     def __call__(self, img: PIL.Image.Image, target: dict):
+         w = random.randint(self.min_size, min(img.width, self.max_size))
+         h = random.randint(self.min_size, min(img.height, self.max_size))
+         region = T.RandomCrop.get_params(img, [h, w])
+         return crop(img, target, region)
+
+
+ class CenterCrop(object):
+     def __init__(self, size):
+         self.size = size
+
+     def __call__(self, img, target):
+         image_width, image_height = img.size
+         crop_height, crop_width = self.size
+         crop_top = int(round((image_height - crop_height) / 2.))
+         crop_left = int(round((image_width - crop_width) / 2.))
+         return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
+
+
+ class RandomHorizontalFlip(object):
+     def __init__(self, p=0.5):
+         self.p = p
+
+     def __call__(self, img, target):
+         if random.random() < self.p:
+             return hflip(img, target)
+         return img, target
+
+
+ class RandomResize(object):
+     def __init__(self, sizes, max_size=None):
+         assert isinstance(sizes, (list, tuple))
+         self.sizes = sizes
+         self.max_size = max_size
+
+     def __call__(self, img, target=None):
+         size = random.choice(self.sizes)
+         return resize(img, target, size, self.max_size)
+
+
+ class SquareResize(object):
+     def __init__(self, sizes):
+         assert isinstance(sizes, (list, tuple))
+         self.sizes = sizes
+
+     def __call__(self, img, target=None):
+         size = random.choice(self.sizes)
+         rescaled_img=F.resize(img, (size, size))
+         w, h = rescaled_img.size
+         if target is None:
+             return rescaled_img, None
+         ratios = tuple(
+             float(s) / float(s_orig) for s, s_orig in zip(rescaled_img.size, img.size))
+         ratio_width, ratio_height = ratios
+
+         target = target.copy()
+         if "boxes" in target:
+             boxes = target["boxes"]
+             scaled_boxes = boxes * torch.as_tensor(
+                 [ratio_width, ratio_height, ratio_width, ratio_height])
+             target["boxes"] = scaled_boxes
+
+         if "area" in target:
+             area = target["area"]
+             scaled_area = area * (ratio_width * ratio_height)
+             target["area"] = scaled_area
+
+         target["size"] = torch.tensor([h, w])
+
+         return rescaled_img, target
+
+
+ class RandomPad(object):
+     def __init__(self, max_pad):
+         self.max_pad = max_pad
+
+     def __call__(self, img, target):
+         pad_x = random.randint(0, self.max_pad)
+         pad_y = random.randint(0, self.max_pad)
+         return pad(img, target, (pad_x, pad_y))
+
+
+ class PILtoNdArray(object):
+
+     def __call__(self, img, target):
+         return np.asarray(img), target
+
+
+ class NdArraytoPIL(object):
+
+     def __call__(self, img, target):
+         return F.to_pil_image(img.astype('uint8')), target
+
+
+ class Pad(object):
+     def __init__(self,
+                  size=None,
+                  size_divisor=32,
+                  pad_mode=0,
+                  offsets=None,
+                  fill_value=(127.5, 127.5, 127.5)):
+         """
+         Pad image to a specified size or multiple of size_divisor.
+         Args:
+             size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None
+             size_divisor (int): size divisor, default 32
+             pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets
+                 if 0, only pad to right and bottom. if 1, pad according to center. if 2, only pad left and top
+             offsets (list): [offset_x, offset_y], specify offset while padding, only supported pad_mode=-1
+             fill_value (bool): rgb value of pad area, default (127.5, 127.5, 127.5)
+         """
+
+         if not isinstance(size, (int, Sequence)):
+             raise TypeError(
+                 "Type of target_size is invalid when random_size is True. \
+                 Must be List, now is {}".format(type(size)))
+
+         if isinstance(size, int):
+             size = [size, size]
+
+         assert pad_mode in [
+             -1, 0, 1, 2
+         ], 'currently only supports four modes [-1, 0, 1, 2]'
+         if pad_mode == -1:
+             assert offsets, 'if pad_mode is -1, offsets should not be None'
+
+         self.size = size
+         self.size_divisor = size_divisor
+         self.pad_mode = pad_mode
+         self.fill_value = fill_value
+         self.offsets = offsets
+
+     def apply_bbox(self, bbox, offsets):
+         return bbox + np.array(offsets * 2, dtype=np.float32)
+
+     def apply_image(self, image, offsets, im_size, size):
+         x, y = offsets
+         im_h, im_w = im_size
+         h, w = size
+         canvas = np.ones((h, w, 3), dtype=np.float32)
+         canvas *= np.array(self.fill_value, dtype=np.float32)
+         canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
+         return canvas
+
+     def __call__(self, im, target):
+         im_h, im_w = im.shape[:2]
+         if self.size:
+             h, w = self.size
+             assert (
+                 im_h <= h and im_w <= w
+             ), '(h, w) of target size should be greater than (im_h, im_w)'
+         else:
+             h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
+             w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
+
+         if h == im_h and w == im_w:
+             return im.astype(np.float32), target
+
+         if self.pad_mode == -1:
+             offset_x, offset_y = self.offsets
+         elif self.pad_mode == 0:
+             offset_y, offset_x = 0, 0
+         elif self.pad_mode == 1:
+             offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
+         else:
+             offset_y, offset_x = h - im_h, w - im_w
+
+         offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
+
+         im = self.apply_image(im, offsets, im_size, size)
+
+         if self.pad_mode == 0:
+             target["size"] = torch.tensor([h, w])
+             return im, target
+         if 'boxes' in target and len(target['boxes']) > 0:
+             boxes = np.asarray(target["boxes"])
+             target["boxes"] = torch.from_numpy(self.apply_bbox(boxes, offsets))
+         target["size"] = torch.tensor([h, w])
+
+         return im, target
+
+
+ class RandomExpand(object):
+     """Random expand the canvas.
+     Args:
+         ratio (float): maximum expansion ratio.
+         prob (float): probability to expand.
+         fill_value (list): color value used to fill the canvas. in RGB order.
+     """
+
+     def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, 127.5, 127.5)):
+         assert ratio > 1.01, "expand ratio must be larger than 1.01"
+         self.ratio = ratio
+         self.prob = prob
+         assert isinstance(fill_value, (Number, Sequence)), \
+             "fill value must be either float or sequence"
+         if isinstance(fill_value, Number):
+             fill_value = (fill_value, ) * 3
+         if not isinstance(fill_value, tuple):
+             fill_value = tuple(fill_value)
+         self.fill_value = fill_value
+
+     def __call__(self, img, target):
+         if np.random.uniform(0., 1.) < self.prob:
+             return img, target
+
+         height, width = img.shape[:2]
+         ratio = np.random.uniform(1., self.ratio)
+         h = int(height * ratio)
+         w = int(width * ratio)
+         if not h > height or not w > width:
+             return img, target
+         y = np.random.randint(0, h - height)
+         x = np.random.randint(0, w - width)
+         offsets, size = [x, y], [h, w]
+
+         pad = Pad(size,
+                   pad_mode=-1,
+                   offsets=offsets,
+                   fill_value=self.fill_value)
+
+         return pad(img, target)
+
+
+ class RandomSelect(object):
+     """
+     Randomly selects between transforms1 and transforms2,
+     with probability p for transforms1 and (1 - p) for transforms2
+     """
+     def __init__(self, transforms1, transforms2, p=0.5):
+         self.transforms1 = transforms1
+         self.transforms2 = transforms2
+         self.p = p
+
+     def __call__(self, img, target):
+         if random.random() < self.p:
+             return self.transforms1(img, target)
+         return self.transforms2(img, target)
+
+
+ class ToTensor(object):
+     def __call__(self, img, target):
+         return F.to_tensor(img), target
+
+
+ class RandomErasing(object):
+
+     def __init__(self, *args, **kwargs):
+         self.eraser = T.RandomErasing(*args, **kwargs)
+
+     def __call__(self, img, target):
+         return self.eraser(img), target
+
+
+ class Normalize(object):
+     def __init__(self, mean, std):
+         self.mean = mean
+         self.std = std
+
+     def __call__(self, image, target=None):
+         image = F.normalize(image, mean=self.mean, std=self.std)
+         if target is None:
+             return image, None
+         target = target.copy()
+         h, w = image.shape[-2:]
+         if "boxes" in target:
+             boxes = target["boxes"]
+             boxes = box_xyxy_to_cxcywh(boxes)
+             boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
+             target["boxes"] = boxes
+         return image, target
+
+
+ class Compose(object):
+     def __init__(self, transforms):
+         self.transforms = transforms
+
+     def __call__(self, image, target):
+         for t in self.transforms:
+             image, target = t(image, target)
+         return image, target
+
+     def __repr__(self):
+         format_string = self.__class__.__name__ + "("
+         for t in self.transforms:
+             format_string += "\n"
+             format_string += "    {0}".format(t)
+         format_string += "\n)"
+         return format_string
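A hedged sketch of how these (image, target) transforms compose, mirroring the pipelines built in coco.py above; the image and boxes are synthetic.

# Sketch: every transform takes and returns an (image, target) pair, so they
# chain through Compose exactly as make_coco_transforms builds them.
import torch
from PIL import Image
import rfdetr.datasets.transforms as T

tfms = T.Compose([
    T.RandomHorizontalFlip(),
    T.SquareResize([560]),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img = Image.new("RGB", (640, 480))                              # synthetic image
target = {"boxes": torch.tensor([[10., 20., 200., 240.]]),      # xyxy pixels
          "labels": torch.tensor([1])}
img, target = tfms(img, target)                                 # boxes come out as normalized cxcywh
print(img.shape, target["boxes"])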
rfdetr/deploy/__init__.py ADDED
File without changes
rfdetr/deploy/_onnx/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # ------------------------------------------------------------------------
+ # LW-DETR
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ """
+ onnx optimizer and symbolic registry
+ """
+ from . import optimizer
+ from . import symbolic
+
+ from .optimizer import OnnxOptimizer
+ from .symbolic import CustomOpSymbolicRegistry
rfdetr/deploy/_onnx/optimizer.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ OnnxOptimizer
12
+ """
13
+ import os
14
+ from collections import OrderedDict
15
+ from copy import deepcopy
16
+
17
+ import numpy as np
18
+ import onnx
19
+ import torch
20
+ from onnx import shape_inference
21
+ import onnx_graphsurgeon as gs
22
+ from polygraphy.backend.onnx.loader import fold_constants
23
+ from onnx_graphsurgeon.logger.logger import G_LOGGER
24
+
25
+ from .symbolic import CustomOpSymbolicRegistry
26
+
27
+
28
+ class OnnxOptimizer():
29
+ def __init__(
30
+ self,
31
+ input,
32
+ severity=G_LOGGER.INFO
33
+ ):
34
+ if isinstance(input, str):
35
+ onnx_graph = self.load_onnx(input)
36
+ else:
37
+ onnx_graph = input
38
+ self.graph = gs.import_onnx(onnx_graph)
39
+ self.severity = severity
40
+ self.set_severity(severity)
41
+
42
+ def set_severity(self, severity):
43
+ G_LOGGER.severity = severity
44
+
45
+ def load_onnx(self, onnx_path:str):
46
+ """Load onnx from file
47
+ """
48
+ assert os.path.isfile(onnx_path), f"not found onnx file: {onnx_path}"
49
+ onnx_graph = onnx.load(onnx_path)
50
+ G_LOGGER.info(f"load onnx file: {onnx_path}")
51
+ return onnx_graph
52
+
53
+ def save_onnx(self, onnx_path:str):
54
+ onnx_graph = gs.export_onnx(self.graph)
55
+ G_LOGGER.info(f"save onnx file: {onnx_path}")
56
+ onnx.save(onnx_graph, onnx_path)
57
+
58
+ def info(self, prefix=''):
59
+ G_LOGGER.verbose(f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs")
60
+
61
+ def cleanup(self, return_onnx=False):
62
+ self.graph.cleanup().toposort()
63
+ if return_onnx:
64
+ return gs.export_onnx(self.graph)
65
+
66
+ def select_outputs(self, keep, names=None):
67
+ self.graph.outputs = [self.graph.outputs[o] for o in keep]
68
+ if names:
69
+ for i, name in enumerate(names):
70
+ self.graph.outputs[i].name = name
71
+
72
+ def find_node_input(self, node, name:str=None, value=None) -> int:
73
+ index = -1
+ for i, inp in enumerate(node.inputs):
74
+ if isinstance(name, str) and inp.name == name:
75
+ index = i
76
+ elif inp == value:
77
+ index = i
78
+ assert index >= 0, f"not found {name}({value}) in node.inputs"
79
+ return index
80
+
81
+ def find_node_output(self, node, name:str=None, value=None) -> int:
82
+ index = -1
+ for i, inp in enumerate(node.outputs):
83
+ if isinstance(name, str) and inp.name == name:
84
+ index = i
85
+ elif inp == value:
86
+ index = i
87
+ assert index >= 0, f"not found {name}({value}) in node.outputs"
88
+ return index
89
+
90
+ def common_opt(self, return_onnx=False):
91
+ for fn in CustomOpSymbolicRegistry._OPTIMIZER:
92
+ fn(self)
93
+ self.cleanup()
94
+ onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=False)
95
+ if onnx_graph.ByteSize() > 2147483648:
96
+ raise TypeError("ERROR: model size exceeds supported 2GB limit")
97
+ else:
98
+ onnx_graph = shape_inference.infer_shapes(onnx_graph)
99
+ self.graph = gs.import_onnx(onnx_graph)
100
+ self.cleanup()
101
+ if return_onnx:
102
+ return onnx_graph
103
+
104
+ def resize_fix(self):
105
+ '''
106
+ This function loops through the graph looking for Resize nodes that use scales for resizing (i.e., have 3 inputs).
107
+ It substitutes each such Resize with one that takes the size of the output tensor instead of scales.
108
+ It adds Shape->Slice->Concat
109
+ Shape->Slice----^ subgraph to the graph to extract the shape of the output tensor.
110
+ This fix is required for the dynamic shape support.
111
+ '''
112
+ mResizeNodes = 0
113
+ for node in self.graph.nodes:
114
+ if node.op == "Resize" and len(node.inputs) == 3:
115
+ name = node.name + "/"
116
+
117
+ add_node = node.o().o().i(1)
118
+ div_node = node.i()
119
+
120
+ shape_hw_out = gs.Variable(name=name + "shape_hw_out", dtype=np.int64, shape=[4])
121
+ shape_hw = gs.Node(op="Shape", name=name+"shape_hw", inputs=[add_node.outputs[0]], outputs=[shape_hw_out])
122
+
123
+ const_zero = gs.Constant(name=name + "const_zero", values=np.array([0], dtype=np.int64))
124
+ const_two = gs.Constant(name=name + "const_two", values=np.array([2], dtype=np.int64))
125
+ const_four = gs.Constant(name=name + "const_four", values=np.array([4], dtype=np.int64))
126
+
127
+ slice_hw_out = gs.Variable(name=name + "slice_hw_out", dtype=np.int64, shape=[2])
128
+ slice_hw = gs.Node(op="Slice", name=name+"slice_hw", inputs=[shape_hw_out, const_two, const_four, const_zero], outputs=[slice_hw_out])
129
+
130
+ shape_bc_out = gs.Variable(name=name + "shape_bc_out", dtype=np.int64, shape=[2])
131
+ shape_bc = gs.Node(op="Shape", name=name+"shape_bc", inputs=[div_node.outputs[0]], outputs=[shape_bc_out])
132
+
133
+ slice_bc_out = gs.Variable(name=name + "slice_bc_out", dtype=np.int64, shape=[2])
134
+ slice_bc = gs.Node(op="Slice", name=name+"slice_bc", inputs=[shape_bc_out, const_zero, const_two, const_zero], outputs=[slice_bc_out])
135
+
136
+ concat_bchw_out = gs.Variable(name=name + "concat_bchw_out", dtype=np.int64, shape=[4])
137
+ concat_bchw = gs.Node(op="Concat", name=name+"concat_bchw", attrs={"axis": 0}, inputs=[slice_bc_out, slice_hw_out], outputs=[concat_bchw_out])
138
+
139
+ none_var = gs.Variable.empty()
140
+
141
+ resize_bchw = gs.Node(op="Resize", name=name+"resize_bchw", attrs=node.attrs, inputs=[node.inputs[0], none_var, none_var, concat_bchw_out], outputs=[node.outputs[0]])
142
+
143
+ self.graph.nodes.extend([shape_hw, slice_hw, shape_bc, slice_bc, concat_bchw, resize_bchw])
144
+
145
+ node.inputs = []
146
+ node.outputs = []
147
+
148
+ mResizeNodes += 1
149
+
150
+ self.cleanup()
151
+ return mResizeNodes
152
+
153
+ def adjustAddNode(self):
154
+ nAdjustAddNode = 0
155
+ for node in self.graph.nodes:
156
+ # Change the bias const to the second input to allow Gemm+BiasAdd fusion in TRT.
157
+ if node.op in ["Add"] and isinstance(node.inputs[0], gs.ir.tensor.Constant):
158
+ tensor = node.inputs[1]
159
+ bias = node.inputs[0]
160
+ node.inputs = [tensor, bias]
161
+ nAdjustAddNode += 1
162
+
163
+ self.cleanup()
164
+ return nAdjustAddNode
165
+
166
+ def decompose_instancenorms(self):
167
+ nRemoveInstanceNorm = 0
168
+ for node in self.graph.nodes:
169
+ if node.op == "InstanceNormalization":
170
+ name = node.name + "/"
171
+ input_tensor = node.inputs[0]
172
+ output_tensor = node.outputs[0]
173
+ mean_out = gs.Variable(name=name + "mean_out")
174
+ mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out])
175
+ sub_out = gs.Variable(name=name + "sub_out")
176
+ sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out])
177
+ pow_out = gs.Variable(name=name + "pow_out")
178
+ pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32))
179
+ pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out])
180
+ mean2_out = gs.Variable(name=name + "mean2_out")
181
+ mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out])
182
+ epsilon_out = gs.Variable(name=name + "epsilon_out")
183
+ epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([node.attrs["epsilon"]], dtype=np.float32))
184
+ epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out])
185
+ sqrt_out = gs.Variable(name=name + "sqrt_out")
186
+ sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out])
187
+ div_out = gs.Variable(name=name + "div_out")
188
+ div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out])
189
+ constantScale = gs.Constant("InstanceNormScaleV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[1].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
190
+ constantBias = gs.Constant("InstanceBiasV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[2].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
191
+ mul_out = gs.Variable(name=name + "mul_out")
192
+ mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out])
193
+ add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor])
194
+ self.graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node])
195
+ node.inputs = []
196
+ node.outputs = []
197
+ nRemoveInstanceNorm += 1
198
+
199
+ self.cleanup()
200
+ return nRemoveInstanceNorm
201
+
202
+ def insert_groupnorm_plugin(self):
203
+ nGroupNormPlugin = 0
204
+ for node in self.graph.nodes:
205
+ if node.op == "Reshape" and node.outputs != [] and \
206
+ node.o().op == "ReduceMean" and node.o(1).op == "Sub" and node.o().o() == node.o(1) and \
207
+ node.o().o().o().o().o().o().o().o().o().o().o().op == "Mul" and \
208
+ node.o().o().o().o().o().o().o().o().o().o().o().o().op == "Add" and \
209
+ len(node.o().o().o().o().o().o().o().o().inputs[1].values.shape) == 3:
210
+ # "node.outputs != []" is added for VAE
211
+
212
+ inputTensor = node.inputs[0]
213
+
214
+ gammaNode = node.o().o().o().o().o().o().o().o().o().o().o()
215
+ index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
216
+ gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
217
+ constantGamma = gs.Constant("groupNormGamma-" + str(nGroupNormPlugin), np.ascontiguousarray(gamma.reshape(-1))) # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!
218
+
219
+ betaNode = gammaNode.o()
220
+ index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
221
+ beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
222
+ constantBeta = gs.Constant("groupNormBeta-" + str(nGroupNormPlugin), np.ascontiguousarray(beta.reshape(-1)))
223
+
224
+ epsilon = node.o().o().o().o().o().inputs[1].values.tolist()[0]
225
+
226
+ if betaNode.o().op == "Sigmoid": # need Swish
227
+ bSwish = True
228
+ lastNode = betaNode.o().o() # Mul node of Swish
229
+ else:
230
+ bSwish = False
231
+ lastNode = betaNode # Cast node after Group Norm
232
+
233
+ if lastNode.o().op == "Cast":
234
+ lastNode = lastNode.o()
235
+ inputList = [inputTensor, constantGamma, constantBeta]
236
+ groupNormV = gs.Variable("GroupNormV-" + str(nGroupNormPlugin), np.dtype(np.float16), inputTensor.shape)
237
+ groupNormN = gs.Node("GroupNorm", "GroupNormN-" + str(nGroupNormPlugin), inputs=inputList, outputs=[groupNormV], attrs=OrderedDict([('epsilon', epsilon), ('bSwish', int(bSwish))]))
238
+ self.graph.nodes.append(groupNormN)
239
+
240
+ for subNode in self.graph.nodes:
241
+ if lastNode.outputs[0] in subNode.inputs:
242
+ index = subNode.inputs.index(lastNode.outputs[0])
243
+ subNode.inputs[index] = groupNormV
244
+ node.inputs = []
245
+ lastNode.outputs = []
246
+ nGroupNormPlugin += 1
247
+
248
+ self.cleanup()
249
+ return nGroupNormPlugin
250
+
251
+ def insert_layernorm_plugin(self):
252
+ nLayerNormPlugin = 0
253
+ for node in self.graph.nodes:
254
+ if node.op == 'ReduceMean' and \
255
+ node.o().op == 'Sub' and node.o().inputs[0] == node.inputs[0] and \
256
+ node.o().o(0).op =='Pow' and node.o().o(1).op =='Div' and \
257
+ node.o().o(0).o().op == 'ReduceMean' and \
258
+ node.o().o(0).o().o().op == 'Add' and \
259
+ node.o().o(0).o().o().o().op == 'Sqrt' and \
260
+ node.o().o(0).o().o().o().o().op == 'Div' and node.o().o(0).o().o().o().o() == node.o().o(1) and \
261
+ node.o().o(0).o().o().o().o().o().op == 'Mul' and \
262
+ node.o().o(0).o().o().o().o().o().o().op == 'Add' and \
263
+ len(node.o().o(0).o().o().o().o().o().inputs[1].values.shape) == 1:
264
+
265
+ if node.i().op == "Add":
266
+ inputTensor = node.inputs[0] # CLIP
267
+ else:
268
+ inputTensor = node.i().inputs[0] # UNet and VAE
269
+
270
+ gammaNode = node.o().o().o().o().o().o().o()
271
+ index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
272
+ gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
273
+ constantGamma = gs.Constant("LayerNormGamma-" + str(nLayerNormPlugin), np.ascontiguousarray(gamma.reshape(-1))) # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!
274
+
275
+ betaNode = gammaNode.o()
276
+ index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
277
+ beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
278
+ constantBeta = gs.Constant("LayerNormBeta-" + str(nLayerNormPlugin), np.ascontiguousarray(beta.reshape(-1)))
279
+
280
+ inputList = [inputTensor, constantGamma, constantBeta]
281
+ layerNormV = gs.Variable("LayerNormV-" + str(nLayerNormPlugin), np.dtype(np.float32), inputTensor.shape)
282
+ layerNormN = gs.Node("LayerNorm", "LayerNormN-" + str(nLayerNormPlugin), inputs=inputList, attrs=OrderedDict([('epsilon', 1.e-5)]), outputs=[layerNormV])
283
+ self.graph.nodes.append(layerNormN)
284
+ nLayerNormPlugin += 1
285
+
286
+ if betaNode.outputs[0] in self.graph.outputs:
287
+ index = self.graph.outputs.index(betaNode.outputs[0])
288
+ self.graph.outputs[index] = layerNormV
289
+ else:
290
+ if betaNode.o().op == "Cast":
291
+ lastNode = betaNode.o()
292
+ else:
293
+ lastNode = betaNode
294
+ for subNode in self.graph.nodes:
295
+ if lastNode.outputs[0] in subNode.inputs:
296
+ index = subNode.inputs.index(lastNode.outputs[0])
297
+ subNode.inputs[index] = layerNormV
298
+ lastNode.outputs = []
299
+
300
+ self.cleanup()
301
+ return nLayerNormPlugin
302
+
303
+ def fuse_kv(self, node_k, node_v, fused_kv_idx, heads, num_dynamic=0):
304
+ # Get weights of K
305
+ weights_k = node_k.inputs[1].values
306
+ # Get weights of V
307
+ weights_v = node_v.inputs[1].values
308
+ # Input number of channels to K and V
309
+ C = weights_k.shape[0]
310
+ # Number of heads
311
+ H = heads
312
+ # Dimension per head
313
+ D = weights_k.shape[1] // H
314
+
315
+ # Concat and interleave weights such that the output of fused KV GEMM has [b, s_kv, h, 2, d] shape
316
+ weights_kv = np.dstack([weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 2 * H * D)
317
+
318
+ # K and V have the same input
319
+ input_tensor = node_k.inputs[0]
320
+ # K and V must have the same output which we feed into fmha plugin
321
+ output_tensor_k = node_k.outputs[0]
322
+ # Create tensor
323
+ constant_weights_kv = gs.Constant("Weights_KV_{}".format(fused_kv_idx), np.ascontiguousarray(weights_kv))
324
+
325
+ # Create fused KV node
326
+ fused_kv_node = gs.Node(op="MatMul", name="MatMul_KV_{}".format(fused_kv_idx), inputs=[input_tensor, constant_weights_kv], outputs=[output_tensor_k])
327
+ self.graph.nodes.append(fused_kv_node)
328
+
329
+ # Connect the output of fused node to the inputs of the nodes after K and V
330
+ node_v.o(num_dynamic).inputs[0] = output_tensor_k
331
+ node_k.o(num_dynamic).inputs[0] = output_tensor_k
332
+ for i in range(0,num_dynamic):
333
+ node_v.o().inputs.clear()
334
+ node_k.o().inputs.clear()
335
+
336
+ # Clear inputs and outputs of K and V to get these nodes cleared
337
+ node_k.outputs.clear()
338
+ node_v.outputs.clear()
339
+ node_k.inputs.clear()
340
+ node_v.inputs.clear()
341
+
342
+ self.cleanup()
343
+ return fused_kv_node
344
+
345
+ def insert_fmhca(self, node_q, node_kv, final_tranpose, mhca_idx, heads, num_dynamic=0):
346
+ # Get inputs and outputs for the fMHCA plugin
347
+ # We take an output of reshape that follows the Q GEMM
348
+ output_q = node_q.o(num_dynamic).o().inputs[0]
349
+ output_kv = node_kv.o().inputs[0]
350
+ output_final_tranpose = final_tranpose.outputs[0]
351
+
352
+ # Clear the inputs of the nodes that follow the Q and KV GEMM
353
+ # to delete these subgraphs (it will be substituted by fMHCA plugin)
354
+ node_kv.outputs[0].outputs[0].inputs.clear()
355
+ node_kv.outputs[0].outputs[0].inputs.clear()
356
+ node_q.o(num_dynamic).o().inputs.clear()
357
+ for i in range(0,num_dynamic):
358
+ node_q.o(i).o().o(1).inputs.clear()
359
+
360
+ weights_kv = node_kv.inputs[1].values
361
+ dims_per_head = weights_kv.shape[1] // (heads * 2)
362
+
363
+ # Reshape dims
364
+ shape = gs.Constant("Shape_KV_{}".format(mhca_idx), np.ascontiguousarray(np.array([0, 0, heads, 2, dims_per_head], dtype=np.int64)))
365
+
366
+ # Reshape output tensor
367
+ output_reshape = gs.Variable("ReshapeKV_{}".format(mhca_idx), np.dtype(np.float16), None)
368
+ # Create fMHA plugin
369
+ reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mhca_idx), inputs=[output_kv, shape], outputs=[output_reshape])
370
+ # Insert node
371
+ self.graph.nodes.append(reshape)
372
+
373
+ # Create fMHCA plugin
374
+ fmhca = gs.Node(op="fMHCA", name="fMHCA_{}".format(mhca_idx), inputs=[output_q, output_reshape], outputs=[output_final_tranpose])
375
+ # Insert node
376
+ self.graph.nodes.append(fmhca)
377
+
378
+ # Connect input of fMHCA to output of Q GEMM
379
+ node_q.o(num_dynamic).outputs[0] = output_q
380
+
381
+ if num_dynamic > 0:
382
+ reshape2_input1_out = gs.Variable("Reshape2_fmhca{}_out".format(mhca_idx), np.dtype(np.int64), None)
383
+ reshape2_input1_shape = gs.Node("Shape", "Reshape2_fmhca{}_shape".format(mhca_idx), inputs=[node_q.inputs[0]], outputs=[reshape2_input1_out])
384
+ self.graph.nodes.append(reshape2_input1_shape)
385
+ final_tranpose.o().inputs[1] = reshape2_input1_out
386
+
387
+ # Clear outputs of transpose to get this subgraph cleared
388
+ final_tranpose.outputs.clear()
389
+
390
+ self.cleanup()
391
+
392
+ def fuse_qkv(self, node_q, node_k, node_v, fused_qkv_idx, heads, num_dynamic=0):
393
+ # Get weights of Q
394
+ weights_q = node_q.inputs[1].values
395
+ # Get weights of K
396
+ weights_k = node_k.inputs[1].values
397
+ # Get weights of V
398
+ weights_v = node_v.inputs[1].values
399
+
400
+ # Input number of channels to Q, K and V
401
+ C = weights_k.shape[0]
402
+ # Number of heads
403
+ H = heads
404
+ # Hidden dimension per head
405
+ D = weights_k.shape[1] // H
406
+
407
+ # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
408
+ weights_qkv = np.dstack([weights_q.reshape(C, H, D), weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 3 * H * D)
409
+
410
+ input_tensor = node_k.inputs[0] # Q, K and V have the same input
411
+ # Q, K and V must have the same output which we feed into fmha plugin
412
+ output_tensor_k = node_k.outputs[0]
413
+ # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
414
+ constant_weights_qkv = gs.Constant("Weights_QKV_{}".format(fused_qkv_idx), np.ascontiguousarray(weights_qkv))
415
+
416
+ # Created a fused node
417
+ fused_qkv_node = gs.Node(op="MatMul", name="MatMul_QKV_{}".format(fused_qkv_idx), inputs=[input_tensor, constant_weights_qkv], outputs=[output_tensor_k])
418
+ self.graph.nodes.append(fused_qkv_node)
419
+
420
+ # Connect the output of the fused node to the inputs of the nodes after Q, K and V
421
+ node_q.o(num_dynamic).inputs[0] = output_tensor_k
422
+ node_k.o(num_dynamic).inputs[0] = output_tensor_k
423
+ node_v.o(num_dynamic).inputs[0] = output_tensor_k
424
+ for i in range(0,num_dynamic):
425
+ node_q.o().inputs.clear()
426
+ node_k.o().inputs.clear()
427
+ node_v.o().inputs.clear()
428
+
429
+ # Clear inputs and outputs of Q, K and V to get these nodes cleared
430
+ node_q.outputs.clear()
431
+ node_k.outputs.clear()
432
+ node_v.outputs.clear()
433
+
434
+ node_q.inputs.clear()
435
+ node_k.inputs.clear()
436
+ node_v.inputs.clear()
437
+
438
+ self.cleanup()
439
+ return fused_qkv_node
440
+
441
+ def insert_fmha(self, node_qkv, final_tranpose, mha_idx, heads, num_dynamic=0):
442
+ # Get inputs and outputs for the fMHA plugin
443
+ output_qkv = node_qkv.o().inputs[0]
444
+ output_final_tranpose = final_tranpose.outputs[0]
445
+
446
+ # Clear the inputs of the nodes that follow the QKV GEMM
447
+ # to delete these subgraphs (it will be substituted by fMHA plugin)
448
+ node_qkv.outputs[0].outputs[2].inputs.clear()
449
+ node_qkv.outputs[0].outputs[1].inputs.clear()
450
+ node_qkv.outputs[0].outputs[0].inputs.clear()
451
+
452
+ weights_qkv = node_qkv.inputs[1].values
453
+ dims_per_head = weights_qkv.shape[1] // (heads * 3)
454
+
455
+ # Reshape dims
456
+ shape = gs.Constant("Shape_QKV_{}".format(mha_idx), np.ascontiguousarray(np.array([0, 0, heads, 3, dims_per_head], dtype=np.int64)))
457
+
458
+ # Reshape output tensor
459
+ output_shape = gs.Variable("ReshapeQKV_{}".format(mha_idx), np.dtype(np.float16), None)
460
+ # Create fMHA plugin
461
+ reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mha_idx), inputs=[output_qkv, shape], outputs=[output_shape])
462
+ # Insert node
463
+ self.graph.nodes.append(reshape)
464
+
465
+ # Create fMHA plugin
466
+ fmha = gs.Node(op="fMHA_V2", name="fMHA_{}".format(mha_idx), inputs=[output_shape], outputs=[output_final_tranpose])
467
+ # Insert node
468
+ self.graph.nodes.append(fmha)
469
+
470
+ if num_dynamic > 0:
471
+ reshape2_input1_out = gs.Variable("Reshape2_{}_out".format(mha_idx), np.dtype(np.int64), None)
472
+ reshape2_input1_shape = gs.Node("Shape", "Reshape2_{}_shape".format(mha_idx), inputs=[node_qkv.inputs[0]], outputs=[reshape2_input1_out])
473
+ self.graph.nodes.append(reshape2_input1_shape)
474
+ final_tranpose.o().inputs[1] = reshape2_input1_out
475
+
476
+ # Clear outputs of transpose to get this subgraph cleared
477
+ final_tranpose.outputs.clear()
478
+
479
+ self.cleanup()
480
+
481
+ def mha_mhca_detected(self, node, mha):
482
+ # Go from V GEMM down to the S*V MatMul and all way up to K GEMM
483
+ # If we are looking for MHCA, the inputs of the two MatMuls (K and V) must be equal.
484
+ # If we are looking for MHA, the inputs (K and V) must not be equal.
485
+ if node.op == "MatMul" and len(node.outputs) == 1 and \
486
+ ((mha and len(node.inputs[0].inputs) > 0 and node.i().op == "Add") or \
487
+ (not mha and len(node.inputs[0].inputs) == 0)):
488
+
489
+ if node.o().op == 'Shape':
490
+ if node.o(1).op == 'Shape':
491
+ num_dynamic_kv = 3 if node.o(2).op == 'Shape' else 2
492
+ else:
493
+ num_dynamic_kv = 1
494
+ # For Cross-Attention, if batch axis is dynamic (in QKV), assume H*W (in Q) is dynamic as well
495
+ num_dynamic_q = num_dynamic_kv if mha else num_dynamic_kv + 1
496
+ else:
497
+ num_dynamic_kv = 0
498
+ num_dynamic_q = 0
499
+
500
+ o = node.o(num_dynamic_kv)
501
+ if o.op == "Reshape" and \
502
+ o.o().op == "Transpose" and \
503
+ o.o().o().op == "Reshape" and \
504
+ o.o().o().o().op == "MatMul" and \
505
+ o.o().o().o().i(0).op == "Softmax" and \
506
+ o.o().o().o().i(1).op == "Reshape" and \
507
+ o.o().o().o().i(0).i().op == "Mul" and \
508
+ o.o().o().o().i(0).i().i().op == "MatMul" and \
509
+ o.o().o().o().i(0).i().i().i(0).op == "Reshape" and \
510
+ o.o().o().o().i(0).i().i().i(1).op == "Transpose" and \
511
+ o.o().o().o().i(0).i().i().i(1).i().op == "Reshape" and \
512
+ o.o().o().o().i(0).i().i().i(1).i().i().op == "Transpose" and \
513
+ o.o().o().o().i(0).i().i().i(1).i().i().i().op == "Reshape" and \
514
+ o.o().o().o().i(0).i().i().i(1).i().i().i().i().op == "MatMul" and \
515
+ node.name != o.o().o().o().i(0).i().i().i(1).i().i().i().i().name:
516
+ # "len(node.outputs) == 1" to make sure we are not in the already fused node
517
+ node_q = o.o().o().o().i(0).i().i().i(0).i().i().i()
518
+ node_k = o.o().o().o().i(0).i().i().i(1).i().i().i().i()
519
+ node_v = node
520
+ final_tranpose = o.o().o().o().o(num_dynamic_q).o()
521
+ # Sanity check to make sure that the graph looks like expected
522
+ if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
523
+ return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose
524
+ return False, 0, 0, None, None, None, None
525
+
526
+ def fuse_kv_insert_fmhca(self, heads, mhca_index, sm):
527
+ nodes = self.graph.nodes
528
+ # Iterate over graph and search for MHCA pattern
529
+ for idx, _ in enumerate(nodes):
530
+ # fMHCA can't be in the last 2 layers of the network; this guards against out-of-bounds access.
531
+ if idx + 1 > len(nodes) or idx + 2 > len(nodes):
532
+ continue
533
+
534
+ # Get anchor nodes for fusion and fMHCA plugin insertion if the MHCA is detected
535
+ detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
536
+ self.mha_mhca_detected(nodes[idx], mha=False)
537
+ if detected:
538
+ assert num_dynamic_q == 0 or num_dynamic_q == num_dynamic_kv + 1
539
+ # Skip the FMHCA plugin for SM75 except for when the dim per head is 40.
540
+ if sm == 75 and node_q.inputs[1].shape[1] // heads == 160:
541
+ continue
542
+ # Fuse K and V GEMMS
543
+ node_kv = self.fuse_kv(node_k, node_v, mhca_index, heads, num_dynamic_kv)
544
+ # Insert fMHCA plugin
545
+ self.insert_fmhca(node_q, node_kv, final_tranpose, mhca_index, heads, num_dynamic_q)
546
+ return True
547
+ return False
548
+
549
+ def fuse_qkv_insert_fmha(self, heads, mha_index):
550
+ nodes = self.graph.nodes
551
+ # Iterate over graph and search for MHA pattern
552
+ for idx, _ in enumerate(nodes):
553
+ # fMHA can't be in the last 2 layers of the network; this guards against out-of-bounds access.
554
+ if idx + 1 > len(nodes) or idx + 2 > len(nodes):
555
+ continue
556
+
557
+ # Get anchor nodes for fusion and fMHA plugin insertion if the MHA is detected
558
+ detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
559
+ self.mha_mhca_detected(nodes[idx], mha=True)
560
+ if detected:
561
+ assert num_dynamic_q == num_dynamic_kv
562
+ # Fuse Q, K and V GEMMS
563
+ node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv)
564
+ # Insert fMHA plugin
565
+ self.insert_fmha(node_qkv, final_tranpose, mha_index, heads, num_dynamic_kv)
566
+ return True
567
+ return False
568
+
569
+ def insert_fmhca_plugin(self, num_heads, sm):
570
+ mhca_index = 0
571
+ while self.fuse_kv_insert_fmhca(num_heads, mhca_index, sm):
572
+ mhca_index += 1
573
+ return mhca_index
574
+
575
+ def insert_fmha_plugin(self, num_heads):
576
+ mha_index = 0
577
+ while self.fuse_qkv_insert_fmha(num_heads, mha_index):
578
+ mha_index += 1
579
+ return mha_index
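
A minimal usage sketch of the OnnxOptimizer added above. The file paths are placeholders, and the LayerNorm plugin step is only worthwhile when the resulting graph is consumed by a TensorRT build that ships the matching plugin:

    from rfdetr.deploy._onnx import OnnxOptimizer

    opt = OnnxOptimizer("inference_model.onnx")   # hypothetical input path
    opt.info("Model: original")
    opt.common_opt()                              # registered passes + constant folding + shape inference
    n_ln = opt.insert_layernorm_plugin()          # swap matched LayerNorm subgraphs for plugin nodes
    opt.info(f"Model: optimized, {n_ln} LayerNorm plugins inserted")
    opt.save_onnx("inference_model.opt.onnx")     # hypothetical output path
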
rfdetr/deploy/_onnx/symbolic.py ADDED
@@ -0,0 +1,37 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ """
10
+ CustomOpSymbolicRegistry class
11
+ """
12
+ from copy import deepcopy
13
+
14
+ import onnx
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.onnx import register_custom_op_symbolic
19
+ from torch.onnx.symbolic_helper import parse_args
20
+ from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes
21
+ from torch.autograd import Function
22
+
23
+
24
+ class CustomOpSymbolicRegistry:
25
+ # _SYMBOLICS = {}
26
+ _OPTIMIZER = []
27
+
28
+ @classmethod
29
+ def optimizer(cls, fn):
30
+ cls._OPTIMIZER.append(fn)
31
+
32
+
33
+ def register_optimizer():
34
+ def optimizer_wrapper(fn):
35
+ CustomOpSymbolicRegistry.optimizer(fn)
36
+ return fn
37
+ return optimizer_wrapper
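
For context, common_opt() in optimizer.py runs every function registered here against the OnnxOptimizer instance before folding constants. A sketch of registering a custom pass (the pass body is purely illustrative, not part of RF-DETR):

    from rfdetr.deploy._onnx.symbolic import register_optimizer

    @register_optimizer()
    def keep_detection_outputs(opt):
        # 'opt' is the OnnxOptimizer instance handed in by common_opt();
        # this example keeps the first two graph outputs and renames them.
        opt.select_outputs([0, 1], names=["dets", "labels"])
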
rfdetr/deploy/benchmark.py ADDED
@@ -0,0 +1,590 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ This tool provides performance benchmarks by using ONNX Runtime and TensorRT
12
+ to run inference on a given model with the COCO validation set. It offers
13
+ reliable measurements of inference latency using ONNX Runtime or TensorRT
14
+ on the device.
15
+ """
16
+ import argparse
17
+ import copy
18
+ import contextlib
19
+ import datetime
20
+ import json
21
+ import os
22
+ import os.path as osp
23
+ import random
24
+ import time
25
+ import ast
26
+ from pathlib import Path
27
+ from collections import namedtuple, OrderedDict
28
+
29
+ from pycocotools.cocoeval import COCOeval
30
+ from pycocotools.coco import COCO
31
+ import pycocotools.mask as mask_util
32
+
33
+ import numpy as np
34
+ from PIL import Image
35
+ import torch
36
+ from torch.utils.data import DataLoader, DistributedSampler
37
+ import torchvision.transforms as T
38
+ import torchvision.transforms.functional as F
39
+ import tqdm
40
+
41
+ import pycuda.driver as cuda
42
+ import pycuda.autoinit
43
+ import onnxruntime as nxrun
44
+ import tensorrt as trt
45
+
46
+
47
+ def parser_args():
48
+ parser = argparse.ArgumentParser('performance benchmark tool for onnx/trt model')
49
+ parser.add_argument('--path', type=str, help='engine file path')
50
+ parser.add_argument('--coco_path', type=str, default="data/coco", help='coco dataset path')
51
+ parser.add_argument('--device', default=0, type=int)
52
+ parser.add_argument('--run_benchmark', action='store_true', help='repeat the inference to benchmark the latency')
53
+ parser.add_argument('--disable_eval', action='store_true', help='disable evaluation')
54
+ return parser.parse_args()
55
+
56
+
57
+ class CocoEvaluator(object):
58
+ def __init__(self, coco_gt, iou_types):
59
+ assert isinstance(iou_types, (list, tuple))
60
+ coco_gt = COCO(coco_gt)
61
+ coco_gt = copy.deepcopy(coco_gt)
62
+ self.coco_gt = coco_gt
63
+
64
+ self.iou_types = iou_types
65
+ self.coco_eval = {}
66
+ for iou_type in iou_types:
67
+ self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
68
+
69
+ self.img_ids = []
70
+ self.eval_imgs = {k: [] for k in iou_types}
71
+
72
+ def update(self, predictions):
73
+ img_ids = list(np.unique(list(predictions.keys())))
74
+ self.img_ids.extend(img_ids)
75
+
76
+ for iou_type in self.iou_types:
77
+ results = self.prepare(predictions, iou_type)
78
+
79
+ # suppress pycocotools prints
80
+ with open(os.devnull, 'w') as devnull:
81
+ with contextlib.redirect_stdout(devnull):
82
+ coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
83
+ coco_eval = self.coco_eval[iou_type]
84
+
85
+ coco_eval.cocoDt = coco_dt
86
+ coco_eval.params.imgIds = list(img_ids)
87
+ img_ids, eval_imgs = evaluate(coco_eval)
88
+
89
+ self.eval_imgs[iou_type].append(eval_imgs)
90
+
91
+ def synchronize_between_processes(self):
92
+ for iou_type in self.iou_types:
93
+ self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
94
+ create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
95
+
96
+ def accumulate(self):
97
+ for coco_eval in self.coco_eval.values():
98
+ coco_eval.accumulate()
99
+
100
+ def summarize(self):
101
+ for iou_type, coco_eval in self.coco_eval.items():
102
+ print("IoU metric: {}".format(iou_type))
103
+ coco_eval.summarize()
104
+
105
+ def prepare(self, predictions, iou_type):
106
+ if iou_type == "bbox":
107
+ return self.prepare_for_coco_detection(predictions)
108
+ else:
109
+ raise ValueError("Unknown iou type {}".format(iou_type))
110
+
111
+ def prepare_for_coco_detection(self, predictions):
112
+ coco_results = []
113
+ for original_id, prediction in predictions.items():
114
+ if len(prediction) == 0:
115
+ continue
116
+
117
+ boxes = prediction["boxes"]
118
+ boxes = convert_to_xywh(boxes).tolist()
119
+ scores = prediction["scores"].tolist()
120
+ labels = prediction["labels"].tolist()
121
+
122
+ coco_results.extend(
123
+ [
124
+ {
125
+ "image_id": original_id,
126
+ "category_id": labels[k],
127
+ "bbox": box,
128
+ "score": scores[k],
129
+ }
130
+ for k, box in enumerate(boxes)
131
+ ]
132
+ )
133
+ return coco_results
134
+
135
+ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
136
+ img_ids = list(img_ids)
137
+ eval_imgs = list(eval_imgs.flatten())
138
+
139
+ coco_eval.evalImgs = eval_imgs
140
+ coco_eval.params.imgIds = img_ids
141
+ coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
142
+
143
+ def evaluate(self):
144
+ '''
145
+ Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
146
+ :return: None
147
+ '''
148
+ # Running per image evaluation...
149
+ p = self.params
150
+ # add backward compatibility if useSegm is specified in params
151
+ if p.useSegm is not None:
152
+ p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
153
+ print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
154
+ # print('Evaluate annotation type *{}*'.format(p.iouType))
155
+ p.imgIds = list(np.unique(p.imgIds))
156
+ if p.useCats:
157
+ p.catIds = list(np.unique(p.catIds))
158
+ p.maxDets = sorted(p.maxDets)
159
+ self.params = p
160
+
161
+ self._prepare()
162
+ # loop through images, area range, max detection number
163
+ catIds = p.catIds if p.useCats else [-1]
164
+
165
+ if p.iouType == 'segm' or p.iouType == 'bbox':
166
+ computeIoU = self.computeIoU
167
+ elif p.iouType == 'keypoints':
168
+ computeIoU = self.computeOks
169
+ self.ious = {
170
+ (imgId, catId): computeIoU(imgId, catId)
171
+ for imgId in p.imgIds
172
+ for catId in catIds}
173
+
174
+ evaluateImg = self.evaluateImg
175
+ maxDet = p.maxDets[-1]
176
+ evalImgs = [
177
+ evaluateImg(imgId, catId, areaRng, maxDet)
178
+ for catId in catIds
179
+ for areaRng in p.areaRng
180
+ for imgId in p.imgIds
181
+ ]
182
+ # this is NOT in the pycocotools code, but could be done outside
183
+ evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
184
+ self._paramsEval = copy.deepcopy(self.params)
185
+ return p.imgIds, evalImgs
186
+
187
+ def convert_to_xywh(boxes):
188
+ boxes[:, 2:] -= boxes[:, :2]
189
+ return boxes
190
+
191
+
192
+ def get_image_list(ann_file):
193
+ with open(ann_file, 'r') as fin:
194
+ data = json.load(fin)
195
+ return data['images']
196
+
197
+
198
+ def load_image(file_path):
199
+ return Image.open(file_path).convert("RGB")
200
+
201
+
202
+ class Compose(object):
203
+ def __init__(self, transforms):
204
+ self.transforms = transforms
205
+
206
+ def __call__(self, image, target):
207
+ for t in self.transforms:
208
+ image, target = t(image, target)
209
+ return image, target
210
+
211
+ def __repr__(self):
212
+ format_string = self.__class__.__name__ + "("
213
+ for t in self.transforms:
214
+ format_string += "\n"
215
+ format_string += " {0}".format(t)
216
+ format_string += "\n)"
217
+ return format_string
218
+
219
+
220
+ class ToTensor(object):
221
+ def __call__(self, img, target):
222
+ return F.to_tensor(img), target
223
+
224
+
225
+ class Normalize(object):
226
+ def __init__(self, mean, std):
227
+ self.mean = mean
228
+ self.std = std
229
+
230
+ def __call__(self, image, target=None):
231
+ image = F.normalize(image, mean=self.mean, std=self.std)
232
+ if target is None:
233
+ return image, None
234
+ target = target.copy()
235
+ h, w = image.shape[-2:]
236
+ if "boxes" in target:
237
+ boxes = target["boxes"]
238
+ boxes = box_xyxy_to_cxcywh(boxes)
239
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
240
+ target["boxes"] = boxes
241
+ return image, target
242
+
243
+
244
+ class SquareResize(object):
245
+ def __init__(self, sizes):
246
+ assert isinstance(sizes, (list, tuple))
247
+ self.sizes = sizes
248
+
249
+ def __call__(self, img, target=None):
250
+ size = random.choice(self.sizes)
251
+ rescaled_img=F.resize(img, (size, size))
252
+ w, h = rescaled_img.size
253
+ if target is None:
254
+ return rescaled_img, None
255
+ ratios = tuple(
256
+ float(s) / float(s_orig) for s, s_orig in zip(rescaled_img.size, img.size))
257
+ ratio_width, ratio_height = ratios
258
+
259
+ target = target.copy()
260
+ if "boxes" in target:
261
+ boxes = target["boxes"]
262
+ scaled_boxes = boxes * torch.as_tensor(
263
+ [ratio_width, ratio_height, ratio_width, ratio_height])
264
+ target["boxes"] = scaled_boxes
265
+
266
+ if "area" in target:
267
+ area = target["area"]
268
+ scaled_area = area * (ratio_width * ratio_height)
269
+ target["area"] = scaled_area
270
+
271
+ target["size"] = torch.tensor([h, w])
272
+
273
+ return rescaled_img, target
274
+
275
+
276
+ def infer_transforms():
277
+ normalize = Compose([
278
+ ToTensor(),
279
+ Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
280
+ ])
281
+ return Compose([
282
+ SquareResize([640]),
283
+ normalize,
284
+ ])
285
+
286
+
287
+ def box_cxcywh_to_xyxy(x):
288
+ x_c, y_c, w, h = x.unbind(-1)
289
+ b = [(x_c - 0.5 * w.clamp(min=0.0)), (y_c - 0.5 * h.clamp(min=0.0)),
290
+ (x_c + 0.5 * w.clamp(min=0.0)), (y_c + 0.5 * h.clamp(min=0.0))]
291
+ return torch.stack(b, dim=-1)
292
+
293
+
294
+ def post_process(outputs, target_sizes):
295
+ out_logits, out_bbox = outputs['labels'], outputs['dets']
296
+
297
+ assert len(out_logits) == len(target_sizes)
298
+ assert target_sizes.shape[1] == 2
299
+
300
+ prob = out_logits.sigmoid()
301
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
302
+ scores = topk_values
303
+ topk_boxes = topk_indexes // out_logits.shape[2]
304
+ labels = topk_indexes % out_logits.shape[2]
305
+ boxes = box_cxcywh_to_xyxy(out_bbox)
306
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
307
+
308
+ # and from relative [0, 1] to absolute [0, height] coordinates
309
+ img_h, img_w = target_sizes.unbind(1)
310
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
311
+ boxes = boxes * scale_fct[:, None, :]
312
+
313
+ results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
314
+
315
+ return results
316
+
317
+
318
+ def infer_onnx(sess, coco_evaluator, time_profile, prefix, img_list, device, repeats=1):
319
+ time_list = []
320
+ for img_dict in tqdm.tqdm(img_list):
321
+ image = load_image(os.path.join(prefix, img_dict['file_name']))
322
+ width, height = image.size
323
+ orig_target_sizes = torch.Tensor([height, width])
324
+ image_tensor, _ = infer_transforms()(image, None) # target is None
325
+
326
+ samples = image_tensor[None].numpy()
327
+
328
+ time_profile.reset()
329
+ with time_profile:
330
+ for _ in range(repeats):
331
+ res = sess.run(None, {"input": samples})
332
+ time_list.append(time_profile.total / repeats)
333
+ outputs = {}
334
+ outputs['labels'] = torch.Tensor(res[1]).to(device)
335
+ outputs['dets'] = torch.Tensor(res[0]).to(device)
336
+
337
+ orig_target_sizes = torch.stack([orig_target_sizes], dim=0).to(device)
338
+ results = post_process(outputs, orig_target_sizes)
339
+ res = {img_dict['id']: results[0]}
340
+ if coco_evaluator is not None:
341
+ coco_evaluator.update(res)
342
+
343
+ print("Model latency with ONNX Runtime: {}ms".format(1000 * sum(time_list) / len(img_list)))
344
+
345
+ # accumulate predictions from all images
346
+ stats = {}
347
+ if coco_evaluator is not None:
348
+ coco_evaluator.synchronize_between_processes()
349
+ coco_evaluator.accumulate()
350
+ coco_evaluator.summarize()
351
+ stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
352
+ print(stats)
353
+
354
+
355
+ def infer_engine(model, coco_evaluator, time_profile, prefix, img_list, device, repeats=1):
356
+ time_list = []
357
+ for img_dict in tqdm.tqdm(img_list):
358
+ image = load_image(os.path.join(prefix, img_dict['file_name']))
359
+ width, height = image.size
360
+ orig_target_sizes = torch.Tensor([height, width])
361
+ image_tensor, _ = infer_transforms()(image, None) # target is None
362
+
363
+ samples = image_tensor[None].to(device)
364
+ _, _, h, w = samples.shape
365
+ im_shape = torch.Tensor(np.array([h, w]).reshape((1, 2)).astype(np.float32)).to(device)
366
+ scale_factor = torch.Tensor(np.array([h / height, w / width]).reshape((1, 2)).astype(np.float32)).to(device)
367
+
368
+ time_profile.reset()
369
+ with time_profile:
370
+ for _ in range(repeats):
371
+ outputs = model({"input": samples})
372
+
373
+ time_list.append(time_profile.total / repeats)
374
+ orig_target_sizes = torch.stack([orig_target_sizes], dim=0).to(device)
375
+ if coco_evaluator is not None:
376
+ results = post_process(outputs, orig_target_sizes)
377
+ res = {img_dict['id']: results[0]}
378
+ coco_evaluator.update(res)
379
+
380
+ print("Model latency with TensorRT: {}ms".format(1000 * sum(time_list) / len(img_list)))
381
+
382
+ # accumulate predictions from all images
383
+ stats = {}
384
+ if coco_evaluator is not None:
385
+ coco_evaluator.synchronize_between_processes()
386
+ coco_evaluator.accumulate()
387
+ coco_evaluator.summarize()
388
+ stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
389
+ print(stats)
390
+
391
+
392
+ class TRTInference(object):
393
+ """TensorRT inference engine
394
+ """
395
+ def __init__(self, engine_path='dino.engine', device='cuda:0', sync_mode:bool=False, max_batch_size=32, verbose=False):
396
+ self.engine_path = engine_path
397
+ self.device = device
398
+ self.sync_mode = sync_mode
399
+ self.max_batch_size = max_batch_size
400
+
401
+ self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO)
402
+
403
+ self.engine = self.load_engine(engine_path)
404
+
405
+ self.context = self.engine.create_execution_context()
406
+
407
+ self.bindings = self.get_bindings(self.engine, self.context, self.max_batch_size, self.device)
408
+ self.bindings_addr = OrderedDict((n, v.ptr) for n, v in self.bindings.items())
409
+
410
+ self.input_names = self.get_input_names()
411
+ self.output_names = self.get_output_names()
412
+
413
+ if not self.sync_mode:
414
+ self.stream = cuda.Stream()
415
+
416
+ # self.time_profile = TimeProfiler()
417
+ self.time_profile = None
418
+
419
+ def get_dummy_input(self, batch_size:int):
420
+ blob = {}
421
+ for name, binding in self.bindings.items():
422
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
423
+ print(f"make dummy input {name} with shape {binding.shape}")
424
+ blob[name] = torch.rand(batch_size, *binding.shape[1:]).float().to('cuda:0')
425
+ return blob
426
+
427
+ def load_engine(self, path):
428
+ '''load engine
429
+ '''
430
+ trt.init_libnvinfer_plugins(self.logger, '')
431
+ with open(path, 'rb') as f, trt.Runtime(self.logger) as runtime:
432
+ return runtime.deserialize_cuda_engine(f.read())
433
+
434
+ def get_input_names(self, ):
435
+ names = []
436
+ for _, name in enumerate(self.engine):
437
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
438
+ names.append(name)
439
+ return names
440
+
441
+ def get_output_names(self, ):
442
+ names = []
443
+ for _, name in enumerate(self.engine):
444
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
445
+ names.append(name)
446
+ return names
447
+
448
+ def get_bindings(self, engine, context, max_batch_size=32, device=None):
449
+ '''build bindings
450
+ '''
451
+ Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
452
+ bindings = OrderedDict()
453
+
454
+ for i, name in enumerate(engine):
455
+ shape = engine.get_tensor_shape(name)
456
+ dtype = trt.nptype(engine.get_tensor_dtype(name))
457
+
458
+ if shape[0] == -1:
459
+ raise NotImplementedError
460
+
461
+ if False:
462
+ if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
463
+ data = np.random.randn(*shape).astype(dtype)
464
+ ptr = cuda.mem_alloc(data.nbytes)
465
+ bindings[name] = Binding(name, dtype, shape, data, ptr)
466
+ else:
467
+ data = cuda.pagelocked_empty(trt.volume(shape), dtype)
468
+ ptr = cuda.mem_alloc(data.nbytes)
469
+ bindings[name] = Binding(name, dtype, shape, data, ptr)
470
+
471
+ else:
472
+ data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
473
+ bindings[name] = Binding(name, dtype, shape, data, data.data_ptr())
474
+
475
+ return bindings
476
+
477
+ def run_sync(self, blob):
478
+ self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names})
479
+ self.context.execute_v2(list(self.bindings_addr.values()))
480
+ outputs = {n: self.bindings[n].data for n in self.output_names}
481
+ return outputs
482
+
483
+ def run_async(self, blob):
484
+ self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names})
485
+ bindings_addr = [int(v) for _, v in self.bindings_addr.items()]
486
+ self.context.execute_async_v2(bindings=bindings_addr, stream_handle=self.stream.handle)
487
+ outputs = {n: self.bindings[n].data for n in self.output_names}
488
+ self.stream.synchronize()
489
+ return outputs
490
+
491
+ def __call__(self, blob):
492
+ if self.sync_mode:
493
+ return self.run_sync(blob)
494
+ else:
495
+ return self.run_async(blob)
496
+
497
+ def synchronize(self, ):
498
+ if not self.sync_mode and torch.cuda.is_available():
499
+ torch.cuda.synchronize()
500
+ elif self.sync_mode:
501
+ self.stream.synchronize()
502
+
503
+ def speed(self, blob, n):
504
+ self.time_profile.reset()
505
+ with self.time_profile:
506
+ for _ in range(n):
507
+ _ = self(blob)
508
+ return self.time_profile.total / n
509
+
510
+
511
+ def build_engine(self, onnx_file_path, engine_file_path, max_batch_size=32):
512
+ '''Takes an ONNX file and creates a TensorRT engine to run inference with
513
+ http://gitlab.baidu.com/paddle-inference/benchmark/blob/main/backend_trt.py#L57
514
+ '''
515
+ EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
516
+ with trt.Builder(self.logger) as builder, \
517
+ builder.create_network(EXPLICIT_BATCH) as network, \
518
+ trt.OnnxParser(network, self.logger) as parser, \
519
+ builder.create_builder_config() as config:
520
+
521
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1024 MiB
522
+ config.set_flag(trt.BuilderFlag.FP16)
523
+
524
+ with open(onnx_file_path, 'rb') as model:
525
+ if not parser.parse(model.read()):
526
+ print('ERROR: Failed to parse the ONNX file.')
527
+ for error in range(parser.num_errors):
528
+ print(parser.get_error(error))
529
+ return None
530
+
531
+ serialized_engine = builder.build_serialized_network(network, config)
532
+ with open(engine_file_path, 'wb') as f:
533
+ f.write(serialized_engine)
534
+
535
+ return serialized_engine
536
+
537
+
538
+ class TimeProfiler(contextlib.ContextDecorator):
539
+ def __init__(self, ):
540
+ self.total = 0
541
+
542
+ def __enter__(self, ):
543
+ self.start = self.time()
544
+ return self
545
+
546
+ def __exit__(self, type, value, traceback):
547
+ self.total += self.time() - self.start
548
+
549
+ def reset(self, ):
550
+ self.total = 0
551
+
552
+ def time(self, ):
553
+ if torch.cuda.is_available():
554
+ torch.cuda.synchronize()
555
+ return time.perf_counter()
556
+
557
+
558
+ def main(args):
559
+ print(args)
560
+
561
+ coco_gt = osp.join(args.coco_path, 'annotations/instances_val2017.json')
562
+ img_list = get_image_list(coco_gt)
563
+ prefix = osp.join(args.coco_path, 'val2017')
564
+ if args.run_benchmark:
565
+ repeats = 10
566
+ print('Inference for each image will be repeated 10 times to obtain '
567
+ 'a reliable measurement of inference latency.')
568
+ else:
569
+ repeats = 1
570
+
571
+ if args.disable_eval:
572
+ coco_evaluator = None
573
+ else:
574
+ coco_evaluator = CocoEvaluator(coco_gt, ('bbox',))
575
+
576
+ time_profile = TimeProfiler()
577
+
578
+ if args.path.endswith(".onnx"):
579
+ sess = nxrun.InferenceSession(args.path, providers=['CUDAExecutionProvider'])
580
+ infer_onnx(sess, coco_evaluator, time_profile, prefix, img_list, device=f'cuda:{args.device}', repeats=repeats)
581
+ elif args.path.endswith(".engine"):
582
+ model = TRTInference(args.path, sync_mode=True, device=f'cuda:{args.device}')
583
+ infer_engine(model, coco_evaluator, time_profile, prefix, img_list, device=f'cuda:{args.device}', repeats=repeats)
584
+ else:
585
+ raise NotImplementedError('Only model file names ending with ".onnx" and ".engine" are supported.')
586
+
587
+
588
+ if __name__ == '__main__':
589
+ args = parser_args()
590
+ main(args)
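
The script above is driven by parser_args(): --path selects a .onnx or .engine file, --coco_path points at a COCO-style dataset, --run_benchmark repeats each inference 10 times for a stable latency estimate, and --disable_eval skips the accuracy evaluation. The TimeProfiler it relies on can also be used on its own; a small sketch, assuming the deploy requirements are installed and a CUDA device is available, with a dummy workload standing in for a real model:

    import torch
    from rfdetr.deploy.benchmark import TimeProfiler

    profiler = TimeProfiler()
    profiler.reset()
    with profiler:
        for _ in range(10):
            _ = torch.randn(1, 3, 640, 640).mean()   # stand-in for a model call
    print(f"average iteration time: {1000 * profiler.total / 10:.3f} ms")
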
rfdetr/deploy/export.py ADDED
@@ -0,0 +1,276 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ export ONNX model and TensorRT engine for deployment
12
+ """
13
+ import os
14
+ import ast
15
+ import random
16
+ import argparse
17
+ import subprocess
18
+ import torch.nn as nn
19
+ from pathlib import Path
20
+ import time
21
+ from collections import defaultdict
22
+
23
+ import onnx
24
+ import torch
25
+ import onnxsim
26
+ import numpy as np
27
+ from PIL import Image
28
+
29
+ import rfdetr.util.misc as utils
30
+ import rfdetr.datasets.transforms as T
31
+ from rfdetr.models import build_model
32
+ from rfdetr.deploy._onnx import OnnxOptimizer
33
+ import re
34
+ import sys
35
+
36
+
37
+ def run_command_shell(command, dry_run:bool = False) -> int:
38
+ if dry_run:
39
+ print("")
40
+ print(f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']} {command}")
41
+ print("")
42
+ try:
43
+ result = subprocess.run(command, shell=True, capture_output=True, text=True)
44
+ return result
45
+ except subprocess.CalledProcessError as e:
46
+ print(f"Command failed with exit code {e.returncode}")
47
+ print(f"Error output:\n{e.stderr.decode('utf-8')}")
48
+ raise
49
+
50
+
51
+ def make_infer_image(infer_dir, shape, batch_size, device="cuda"):
52
+ if infer_dir is None:
53
+ dummy = np.random.randint(0, 256, (shape[0], shape[1], 3), dtype=np.uint8)
54
+ image = Image.fromarray(dummy, mode="RGB")
55
+ else:
56
+ image = Image.open(infer_dir).convert("RGB")
57
+
58
+ transforms = T.Compose([
59
+ T.SquareResize([shape[0]]),
60
+ T.ToTensor(),
61
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
62
+ ])
63
+
64
+ inps, _ = transforms(image, None)
65
+ inps = inps.to(device)
66
+ # inps = utils.nested_tensor_from_tensor_list([inps for _ in range(args.batch_size)])
67
+ inps = torch.stack([inps for _ in range(batch_size)])
68
+ return inps
69
+
70
+ def export_onnx(output_dir, model, input_names, input_tensors, output_names, dynamic_axes, backbone_only=False, verbose=True, opset_version=17):
71
+ export_name = "backbone_model" if backbone_only else "inference_model"
72
+ output_file = os.path.join(output_dir, f"{export_name}.onnx")
73
+
74
+ # Prepare model for export
75
+ if hasattr(model, "export"):
76
+ model.export()
77
+
78
+ torch.onnx.export(
79
+ model,
80
+ input_tensors,
81
+ output_file,
82
+ input_names=input_names,
83
+ output_names=output_names,
84
+ export_params=True,
85
+ keep_initializers_as_inputs=False,
86
+ do_constant_folding=True,
87
+ verbose=verbose,
88
+ opset_version=opset_version,
89
+ dynamic_axes=dynamic_axes)
90
+
91
+ print(f'\nSuccessfully exported ONNX model: {output_file}')
92
+ return output_file
93
+
94
+
95
+ def onnx_simplify(onnx_dir:str, input_names, input_tensors, force=False):
96
+ sim_onnx_dir = onnx_dir.replace(".onnx", ".sim.onnx")
97
+ if os.path.isfile(sim_onnx_dir) and not force:
98
+ return sim_onnx_dir
99
+
100
+ if isinstance(input_tensors, torch.Tensor):
101
+ input_tensors = [input_tensors]
102
+
103
+ print(f'Simplifying ONNX model: {onnx_dir}')
104
+ opt = OnnxOptimizer(onnx_dir)
105
+ opt.info('Model: original')
106
+ opt.common_opt()
107
+ opt.info('Model: optimized')
108
+ opt.save_onnx(sim_onnx_dir)
109
+ input_dict = {name: tensor.detach().cpu().numpy() for name, tensor in zip(input_names, input_tensors)}
110
+ model_opt, check_ok = onnxsim.simplify(
111
+ onnx_dir,
112
+ check_n = 3,
113
+ input_data=input_dict,
114
+ dynamic_input_shape=False)
115
+ if check_ok:
116
+ onnx.save(model_opt, sim_onnx_dir)
117
+ else:
118
+ raise RuntimeError("Failed to simplify ONNX model.")
119
+ print(f'Successfully simplified ONNX model: {sim_onnx_dir}')
120
+ return sim_onnx_dir
121
+
122
+
123
+ def trtexec(onnx_dir:str, args) -> None:
124
+ engine_dir = onnx_dir.replace(".onnx", f".engine")
125
+
126
+ # Base trtexec command
127
+ trt_command = " ".join([
128
+ "trtexec",
129
+ f"--onnx={onnx_dir}",
130
+ f"--saveEngine={engine_dir}",
131
+ f"--memPoolSize=workspace:4096 --fp16",
132
+ f"--useCudaGraph --useSpinWait --warmUp=500 --avgRuns=1000 --duration=10",
133
+ f"{'--verbose' if args.verbose else ''}"])
134
+
135
+ if args.profile:
136
+ profile_dir = onnx_dir.replace(".onnx", f".nsys-rep")
137
+ # Wrap with nsys profile command
138
+ command = " ".join([
139
+ "nsys profile",
140
+ f"--output={profile_dir}",
141
+ "--trace=cuda,nvtx",
142
+ "--force-overwrite true",
143
+ trt_command
144
+ ])
145
+ print(f'Profile data will be saved to: {profile_dir}')
146
+ else:
147
+ command = trt_command
148
+
149
+ output = run_command_shell(command, args.dry_run)
150
+ stats = parse_trtexec_output(output.stdout)
151
+
152
+ def parse_trtexec_output(output_text):
153
+ print(output_text)
154
+ # Common patterns in trtexec output
155
+ gpu_compute_pattern = r"GPU Compute Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms, median = (\d+\.\d+) ms"
156
+ h2d_pattern = r"Host to Device Transfer Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
157
+ d2h_pattern = r"Device to Host Transfer Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
158
+ latency_pattern = r"Latency: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
159
+ throughput_pattern = r"Throughput: (\d+\.\d+) qps"
160
+
161
+ stats = {}
162
+
163
+ # Extract compute times
164
+ if match := re.search(gpu_compute_pattern, output_text):
165
+ stats.update({
166
+ 'compute_min_ms': float(match.group(1)),
167
+ 'compute_max_ms': float(match.group(2)),
168
+ 'compute_mean_ms': float(match.group(3)),
169
+ 'compute_median_ms': float(match.group(4))
170
+ })
171
+
172
+ # Extract H2D times
173
+ if match := re.search(h2d_pattern, output_text):
174
+ stats.update({
175
+ 'h2d_min_ms': float(match.group(1)),
176
+ 'h2d_max_ms': float(match.group(2)),
177
+ 'h2d_mean_ms': float(match.group(3))
178
+ })
179
+
180
+ # Extract D2H times
181
+ if match := re.search(d2h_pattern, output_text):
182
+ stats.update({
183
+ 'd2h_min_ms': float(match.group(1)),
184
+ 'd2h_max_ms': float(match.group(2)),
185
+ 'd2h_mean_ms': float(match.group(3))
186
+ })
187
+
188
+ if match := re.search(latency_pattern, output_text):
189
+ stats.update({
190
+ 'latency_min_ms': float(match.group(1)),
191
+ 'latency_max_ms': float(match.group(2)),
192
+ 'latency_mean_ms': float(match.group(3))
193
+ })
194
+
195
+ # Extract throughput
196
+ if match := re.search(throughput_pattern, output_text):
197
+ stats['throughput_qps'] = float(match.group(1))
198
+
199
+ return stats
200
+
201
+ def no_batch_norm(model):
202
+ for module in model.modules():
203
+ if isinstance(module, nn.BatchNorm2d):
204
+ raise ValueError("BatchNorm2d found in the model. Please remove it.")
205
+
206
+ def main(args):
207
+ print("git:\n {}\n".format(utils.get_sha()))
208
+ print(args)
209
+ # convert device to device_id
210
+ if args.device == 'cuda':
211
+ device_id = "0"
212
+ elif args.device == 'cpu':
213
+ device_id = ""
214
+ else:
215
+ device_id = str(int(args.device))
216
+ args.device = f"cuda:{device_id}"
217
+
218
+ # device for export onnx
219
+ # TODO: export onnx with cuda failed with onnx error
220
+ device = torch.device("cpu")
221
+ os.environ["CUDA_VISIBLE_DEVICES"] = device_id
222
+
223
+ # fix the seed for reproducibility
224
+ seed = args.seed + utils.get_rank()
225
+ torch.manual_seed(seed)
226
+ np.random.seed(seed)
227
+ random.seed(seed)
228
+
229
+ model, criterion, postprocessors = build_model(args)
230
+ n_parameters = sum(p.numel() for p in model.parameters())
231
+ print(f"number of parameters: {n_parameters}")
232
+ n_backbone_parameters = sum(p.numel() for p in model.backbone.parameters())
233
+ print(f"number of backbone parameters: {n_backbone_parameters}")
234
+ n_projector_parameters = sum(p.numel() for p in model.backbone[0].projector.parameters())
235
+ print(f"number of projector parameters: {n_projector_parameters}")
236
+ n_backbone_encoder_parameters = sum(p.numel() for p in model.backbone[0].encoder.parameters())
237
+ print(f"number of backbone encoder parameters: {n_backbone_encoder_parameters}")
238
+ n_transformer_parameters = sum(p.numel() for p in model.transformer.parameters())
239
+ print(f"number of transformer parameters: {n_transformer_parameters}")
240
+ if args.resume:
241
+ checkpoint = torch.load(args.resume, map_location='cpu')
242
+ model.load_state_dict(checkpoint['model'], strict=True)
243
+ print(f"load checkpoints {args.resume}")
244
+
245
+ if args.layer_norm:
246
+ no_batch_norm(model)
247
+
248
+ model.to(device)
249
+
250
+ input_tensors = make_infer_image(args, device)
251
+ input_names = ['input']
252
+ output_names = ['features'] if args.backbone_only else ['dets', 'labels']
253
+ dynamic_axes = None
254
+ # Run model inference in pytorch mode
255
+ model.eval().to("cuda")
256
+ input_tensors = input_tensors.to("cuda")
257
+ with torch.no_grad():
258
+ if args.backbone_only:
259
+ features = model(input_tensors)
260
+ print(f"PyTorch inference output shape: {features.shape}")
261
+ else:
262
+ outputs = model(input_tensors)
263
+ dets = outputs['pred_boxes']
264
+ labels = outputs['pred_logits']
265
+ print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
266
+ model.cpu()
267
+ input_tensors = input_tensors.cpu()
268
+
269
+
270
+ output_file = export_onnx(model, args, input_names, input_tensors, output_names, dynamic_axes)
271
+
272
+ if args.simplify:
273
+ output_file = onnx_simplify(output_file, input_names, input_tensors, args)
274
+
275
+ if args.tensorrt:
276
+ output_file = trtexec(output_file, args)
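Note: the trtexec statistics above are recovered purely by regular expressions over the tool's log output. The following is a minimal sketch that exercises the same patterns on a synthetic snippet shaped like the regexes in parse_trtexec_output (it is not captured from a real trtexec run), just to illustrate what the extracted dictionary looks like.

import re

# Synthetic text shaped like the patterns matched by parse_trtexec_output above.
sample = (
    "GPU Compute Time: min = 1.20 ms, max = 2.50 ms, mean = 1.45 ms, median = 1.40 ms\n"
    "Throughput: 690.00 qps\n"
)

gpu_compute_pattern = (
    r"GPU Compute Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, "
    r"mean = (\d+\.\d+) ms, median = (\d+\.\d+) ms"
)
throughput_pattern = r"Throughput: (\d+\.\d+) qps"

stats = {}
if match := re.search(gpu_compute_pattern, sample):
    stats["compute_mean_ms"] = float(match.group(3))
if match := re.search(throughput_pattern, sample):
    stats["throughput_qps"] = float(match.group(1))

print(stats)  # {'compute_mean_ms': 1.45, 'throughput_qps': 690.0}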
rfdetr/deploy/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ pycuda
2
+ onnx
3
+ onnxsim
4
+ onnxruntime
5
+ onnxruntime-gpu
6
+ onnx_graphsurgeon
7
+ tensorrt>=8.6.1
8
+ polygraphy
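Note: these packages are only needed for the deploy path (ONNX export, simplification, and TensorRT benchmarking); the rest of the package imports them lazily when export is requested. A small sketch, assuming the import names match the distributions listed above (tensorrt and pycuda additionally require an NVIDIA toolchain), that checks availability before attempting an export:

import importlib.util

# Import names are assumed to match the distributions in the requirements file.
deploy_modules = [
    "pycuda", "onnx", "onnxsim", "onnxruntime",
    "onnx_graphsurgeon", "tensorrt", "polygraphy",
]
missing = [name for name in deploy_modules if importlib.util.find_spec(name) is None]
if missing:
    print(f"Missing deploy dependencies: {missing}; "
          f"install rfdetr/deploy/requirements.txt before exporting.")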
rfdetr/detr.py ADDED
@@ -0,0 +1,315 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+
7
+
8
+ import json
9
+ import os
10
+ from collections import defaultdict
11
+ from logging import getLogger
12
+ from typing import Union, List
13
+ from copy import deepcopy
14
+
15
+ import numpy as np
16
+ import supervision as sv
17
+ import torch
18
+ import torchvision.transforms.functional as F
19
+ from PIL import Image
20
+
21
+ try:
22
+ torch.set_float32_matmul_precision('high')
23
+ except:
24
+ pass
25
+
26
+ from rfdetr.config import RFDETRBaseConfig, RFDETRLargeConfig, TrainConfig, ModelConfig
27
+ from rfdetr.main import Model, download_pretrain_weights
28
+ from rfdetr.util.metrics import MetricsPlotSink, MetricsTensorBoardSink, MetricsWandBSink
29
+ from rfdetr.util.coco_classes import COCO_CLASSES
30
+
31
+ logger = getLogger(__name__)
32
+ class RFDETR:
33
+ means = [0.485, 0.456, 0.406]
34
+ stds = [0.229, 0.224, 0.225]
35
+
36
+ def __init__(self, **kwargs):
37
+ self.model_config = self.get_model_config(**kwargs)
38
+ self.maybe_download_pretrain_weights()
39
+ self.model = self.get_model(self.model_config)
40
+ self.callbacks = defaultdict(list)
41
+
42
+ self.model.inference_model = None
43
+ self._is_optimized_for_inference = False
44
+ self._has_warned_about_not_being_optimized_for_inference = False
45
+ self._optimized_has_been_compiled = False
46
+ self._optimized_batch_size = None
47
+ self._optimized_resolution = None
48
+ self._optimized_dtype = None
49
+
50
+ def maybe_download_pretrain_weights(self):
51
+ download_pretrain_weights(self.model_config.pretrain_weights)
52
+
53
+ def get_model_config(self, **kwargs):
54
+ return ModelConfig(**kwargs)
55
+
56
+ def train(self, **kwargs):
57
+ config = self.get_train_config(**kwargs)
58
+ self.train_from_config(config, **kwargs)
59
+
60
+ def optimize_for_inference(self, compile=True, batch_size=1, dtype=torch.float32):
61
+ self.remove_optimized_model()
62
+
63
+ self.model.inference_model = deepcopy(self.model.model)
64
+ self.model.inference_model.eval()
65
+ self.model.inference_model.export()
66
+
67
+ self._optimized_resolution = self.model.resolution
68
+ self._is_optimized_for_inference = True
69
+
70
+ self.model.inference_model = self.model.inference_model.to(dtype=dtype)
71
+ self._optimized_dtype = dtype
72
+
73
+ if compile:
74
+ self.model.inference_model = torch.jit.trace(
75
+ self.model.inference_model,
76
+ torch.randn(
77
+ batch_size, 3, self.model.resolution, self.model.resolution,
78
+ device=self.model.device,
79
+ dtype=dtype
80
+ )
81
+ )
82
+ self._optimized_has_been_compiled = True
83
+ self._optimized_batch_size = batch_size
84
+
85
+ def remove_optimized_model(self):
86
+ self.model.inference_model = None
87
+ self._is_optimized_for_inference = False
88
+ self._optimized_has_been_compiled = False
89
+ self._optimized_batch_size = None
90
+ self._optimized_resolution = None
91
+ self._optimized_dtype = None
92
+
93
+ def export(self, **kwargs):
94
+ self.model.export(**kwargs)
95
+
96
+ def train_from_config(self, config: TrainConfig, **kwargs):
97
+ with open(
98
+ os.path.join(config.dataset_dir, "train", "_annotations.coco.json"), "r"
99
+ ) as f:
100
+ anns = json.load(f)
101
+ num_classes = len(anns["categories"])
102
+ class_names = [c["name"] for c in anns["categories"] if c["supercategory"] != "none"]
103
+ self.model.class_names = class_names
104
+
105
+ if self.model_config.num_classes != num_classes:
106
+ logger.warning(
107
+ f"num_classes mismatch: model has {self.model_config.num_classes} classes, but your dataset has {num_classes} classes\n"
108
+ f"reinitializing your detection head with {num_classes} classes."
109
+ )
110
+ self.model.reinitialize_detection_head(num_classes)
111
+
112
+
113
+ train_config = config.dict()
114
+ model_config = self.model_config.dict()
115
+ model_config.pop("num_classes")
116
+ if "class_names" in model_config:
117
+ model_config.pop("class_names")
118
+
119
+ if "class_names" in train_config and train_config["class_names"] is None:
120
+ train_config["class_names"] = class_names
121
+
122
+ for k, v in train_config.items():
123
+ if k in model_config:
124
+ model_config.pop(k)
125
+ if k in kwargs:
126
+ kwargs.pop(k)
127
+
128
+ all_kwargs = {**model_config, **train_config, **kwargs, "num_classes": num_classes}
129
+
130
+ metrics_plot_sink = MetricsPlotSink(output_dir=config.output_dir)
131
+ self.callbacks["on_fit_epoch_end"].append(metrics_plot_sink.update)
132
+ self.callbacks["on_train_end"].append(metrics_plot_sink.save)
133
+
134
+ if config.tensorboard:
135
+ metrics_tensor_board_sink = MetricsTensorBoardSink(output_dir=config.output_dir)
136
+ self.callbacks["on_fit_epoch_end"].append(metrics_tensor_board_sink.update)
137
+ self.callbacks["on_train_end"].append(metrics_tensor_board_sink.close)
138
+
139
+ if config.wandb:
140
+ metrics_wandb_sink = MetricsWandBSink(
141
+ output_dir=config.output_dir,
142
+ project=config.project,
143
+ run=config.run,
144
+ config=config.model_dump()
145
+ )
146
+ self.callbacks["on_fit_epoch_end"].append(metrics_wandb_sink.update)
147
+ self.callbacks["on_train_end"].append(metrics_wandb_sink.close)
148
+
149
+ if config.early_stopping:
150
+ from rfdetr.util.early_stopping import EarlyStoppingCallback
151
+ early_stopping_callback = EarlyStoppingCallback(
152
+ model=self.model,
153
+ patience=config.early_stopping_patience,
154
+ min_delta=config.early_stopping_min_delta,
155
+ use_ema=config.early_stopping_use_ema
156
+ )
157
+ self.callbacks["on_fit_epoch_end"].append(early_stopping_callback.update)
158
+
159
+ self.model.train(
160
+ **all_kwargs,
161
+ callbacks=self.callbacks,
162
+ )
163
+
164
+ def get_train_config(self, **kwargs):
165
+ return TrainConfig(**kwargs)
166
+
167
+ def get_model(self, config: ModelConfig):
168
+ return Model(**config.dict())
169
+
170
+ # Get class_names from the model
171
+ @property
172
+ def class_names(self):
173
+ if hasattr(self.model, 'class_names') and self.model.class_names:
174
+ return {i+1: name for i, name in enumerate(self.model.class_names)}
175
+
176
+ return COCO_CLASSES
177
+
178
+ def predict(
179
+ self,
180
+ images: Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]],
181
+ threshold: float = 0.5,
182
+ **kwargs,
183
+ ) -> Union[sv.Detections, List[sv.Detections]]:
184
+ """Performs object detection on the input images and returns bounding box
185
+ predictions.
186
+
187
+ This method accepts a single image or a list of images in various formats
188
+ (file path, PIL Image, NumPy array, or torch.Tensor). The images should be in
189
+ RGB channel order. If a torch.Tensor is provided, it must already be normalized
190
+ to values in the [0, 1] range and have the shape (C, H, W).
191
+
192
+ Args:
193
+ images (Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]]):
194
+ A single image or a list of images to process. Images can be provided
195
+ as file paths, PIL Images, NumPy arrays, or torch.Tensors.
196
+ threshold (float, optional):
197
+ The minimum confidence score needed to consider a detected bounding box valid.
198
+ **kwargs:
199
+ Additional keyword arguments.
200
+
201
+ Returns:
202
+ Union[sv.Detections, List[sv.Detections]]: A single or multiple Detections
203
+ objects, each containing bounding box coordinates, confidence scores,
204
+ and class IDs.
205
+ """
206
+ if not self._is_optimized_for_inference and not self._has_warned_about_not_being_optimized_for_inference:
207
+ logger.warning(
208
+ "Model is not optimized for inference. "
209
+ "Latency may be higher than expected. "
210
+ "You can optimize the model for inference by calling model.optimize_for_inference()."
211
+ )
212
+ self._has_warned_about_not_being_optimized_for_inference = True
213
+
214
+ self.model.model.eval()
215
+
216
+ if not isinstance(images, list):
217
+ images = [images]
218
+
219
+ orig_sizes = []
220
+ processed_images = []
221
+
222
+ for img in images:
223
+
224
+ if isinstance(img, str):
225
+ img = Image.open(img)
226
+
227
+ if not isinstance(img, torch.Tensor):
228
+ img = F.to_tensor(img)
229
+
230
+ if (img > 1).any():
231
+ raise ValueError(
232
+ "Image has pixel values above 1. Please ensure the image is "
233
+ "normalized (scaled to [0, 1])."
234
+ )
235
+ if img.shape[0] != 3:
236
+ raise ValueError(
237
+ f"Invalid image shape. Expected 3 channels (RGB), but got "
238
+ f"{img.shape[0]} channels."
239
+ )
240
+ img_tensor = img
241
+
242
+ h, w = img_tensor.shape[1:]
243
+ orig_sizes.append((h, w))
244
+
245
+ img_tensor = img_tensor.to(self.model.device)
246
+ img_tensor = F.normalize(img_tensor, self.means, self.stds)
247
+ img_tensor = F.resize(img_tensor, (self.model.resolution, self.model.resolution))
248
+
249
+ processed_images.append(img_tensor)
250
+
251
+ batch_tensor = torch.stack(processed_images)
252
+
253
+ if self._is_optimized_for_inference:
254
+ if self._optimized_resolution != batch_tensor.shape[2]:
255
+ # this could happen if someone manually changes self.model.resolution after optimizing the model
256
+ raise ValueError(f"Resolution mismatch. "
257
+ f"Model was optimized for resolution {self._optimized_resolution}, "
258
+ f"but got {batch_tensor.shape[2]}. "
259
+ "You can explicitly remove the optimized model by calling model.remove_optimized_model().")
260
+ if self._optimized_has_been_compiled:
261
+ if self._optimized_batch_size != batch_tensor.shape[0]:
262
+ raise ValueError(f"Batch size mismatch. "
263
+ f"Optimized model was compiled for batch size {self._optimized_batch_size}, "
264
+ f"but got {batch_tensor.shape[0]}. "
265
+ "You can explicitly remove the optimized model by calling model.remove_optimized_model(). "
266
+ "Alternatively, you can recompile the optimized model for a different batch size "
267
+ "by calling model.optimize_for_inference(batch_size=<new_batch_size>).")
268
+
269
+ with torch.inference_mode():
270
+ if self._is_optimized_for_inference:
271
+ predictions = self.model.inference_model(batch_tensor.to(dtype=self._optimized_dtype))
272
+ else:
273
+ predictions = self.model.model(batch_tensor)
274
+ if isinstance(predictions, tuple):
275
+ predictions = {
276
+ "pred_logits": predictions[1],
277
+ "pred_boxes": predictions[0]
278
+ }
279
+ target_sizes = torch.tensor(orig_sizes, device=self.model.device)
280
+ results = self.model.postprocessors["bbox"](predictions, target_sizes=target_sizes)
281
+
282
+ detections_list = []
283
+ for result in results:
284
+ scores = result["scores"]
285
+ labels = result["labels"]
286
+ boxes = result["boxes"]
287
+
288
+ keep = scores > threshold
289
+ scores = scores[keep]
290
+ labels = labels[keep]
291
+ boxes = boxes[keep]
292
+
293
+ detections = sv.Detections(
294
+ xyxy=boxes.float().cpu().numpy(),
295
+ confidence=scores.float().cpu().numpy(),
296
+ class_id=labels.cpu().numpy(),
297
+ )
298
+ detections_list.append(detections)
299
+
300
+ return detections_list if len(detections_list) > 1 else detections_list[0]
301
+
302
+
303
+ class RFDETRBase(RFDETR):
304
+ def get_model_config(self, **kwargs):
305
+ return RFDETRBaseConfig(**kwargs)
306
+
307
+ def get_train_config(self, **kwargs):
308
+ return TrainConfig(**kwargs)
309
+
310
+ class RFDETRLarge(RFDETR):
311
+ def get_model_config(self, **kwargs):
312
+ return RFDETRLargeConfig(**kwargs)
313
+
314
+ def get_train_config(self, **kwargs):
315
+ return TrainConfig(**kwargs)
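Note: a short usage sketch of the predict() API documented above, following the docstring's contract (RGB input, confidence threshold, sv.Detections output). The image path is a placeholder, and looking labels up through the class_names property assumes the ids returned by the model are keys of that mapping.

from rfdetr import RFDETRBase

model = RFDETRBase()                    # downloads hosted COCO weights on first use
model.optimize_for_inference()          # optional: fixes batch size/resolution for lower latency
detections = model.predict("image.jpg", threshold=0.5)  # placeholder path

for class_id, confidence, box in zip(detections.class_id, detections.confidence, detections.xyxy):
    print(model.class_names[int(class_id)], float(confidence), box)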
rfdetr/engine.py ADDED
@@ -0,0 +1,256 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ # Conditional DETR
10
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
11
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
12
+ # ------------------------------------------------------------------------
13
+ # Copied from DETR (https://github.com/facebookresearch/detr)
14
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
15
+ # ------------------------------------------------------------------------
16
+
17
+ """
18
+ Train and eval functions used in main.py
19
+ """
20
+ import math
21
+ import sys
22
+ from typing import Iterable
23
+
24
+ import torch
25
+
26
+ import rfdetr.util.misc as utils
27
+ from rfdetr.datasets.coco_eval import CocoEvaluator
28
+
29
+ try:
30
+ from torch.amp import autocast, GradScaler
31
+ DEPRECATED_AMP = False
32
+ except ImportError:
33
+ from torch.cuda.amp import autocast, GradScaler
34
+ DEPRECATED_AMP = True
35
+ from typing import DefaultDict, List, Callable
36
+ from rfdetr.util.misc import NestedTensor
37
+
38
+
39
+
40
+ def get_autocast_args(args):
41
+ if DEPRECATED_AMP:
42
+ return {'enabled': args.amp, 'dtype': torch.bfloat16}
43
+ else:
44
+ return {'device_type': 'cuda', 'enabled': args.amp, 'dtype': torch.bfloat16}
45
+
46
+
47
+ def train_one_epoch(
48
+ model: torch.nn.Module,
49
+ criterion: torch.nn.Module,
50
+ lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
51
+ data_loader: Iterable,
52
+ optimizer: torch.optim.Optimizer,
53
+ device: torch.device,
54
+ epoch: int,
55
+ batch_size: int,
56
+ max_norm: float = 0,
57
+ ema_m: torch.nn.Module = None,
58
+ schedules: dict = {},
59
+ num_training_steps_per_epoch=None,
60
+ vit_encoder_num_layers=None,
61
+ args=None,
62
+ callbacks: DefaultDict[str, List[Callable]] = None,
63
+ ):
64
+ metric_logger = utils.MetricLogger(delimiter=" ")
65
+ metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
66
+ metric_logger.add_meter(
67
+ "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
68
+ )
69
+ header = "Epoch: [{}]".format(epoch)
70
+ print_freq = 10
71
+ start_steps = epoch * num_training_steps_per_epoch
72
+
73
+ print("Grad accum steps: ", args.grad_accum_steps)
74
+ print("Total batch size: ", batch_size * utils.get_world_size())
75
+
76
+ # Add gradient scaler for AMP
77
+ if DEPRECATED_AMP:
78
+ scaler = GradScaler(enabled=args.amp)
79
+ else:
80
+ scaler = GradScaler('cuda', enabled=args.amp)
81
+
82
+ optimizer.zero_grad()
83
+ assert batch_size % args.grad_accum_steps == 0
84
+ sub_batch_size = batch_size // args.grad_accum_steps
85
+ print("LENGTH OF DATA LOADER:", len(data_loader))
86
+ for data_iter_step, (samples, targets) in enumerate(
87
+ metric_logger.log_every(data_loader, print_freq, header)
88
+ ):
89
+ it = start_steps + data_iter_step
90
+ callback_dict = {
91
+ "step": it,
92
+ "model": model,
93
+ "epoch": epoch,
94
+ }
95
+ for callback in callbacks["on_train_batch_start"]:
96
+ callback(callback_dict)
97
+ if "dp" in schedules:
98
+ if args.distributed:
99
+ model.module.update_drop_path(
100
+ schedules["dp"][it], vit_encoder_num_layers
101
+ )
102
+ else:
103
+ model.update_drop_path(schedules["dp"][it], vit_encoder_num_layers)
104
+ if "do" in schedules:
105
+ if args.distributed:
106
+ model.module.update_dropout(schedules["do"][it])
107
+ else:
108
+ model.update_dropout(schedules["do"][it])
109
+
110
+ for i in range(args.grad_accum_steps):
111
+ start_idx = i * sub_batch_size
112
+ final_idx = start_idx + sub_batch_size
113
+ new_samples_tensors = samples.tensors[start_idx:final_idx]
114
+ new_samples = NestedTensor(new_samples_tensors, samples.mask[start_idx:final_idx])
115
+ new_samples = new_samples.to(device)
116
+ new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]]
117
+
118
+ with autocast(**get_autocast_args(args)):
119
+ outputs = model(new_samples, new_targets)
120
+ loss_dict = criterion(outputs, new_targets)
121
+ weight_dict = criterion.weight_dict
122
+ losses = sum(
123
+ (1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k]
124
+ for k in loss_dict.keys()
125
+ if k in weight_dict
126
+ )
127
+
128
+
129
+ scaler.scale(losses).backward()
130
+
131
+ # reduce losses over all GPUs for logging purposes
132
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
133
+ loss_dict_reduced_unscaled = {
134
+ f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
135
+ }
136
+ loss_dict_reduced_scaled = {
137
+ k: v * weight_dict[k]
138
+ for k, v in loss_dict_reduced.items()
139
+ if k in weight_dict
140
+ }
141
+ losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
142
+
143
+ loss_value = losses_reduced_scaled.item()
144
+
145
+ if not math.isfinite(loss_value):
146
+ print(loss_dict_reduced)
147
+ raise ValueError("Loss is {}, stopping training".format(loss_value))
148
+
149
+ if max_norm > 0:
150
+ scaler.unscale_(optimizer)
151
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
152
+
153
+ scaler.step(optimizer)
154
+ scaler.update()
155
+ lr_scheduler.step()
156
+ optimizer.zero_grad()
157
+ if ema_m is not None:
158
+ if epoch >= 0:
159
+ ema_m.update(model)
160
+ metric_logger.update(
161
+ loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
162
+ )
163
+ metric_logger.update(class_error=loss_dict_reduced["class_error"])
164
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
165
+ # gather the stats from all processes
166
+ metric_logger.synchronize_between_processes()
167
+ print("Averaged stats:", metric_logger)
168
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
169
+
170
+
171
+ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None):
172
+ model.eval()
173
+ if args.fp16_eval:
174
+ model.half()
175
+ criterion.eval()
176
+
177
+ metric_logger = utils.MetricLogger(delimiter=" ")
178
+ metric_logger.add_meter(
179
+ "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
180
+ )
181
+ header = "Test:"
182
+
183
+ iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
184
+ coco_evaluator = CocoEvaluator(base_ds, iou_types)
185
+
186
+ for samples, targets in metric_logger.log_every(data_loader, 10, header):
187
+ samples = samples.to(device)
188
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
189
+
190
+ if args.fp16_eval:
191
+ samples.tensors = samples.tensors.half()
192
+
193
+ # Add autocast for evaluation
194
+ with autocast(**get_autocast_args(args)):
195
+ outputs = model(samples)
196
+
197
+ if args.fp16_eval:
198
+ for key in outputs.keys():
199
+ if key == "enc_outputs":
200
+ for sub_key in outputs[key].keys():
201
+ outputs[key][sub_key] = outputs[key][sub_key].float()
202
+ elif key == "aux_outputs":
203
+ for idx in range(len(outputs[key])):
204
+ for sub_key in outputs[key][idx].keys():
205
+ outputs[key][idx][sub_key] = outputs[key][idx][
206
+ sub_key
207
+ ].float()
208
+ else:
209
+ outputs[key] = outputs[key].float()
210
+
211
+ loss_dict = criterion(outputs, targets)
212
+ weight_dict = criterion.weight_dict
213
+
214
+ # reduce losses over all GPUs for logging purposes
215
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
216
+ loss_dict_reduced_scaled = {
217
+ k: v * weight_dict[k]
218
+ for k, v in loss_dict_reduced.items()
219
+ if k in weight_dict
220
+ }
221
+ loss_dict_reduced_unscaled = {
222
+ f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
223
+ }
224
+ metric_logger.update(
225
+ loss=sum(loss_dict_reduced_scaled.values()),
226
+ **loss_dict_reduced_scaled,
227
+ **loss_dict_reduced_unscaled,
228
+ )
229
+ metric_logger.update(class_error=loss_dict_reduced["class_error"])
230
+
231
+ orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
232
+ results = postprocessors["bbox"](outputs, orig_target_sizes)
233
+ res = {
234
+ target["image_id"].item(): output
235
+ for target, output in zip(targets, results)
236
+ }
237
+ if coco_evaluator is not None:
238
+ coco_evaluator.update(res)
239
+
240
+ # gather the stats from all processes
241
+ metric_logger.synchronize_between_processes()
242
+ print("Averaged stats:", metric_logger)
243
+ if coco_evaluator is not None:
244
+ coco_evaluator.synchronize_between_processes()
245
+
246
+ # accumulate predictions from all images
247
+ if coco_evaluator is not None:
248
+ coco_evaluator.accumulate()
249
+ coco_evaluator.summarize()
250
+ stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
251
+ if coco_evaluator is not None:
252
+ if "bbox" in postprocessors.keys():
253
+ stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
254
+ if "segm" in postprocessors.keys():
255
+ stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
256
+ return stats, coco_evaluator
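Note: the try/except at the top of this file exists because autocast and GradScaler moved from torch.cuda.amp to torch.amp in newer PyTorch releases, and the newer entry point requires an explicit device_type. Below is a minimal sketch of the same compatibility shim in isolation, with a stand-in args namespace (the real training args carry many more fields).

from types import SimpleNamespace
import torch

# Stand-in for the training args; only the amp flag matters here.
args = SimpleNamespace(amp=torch.cuda.is_available())

try:
    from torch.amp import autocast  # newer API: needs an explicit device_type
    autocast_kwargs = {"device_type": "cuda", "enabled": args.amp, "dtype": torch.bfloat16}
except ImportError:
    from torch.cuda.amp import autocast  # older API: device is implicit
    autocast_kwargs = {"enabled": args.amp, "dtype": torch.bfloat16}

with autocast(**autocast_kwargs):
    pass  # forward pass and loss computation would run here under bf16 autocast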
rfdetr/main.py ADDED
@@ -0,0 +1,1034 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
10
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
11
+ # ------------------------------------------------------------------------
12
+ # Modified from DETR (https://github.com/facebookresearch/detr)
13
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
14
+ # ------------------------------------------------------------------------
15
+
16
+ """
17
+ Main training, evaluation, and export entry point for RF-DETR.
18
+ """
19
+ import argparse
20
+ import ast
21
+ import copy
22
+ import datetime
23
+ import json
24
+ import math
25
+ import os
26
+ import random
27
+ import shutil
28
+ import time
29
+ from copy import deepcopy
30
+ from logging import getLogger
31
+ from pathlib import Path
32
+ from typing import DefaultDict, List, Callable
33
+
34
+ import numpy as np
35
+ import torch
36
+ from peft import LoraConfig, get_peft_model
37
+ from torch.utils.data import DataLoader, DistributedSampler
38
+
39
+ import rfdetr.util.misc as utils
40
+ from rfdetr.datasets import build_dataset, get_coco_api_from_dataset
41
+ from rfdetr.engine import evaluate, train_one_epoch
42
+ from rfdetr.models import build_model, build_criterion_and_postprocessors
43
+ from rfdetr.util.benchmark import benchmark
44
+ from rfdetr.util.drop_scheduler import drop_scheduler
45
+ from rfdetr.util.files import download_file
46
+ from rfdetr.util.get_param_dicts import get_param_dict
47
+ from rfdetr.util.utils import ModelEma, BestMetricHolder, clean_state_dict
48
+
49
+ if str(os.environ.get("USE_FILE_SYSTEM_SHARING", "False")).lower() in ["true", "1"]:
50
+ import torch.multiprocessing
51
+ torch.multiprocessing.set_sharing_strategy('file_system')
52
+
53
+ logger = getLogger(__name__)
54
+
55
+ HOSTED_MODELS = {
56
+ "rf-detr-base.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
57
+ # below is a less converged model that may be better for finetuning but worse for inference
58
+ "rf-detr-base-2.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
59
+ "rf-detr-large.pth": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth"
60
+ }
61
+
62
+ def download_pretrain_weights(pretrain_weights: str, redownload=False):
63
+ if pretrain_weights in HOSTED_MODELS:
64
+ if redownload or not os.path.exists(pretrain_weights):
65
+ logger.info(
66
+ f"Downloading pretrained weights for {pretrain_weights}"
67
+ )
68
+ download_file(
69
+ HOSTED_MODELS[pretrain_weights],
70
+ pretrain_weights,
71
+ )
72
+
73
+ class Model:
74
+ def __init__(self, **kwargs):
75
+ args = populate_args(**kwargs)
76
+ self.resolution = args.resolution
77
+ self.model = build_model(args)
78
+ self.device = torch.device(args.device)
79
+ if args.pretrain_weights is not None:
80
+ print("Loading pretrain weights")
81
+ try:
82
+ checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
83
+ except Exception as e:
84
+ print(f"Failed to load pretrain weights: {e}")
85
+ # re-download weights if they are corrupted
86
+ print("Failed to load pretrain weights, re-downloading")
87
+ download_pretrain_weights(args.pretrain_weights, redownload=True)
88
+ checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
89
+
90
+ # Extract class_names from checkpoint if available
91
+ if 'args' in checkpoint and hasattr(checkpoint['args'], 'class_names'):
92
+ self.class_names = checkpoint['args'].class_names
93
+
94
+ checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0]
95
+ if checkpoint_num_classes != args.num_classes + 1:
96
+ logger.warning(
97
+ f"num_classes mismatch: pretrain weights has {checkpoint_num_classes - 1} classes, but your model has {args.num_classes} classes\n"
98
+ f"reinitializing detection head with {checkpoint_num_classes - 1} classes"
99
+ )
100
+ self.reinitialize_detection_head(checkpoint_num_classes)
101
+ # add support to exclude_keys
102
+ # e.g., when load object365 pretrain, do not load `class_embed.[weight, bias]`
103
+ if args.pretrain_exclude_keys is not None:
104
+ assert isinstance(args.pretrain_exclude_keys, list)
105
+ for exclude_key in args.pretrain_exclude_keys:
106
+ checkpoint['model'].pop(exclude_key)
107
+ if args.pretrain_keys_modify_to_load is not None:
108
+ from rfdetr.util.obj365_to_coco_model import get_coco_pretrain_from_obj365
109
+ assert isinstance(args.pretrain_keys_modify_to_load, list)
110
+ for modify_key_to_load in args.pretrain_keys_modify_to_load:
111
+ try:
112
+ checkpoint['model'][modify_key_to_load] = get_coco_pretrain_from_obj365(
113
+ self.model.state_dict()[modify_key_to_load],
114
+ checkpoint['model'][modify_key_to_load]
115
+ )
116
+ except:
117
+ print(f"Failed to load {modify_key_to_load}, deleting from checkpoint")
118
+ checkpoint['model'].pop(modify_key_to_load)
119
+
120
+ # we may want to resume training with a smaller number of groups for group detr
121
+ num_desired_queries = args.num_queries * args.group_detr
122
+ query_param_names = ["refpoint_embed.weight", "query_feat.weight"]
123
+ for name, state in checkpoint['model'].items():
124
+ if any(name.endswith(x) for x in query_param_names):
125
+ checkpoint['model'][name] = state[:num_desired_queries]
126
+
127
+ self.model.load_state_dict(checkpoint['model'], strict=False)
128
+
129
+ if args.backbone_lora:
130
+ print("Applying LORA to backbone")
131
+ lora_config = LoraConfig(
132
+ r=16,
133
+ lora_alpha=16,
134
+ use_dora=True,
135
+ target_modules=[
136
+ "q_proj", "v_proj", "k_proj", # covers OWL-ViT
137
+ "qkv", # covers open_clip ie Siglip2
138
+ "query", "key", "value", "cls_token", "register_tokens", # covers Dinov2 with windowed attn
139
+ ]
140
+ )
141
+ self.model.backbone[0].encoder = get_peft_model(self.model.backbone[0].encoder, lora_config)
142
+ self.model = self.model.to(self.device)
143
+ self.criterion, self.postprocessors = build_criterion_and_postprocessors(args)
144
+ self.stop_early = False
145
+
146
+ def reinitialize_detection_head(self, num_classes):
147
+ self.model.reinitialize_detection_head(num_classes)
148
+
149
+ def request_early_stop(self):
150
+ self.stop_early = True
151
+ print("Early stopping requested, will complete current epoch and stop")
152
+
153
+ def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
154
+ currently_supported_callbacks = ["on_fit_epoch_end", "on_train_batch_start", "on_train_end"]
155
+ for key in callbacks.keys():
156
+ if key not in currently_supported_callbacks:
157
+ raise ValueError(
158
+ f"Callback {key} is not currently supported, please file an issue if you need it!\n"
159
+ f"Currently supported callbacks: {currently_supported_callbacks}"
160
+ )
161
+ args = populate_args(**kwargs)
162
+ utils.init_distributed_mode(args)
163
+ print("git:\n {}\n".format(utils.get_sha()))
164
+ print(args)
165
+ device = torch.device(args.device)
166
+
167
+ # fix the seed for reproducibility
168
+ seed = args.seed + utils.get_rank()
169
+ torch.manual_seed(seed)
170
+ np.random.seed(seed)
171
+ random.seed(seed)
172
+
173
+ criterion, postprocessors = build_criterion_and_postprocessors(args)
174
+ model = self.model
175
+ model.to(device)
176
+
177
+ model_without_ddp = model
178
+ if args.distributed:
179
+ if args.sync_bn:
180
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
181
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
182
+ model_without_ddp = model.module
183
+
184
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
185
+ print('number of params:', n_parameters)
186
+ param_dicts = get_param_dict(args, model_without_ddp)
187
+
188
+ param_dicts = [p for p in param_dicts if p['params'].requires_grad]
189
+
190
+ optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
191
+ weight_decay=args.weight_decay)
192
+ # Choose the learning rate scheduler based on the new argument
193
+
194
+ dataset_train = build_dataset(image_set='train', args=args, resolution=args.resolution)
195
+ dataset_val = build_dataset(image_set='val', args=args, resolution=args.resolution)
196
+
197
+ # for cosine annealing, calculate total training steps and warmup steps
198
+ total_batch_size_for_lr = args.batch_size * utils.get_world_size() * args.grad_accum_steps
199
+ num_training_steps_per_epoch_lr = (len(dataset_train) + total_batch_size_for_lr - 1) // total_batch_size_for_lr
200
+ total_training_steps_lr = num_training_steps_per_epoch_lr * args.epochs
201
+ warmup_steps_lr = num_training_steps_per_epoch_lr * args.warmup_epochs
202
+ def lr_lambda(current_step: int):
203
+ if current_step < warmup_steps_lr:
204
+ # Linear warmup
205
+ return float(current_step) / float(max(1, warmup_steps_lr))
206
+ else:
207
+ # Cosine annealing from multiplier 1.0 down to lr_min_factor
208
+ if args.lr_scheduler == 'cosine':
209
+ progress = float(current_step - warmup_steps_lr) / float(max(1, total_training_steps_lr - warmup_steps_lr))
210
+ return args.lr_min_factor + (1 - args.lr_min_factor) * 0.5 * (1 + math.cos(math.pi * progress))
211
+ elif args.lr_scheduler == 'step':
212
+ if current_step < args.lr_drop * num_training_steps_per_epoch_lr:
213
+ return 1.0
214
+ else:
215
+ return 0.1
216
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
217
+
218
+ if args.distributed:
219
+ sampler_train = DistributedSampler(dataset_train)
220
+ sampler_val = DistributedSampler(dataset_val, shuffle=False)
221
+ else:
222
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
223
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
224
+
225
+ effective_batch_size = args.batch_size * args.grad_accum_steps
226
+ min_batches = kwargs.get('min_batches', 5)
227
+ if len(dataset_train) < effective_batch_size * min_batches:
228
+ logger.info(
229
+ f"Training with uniform sampler because dataset is too small: {len(dataset_train)} < {effective_batch_size * min_batches}"
230
+ )
231
+ sampler = torch.utils.data.RandomSampler(
232
+ dataset_train,
233
+ replacement=True,
234
+ num_samples=effective_batch_size * min_batches,
235
+ )
236
+ data_loader_train = DataLoader(
237
+ dataset_train,
238
+ batch_size=effective_batch_size,
239
+ collate_fn=utils.collate_fn,
240
+ num_workers=args.num_workers,
241
+ sampler=sampler,
242
+ )
243
+ else:
244
+ batch_sampler_train = torch.utils.data.BatchSampler(
245
+ sampler_train, effective_batch_size, drop_last=True)
246
+ data_loader_train = DataLoader(
247
+ dataset_train,
248
+ batch_sampler=batch_sampler_train,
249
+ collate_fn=utils.collate_fn,
250
+ num_workers=args.num_workers
251
+ )
252
+
253
+ data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
254
+ drop_last=False, collate_fn=utils.collate_fn,
255
+ num_workers=args.num_workers)
256
+
257
+ base_ds = get_coco_api_from_dataset(dataset_val)
258
+
259
+ if args.use_ema:
260
+ self.ema_m = ModelEma(model_without_ddp, decay=args.ema_decay, tau=args.ema_tau)
261
+ else:
262
+ self.ema_m = None
263
+
264
+
265
+ output_dir = Path(args.output_dir)
266
+
267
+ if utils.is_main_process():
268
+ print("Get benchmark")
269
+ if args.do_benchmark:
270
+ benchmark_model = copy.deepcopy(model_without_ddp)
271
+ bm = benchmark(benchmark_model.float(), dataset_val, output_dir)
272
+ print(json.dumps(bm, indent=2))
273
+ del benchmark_model
274
+
275
+ if args.resume:
276
+ checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
277
+ model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
278
+ if args.use_ema:
279
+ if 'ema_model' in checkpoint:
280
+ self.ema_m.module.load_state_dict(clean_state_dict(checkpoint['ema_model']))
281
+ else:
282
+ del self.ema_m
283
+ self.ema_m = ModelEma(model, decay=args.ema_decay, tau=args.ema_tau)
284
+ if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
285
+ optimizer.load_state_dict(checkpoint['optimizer'])
286
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
287
+ args.start_epoch = checkpoint['epoch'] + 1
288
+
289
+ if args.eval:
290
+ test_stats, coco_evaluator = evaluate(
291
+ model, criterion, postprocessors, data_loader_val, base_ds, device, args)
292
+ if args.output_dir:
293
+ utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
294
+ return
295
+
296
+ # for drop
297
+ total_batch_size = effective_batch_size * utils.get_world_size()
298
+ num_training_steps_per_epoch = (len(dataset_train) + total_batch_size - 1) // total_batch_size
299
+ schedules = {}
300
+ if args.dropout > 0:
301
+ schedules['do'] = drop_scheduler(
302
+ args.dropout, args.epochs, num_training_steps_per_epoch,
303
+ args.cutoff_epoch, args.drop_mode, args.drop_schedule)
304
+ print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do'])))
305
+
306
+ if args.drop_path > 0:
307
+ schedules['dp'] = drop_scheduler(
308
+ args.drop_path, args.epochs, num_training_steps_per_epoch,
309
+ args.cutoff_epoch, args.drop_mode, args.drop_schedule)
310
+ print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))
311
+
312
+ print("Start training")
313
+ start_time = time.time()
314
+ best_map_holder = BestMetricHolder(use_ema=args.use_ema)
315
+ best_map_5095 = 0
316
+ best_map_50 = 0
317
+ best_map_ema_5095 = 0
318
+ best_map_ema_50 = 0
319
+ for epoch in range(args.start_epoch, args.epochs):
320
+ epoch_start_time = time.time()
321
+ if args.distributed:
322
+ sampler_train.set_epoch(epoch)
323
+
324
+ model.train()
325
+ criterion.train()
326
+ train_stats = train_one_epoch(
327
+ model, criterion, lr_scheduler, data_loader_train, optimizer, device, epoch,
328
+ effective_batch_size, args.clip_max_norm, ema_m=self.ema_m, schedules=schedules,
329
+ num_training_steps_per_epoch=num_training_steps_per_epoch,
330
+ vit_encoder_num_layers=args.vit_encoder_num_layers, args=args, callbacks=callbacks)
331
+ train_epoch_time = time.time() - epoch_start_time
332
+ train_epoch_time_str = str(datetime.timedelta(seconds=int(train_epoch_time)))
333
+ if args.output_dir:
334
+ checkpoint_paths = [output_dir / 'checkpoint.pth']
335
+ # extra checkpoint before LR drop and every `checkpoint_interval` epochs
336
+ if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % args.checkpoint_interval == 0:
337
+ checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
338
+ for checkpoint_path in checkpoint_paths:
339
+ weights = {
340
+ 'model': model_without_ddp.state_dict(),
341
+ 'optimizer': optimizer.state_dict(),
342
+ 'lr_scheduler': lr_scheduler.state_dict(),
343
+ 'epoch': epoch,
344
+ 'args': args,
345
+ }
346
+ if args.use_ema:
347
+ weights.update({
348
+ 'ema_model': self.ema_m.module.state_dict(),
349
+ })
350
+ if not args.dont_save_weights:
351
+ # create checkpoint dir
352
+ checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
353
+
354
+ utils.save_on_master(weights, checkpoint_path)
355
+
356
+ with torch.inference_mode():
357
+ test_stats, coco_evaluator = evaluate(
358
+ model, criterion, postprocessors, data_loader_val, base_ds, device, args=args
359
+ )
360
+
361
+ map_regular = test_stats['coco_eval_bbox'][0]
362
+ _isbest = best_map_holder.update(map_regular, epoch, is_ema=False)
363
+ if _isbest:
364
+ best_map_5095 = max(best_map_5095, map_regular)
365
+ best_map_50 = max(best_map_50, test_stats["coco_eval_bbox"][1])
366
+ checkpoint_path = output_dir / 'checkpoint0009.pth'
367
+ if not args.dont_save_weights:
368
+ utils.save_on_master({
369
+ 'model': model_without_ddp.state_dict(),
370
+ 'optimizer': optimizer.state_dict(),
371
+ 'lr_scheduler': lr_scheduler.state_dict(),
372
+ 'epoch': epoch,
373
+ 'args': args,
374
+ }, checkpoint_path)
375
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
376
+ **{f'test_{k}': v for k, v in test_stats.items()},
377
+ 'epoch': epoch,
378
+ 'n_parameters': n_parameters}
379
+ if args.use_ema:
380
+ ema_test_stats, _ = evaluate(
381
+ self.ema_m.module, criterion, postprocessors, data_loader_val, base_ds, device, args=args
382
+ )
383
+ log_stats.update({f'ema_test_{k}': v for k,v in ema_test_stats.items()})
384
+ map_ema = ema_test_stats['coco_eval_bbox'][0]
385
+ best_map_ema_5095 = max(best_map_ema_5095, map_ema)
386
+ _isbest = best_map_holder.update(map_ema, epoch, is_ema=True)
387
+ if _isbest:
388
+ best_map_ema_50 = max(best_map_ema_50, ema_test_stats["coco_eval_bbox"][1])
389
+ checkpoint_path = output_dir / 'checkpoint_best_ema.pth'
390
+ if not args.dont_save_weights:
391
+ utils.save_on_master({
392
+ 'model': self.ema_m.module.state_dict(),
393
+ 'optimizer': optimizer.state_dict(),
394
+ 'lr_scheduler': lr_scheduler.state_dict(),
395
+ 'epoch': epoch,
396
+ 'args': args,
397
+ }, checkpoint_path)
398
+ log_stats.update(best_map_holder.summary())
399
+
400
+ # epoch parameters
401
+ ep_paras = {
402
+ 'epoch': epoch,
403
+ 'n_parameters': n_parameters
404
+ }
405
+ log_stats.update(ep_paras)
406
+ try:
407
+ log_stats.update({'now_time': str(datetime.datetime.now())})
408
+ except:
409
+ pass
410
+ log_stats['train_epoch_time'] = train_epoch_time_str
411
+ epoch_time = time.time() - epoch_start_time
412
+ epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
413
+ log_stats['epoch_time'] = epoch_time_str
414
+ if args.output_dir and utils.is_main_process():
415
+ with (output_dir / "log.txt").open("a") as f:
416
+ f.write(json.dumps(log_stats) + "\n")
417
+
418
+ # for evaluation logs
419
+ if coco_evaluator is not None:
420
+ (output_dir / 'eval').mkdir(exist_ok=True)
421
+ if "bbox" in coco_evaluator.coco_eval:
422
+ filenames = ['latest.pth']
423
+ if epoch % 50 == 0:
424
+ filenames.append(f'{epoch:03}.pth')
425
+ for name in filenames:
426
+ torch.save(coco_evaluator.coco_eval["bbox"].eval,
427
+ output_dir / "eval" / name)
428
+
429
+ for callback in callbacks["on_fit_epoch_end"]:
430
+ callback(log_stats)
431
+
432
+ if self.stop_early:
433
+ print(f"Early stopping requested, stopping at epoch {epoch}")
434
+ break
435
+
436
+ best_is_ema = best_map_ema_5095 > best_map_5095
437
+
438
+ if utils.is_main_process():
439
+ if best_is_ema:
440
+ shutil.copy2(output_dir / 'checkpoint_best_ema.pth', output_dir / 'checkpoint_best_total.pth')
441
+ else:
442
+ shutil.copy2(output_dir / 'checkpoint0009.pth', output_dir / 'checkpoint_best_total.pth')
443
+
444
+ utils.strip_checkpoint(output_dir / 'checkpoint_best_total.pth')
445
+
446
+ best_map_5095 = max(best_map_5095, best_map_ema_5095)
447
+ best_map_50 = max(best_map_50, best_map_ema_50)
448
+
449
+ results_json = {
450
+ "map95": best_map_5095,
451
+ "map50": best_map_50,
452
+ "class": "all"
453
+ }
454
+ results = {
455
+ "class_map": {
456
+ "valid": [
457
+ results_json
458
+ ]
459
+ }
460
+ }
461
+ with open(output_dir / "results.json", "w") as f:
462
+ json.dump(results, f)
463
+
464
+ total_time = time.time() - start_time
465
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
466
+ print('Training time {}'.format(total_time_str))
467
+ print('Results saved to {}'.format(output_dir / "results.json"))
468
+
469
+ if best_is_ema:
470
+ self.model = self.ema_m.module
471
+ self.model.eval()
472
+
473
+ for callback in callbacks["on_train_end"]:
474
+ callback()
475
+
476
+ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs):
477
+ """Export the trained model to ONNX format"""
478
+ print(f"Exporting model to ONNX format")
479
+ try:
480
+ from rfdetr.deploy.export import export_onnx, onnx_simplify, make_infer_image
481
+ except ImportError:
482
+ print("It seems some dependencies for ONNX export are missing. Please run `pip install rfdetr[onnxexport]` and try again.")
483
+ raise
484
+
485
+
486
+ device = self.device
487
+ model = deepcopy(self.model.to("cpu"))
488
+ model.to(device)
489
+
490
+ os.makedirs(output_dir, exist_ok=True)
491
+ output_dir = Path(output_dir)
492
+ if shape is None:
493
+ shape = (self.resolution, self.resolution)
494
+ else:
495
+ if shape[0] % 14 != 0 or shape[1] % 14 != 0:
496
+ raise ValueError("Shape must be divisible by 14")
497
+
498
+ input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)
499
+ input_names = ['input']
500
+ output_names = ['features'] if backbone_only else ['dets', 'labels']
501
+ dynamic_axes = None
502
+ self.model.eval()
503
+ with torch.no_grad():
504
+ if backbone_only:
505
+ features = model(input_tensors)
506
+ print(f"PyTorch inference output shape: {features.shape}")
507
+ else:
508
+ outputs = model(input_tensors)
509
+ dets = outputs['pred_boxes']
510
+ labels = outputs['pred_logits']
511
+ print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
512
+ model.cpu()
513
+ input_tensors = input_tensors.cpu()
514
+
515
+ # Export to ONNX
516
+ output_file = export_onnx(
517
+ output_dir=output_dir,
518
+ model=model,
519
+ input_names=input_names,
520
+ input_tensors=input_tensors,
521
+ output_names=output_names,
522
+ dynamic_axes=dynamic_axes,
523
+ backbone_only=backbone_only,
524
+ verbose=verbose,
525
+ opset_version=opset_version
526
+ )
527
+
528
+ print(f"Successfully exported ONNX model to: {output_file}")
529
+
530
+ if simplify:
531
+ sim_output_file = onnx_simplify(
532
+ onnx_dir=output_file,
533
+ input_names=input_names,
534
+ input_tensors=input_tensors,
535
+ force=force
536
+ )
537
+ print(f"Successfully simplified ONNX model to: {sim_output_file}")
538
+
539
+ print("ONNX export completed successfully")
540
+ self.model = self.model.to(device)
541
+
542
+
543
+ if __name__ == '__main__':
544
+ parser = argparse.ArgumentParser('LWDETR training and evaluation script', parents=[get_args_parser()])
545
+ args = parser.parse_args()
546
+
547
+ if args.output_dir:
548
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
549
+
550
+ config = vars(args) # Convert Namespace to dictionary
551
+
552
+ if args.subcommand == 'distill':
553
+ distill(**config)
554
+ elif args.subcommand is None:
555
+ main(**config)
556
+ elif args.subcommand == 'export_model':
557
+ filter_keys = [
558
+ "num_classes",
559
+ "grad_accum_steps",
560
+ "lr",
561
+ "lr_encoder",
562
+ "weight_decay",
563
+ "epochs",
564
+ "lr_drop",
565
+ "clip_max_norm",
566
+ "lr_vit_layer_decay",
567
+ "lr_component_decay",
568
+ "dropout",
569
+ "drop_path",
570
+ "drop_mode",
571
+ "drop_schedule",
572
+ "cutoff_epoch",
573
+ "pretrained_encoder",
574
+ "pretrain_weights",
575
+ "pretrain_exclude_keys",
576
+ "pretrain_keys_modify_to_load",
577
+ "freeze_florence",
578
+ "freeze_aimv2",
579
+ "decoder_norm",
580
+ "set_cost_class",
581
+ "set_cost_bbox",
582
+ "set_cost_giou",
583
+ "cls_loss_coef",
584
+ "bbox_loss_coef",
585
+ "giou_loss_coef",
586
+ "focal_alpha",
587
+ "aux_loss",
588
+ "sum_group_losses",
589
+ "use_varifocal_loss",
590
+ "use_position_supervised_loss",
591
+ "ia_bce_loss",
592
+ "dataset_file",
593
+ "coco_path",
594
+ "dataset_dir",
595
+ "square_resize_div_64",
596
+ "output_dir",
597
+ "checkpoint_interval",
598
+ "seed",
599
+ "resume",
600
+ "start_epoch",
601
+ "eval",
602
+ "use_ema",
603
+ "ema_decay",
604
+ "ema_tau",
605
+ "num_workers",
606
+ "device",
607
+ "world_size",
608
+ "dist_url",
609
+ "sync_bn",
610
+ "fp16_eval",
611
+ "infer_dir",
612
+ "verbose",
613
+ "opset_version",
614
+ "dry_run",
615
+ "shape",
616
+ ]
617
+ for key in filter_keys:
618
+ config.pop(key, None) # Use pop with None to avoid KeyError
619
+
620
+ from rfdetr.deploy.export import main as export_main
621
+ if args.batch_size != 1:
622
+ config['batch_size'] = 1
623
+ print(f"Only batch_size 1 is supported for onnx export, \
624
+ but got batch_size = {args.batch_size}. Forcing batch_size to 1.")
625
+ export_main(**config)
626
+
627
+ def get_args_parser():
628
+ parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
629
+ parser.add_argument('--num_classes', default=2, type=int)
630
+ parser.add_argument('--grad_accum_steps', default=1, type=int)
631
+ parser.add_argument('--amp', default=False, type=bool)
632
+ parser.add_argument('--lr', default=1e-4, type=float)
633
+ parser.add_argument('--lr_encoder', default=1.5e-4, type=float)
634
+ parser.add_argument('--batch_size', default=2, type=int)
635
+ parser.add_argument('--weight_decay', default=1e-4, type=float)
636
+ parser.add_argument('--epochs', default=12, type=int)
637
+ parser.add_argument('--lr_drop', default=11, type=int)
638
+ parser.add_argument('--clip_max_norm', default=0.1, type=float,
639
+ help='gradient clipping max norm')
640
+ parser.add_argument('--lr_vit_layer_decay', default=0.8, type=float)
641
+ parser.add_argument('--lr_component_decay', default=1.0, type=float)
642
+ parser.add_argument('--do_benchmark', action='store_true', help='benchmark the model')
643
+
644
+ # drop args
645
+ # dropout and stochastic depth drop rate; set at most one to non-zero
646
+ parser.add_argument('--dropout', type=float, default=0,
647
+ help='Dropout rate (default: 0.0)')
648
+ parser.add_argument('--drop_path', type=float, default=0,
649
+ help='Drop path rate (default: 0.0)')
650
+
651
+ # early / late dropout and stochastic depth settings
652
+ parser.add_argument('--drop_mode', type=str, default='standard',
653
+ choices=['standard', 'early', 'late'], help='drop mode')
654
+ parser.add_argument('--drop_schedule', type=str, default='constant',
655
+ choices=['constant', 'linear'],
656
+ help='drop schedule for early dropout / s.d. only')
657
+ parser.add_argument('--cutoff_epoch', type=int, default=0,
658
+ help='if drop_mode is early / late, this is the epoch where dropout ends / starts')
659
+
660
+ # Model parameters
661
+ parser.add_argument('--pretrained_encoder', type=str, default=None,
662
+ help="Path to the pretrained encoder.")
663
+ parser.add_argument('--pretrain_weights', type=str, default=None,
664
+ help="Path to the pretrained model.")
665
+ parser.add_argument('--pretrain_exclude_keys', type=str, default=None, nargs='+',
666
+ help="Keys you do not want to load.")
667
+ parser.add_argument('--pretrain_keys_modify_to_load', type=str, default=None, nargs='+',
668
+ help="Keys you want to modify to load. Only used when loading objects365 pre-trained weights.")
669
+
670
+ # * Backbone
671
+ parser.add_argument('--encoder', default='vit_tiny', type=str,
672
+ help="Name of the transformer or convolutional encoder to use")
673
+ parser.add_argument('--vit_encoder_num_layers', default=12, type=int,
674
+ help="Number of layers used in ViT encoder")
675
+ parser.add_argument('--window_block_indexes', default=None, type=int, nargs='+')
676
+ parser.add_argument('--position_embedding', default='sine', type=str,
677
+ choices=('sine', 'learned'),
678
+ help="Type of positional embedding to use on top of the image features")
679
+ parser.add_argument('--out_feature_indexes', default=[-1], type=int, nargs='+', help='only for vit now')
680
+ parser.add_argument("--freeze_encoder", action="store_true", dest="freeze_encoder")
681
+ parser.add_argument("--layer_norm", action="store_true", dest="layer_norm")
682
+ parser.add_argument("--rms_norm", action="store_true", dest="rms_norm")
683
+ parser.add_argument("--backbone_lora", action="store_true", dest="backbone_lora")
684
+ parser.add_argument("--force_no_pretrain", action="store_true", dest="force_no_pretrain")
685
+
686
+ # * Transformer
687
+ parser.add_argument('--dec_layers', default=3, type=int,
688
+ help="Number of decoding layers in the transformer")
689
+ parser.add_argument('--dim_feedforward', default=2048, type=int,
690
+ help="Intermediate size of the feedforward layers in the transformer blocks")
691
+ parser.add_argument('--hidden_dim', default=256, type=int,
692
+ help="Size of the embeddings (dimension of the transformer)")
693
+ parser.add_argument('--sa_nheads', default=8, type=int,
694
+ help="Number of attention heads inside the transformer's self-attentions")
695
+ parser.add_argument('--ca_nheads', default=8, type=int,
696
+ help="Number of attention heads inside the transformer's cross-attentions")
697
+ parser.add_argument('--num_queries', default=300, type=int,
698
+ help="Number of query slots")
699
+ parser.add_argument('--group_detr', default=13, type=int,
700
+ help="Number of groups to speed up detr training")
701
+ parser.add_argument('--two_stage', action='store_true')
702
+ parser.add_argument('--projector_scale', default='P4', type=str, nargs='+', choices=('P3', 'P4', 'P5', 'P6'))
703
+ parser.add_argument('--lite_refpoint_refine', action='store_true', help='lite refpoint refine mode for speed-up')
704
+ parser.add_argument('--num_select', default=100, type=int,
705
+ help='the number of predictions selected for evaluation')
706
+ parser.add_argument('--dec_n_points', default=4, type=int,
707
+ help='the number of sampling points')
708
+ parser.add_argument('--decoder_norm', default='LN', type=str)
709
+ parser.add_argument('--bbox_reparam', action='store_true')
710
+ parser.add_argument('--freeze_batch_norm', action='store_true')
711
+ # * Matcher
712
+ parser.add_argument('--set_cost_class', default=2, type=float,
713
+ help="Class coefficient in the matching cost")
714
+ parser.add_argument('--set_cost_bbox', default=5, type=float,
715
+ help="L1 box coefficient in the matching cost")
716
+ parser.add_argument('--set_cost_giou', default=2, type=float,
717
+ help="giou box coefficient in the matching cost")
718
+
719
+ # * Loss coefficients
720
+ parser.add_argument('--cls_loss_coef', default=2, type=float)
721
+ parser.add_argument('--bbox_loss_coef', default=5, type=float)
722
+ parser.add_argument('--giou_loss_coef', default=2, type=float)
723
+ parser.add_argument('--focal_alpha', default=0.25, type=float)
724
+
725
+ # Loss
+ parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
+ help="Disables auxiliary decoding losses (loss at each layer)")
+ parser.add_argument('--sum_group_losses', action='store_true',
+ help="Sum losses across query groups instead of averaging them")
+ parser.add_argument('--use_varifocal_loss', action='store_true')
+ parser.add_argument('--use_position_supervised_loss', action='store_true')
+ parser.add_argument('--ia_bce_loss', action='store_true')
+
+ # dataset parameters
+ parser.add_argument('--dataset_file', default='coco')
+ parser.add_argument('--coco_path', type=str)
+ parser.add_argument('--dataset_dir', type=str)
+ parser.add_argument('--square_resize_div_64', action='store_true')
+
+ parser.add_argument('--output_dir', default='output',
+ help='Path where outputs are saved; leave empty to disable saving')
+ parser.add_argument('--dont_save_weights', action='store_true')
+ parser.add_argument('--checkpoint_interval', default=10, type=int,
+ help='Epoch interval at which checkpoints are saved')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--eval', action='store_true')
+ parser.add_argument('--use_ema', action='store_true')
+ parser.add_argument('--ema_decay', default=0.9997, type=float)
+ parser.add_argument('--ema_tau', default=0, type=float)
+
+ parser.add_argument('--num_workers', default=2, type=int)
+
+ # distributed training parameters
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+ # argparse passes the raw string to `type`, and bool('False') is truthy,
+ # so ast.literal_eval (already used by the export sub-command below) is
+ # needed for `--sync_bn False` to actually disable the flag
+ parser.add_argument('--sync_bn', default=True, type=ast.literal_eval,
+ help='use synchronized BatchNorm for distributed training')
+
+ # fp16
+ parser.add_argument('--fp16_eval', default=False, action='store_true',
+ help='evaluate in fp16 precision.')
+
+ # custom args
+ parser.add_argument('--encoder_only', action='store_true', help='Export and benchmark encoder only')
+ parser.add_argument('--backbone_only', action='store_true', help='Export and benchmark backbone only')
+ parser.add_argument('--resolution', type=int, default=640, help="input resolution")
+ parser.add_argument('--use_cls_token', action='store_true', help='use cls token')
+ parser.add_argument('--multi_scale', action='store_true', help='use multi scale')
+ parser.add_argument('--expanded_scales', action='store_true', help='use expanded scales')
+ parser.add_argument('--warmup_epochs', default=1, type=float,
+ help='Number of warmup epochs for linear warmup before cosine annealing')
+ # Add scheduler type argument: 'step' or 'cosine'
+ parser.add_argument(
+ '--lr_scheduler',
+ default='step',
+ choices=['step', 'cosine'],
+ help="Type of learning rate scheduler to use: 'step' (default) or 'cosine'"
+ )
+ parser.add_argument('--lr_min_factor', default=0.0, type=float,
+ help='Minimum learning rate factor (as a fraction of initial lr) at the end of cosine annealing')
+ # Early stopping parameters
+ parser.add_argument('--early_stopping', action='store_true',
+ help='Enable early stopping based on mAP improvement')
+ parser.add_argument('--early_stopping_patience', default=10, type=int,
+ help='Number of epochs with no improvement after which training will be stopped')
+ parser.add_argument('--early_stopping_min_delta', default=0.001, type=float,
+ help='Minimum change in mAP to qualify as an improvement')
+ parser.add_argument('--early_stopping_use_ema', action='store_true',
+ help='Use EMA model metrics for early stopping')
+ # subparsers
+ subparsers = parser.add_subparsers(title='sub-commands', dest='subcommand',
+ description='valid subcommands', help='additional help')
+
+ # subparser for export model
+ parser_export = subparsers.add_parser('export_model', help='LW-DETR model export')
+ parser_export.add_argument('--infer_dir', type=str, default=None)
+ parser_export.add_argument('--verbose', type=ast.literal_eval, default=False, nargs="?", const=True)
+ parser_export.add_argument('--opset_version', type=int, default=17)
+ parser_export.add_argument('--simplify', action='store_true', help="Simplify the ONNX model")
+ parser_export.add_argument('--tensorrt', '--trtexec', '--trt', action='store_true',
+ help="Build a TensorRT engine")
+ parser_export.add_argument('--dry-run', '--test', '-t', action='store_true',
+ help="Just print the command instead of running it")
+ parser_export.add_argument('--profile', action='store_true', help='Run nsys profiling during TensorRT export')
+ parser_export.add_argument('--shape', type=int, nargs=2, default=(640, 640), help="input shape (width, height)")
+ return parser
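For readers skimming the diff, here is a hedged usage sketch of the parser defined above. The enclosing function name `get_args_parser` and the `rfdetr.main` module path are assumptions taken from the DETR-family convention and the commit's file list; neither is visible in this hunk.

```python
# Hedged sketch -- `get_args_parser` and `rfdetr.main` are assumed names;
# only flags visible in this diff are relied upon, and values are illustrative.
from rfdetr.main import get_args_parser

parser = get_args_parser()
args = parser.parse_args([
    "--coco_path", "/data/coco",     # illustrative dataset location
    "--resolution", "560",
    "--lr_scheduler", "cosine",
    "export_model",                  # sub-command registered above
    "--simplify",
    "--shape", "640", "640",
])
print(args.subcommand, args.resolution, args.shape)  # export_model 560 [640, 640]
```

Note that parent-parser flags must appear before the `export_model` sub-command; everything after it is handled by `parser_export`.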
+
+ def populate_args(
+ # Basic training parameters
+ num_classes=2,
+ grad_accum_steps=1,
+ amp=False,
+ lr=1e-4,
+ lr_encoder=1.5e-4,
+ batch_size=2,
+ weight_decay=1e-4,
+ epochs=12,
+ lr_drop=11,
+ clip_max_norm=0.1,
+ lr_vit_layer_decay=0.8,
+ lr_component_decay=1.0,
+ do_benchmark=False,
+
+ # Drop parameters
+ dropout=0,
+ drop_path=0,
+ drop_mode='standard',
+ drop_schedule='constant',
+ cutoff_epoch=0,
+
+ # Model parameters
+ pretrained_encoder=None,
+ pretrain_weights=None,
+ pretrain_exclude_keys=None,
+ pretrain_keys_modify_to_load=None,
+ pretrained_distiller=None,
+
+ # Backbone parameters
+ encoder='vit_tiny',
+ vit_encoder_num_layers=12,
+ window_block_indexes=None,
+ position_embedding='sine',
+ out_feature_indexes=[-1],
+ freeze_encoder=False,
+ layer_norm=False,
+ rms_norm=False,
+ backbone_lora=False,
+ force_no_pretrain=False,
+
+ # Transformer parameters
+ dec_layers=3,
+ dim_feedforward=2048,
+ hidden_dim=256,
+ sa_nheads=8,
+ ca_nheads=8,
+ num_queries=300,
+ group_detr=13,
+ two_stage=False,
+ projector_scale='P4',
+ lite_refpoint_refine=False,
+ num_select=100,
+ dec_n_points=4,
+ decoder_norm='LN',
+ bbox_reparam=False,
+ freeze_batch_norm=False,
+
+ # Matcher parameters
+ set_cost_class=2,
+ set_cost_bbox=5,
+ set_cost_giou=2,
+
+ # Loss coefficients
+ cls_loss_coef=2,
+ bbox_loss_coef=5,
+ giou_loss_coef=2,
+ focal_alpha=0.25,
+ aux_loss=True,
+ sum_group_losses=False,
+ use_varifocal_loss=False,
+ use_position_supervised_loss=False,
+ ia_bce_loss=False,
+
+ # Dataset parameters
+ dataset_file='coco',
+ coco_path=None,
+ dataset_dir=None,
+ square_resize_div_64=False,
+
+ # Output parameters
+ output_dir='output',
+ dont_save_weights=False,
+ checkpoint_interval=10,
+ seed=42,
+ resume='',
+ start_epoch=0,
+ eval=False,
+ use_ema=False,
+ ema_decay=0.9997,
+ ema_tau=0,
+ num_workers=2,
+
+ # Distributed training parameters
+ device='cuda',
+ world_size=1,
+ dist_url='env://',
+ sync_bn=True,
+
+ # FP16
+ fp16_eval=False,
+
+ # Custom args
+ encoder_only=False,
+ backbone_only=False,
+ resolution=640,
+ use_cls_token=False,
+ multi_scale=False,
+ expanded_scales=False,
+ warmup_epochs=1,
+ lr_scheduler='step',
+ lr_min_factor=0.0,
+ # Early stopping parameters
+ early_stopping=True,
+ early_stopping_patience=10,
+ early_stopping_min_delta=0.001,
+ early_stopping_use_ema=False,
+ gradient_checkpointing=False,
+ # Additional
+ subcommand=None,
+ **extra_kwargs # To handle any unexpected arguments
+ ):
+ args = argparse.Namespace(
+ num_classes=num_classes,
+ grad_accum_steps=grad_accum_steps,
+ amp=amp,
+ lr=lr,
+ lr_encoder=lr_encoder,
+ batch_size=batch_size,
+ weight_decay=weight_decay,
+ epochs=epochs,
+ lr_drop=lr_drop,
+ clip_max_norm=clip_max_norm,
+ lr_vit_layer_decay=lr_vit_layer_decay,
+ lr_component_decay=lr_component_decay,
+ do_benchmark=do_benchmark,
+ dropout=dropout,
+ drop_path=drop_path,
+ drop_mode=drop_mode,
+ drop_schedule=drop_schedule,
+ cutoff_epoch=cutoff_epoch,
+ pretrained_encoder=pretrained_encoder,
+ pretrain_weights=pretrain_weights,
+ pretrain_exclude_keys=pretrain_exclude_keys,
+ pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
+ pretrained_distiller=pretrained_distiller,
+ encoder=encoder,
+ vit_encoder_num_layers=vit_encoder_num_layers,
+ window_block_indexes=window_block_indexes,
+ position_embedding=position_embedding,
+ out_feature_indexes=out_feature_indexes,
+ freeze_encoder=freeze_encoder,
+ layer_norm=layer_norm,
+ rms_norm=rms_norm,
+ backbone_lora=backbone_lora,
+ force_no_pretrain=force_no_pretrain,
+ dec_layers=dec_layers,
+ dim_feedforward=dim_feedforward,
+ hidden_dim=hidden_dim,
+ sa_nheads=sa_nheads,
+ ca_nheads=ca_nheads,
+ num_queries=num_queries,
+ group_detr=group_detr,
+ two_stage=two_stage,
+ projector_scale=projector_scale,
+ lite_refpoint_refine=lite_refpoint_refine,
+ num_select=num_select,
+ dec_n_points=dec_n_points,
+ decoder_norm=decoder_norm,
+ bbox_reparam=bbox_reparam,
+ freeze_batch_norm=freeze_batch_norm,
+ set_cost_class=set_cost_class,
+ set_cost_bbox=set_cost_bbox,
+ set_cost_giou=set_cost_giou,
+ cls_loss_coef=cls_loss_coef,
+ bbox_loss_coef=bbox_loss_coef,
+ giou_loss_coef=giou_loss_coef,
+ focal_alpha=focal_alpha,
+ aux_loss=aux_loss,
+ sum_group_losses=sum_group_losses,
+ use_varifocal_loss=use_varifocal_loss,
+ use_position_supervised_loss=use_position_supervised_loss,
+ ia_bce_loss=ia_bce_loss,
+ dataset_file=dataset_file,
+ coco_path=coco_path,
+ dataset_dir=dataset_dir,
+ square_resize_div_64=square_resize_div_64,
+ output_dir=output_dir,
+ dont_save_weights=dont_save_weights,
+ checkpoint_interval=checkpoint_interval,
+ seed=seed,
+ resume=resume,
+ start_epoch=start_epoch,
+ eval=eval,
+ use_ema=use_ema,
+ ema_decay=ema_decay,
+ ema_tau=ema_tau,
+ num_workers=num_workers,
+ device=device,
+ world_size=world_size,
+ dist_url=dist_url,
+ sync_bn=sync_bn,
+ fp16_eval=fp16_eval,
+ encoder_only=encoder_only,
+ backbone_only=backbone_only,
+ resolution=resolution,
+ use_cls_token=use_cls_token,
+ multi_scale=multi_scale,
+ expanded_scales=expanded_scales,
+ warmup_epochs=warmup_epochs,
+ lr_scheduler=lr_scheduler,
+ lr_min_factor=lr_min_factor,
+ early_stopping=early_stopping,
+ early_stopping_patience=early_stopping_patience,
+ early_stopping_min_delta=early_stopping_min_delta,
+ early_stopping_use_ema=early_stopping_use_ema,
+ gradient_checkpointing=gradient_checkpointing,
+ **extra_kwargs
+ )
+ return args
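`populate_args` mirrors every CLI default as a keyword argument, so a training config can be assembled programmatically without touching argparse. A minimal sketch (the `rfdetr.main` module path is assumed from the commit's file list; values are illustrative):

```python
# Hedged sketch: build a config Namespace directly from Python.
# Only keyword names visible in this diff are used; values are illustrative.
from rfdetr.main import populate_args

args = populate_args(
    num_classes=90,
    encoder="vit_tiny",
    resolution=560,
    use_ema=True,
    lr_scheduler="cosine",
)
assert args.num_queries == 300        # unspecified keywords keep their defaults
assert args.lr_scheduler == "cosine"
```

The `--warmup_epochs`, `--lr_scheduler cosine` and `--lr_min_factor` flags defined earlier in this hunk conventionally combine into a linear ramp followed by cosine decay toward `lr_min_factor` of the base learning rate. The scheduler actually used by this commit lives elsewhere in `main.py`, so the helper below is only an illustrative interpretation of those flags:

```python
# Illustrative only: one conventional reading of warmup + cosine + min factor.
# The scheduler actually used by this commit is defined elsewhere in main.py.
import math

def lr_multiplier(epoch, total_epochs, warmup_epochs=1.0, lr_min_factor=0.0):
    if epoch < warmup_epochs:                       # linear warmup
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(total_epochs - warmup_epochs, 1e-8)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_min_factor + (1.0 - lr_min_factor) * cosine

print([round(lr_multiplier(e, total_epochs=12), 3) for e in range(12)])
```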
rfdetr/models/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Copied from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Copied from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Copied from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ from .lwdetr import build_model, build_criterion_and_postprocessors
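For orientation, the two names exported above are the model and criterion factories inherited from LW-DETR. A heavily hedged sketch of how they are typically wired to the config produced by `populate_args`; the call signatures are not shown in this diff, so the single-`args` form below is an assumption borrowed from the LW-DETR/DETR lineage:

```python
# Hedged sketch -- build_model(args) and build_criterion_and_postprocessors(args)
# signatures are assumed from the LW-DETR lineage and are not confirmed by this diff.
from rfdetr.main import populate_args
from rfdetr.models import build_model, build_criterion_and_postprocessors

args = populate_args(num_classes=2, device="cpu")
model = build_model(args)                                             # assumed signature
criterion, postprocessors = build_criterion_and_postprocessors(args)  # assumed signature
print(type(model).__name__)
```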
rfdetr/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (277 Bytes). View file