Matis Despujols committed
Commit 066effd · verified · 1 Parent(s): aa2b37d

Upload 97 files

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. rfdetr/__init__.py +12 -0
  2. rfdetr/__pycache__/__init__.cpython-313.pyc +0 -0
  3. rfdetr/__pycache__/config.cpython-313.pyc +0 -0
  4. rfdetr/__pycache__/detr.cpython-313.pyc +0 -0
  5. rfdetr/__pycache__/engine.cpython-313.pyc +0 -0
  6. rfdetr/__pycache__/main.cpython-313.pyc +0 -0
  7. rfdetr/cli/__pycache__/main.cpython-313.pyc +0 -0
  8. rfdetr/cli/main.py +87 -0
  9. rfdetr/config.py +142 -0
  10. rfdetr/datasets/__init__.py +36 -0
  11. rfdetr/datasets/__pycache__/__init__.cpython-313.pyc +0 -0
  12. rfdetr/datasets/__pycache__/coco.cpython-313.pyc +0 -0
  13. rfdetr/datasets/__pycache__/coco_eval.cpython-313.pyc +0 -0
  14. rfdetr/datasets/__pycache__/o365.cpython-313.pyc +0 -0
  15. rfdetr/datasets/__pycache__/transforms.cpython-313.pyc +0 -0
  16. rfdetr/datasets/coco.py +280 -0
  17. rfdetr/datasets/coco_eval.py +271 -0
  18. rfdetr/datasets/o365.py +53 -0
  19. rfdetr/datasets/transforms.py +475 -0
  20. rfdetr/deploy/__init__.py +0 -0
  21. rfdetr/deploy/__pycache__/__init__.cpython-313.pyc +0 -0
  22. rfdetr/deploy/__pycache__/benchmark.cpython-313.pyc +0 -0
  23. rfdetr/deploy/__pycache__/export.cpython-313.pyc +0 -0
  24. rfdetr/deploy/_onnx/__init__.py +13 -0
  25. rfdetr/deploy/_onnx/__pycache__/__init__.cpython-313.pyc +0 -0
  26. rfdetr/deploy/_onnx/__pycache__/optimizer.cpython-313.pyc +0 -0
  27. rfdetr/deploy/_onnx/__pycache__/symbolic.cpython-313.pyc +0 -0
  28. rfdetr/deploy/_onnx/optimizer.py +579 -0
  29. rfdetr/deploy/_onnx/symbolic.py +37 -0
  30. rfdetr/deploy/benchmark.py +590 -0
  31. rfdetr/deploy/export.py +276 -0
  32. rfdetr/detr.py +451 -0
  33. rfdetr/engine.py +340 -0
  34. rfdetr/main.py +1062 -0
  35. rfdetr/models/__init__.py +16 -0
  36. rfdetr/models/__pycache__/__init__.cpython-313.pyc +0 -0
  37. rfdetr/models/__pycache__/lwdetr.cpython-313.pyc +0 -0
  38. rfdetr/models/__pycache__/matcher.cpython-313.pyc +0 -0
  39. rfdetr/models/__pycache__/position_encoding.cpython-313.pyc +0 -0
  40. rfdetr/models/__pycache__/transformer.cpython-313.pyc +0 -0
  41. rfdetr/models/backbone/__init__.py +110 -0
  42. rfdetr/models/backbone/__pycache__/__init__.cpython-313.pyc +0 -0
  43. rfdetr/models/backbone/__pycache__/backbone.cpython-313.pyc +0 -0
  44. rfdetr/models/backbone/__pycache__/base.cpython-313.pyc +0 -0
  45. rfdetr/models/backbone/__pycache__/dinov2.cpython-313.pyc +0 -0
  46. rfdetr/models/backbone/__pycache__/dinov2_with_windowed_attn.cpython-313.pyc +0 -0
  47. rfdetr/models/backbone/__pycache__/projector.cpython-313.pyc +0 -0
  48. rfdetr/models/backbone/backbone.py +205 -0
  49. rfdetr/models/backbone/base.py +20 -0
  50. rfdetr/models/backbone/dinov2.py +197 -0
rfdetr/__init__.py ADDED
@@ -0,0 +1,12 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------


import os
if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from rfdetr.detr import RFDETRBase, RFDETRLarge, RFDETRNano, RFDETRSmall, RFDETRMedium
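As a quick illustration of what this __init__ does at import time (not part of the commit): importing the package sets the MPS fallback flag unless the caller already set it, and re-exports the five model classes from rfdetr.detr. A minimal sketch, assuming the rfdetr package is installed:

# Sketch only: shows the import-time side effect and the re-exported classes.
import os

os.environ.pop("PYTORCH_ENABLE_MPS_FALLBACK", None)  # start from a clean environment

import rfdetr  # sets PYTORCH_ENABLE_MPS_FALLBACK=1 because it was unset
print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])     # -> "1"

from rfdetr import RFDETRBase, RFDETRNano  # classes re-exported from rfdetr.detr
print(RFDETRBase, RFDETRNano)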
rfdetr/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (530 Bytes).

rfdetr/__pycache__/config.cpython-313.pyc ADDED
Binary file (7.12 kB).

rfdetr/__pycache__/detr.cpython-313.pyc ADDED
Binary file (22.4 kB).

rfdetr/__pycache__/engine.cpython-313.pyc ADDED
Binary file (17.6 kB).

rfdetr/__pycache__/main.cpython-313.pyc ADDED
Binary file (47.4 kB).

rfdetr/cli/__pycache__/main.cpython-313.pyc ADDED
Binary file (4.16 kB).

rfdetr/cli/main.py ADDED
@@ -0,0 +1,87 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------

import argparse
from rf100vl import get_rf100vl_projects
import roboflow
from rfdetr import RFDETRBase
import torch
import os

def download_dataset(rf_project: roboflow.Project, dataset_version: int):
    versions = rf_project.versions()
    if dataset_version is not None:
        versions = [v for v in versions if v.version == str(dataset_version)]
        if len(versions) == 0:
            raise ValueError(f"Dataset version {dataset_version} not found")
        version = versions[0]
    else:
        version = max(versions, key=lambda v: v.id)
    location = os.path.join("datasets/", rf_project.name + "_v" + version.version)
    if not os.path.exists(location):
        location = version.download(
            model_format="coco", location=location, overwrite=False
        ).location

    return location


def train_from_rf_project(rf_project: roboflow.Project, dataset_version: int):
    location = download_dataset(rf_project, dataset_version)
    print(location)
    rf_detr = RFDETRBase()
    device_supports_cuda = torch.cuda.is_available()
    rf_detr.train(
        dataset_dir=location,
        epochs=1,
        device="cuda" if device_supports_cuda else "cpu",
    )


def train_from_coco_dir(coco_dir: str):
    rf_detr = RFDETRBase()
    device_supports_cuda = torch.cuda.is_available()
    rf_detr.train(
        dataset_dir=coco_dir,
        epochs=1,
        device="cuda" if device_supports_cuda else "cpu",
    )


def trainer():
    parser = argparse.ArgumentParser()
    parser.add_argument("--coco_dir", type=str, required=False)
    parser.add_argument("--api_key", type=str, required=False)
    parser.add_argument("--workspace", type=str, required=False, default=None)
    parser.add_argument("--project_name", type=str, required=False, default=None)
    parser.add_argument("--dataset_version", type=int, required=False, default=None)
    args = parser.parse_args()

    if args.coco_dir is not None:
        train_from_coco_dir(args.coco_dir)
        return

    if (args.workspace is None and args.project_name is not None) or (
        args.workspace is not None and args.project_name is None
    ):
        raise ValueError(
            "Either both workspace and project_name must be provided or none of them"
        )

    if args.workspace is not None:
        rf = roboflow.Roboflow(api_key=args.api_key)
        project = rf.workspace(args.workspace).project(args.project_name)
    else:
        projects = get_rf100vl_projects(api_key=args.api_key)
        project = projects[0].rf_project

    train_from_rf_project(project, args.dataset_version)


if __name__ == "__main__":
    trainer()
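For reference, the same training call can be made directly from Python without the CLI wrapper; a minimal sketch mirroring train_from_coco_dir, where the dataset path is hypothetical:

# Sketch mirroring train_from_coco_dir above; "datasets/my_dataset" is a hypothetical
# COCO-format directory (train/valid/test splits with _annotations.coco.json files).
import torch
from rfdetr import RFDETRBase

model = RFDETRBase()
model.train(
    dataset_dir="datasets/my_dataset",
    epochs=1,
    device="cuda" if torch.cuda.is_available() else "cpu",
)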
rfdetr/config.py ADDED
@@ -0,0 +1,142 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------


from pydantic import BaseModel
from typing import List, Optional, Literal, Type
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

class ModelConfig(BaseModel):
    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
    out_feature_indexes: List[int]
    dec_layers: int
    two_stage: bool = True
    projector_scale: List[Literal["P3", "P4", "P5"]]
    hidden_dim: int
    patch_size: int
    num_windows: int
    sa_nheads: int
    ca_nheads: int
    dec_n_points: int
    bbox_reparam: bool = True
    lite_refpoint_refine: bool = True
    layer_norm: bool = True
    amp: bool = True
    num_classes: int = 90
    pretrain_weights: Optional[str] = None
    device: Literal["cpu", "cuda", "mps"] = DEVICE
    resolution: int
    group_detr: int = 13
    gradient_checkpointing: bool = False
    positional_encoding_size: int

class RFDETRBaseConfig(ModelConfig):
    """
    The configuration for an RF-DETR Base model.
    """
    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
    hidden_dim: int = 256
    patch_size: int = 14
    num_windows: int = 4
    dec_layers: int = 3
    sa_nheads: int = 8
    ca_nheads: int = 16
    dec_n_points: int = 2
    num_queries: int = 300
    num_select: int = 300
    projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
    out_feature_indexes: List[int] = [2, 5, 8, 11]
    pretrain_weights: Optional[str] = "rf-detr-base.pth"
    resolution: int = 560
    positional_encoding_size: int = 37

class RFDETRLargeConfig(RFDETRBaseConfig):
    """
    The configuration for an RF-DETR Large model.
    """
    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
    hidden_dim: int = 384
    sa_nheads: int = 12
    ca_nheads: int = 24
    dec_n_points: int = 4
    projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
    pretrain_weights: Optional[str] = "rf-detr-large.pth"

class RFDETRNanoConfig(RFDETRBaseConfig):
    """
    The configuration for an RF-DETR Nano model.
    """
    out_feature_indexes: List[int] = [3, 6, 9, 12]
    num_windows: int = 2
    dec_layers: int = 2
    patch_size: int = 16
    resolution: int = 384
    positional_encoding_size: int = 24
    pretrain_weights: Optional[str] = "rf-detr-nano.pth"

class RFDETRSmallConfig(RFDETRBaseConfig):
    """
    The configuration for an RF-DETR Small model.
    """
    out_feature_indexes: List[int] = [3, 6, 9, 12]
    num_windows: int = 2
    dec_layers: int = 3
    patch_size: int = 16
    resolution: int = 512
    positional_encoding_size: int = 32
    pretrain_weights: Optional[str] = "rf-detr-small.pth"

class RFDETRMediumConfig(RFDETRBaseConfig):
    """
    The configuration for an RF-DETR Medium model.
    """
    out_feature_indexes: List[int] = [3, 6, 9, 12]
    num_windows: int = 2
    dec_layers: int = 4
    patch_size: int = 16
    resolution: int = 576
    positional_encoding_size: int = 36
    pretrain_weights: Optional[str] = "rf-detr-medium.pth"

class TrainConfig(BaseModel):
    lr: float = 1e-4
    lr_encoder: float = 1.5e-4
    batch_size: int = 4
    grad_accum_steps: int = 4
    epochs: int = 100
    ema_decay: float = 0.993
    ema_tau: int = 100
    lr_drop: int = 100
    checkpoint_interval: int = 10
    warmup_epochs: int = 0
    lr_vit_layer_decay: float = 0.8
    lr_component_decay: float = 0.7
    drop_path: float = 0.0
    group_detr: int = 13
    ia_bce_loss: bool = True
    cls_loss_coef: float = 1.0
    num_select: int = 300
    dataset_file: Literal["coco", "o365", "roboflow"] = "roboflow"
    square_resize_div_64: bool = True
    dataset_dir: str
    output_dir: str = "output"
    multi_scale: bool = True
    expanded_scales: bool = True
    do_random_resize_via_padding: bool = False
    use_ema: bool = True
    num_workers: int = 2
    weight_decay: float = 1e-4
    early_stopping: bool = False
    early_stopping_patience: int = 10
    early_stopping_min_delta: float = 0.001
    early_stopping_use_ema: bool = False
    tensorboard: bool = True
    wandb: bool = False
    project: Optional[str] = None
    run: Optional[str] = None
    class_names: List[str] = None
    run_test: bool = True
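The configs above are plain pydantic models, so the per-variant defaults can be inspected or overridden at construction time; a small sketch (the dataset_dir value is a hypothetical path):

# Sketch: inspect and override config defaults.
from rfdetr.config import RFDETRBaseConfig, RFDETRNanoConfig, TrainConfig

base = RFDETRBaseConfig()
nano = RFDETRNanoConfig(resolution=448)   # override a default
print(base.resolution, base.patch_size)   # 560 14
print(nano.resolution, nano.patch_size)   # 448 16

train_cfg = TrainConfig(dataset_dir="datasets/my_dataset")  # dataset_dir has no default
print(train_cfg.batch_size, train_cfg.grad_accum_steps)     # 4 4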
rfdetr/datasets/__init__.py ADDED
@@ -0,0 +1,36 @@
# ------------------------------------------------------------------------
# LW-DETR
# Copyright (c) 2024 Baidu. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

import torch.utils.data
import torchvision

from .coco import build as build_coco
from .o365 import build_o365
from .coco import build_roboflow


def get_coco_api_from_dataset(dataset):
    for _ in range(10):
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco


def build_dataset(image_set, args, resolution):
    if args.dataset_file == 'coco':
        return build_coco(image_set, args, resolution)
    if args.dataset_file == 'o365':
        return build_o365(image_set, args, resolution)
    if args.dataset_file == 'roboflow':
        return build_roboflow(image_set, args, resolution)
    raise ValueError(f'dataset {args.dataset_file} not supported')
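build_dataset only needs an args object exposing the fields the builders read; a sketch using a plain namespace against a hypothetical Roboflow-format dataset directory:

# Sketch: build_dataset reads attributes off `args`; a SimpleNamespace is enough.
# "datasets/my_dataset" is a hypothetical directory in Roboflow COCO export layout.
from types import SimpleNamespace
from rfdetr.datasets import build_dataset

args = SimpleNamespace(
    dataset_file="roboflow",
    dataset_dir="datasets/my_dataset",
    square_resize_div_64=True,
    multi_scale=True,
    expanded_scales=True,
    do_random_resize_via_padding=False,
    patch_size=14,
    num_windows=4,
)
train_ds = build_dataset(image_set="train", args=args, resolution=560)
print(len(train_ds))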
rfdetr/datasets/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.53 kB).

rfdetr/datasets/__pycache__/coco.cpython-313.pyc ADDED
Binary file (11 kB).

rfdetr/datasets/__pycache__/coco_eval.cpython-313.pyc ADDED
Binary file (11.8 kB).

rfdetr/datasets/__pycache__/o365.cpython-313.pyc ADDED
Binary file (1.93 kB).

rfdetr/datasets/__pycache__/transforms.cpython-313.pyc ADDED
Binary file (23.9 kB).

rfdetr/datasets/coco.py ADDED
@@ -0,0 +1,280 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

"""
COCO dataset which returns image_id for evaluation.

Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path

import torch
import torch.utils.data
import torchvision

import rfdetr.datasets.transforms as T


def compute_multi_scale_scales(resolution, expanded_scales=False, patch_size=16, num_windows=4):
    # round to the nearest multiple of 4*patch_size to enable both patching and windowing
    base_num_patches_per_window = resolution // (patch_size * num_windows)
    offsets = [-3, -2, -1, 0, 1, 2, 3, 4] if not expanded_scales else [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
    scales = [base_num_patches_per_window + offset for offset in offsets]
    proposed_scales = [scale * patch_size * num_windows for scale in scales]
    proposed_scales = [scale for scale in proposed_scales if scale >= patch_size * num_windows * 2]  # ensure minimum image size
    return proposed_scales


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCoco()

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target


class ConvertCoco(object):

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["image_id"] = image_id

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        return image, target


def make_coco_transforms(image_set, resolution, multi_scale=False, expanded_scales=False, skip_random_resize=False, patch_size=16, num_windows=4):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [resolution]
    if multi_scale:
        # scales = [448, 512, 576, 640, 704, 768, 832, 896]
        scales = compute_multi_scale_scales(resolution, expanded_scales, patch_size, num_windows)
        if skip_random_resize:
            scales = [scales[-1]]
        print(scales)

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([resolution], max_size=1333),
            normalize,
        ])
    if image_set == 'val_speed':
        return T.Compose([
            T.SquareResize([resolution]),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def make_coco_transforms_square_div_64(image_set, resolution, multi_scale=False, expanded_scales=False, skip_random_resize=False, patch_size=16, num_windows=4):
    """
    """

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


    scales = [resolution]
    if multi_scale:
        # scales = [448, 512, 576, 640, 704, 768, 832, 896]
        scales = compute_multi_scale_scales(resolution, expanded_scales, patch_size, num_windows)
        if skip_random_resize:
            scales = [scales[-1]]
        print(scales)

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.SquareResize(scales),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.SquareResize(scales),
                ]),
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.SquareResize([resolution]),
            normalize,
        ])
    if image_set == 'test':
        return T.Compose([
            T.SquareResize([resolution]),
            normalize,
        ])
    if image_set == 'val_speed':
        return T.Compose([
            T.SquareResize([resolution]),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')

def build(image_set, args, resolution):
    root = Path(args.coco_path)
    assert root.exists(), f'provided COCO path {root} does not exist'
    mode = 'instances'
    PATHS = {
        "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
        "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
        "test": (root / "test2017", root / "annotations" / f'image_info_test-dev2017.json'),
    }

    img_folder, ann_file = PATHS[image_set.split("_")[0]]

    try:
        square_resize = args.square_resize
    except:
        square_resize = False

    try:
        square_resize_div_64 = args.square_resize_div_64
    except:
        square_resize_div_64 = False


    if square_resize_div_64:
        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(
            image_set,
            resolution,
            multi_scale=args.multi_scale,
            expanded_scales=args.expanded_scales,
            skip_random_resize=not args.do_random_resize_via_padding,
            patch_size=args.patch_size,
            num_windows=args.num_windows
        ))
    else:
        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(
            image_set,
            resolution,
            multi_scale=args.multi_scale,
            expanded_scales=args.expanded_scales,
            skip_random_resize=not args.do_random_resize_via_padding,
            patch_size=args.patch_size,
            num_windows=args.num_windows
        ))
    return dataset

def build_roboflow(image_set, args, resolution):
    root = Path(args.dataset_dir)
    assert root.exists(), f'provided Roboflow path {root} does not exist'
    mode = 'instances'
    PATHS = {
        "train": (root / "train", root / "train" / "_annotations.coco.json"),
        "val": (root / "valid", root / "valid" / "_annotations.coco.json"),
        "test": (root / "test", root / "test" / "_annotations.coco.json"),
    }

    img_folder, ann_file = PATHS[image_set.split("_")[0]]

    try:
        square_resize = args.square_resize
    except:
        square_resize = False

    try:
        square_resize_div_64 = args.square_resize_div_64
    except:
        square_resize_div_64 = False


    if square_resize_div_64:
        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(
            image_set,
            resolution,
            multi_scale=args.multi_scale,
            expanded_scales=args.expanded_scales,
            skip_random_resize=not args.do_random_resize_via_padding,
            patch_size=args.patch_size,
            num_windows=args.num_windows
        ))
    else:
        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(
            image_set,
            resolution,
            multi_scale=args.multi_scale,
            expanded_scales=args.expanded_scales,
            skip_random_resize=not args.do_random_resize_via_padding,
            patch_size=args.patch_size,
            num_windows=args.num_windows
        ))
    return dataset
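compute_multi_scale_scales is a pure function, so its behaviour is easy to check; for the Base defaults (resolution 560, patch size 14, 4 windows) it yields multiples of 56 around the training resolution:

# Sketch: multi-scale sizes produced for the RF-DETR Base defaults.
from rfdetr.datasets.coco import compute_multi_scale_scales

# base window size = 560 // (14 * 4) = 10 patches; offsets -3..4 give 7..14 patches
# per window; each candidate scale is then patches * 14 * 4, i.e. a multiple of 56.
print(compute_multi_scale_scales(560, expanded_scales=False, patch_size=14, num_windows=4))
# -> [392, 448, 504, 560, 616, 672, 728, 784]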
rfdetr/datasets/coco_eval.py ADDED
@@ -0,0 +1,271 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Copied from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

"""
COCO evaluator that works in distributed mode.

Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
in the end of the file, as python3 can suppress prints with contextlib
"""
import os
import contextlib
import copy
import numpy as np
import torch

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from rfdetr.util.misc import all_gather


class CocoEvaluator(object):
    def __init__(self, coco_gt, iou_types):
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            img_ids, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        'keypoints': keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    all_img_ids = all_gather(img_ids)
    all_eval_imgs = all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################


def evaluate(self):
    '''
    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
    :return: None
    '''
    # tic = time.time()
    # print('Running per image evaluation...')
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
    # print('Evaluate annotation type *{}*'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == 'segm' or p.iouType == 'bbox':
        computeIoU = self.computeIoU
    elif p.iouType == 'keypoints':
        computeIoU = self.computeOks
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code, but could be done outside
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    # toc = time.time()
    # print('DONE (t={:0.2f}s).'.format(toc-tic))
    return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
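A sketch of how this evaluator is driven in a single process; the annotation file and the prediction tensors are placeholders for real ground truth and model outputs (boxes in absolute xyxy coordinates, image_id and category_id must exist in the ground truth):

# Sketch: single-process evaluation loop skeleton with placeholder inputs.
import torch
from pycocotools.coco import COCO
from rfdetr.datasets.coco_eval import CocoEvaluator

coco_gt = COCO("annotations.json")          # placeholder ground-truth file
evaluator = CocoEvaluator(coco_gt, iou_types=("bbox",))

predictions = {
    42: {  # image_id -> tensors in absolute xyxy coordinates
        "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]),
        "scores": torch.tensor([0.9]),
        "labels": torch.tensor([1]),
    }
}
evaluator.update(predictions)
evaluator.synchronize_between_processes()   # no-op gather when not distributed
evaluator.accumulate()
evaluator.summarize()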
rfdetr/datasets/o365.py ADDED
@@ -0,0 +1,53 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------

"""Dataset file for Object365."""
from pathlib import Path

from .coco import (
    CocoDetection, make_coco_transforms, make_coco_transforms_square_div_64
)

from PIL import Image
Image.MAX_IMAGE_PIXELS = None


def build_o365_raw(image_set, args, resolution):
    root = Path(args.coco_path)
    PATHS = {
        "train": (root, root / 'zhiyuan_objv2_train_val_wo_5k.json'),
        "val": (root, root / 'zhiyuan_objv2_minival5k.json'),
    }
    img_folder, ann_file = PATHS[image_set]

    try:
        square_resize = args.square_resize
    except:
        square_resize = False

    try:
        square_resize_div_64 = args.square_resize_div_64
    except:
        square_resize_div_64 = False

    if square_resize_div_64:
        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
    else:
        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
    return dataset


def build_o365(image_set, args, resolution):
    if image_set == 'train':
        train_ds = build_o365_raw('train', args, resolution=resolution)
        return train_ds
    if image_set == 'val':
        val_ds = build_o365_raw('val', args, resolution=resolution)
        return val_ds
    raise ValueError('Unknown image_set: {}'.format(image_set))
rfdetr/datasets/transforms.py ADDED
@@ -0,0 +1,475 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

"""
Transforms and data augmentation for both image + bbox.
"""
import random

import PIL
import numpy as np
try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence
from numbers import Number
import torch
import torchvision.transforms as T
# from detectron2.data import transforms as DT
import torchvision.transforms.functional as F

from rfdetr.util.box_ops import box_xyxy_to_cxcywh
from rfdetr.util.misc import interpolate


def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(
        float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target['masks'] = interpolate(
            target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5


    return rescaled_image, target


def pad(image, target, padding):
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target['masks'] = torch.nn.functional.pad(
            target['masks'], (0, padding[0], 0, padding[1]))
    return padded_image, target


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, region)


class RandomSizeCrop(object):
    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        region = T.RandomCrop.get_params(img, [h, w])
        return crop(img, target, region)


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.))
        crop_left = int(round((image_width - crop_width) / 2.))
        return crop(img, target, (crop_top, crop_left, crop_height, crop_width))


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target


class RandomResize(object):
    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        return resize(img, target, size, self.max_size)


class SquareResize(object):
    def __init__(self, sizes):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        rescaled_img = F.resize(img, (size, size))
        w, h = rescaled_img.size
        if target is None:
            return rescaled_img, None
        ratios = tuple(
            float(s) / float(s_orig) for s, s_orig in zip(rescaled_img.size, img.size))
        ratio_width, ratio_height = ratios

        target = target.copy()
        if "boxes" in target:
            boxes = target["boxes"]
            scaled_boxes = boxes * torch.as_tensor(
                [ratio_width, ratio_height, ratio_width, ratio_height])
            target["boxes"] = scaled_boxes

        if "area" in target:
            area = target["area"]
            scaled_area = area * (ratio_width * ratio_height)
            target["area"] = scaled_area

        target["size"] = torch.tensor([h, w])

        return rescaled_img, target


class RandomPad(object):
    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad(img, target, (pad_x, pad_y))


class PILtoNdArray(object):

    def __call__(self, img, target):
        return np.asarray(img), target


class NdArraytoPIL(object):

    def __call__(self, img, target):
        return F.to_pil_image(img.astype('uint8')), target


class Pad(object):
    def __init__(self,
                 size=None,
                 size_divisor=32,
                 pad_mode=0,
                 offsets=None,
                 fill_value=(127.5, 127.5, 127.5)):
        """
        Pad image to a specified size or multiple of size_divisor.
        Args:
            size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None
            size_divisor (int): size divisor, default 32
            pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets
                if 0, only pad to right and bottom. if 1, pad according to center. if 2, only pad left and top
            offsets (list): [offset_x, offset_y], specify offset while padding, only supported pad_mode=-1
            fill_value (bool): rgb value of pad area, default (127.5, 127.5, 127.5)
        """

        if not isinstance(size, (int, Sequence)):
            raise TypeError(
                "Type of target_size is invalid when random_size is True. \
                    Must be List, now is {}".format(type(size)))

        if isinstance(size, int):
            size = [size, size]

        assert pad_mode in [
            -1, 0, 1, 2
        ], 'currently only supports four modes [-1, 0, 1, 2]'
        if pad_mode == -1:
            assert offsets, 'if pad_mode is -1, offsets should not be None'

        self.size = size
        self.size_divisor = size_divisor
        self.pad_mode = pad_mode
        self.fill_value = fill_value
        self.offsets = offsets

    def apply_bbox(self, bbox, offsets):
        return bbox + np.array(offsets * 2, dtype=np.float32)

    def apply_image(self, image, offsets, im_size, size):
        x, y = offsets
        im_h, im_w = im_size
        h, w = size
        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
        return canvas

    def __call__(self, im, target):
        im_h, im_w = im.shape[:2]
        if self.size:
            h, w = self.size
            assert (
                im_h <= h and im_w <= w
            ), '(h, w) of target size should be greater than (im_h, im_w)'
        else:
            h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
            w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)

        if h == im_h and w == im_w:
            return im.astype(np.float32), target

        if self.pad_mode == -1:
            offset_x, offset_y = self.offsets
        elif self.pad_mode == 0:
            offset_y, offset_x = 0, 0
        elif self.pad_mode == 1:
            offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
        else:
            offset_y, offset_x = h - im_h, w - im_w

        offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]

        im = self.apply_image(im, offsets, im_size, size)

        if self.pad_mode == 0:
            target["size"] = torch.tensor([h, w])
            return im, target
        if 'boxes' in target and len(target['boxes']) > 0:
            boxes = np.asarray(target["boxes"])
            target["boxes"] = torch.from_numpy(self.apply_bbox(boxes, offsets))
            target["size"] = torch.tensor([h, w])

        return im, target


class RandomExpand(object):
    """Random expand the canvas.
    Args:
        ratio (float): maximum expansion ratio.
        prob (float): probability to expand.
        fill_value (list): color value used to fill the canvas. in RGB order.
    """

    def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, 127.5, 127.5)):
        assert ratio > 1.01, "expand ratio must be larger than 1.01"
        self.ratio = ratio
        self.prob = prob
        assert isinstance(fill_value, (Number, Sequence)), \
            "fill value must be either float or sequence"
        if isinstance(fill_value, Number):
            fill_value = (fill_value, ) * 3
        if not isinstance(fill_value, tuple):
            fill_value = tuple(fill_value)
        self.fill_value = fill_value

    def __call__(self, img, target):
        if np.random.uniform(0., 1.) < self.prob:
            return img, target

        height, width = img.shape[:2]
        ratio = np.random.uniform(1., self.ratio)
        h = int(height * ratio)
        w = int(width * ratio)
        if not h > height or not w > width:
            return img, target
        y = np.random.randint(0, h - height)
        x = np.random.randint(0, w - width)
        offsets, size = [x, y], [h, w]

        pad = Pad(size,
                  pad_mode=-1,
                  offsets=offsets,
                  fill_value=self.fill_value)

        return pad(img, target)


class RandomSelect(object):
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """
    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return self.transforms1(img, target)
        return self.transforms2(img, target)


class ToTensor(object):
    def __call__(self, img, target):
        return F.to_tensor(img), target


class RandomErasing(object):

    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        return self.eraser(img), target


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        return image, target


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
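These transforms operate on (PIL image, target dict) pairs rather than images alone; a short sketch of the square-resize pipeline used at training time, run on a synthetic image with a single placeholder box:

# Sketch: applying the square-resize pipeline to a synthetic image + target.
import torch
from PIL import Image
import rfdetr.datasets.transforms as T

transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.SquareResize([560]),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.new("RGB", (640, 480))
target = {
    "boxes": torch.tensor([[100.0, 50.0, 300.0, 200.0]]),  # xyxy, absolute pixels
    "labels": torch.tensor([1]),
    "area": torch.tensor([200.0 * 150.0]),
    "size": torch.tensor([480, 640]),
}
img, target = transform(img, target)
print(img.shape)        # torch.Size([3, 560, 560])
print(target["boxes"])  # normalized cxcywh after Normalize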
rfdetr/deploy/__init__.py ADDED
File without changes
rfdetr/deploy/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (213 Bytes).

rfdetr/deploy/__pycache__/benchmark.cpython-313.pyc ADDED
Binary file (36 kB).

rfdetr/deploy/__pycache__/export.cpython-313.pyc ADDED
Binary file (14.5 kB).

rfdetr/deploy/_onnx/__init__.py ADDED
@@ -0,0 +1,13 @@
# ------------------------------------------------------------------------
# LW-DETR
# Copyright (c) 2024 Baidu. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
"""
onnx optimizer and symbolic registry
"""
from . import optimizer
from . import symbolic

from .optimizer import OnnxOptimizer
from .symbolic import CustomOpSymbolicRegistry
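A sketch of driving the optimizer exposed here, based only on the methods visible in optimizer.py below; it assumes onnx, onnx-graphsurgeon and polygraphy are installed and that the ONNX paths are placeholders:

# Sketch: load an ONNX file, run the registered graph optimizations, and save it.
# "model.onnx" / "model.opt.onnx" are placeholder paths.
from rfdetr.deploy._onnx import OnnxOptimizer

opt = OnnxOptimizer("model.onnx")   # accepts a path or an already-loaded onnx graph
opt.info("original")
opt.common_opt()                    # applies every optimizer registered on CustomOpSymbolicRegistry
opt.info("optimized")
opt.save_onnx("model.opt.onnx")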
rfdetr/deploy/_onnx/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (433 Bytes).

rfdetr/deploy/_onnx/__pycache__/optimizer.cpython-313.pyc ADDED
Binary file (46.7 kB).

rfdetr/deploy/_onnx/__pycache__/symbolic.cpython-313.pyc ADDED
Binary file (1.55 kB).

rfdetr/deploy/_onnx/optimizer.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ OnnxOptimizer
12
+ """
13
+ import os
14
+ from collections import OrderedDict
15
+ from copy import deepcopy
16
+
17
+ import numpy as np
18
+ import onnx
19
+ import torch
20
+ from onnx import shape_inference
21
+ import onnx_graphsurgeon as gs
22
+ from polygraphy.backend.onnx.loader import fold_constants
23
+ from onnx_graphsurgeon.logger.logger import G_LOGGER
24
+
25
+ from .symbolic import CustomOpSymbolicRegistry
26
+
27
+
28
+ class OnnxOptimizer():
29
+ def __init__(
30
+ self,
31
+ input,
32
+ severity=G_LOGGER.INFO
33
+ ):
34
+ if isinstance(input, str):
35
+ onnx_graph = self.load_onnx(input)
36
+ else:
37
+ onnx_graph = input
38
+ self.graph = gs.import_onnx(onnx_graph)
39
+ self.severity = severity
40
+ self.set_severity(severity)
41
+
42
+ def set_severity(self, severity):
43
+ G_LOGGER.severity = severity
44
+
45
+ def load_onnx(self, onnx_path:str):
46
+ """Load onnx from file
47
+ """
48
+ assert os.path.isfile(onnx_path), f"not found onnx file: {onnx_path}"
49
+ onnx_graph = onnx.load(onnx_path)
50
+ G_LOGGER.info(f"load onnx file: {onnx_path}")
51
+ return onnx_graph
52
+
53
+ def save_onnx(self, onnx_path:str):
54
+ onnx_graph = gs.export_onnx(self.graph)
55
+ G_LOGGER.info(f"save onnx file: {onnx_path}")
56
+ onnx.save(onnx_graph, onnx_path)
57
+
58
+ def info(self, prefix=''):
59
+ G_LOGGER.verbose(f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs")
60
+
61
+ def cleanup(self, return_onnx=False):
62
+ self.graph.cleanup().toposort()
63
+ if return_onnx:
64
+ return gs.export_onnx(self.graph)
65
+
66
+ def select_outputs(self, keep, names=None):
67
+ self.graph.outputs = [self.graph.outputs[o] for o in keep]
68
+ if names:
69
+ for i, name in enumerate(names):
70
+ self.graph.outputs[i].name = name
71
+
72
+ def find_node_input(self, node, name:str=None, value=None) -> int:
73
+ for i, inp in enumerate(node.inputs):
74
+ if isinstance(name, str) and inp.name == name:
75
+ index = i
76
+ elif inp == value:
77
+ index = i
78
+ assert index >= 0, f"not found {name}({value}) in node.inputs"
79
+ return index
80
+
81
+ def find_node_output(self, node, name:str=None, value=None) -> int:
82
+ for i, inp in enumerate(node.outputs):
83
+ if isinstance(name, str) and inp.name == name:
84
+ index = i
85
+ elif inp == value:
86
+ index = i
87
+ assert index >= 0, f"not found {name}({value}) in node.outputs"
88
+ return index
89
+
90
+ def common_opt(self, return_onnx=False):
91
+ for fn in CustomOpSymbolicRegistry._OPTIMIZER:
92
+ fn(self)
93
+ self.cleanup()
94
+ onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=False)
95
+ if onnx_graph.ByteSize() > 2147483648:
96
+ raise TypeError("ERROR: model size exceeds supported 2GB limit")
97
+ else:
98
+ onnx_graph = shape_inference.infer_shapes(onnx_graph)
99
+ self.graph = gs.import_onnx(onnx_graph)
100
+ self.cleanup()
101
+ if return_onnx:
102
+ return onnx_graph
103
+
104
+ def resize_fix(self):
105
+ '''
106
+ This function loops through the graph looking for Resize nodes that use scales for resizing (i.e. they have 3 inputs).
108
+ It replaces each such Resize with a Resize that takes the size of the output tensor instead of scales.
108
+ It adds Shape->Slice->Concat
109
+ Shape->Slice----^ subgraph to the graph to extract the shape of the output tensor.
110
+ This fix is required for the dynamic shape support.
111
+ '''
112
+ mResizeNodes = 0
113
+ for node in self.graph.nodes:
114
+ if node.op == "Resize" and len(node.inputs) == 3:
115
+ name = node.name + "/"
116
+
117
+ add_node = node.o().o().i(1)
118
+ div_node = node.i()
119
+
120
+ shape_hw_out = gs.Variable(name=name + "shape_hw_out", dtype=np.int64, shape=[4])
121
+ shape_hw = gs.Node(op="Shape", name=name+"shape_hw", inputs=[add_node.outputs[0]], outputs=[shape_hw_out])
122
+
123
+ const_zero = gs.Constant(name=name + "const_zero", values=np.array([0], dtype=np.int64))
124
+ const_two = gs.Constant(name=name + "const_two", values=np.array([2], dtype=np.int64))
125
+ const_four = gs.Constant(name=name + "const_four", values=np.array([4], dtype=np.int64))
126
+
127
+ slice_hw_out = gs.Variable(name=name + "slice_hw_out", dtype=np.int64, shape=[2])
128
+ slice_hw = gs.Node(op="Slice", name=name+"slice_hw", inputs=[shape_hw_out, const_two, const_four, const_zero], outputs=[slice_hw_out])
129
+
130
+ shape_bc_out = gs.Variable(name=name + "shape_bc_out", dtype=np.int64, shape=[2])
131
+ shape_bc = gs.Node(op="Shape", name=name+"shape_bc", inputs=[div_node.outputs[0]], outputs=[shape_bc_out])
132
+
133
+ slice_bc_out = gs.Variable(name=name + "slice_bc_out", dtype=np.int64, shape=[2])
134
+ slice_bc = gs.Node(op="Slice", name=name+"slice_bc", inputs=[shape_bc_out, const_zero, const_two, const_zero], outputs=[slice_bc_out])
135
+
136
+ concat_bchw_out = gs.Variable(name=name + "concat_bchw_out", dtype=np.int64, shape=[4])
137
+ concat_bchw = gs.Node(op="Concat", name=name+"concat_bchw", attrs={"axis": 0}, inputs=[slice_bc_out, slice_hw_out], outputs=[concat_bchw_out])
138
+
139
+ none_var = gs.Variable.empty()
140
+
141
+ resize_bchw = gs.Node(op="Resize", name=name+"resize_bchw", attrs=node.attrs, inputs=[node.inputs[0], none_var, none_var, concat_bchw_out], outputs=[node.outputs[0]])
142
+
143
+ self.graph.nodes.extend([shape_hw, slice_hw, shape_bc, slice_bc, concat_bchw, resize_bchw])
144
+
145
+ node.inputs = []
146
+ node.outputs = []
147
+
148
+ mResizeNodes += 1
149
+
150
+ self.cleanup()
151
+ return mResizeNodes
152
+
153
+ def adjustAddNode(self):
154
+ nAdjustAddNode = 0
155
+ for node in self.graph.nodes:
156
+ # Change the bias const to the second input to allow Gemm+BiasAdd fusion in TRT.
157
+ if node.op in ["Add"] and isinstance(node.inputs[0], gs.ir.tensor.Constant):
158
+ tensor = node.inputs[1]
159
+ bias = node.inputs[0]
160
+ node.inputs = [tensor, bias]
161
+ nAdjustAddNode += 1
162
+
163
+ self.cleanup()
164
+ return nAdjustAddNode
165
+
166
+ def decompose_instancenorms(self):
167
+ nRemoveInstanceNorm = 0
168
+ for node in self.graph.nodes:
169
+ if node.op == "InstanceNormalization":
170
+ name = node.name + "/"
171
+ input_tensor = node.inputs[0]
172
+ output_tensor = node.outputs[0]
173
+ mean_out = gs.Variable(name=name + "mean_out")
174
+ mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out])
175
+ sub_out = gs.Variable(name=name + "sub_out")
176
+ sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out])
177
+ pow_out = gs.Variable(name=name + "pow_out")
178
+ pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32))
179
+ pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out])
180
+ mean2_out = gs.Variable(name=name + "mean2_out")
181
+ mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out])
182
+ epsilon_out = gs.Variable(name=name + "epsilon_out")
183
+ epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([node.attrs["epsilon"]], dtype=np.float32))
184
+ epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out])
185
+ sqrt_out = gs.Variable(name=name + "sqrt_out")
186
+ sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out])
187
+ div_out = gs.Variable(name=name + "div_out")
188
+ div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out])
189
+ constantScale = gs.Constant("InstanceNormScaleV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[1].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
190
+ constantBias = gs.Constant("InstanceBiasV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[2].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
191
+ mul_out = gs.Variable(name=name + "mul_out")
192
+ mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out])
193
+ add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor])
194
+ self.graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node])
195
+ node.inputs = []
196
+ node.outputs = []
197
+ nRemoveInstanceNorm += 1
198
+
199
+ self.cleanup()
200
+ return nRemoveInstanceNorm
201
+
202
+ def insert_groupnorm_plugin(self):
203
+ nGroupNormPlugin = 0
204
+ for node in self.graph.nodes:
205
+ if node.op == "Reshape" and node.outputs != [] and \
206
+ node.o().op == "ReduceMean" and node.o(1).op == "Sub" and node.o().o() == node.o(1) and \
207
+ node.o().o().o().o().o().o().o().o().o().o().o().op == "Mul" and \
208
+ node.o().o().o().o().o().o().o().o().o().o().o().o().op == "Add" and \
209
+ len(node.o().o().o().o().o().o().o().o().inputs[1].values.shape) == 3:
210
+ # "node.outputs != []" is added for VAE
211
+
212
+ inputTensor = node.inputs[0]
213
+
214
+ gammaNode = node.o().o().o().o().o().o().o().o().o().o().o()
215
+ index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
216
+ gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
217
+ constantGamma = gs.Constant("groupNormGamma-" + str(nGroupNormPlugin), np.ascontiguousarray(gamma.reshape(-1))) # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!
218
+
219
+ betaNode = gammaNode.o()
220
+ index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
221
+ beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
222
+ constantBeta = gs.Constant("groupNormBeta-" + str(nGroupNormPlugin), np.ascontiguousarray(beta.reshape(-1)))
223
+
224
+ epsilon = node.o().o().o().o().o().inputs[1].values.tolist()[0]
225
+
226
+ if betaNode.o().op == "Sigmoid": # need Swish
227
+ bSwish = True
228
+ lastNode = betaNode.o().o() # Mul node of Swish
229
+ else:
230
+ bSwish = False
231
+ lastNode = betaNode # Cast node after Group Norm
232
+
233
+ if lastNode.o().op == "Cast":
234
+ lastNode = lastNode.o()
235
+ inputList = [inputTensor, constantGamma, constantBeta]
236
+ groupNormV = gs.Variable("GroupNormV-" + str(nGroupNormPlugin), np.dtype(np.float16), inputTensor.shape)
237
+ groupNormN = gs.Node("GroupNorm", "GroupNormN-" + str(nGroupNormPlugin), inputs=inputList, outputs=[groupNormV], attrs=OrderedDict([('epsilon', epsilon), ('bSwish', int(bSwish))]))
238
+ self.graph.nodes.append(groupNormN)
239
+
240
+ for subNode in self.graph.nodes:
241
+ if lastNode.outputs[0] in subNode.inputs:
242
+ index = subNode.inputs.index(lastNode.outputs[0])
243
+ subNode.inputs[index] = groupNormV
244
+ node.inputs = []
245
+ lastNode.outputs = []
246
+ nGroupNormPlugin += 1
247
+
248
+ self.cleanup()
249
+ return nGroupNormPlugin
250
+
251
+ def insert_layernorm_plugin(self):
252
+ nLayerNormPlugin = 0
253
+ for node in self.graph.nodes:
254
+ if node.op == 'ReduceMean' and \
255
+ node.o().op == 'Sub' and node.o().inputs[0] == node.inputs[0] and \
256
+ node.o().o(0).op =='Pow' and node.o().o(1).op =='Div' and \
257
+ node.o().o(0).o().op == 'ReduceMean' and \
258
+ node.o().o(0).o().o().op == 'Add' and \
259
+ node.o().o(0).o().o().o().op == 'Sqrt' and \
260
+ node.o().o(0).o().o().o().o().op == 'Div' and node.o().o(0).o().o().o().o() == node.o().o(1) and \
261
+ node.o().o(0).o().o().o().o().o().op == 'Mul' and \
262
+ node.o().o(0).o().o().o().o().o().o().op == 'Add' and \
263
+ len(node.o().o(0).o().o().o().o().o().inputs[1].values.shape) == 1:
264
+
265
+ if node.i().op == "Add":
266
+ inputTensor = node.inputs[0] # CLIP
267
+ else:
268
+ inputTensor = node.i().inputs[0] # UNet and VAE
269
+
270
+ gammaNode = node.o().o().o().o().o().o().o()
271
+ index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
272
+ gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
273
+ constantGamma = gs.Constant("LayerNormGamma-" + str(nLayerNormPlugin), np.ascontiguousarray(gamma.reshape(-1))) # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!
274
+
275
+ betaNode = gammaNode.o()
276
+ index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
277
+ beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
278
+ constantBeta = gs.Constant("LayerNormBeta-" + str(nLayerNormPlugin), np.ascontiguousarray(beta.reshape(-1)))
279
+
280
+ inputList = [inputTensor, constantGamma, constantBeta]
281
+ layerNormV = gs.Variable("LayerNormV-" + str(nLayerNormPlugin), np.dtype(np.float32), inputTensor.shape)
282
+ layerNormN = gs.Node("LayerNorm", "LayerNormN-" + str(nLayerNormPlugin), inputs=inputList, attrs=OrderedDict([('epsilon', 1.e-5)]), outputs=[layerNormV])
283
+ self.graph.nodes.append(layerNormN)
284
+ nLayerNormPlugin += 1
285
+
286
+ if betaNode.outputs[0] in self.graph.outputs:
287
+ index = self.graph.outputs.index(betaNode.outputs[0])
288
+ self.graph.outputs[index] = layerNormV
289
+ else:
290
+ if betaNode.o().op == "Cast":
291
+ lastNode = betaNode.o()
292
+ else:
293
+ lastNode = betaNode
294
+ for subNode in self.graph.nodes:
295
+ if lastNode.outputs[0] in subNode.inputs:
296
+ index = subNode.inputs.index(lastNode.outputs[0])
297
+ subNode.inputs[index] = layerNormV
298
+ lastNode.outputs = []
299
+
300
+ self.cleanup()
301
+ return nLayerNormPlugin
302
+
303
+ def fuse_kv(self, node_k, node_v, fused_kv_idx, heads, num_dynamic=0):
304
+ # Get weights of K
305
+ weights_k = node_k.inputs[1].values
306
+ # Get weights of V
307
+ weights_v = node_v.inputs[1].values
308
+ # Input number of channels to K and V
309
+ C = weights_k.shape[0]
310
+ # Number of heads
311
+ H = heads
312
+ # Dimension per head
313
+ D = weights_k.shape[1] // H
314
+
315
+ # Concat and interleave weights such that the output of fused KV GEMM has [b, s_kv, h, 2, d] shape
316
+ weights_kv = np.dstack([weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 2 * H * D)
317
+
318
+ # K and V have the same input
319
+ input_tensor = node_k.inputs[0]
320
+ # K and V must have the same output which we feed into fmha plugin
321
+ output_tensor_k = node_k.outputs[0]
322
+ # Create tensor
323
+ constant_weights_kv = gs.Constant("Weights_KV_{}".format(fused_kv_idx), np.ascontiguousarray(weights_kv))
324
+
325
+ # Create fused KV node
326
+ fused_kv_node = gs.Node(op="MatMul", name="MatMul_KV_{}".format(fused_kv_idx), inputs=[input_tensor, constant_weights_kv], outputs=[output_tensor_k])
327
+ self.graph.nodes.append(fused_kv_node)
328
+
329
+ # Connect the output of fused node to the inputs of the nodes after K and V
330
+ node_v.o(num_dynamic).inputs[0] = output_tensor_k
331
+ node_k.o(num_dynamic).inputs[0] = output_tensor_k
332
+ for i in range(0,num_dynamic):
333
+ node_v.o().inputs.clear()
334
+ node_k.o().inputs.clear()
335
+
336
+ # Clear inputs and outputs of K and V to get these nodes cleared
337
+ node_k.outputs.clear()
338
+ node_v.outputs.clear()
339
+ node_k.inputs.clear()
340
+ node_v.inputs.clear()
341
+
342
+ self.cleanup()
343
+ return fused_kv_node
344
+
345
+ def insert_fmhca(self, node_q, node_kv, final_tranpose, mhca_idx, heads, num_dynamic=0):
346
+ # Get inputs and outputs for the fMHCA plugin
347
+ # We take an output of reshape that follows the Q GEMM
348
+ output_q = node_q.o(num_dynamic).o().inputs[0]
349
+ output_kv = node_kv.o().inputs[0]
350
+ output_final_tranpose = final_tranpose.outputs[0]
351
+
352
+ # Clear the inputs of the nodes that follow the Q and KV GEMM
353
+ # to delete these subgraphs (it will be substituted by fMHCA plugin)
354
+ node_kv.outputs[0].outputs[0].inputs.clear()
355
+ node_kv.outputs[0].outputs[0].inputs.clear()
356
+ node_q.o(num_dynamic).o().inputs.clear()
357
+ for i in range(0,num_dynamic):
358
+ node_q.o(i).o().o(1).inputs.clear()
359
+
360
+ weights_kv = node_kv.inputs[1].values
361
+ dims_per_head = weights_kv.shape[1] // (heads * 2)
362
+
363
+ # Reshape dims
364
+ shape = gs.Constant("Shape_KV_{}".format(mhca_idx), np.ascontiguousarray(np.array([0, 0, heads, 2, dims_per_head], dtype=np.int64)))
365
+
366
+ # Reshape output tensor
367
+ output_reshape = gs.Variable("ReshapeKV_{}".format(mhca_idx), np.dtype(np.float16), None)
368
+ # Create the Reshape node that feeds the fMHCA plugin
369
+ reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mhca_idx), inputs=[output_kv, shape], outputs=[output_reshape])
370
+ # Insert node
371
+ self.graph.nodes.append(reshape)
372
+
373
+ # Create fMHCA plugin
374
+ fmhca = gs.Node(op="fMHCA", name="fMHCA_{}".format(mhca_idx), inputs=[output_q, output_reshape], outputs=[output_final_tranpose])
375
+ # Insert node
376
+ self.graph.nodes.append(fmhca)
377
+
378
+ # Connect input of fMHCA to output of Q GEMM
379
+ node_q.o(num_dynamic).outputs[0] = output_q
380
+
381
+ if num_dynamic > 0:
382
+ reshape2_input1_out = gs.Variable("Reshape2_fmhca{}_out".format(mhca_idx), np.dtype(np.int64), None)
383
+ reshape2_input1_shape = gs.Node("Shape", "Reshape2_fmhca{}_shape".format(mhca_idx), inputs=[node_q.inputs[0]], outputs=[reshape2_input1_out])
384
+ self.graph.nodes.append(reshape2_input1_shape)
385
+ final_tranpose.o().inputs[1] = reshape2_input1_out
386
+
387
+ # Clear outputs of transpose to get this subgraph cleared
388
+ final_tranpose.outputs.clear()
389
+
390
+ self.cleanup()
391
+
392
+ def fuse_qkv(self, node_q, node_k, node_v, fused_qkv_idx, heads, num_dynamic=0):
393
+ # Get weights of Q
394
+ weights_q = node_q.inputs[1].values
395
+ # Get weights of K
396
+ weights_k = node_k.inputs[1].values
397
+ # Get weights of V
398
+ weights_v = node_v.inputs[1].values
399
+
400
+ # Input number of channels to Q, K and V
401
+ C = weights_k.shape[0]
402
+ # Number of heads
403
+ H = heads
404
+ # Hidden dimension per head
405
+ D = weights_k.shape[1] // H
406
+
407
+ # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
408
+ weights_qkv = np.dstack([weights_q.reshape(C, H, D), weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 3 * H * D)
409
+
410
+ input_tensor = node_k.inputs[0] # K and V have the same input
411
+ # Q, K and V must have the same output which we feed into fmha plugin
412
+ output_tensor_k = node_k.outputs[0]
413
+ # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
414
+ constant_weights_qkv = gs.Constant("Weights_QKV_{}".format(fused_qkv_idx), np.ascontiguousarray(weights_qkv))
415
+
416
+ # Create the fused QKV node
417
+ fused_qkv_node = gs.Node(op="MatMul", name="MatMul_QKV_{}".format(fused_qkv_idx), inputs=[input_tensor, constant_weights_qkv], outputs=[output_tensor_k])
418
+ self.graph.nodes.append(fused_qkv_node)
419
+
420
+ # Connect the output of the fused node to the inputs of the nodes after Q, K and V
421
+ node_q.o(num_dynamic).inputs[0] = output_tensor_k
422
+ node_k.o(num_dynamic).inputs[0] = output_tensor_k
423
+ node_v.o(num_dynamic).inputs[0] = output_tensor_k
424
+ for i in range(0,num_dynamic):
425
+ node_q.o().inputs.clear()
426
+ node_k.o().inputs.clear()
427
+ node_v.o().inputs.clear()
428
+
429
+ # Clear inputs and outputs of Q, K and V to get these nodes cleared
430
+ node_q.outputs.clear()
431
+ node_k.outputs.clear()
432
+ node_v.outputs.clear()
433
+
434
+ node_q.inputs.clear()
435
+ node_k.inputs.clear()
436
+ node_v.inputs.clear()
437
+
438
+ self.cleanup()
439
+ return fused_qkv_node
440
+
441
+ def insert_fmha(self, node_qkv, final_tranpose, mha_idx, heads, num_dynamic=0):
442
+ # Get inputs and outputs for the fMHA plugin
443
+ output_qkv = node_qkv.o().inputs[0]
444
+ output_final_tranpose = final_tranpose.outputs[0]
445
+
446
+ # Clear the inputs of the nodes that follow the QKV GEMM
447
+ # to delete these subgraphs (it will be substituted by fMHA plugin)
448
+ node_qkv.outputs[0].outputs[2].inputs.clear()
449
+ node_qkv.outputs[0].outputs[1].inputs.clear()
450
+ node_qkv.outputs[0].outputs[0].inputs.clear()
451
+
452
+ weights_qkv = node_qkv.inputs[1].values
453
+ dims_per_head = weights_qkv.shape[1] // (heads * 3)
454
+
455
+ # Reshape dims
456
+ shape = gs.Constant("Shape_QKV_{}".format(mha_idx), np.ascontiguousarray(np.array([0, 0, heads, 3, dims_per_head], dtype=np.int64)))
457
+
458
+ # Reshape output tensor
459
+ output_shape = gs.Variable("ReshapeQKV_{}".format(mha_idx), np.dtype(np.float16), None)
460
+ # Create the Reshape node that feeds the fMHA plugin
461
+ reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mha_idx), inputs=[output_qkv, shape], outputs=[output_shape])
462
+ # Insert node
463
+ self.graph.nodes.append(reshape)
464
+
465
+ # Create fMHA plugin
466
+ fmha = gs.Node(op="fMHA_V2", name="fMHA_{}".format(mha_idx), inputs=[output_shape], outputs=[output_final_tranpose])
467
+ # Insert node
468
+ self.graph.nodes.append(fmha)
469
+
470
+ if num_dynamic > 0:
471
+ reshape2_input1_out = gs.Variable("Reshape2_{}_out".format(mha_idx), np.dtype(np.int64), None)
472
+ reshape2_input1_shape = gs.Node("Shape", "Reshape2_{}_shape".format(mha_idx), inputs=[node_qkv.inputs[0]], outputs=[reshape2_input1_out])
473
+ self.graph.nodes.append(reshape2_input1_shape)
474
+ final_tranpose.o().inputs[1] = reshape2_input1_out
475
+
476
+ # Clear outputs of transpose to get this subgraph cleared
477
+ final_tranpose.outputs.clear()
478
+
479
+ self.cleanup()
480
+
481
+ def mha_mhca_detected(self, node, mha):
482
+ # Go from V GEMM down to the S*V MatMul and all way up to K GEMM
483
+ # If we are looking for MHCA inputs of two matmuls (K and V) must be equal.
484
+ # If we are looking for MHA inputs (K and V) must be not equal.
485
+ if node.op == "MatMul" and len(node.outputs) == 1 and \
486
+ ((mha and len(node.inputs[0].inputs) > 0 and node.i().op == "Add") or \
487
+ (not mha and len(node.inputs[0].inputs) == 0)):
488
+
489
+ if node.o().op == 'Shape':
490
+ if node.o(1).op == 'Shape':
491
+ num_dynamic_kv = 3 if node.o(2).op == 'Shape' else 2
492
+ else:
493
+ num_dynamic_kv = 1
494
+ # For Cross-Attention, if batch axis is dynamic (in QKV), assume H*W (in Q) is dynamic as well
495
+ num_dynamic_q = num_dynamic_kv if mha else num_dynamic_kv + 1
496
+ else:
497
+ num_dynamic_kv = 0
498
+ num_dynamic_q = 0
499
+
500
+ o = node.o(num_dynamic_kv)
501
+ if o.op == "Reshape" and \
502
+ o.o().op == "Transpose" and \
503
+ o.o().o().op == "Reshape" and \
504
+ o.o().o().o().op == "MatMul" and \
505
+ o.o().o().o().i(0).op == "Softmax" and \
506
+ o.o().o().o().i(1).op == "Reshape" and \
507
+ o.o().o().o().i(0).i().op == "Mul" and \
508
+ o.o().o().o().i(0).i().i().op == "MatMul" and \
509
+ o.o().o().o().i(0).i().i().i(0).op == "Reshape" and \
510
+ o.o().o().o().i(0).i().i().i(1).op == "Transpose" and \
511
+ o.o().o().o().i(0).i().i().i(1).i().op == "Reshape" and \
512
+ o.o().o().o().i(0).i().i().i(1).i().i().op == "Transpose" and \
513
+ o.o().o().o().i(0).i().i().i(1).i().i().i().op == "Reshape" and \
514
+ o.o().o().o().i(0).i().i().i(1).i().i().i().i().op == "MatMul" and \
515
+ node.name != o.o().o().o().i(0).i().i().i(1).i().i().i().i().name:
516
+ # "len(node.outputs) == 1" to make sure we are not in the already fused node
517
+ node_q = o.o().o().o().i(0).i().i().i(0).i().i().i()
518
+ node_k = o.o().o().o().i(0).i().i().i(1).i().i().i().i()
519
+ node_v = node
520
+ final_tranpose = o.o().o().o().o(num_dynamic_q).o()
521
+ # Sanity check to make sure that the graph looks as expected
522
+ if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
523
+ return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose
524
+ return False, 0, 0, None, None, None, None
525
+
526
+ def fuse_kv_insert_fmhca(self, heads, mhca_index, sm):
527
+ nodes = self.graph.nodes
528
+ # Iterate over graph and search for MHCA pattern
529
+ for idx, _ in enumerate(nodes):
530
+ # fMHCA can't be at the 2 last layers of the network. It is a guard from OOB
531
+ if idx + 1 > len(nodes) or idx + 2 > len(nodes):
532
+ continue
533
+
534
+ # Get anchor nodes for fusion and fMHCA plugin insertion if the MHCA is detected
535
+ detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
536
+ self.mha_mhca_detected(nodes[idx], mha=False)
537
+ if detected:
538
+ assert num_dynamic_q == 0 or num_dynamic_q == num_dynamic_kv + 1
539
+ # Skip the FMHCA plugin for SM75 except for when the dim per head is 40.
540
+ if sm == 75 and node_q.inputs[1].shape[1] // heads == 160:
541
+ continue
542
+ # Fuse K and V GEMMS
543
+ node_kv = self.fuse_kv(node_k, node_v, mhca_index, heads, num_dynamic_kv)
544
+ # Insert fMHCA plugin
545
+ self.insert_fmhca(node_q, node_kv, final_tranpose, mhca_index, heads, num_dynamic_q)
546
+ return True
547
+ return False
548
+
549
+ def fuse_qkv_insert_fmha(self, heads, mha_index):
550
+ nodes = self.graph.nodes
551
+ # Iterate over graph and search for MHA pattern
552
+ for idx, _ in enumerate(nodes):
553
+ # fMHA can't be at the 2 last layers of the network. It is a guard from OOB
554
+ if idx + 1 > len(nodes) or idx + 2 > len(nodes):
555
+ continue
556
+
557
+ # Get anchor nodes for fusion and fMHA plugin insertion if the MHA is detected
558
+ detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
559
+ self.mha_mhca_detected(nodes[idx], mha=True)
560
+ if detected:
561
+ assert num_dynamic_q == num_dynamic_kv
562
+ # Fuse Q, K and V GEMMS
563
+ node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv)
564
+ # Insert fMHA plugin
565
+ self.insert_fmha(node_qkv, final_tranpose, mha_index, heads, num_dynamic_kv)
566
+ return True
567
+ return False
568
+
569
+ def insert_fmhca_plugin(self, num_heads, sm):
570
+ mhca_index = 0
571
+ while self.fuse_kv_insert_fmhca(num_heads, mhca_index, sm):
572
+ mhca_index += 1
573
+ return mhca_index
574
+
575
+ def insert_fmha_plugin(self, num_heads):
576
+ mha_index = 0
577
+ while self.fuse_qkv_insert_fmha(num_heads, mha_index):
578
+ mha_index += 1
579
+ return mha_index
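The class is typically driven in a load → optimize → save sequence, which is how the export script later in this commit uses it; a minimal sketch with placeholder paths:

    # Minimal sketch of driving OnnxOptimizer (paths are placeholders).
    from rfdetr.deploy._onnx import OnnxOptimizer

    opt = OnnxOptimizer("inference_model.onnx")  # also accepts an in-memory onnx ModelProto
    opt.info("Model: original")                  # logs node/tensor/input/output counts
    opt.common_opt()                             # registered passes + constant folding + shape inference
    opt.info("Model: optimized")
    opt.save_onnx("inference_model.sim.onnx")    # write the optimized graph back to disk

The plugin-insertion helpers (insert_fmha_plugin, insert_fmhca_plugin, insert_layernorm_plugin, insert_groupnorm_plugin) are separate, opt-in passes that rewrite matched subgraphs into TensorRT plugin ops; they are not invoked by common_opt unless registered.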
rfdetr/deploy/_onnx/symbolic.py ADDED
@@ -0,0 +1,37 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ """
10
+ CustomOpSymbolicRegistry class
11
+ """
12
+ from copy import deepcopy
13
+
14
+ import onnx
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.onnx import register_custom_op_symbolic
19
+ from torch.onnx.symbolic_helper import parse_args
20
+ from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes
21
+ from torch.autograd import Function
22
+
23
+
24
+ class CustomOpSymbolicRegistry:
25
+ # _SYMBOLICS = {}
26
+ _OPTIMIZER = []
27
+
28
+ @classmethod
29
+ def optimizer(cls, fn):
30
+ cls._OPTIMIZER.append(fn)
31
+
32
+
33
+ def register_optimizer():
34
+ def optimizer_wrapper(fn):
35
+ CustomOpSymbolicRegistry.optimizer(fn)
36
+ return fn
37
+ return optimizer_wrapper
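A pass is added to the registry with the register_optimizer decorator; each registered function receives the OnnxOptimizer instance, because common_opt() calls fn(self) for every entry in _OPTIMIZER. A sketch with an illustrative, hypothetical pass:

    # Hypothetical example pass; the name and body are illustrative only.
    from rfdetr.deploy._onnx.optimizer import OnnxOptimizer
    from rfdetr.deploy._onnx.symbolic import register_optimizer

    @register_optimizer()
    def log_graph_size(opt: OnnxOptimizer):
        # Runs automatically inside opt.common_opt(), before constant folding.
        opt.info("log_graph_size pass")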
rfdetr/deploy/benchmark.py ADDED
@@ -0,0 +1,590 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ This tool provides performance benchmarks by using ONNX Runtime and TensorRT
12
+ to run inference on a given model with the COCO validation set. It offers
13
+ reliable measurements of inference latency using ONNX Runtime or TensorRT
14
+ on the device.
15
+ """
16
+ import argparse
17
+ import copy
18
+ import contextlib
19
+ import datetime
20
+ import json
21
+ import os
22
+ import os.path as osp
23
+ import random
24
+ import time
25
+ import ast
26
+ from pathlib import Path
27
+ from collections import namedtuple, OrderedDict
28
+
29
+ from pycocotools.cocoeval import COCOeval
30
+ from pycocotools.coco import COCO
31
+ import pycocotools.mask as mask_util
32
+
33
+ import numpy as np
34
+ from PIL import Image
35
+ import torch
36
+ from torch.utils.data import DataLoader, DistributedSampler
37
+ import torchvision.transforms as T
38
+ import torchvision.transforms.functional as F
39
+ import tqdm
40
+
41
+ import pycuda.driver as cuda
42
+ import pycuda.autoinit
43
+ import onnxruntime as nxrun
44
+ import tensorrt as trt
45
+
46
+
47
+ def parser_args():
48
+ parser = argparse.ArgumentParser('performance benchmark tool for onnx/trt model')
49
+ parser.add_argument('--path', type=str, help='engine file path')
50
+ parser.add_argument('--coco_path', type=str, default="data/coco", help='coco dataset path')
51
+ parser.add_argument('--device', default=0, type=int)
52
+ parser.add_argument('--run_benchmark', action='store_true', help='repeat the inference to benchmark the latency')
53
+ parser.add_argument('--disable_eval', action='store_true', help='disable evaluation')
54
+ return parser.parse_args()
55
+
56
+
57
+ class CocoEvaluator(object):
58
+ def __init__(self, coco_gt, iou_types):
59
+ assert isinstance(iou_types, (list, tuple))
60
+ coco_gt = COCO(coco_gt)
61
+ coco_gt = copy.deepcopy(coco_gt)
62
+ self.coco_gt = coco_gt
63
+
64
+ self.iou_types = iou_types
65
+ self.coco_eval = {}
66
+ for iou_type in iou_types:
67
+ self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
68
+
69
+ self.img_ids = []
70
+ self.eval_imgs = {k: [] for k in iou_types}
71
+
72
+ def update(self, predictions):
73
+ img_ids = list(np.unique(list(predictions.keys())))
74
+ self.img_ids.extend(img_ids)
75
+
76
+ for iou_type in self.iou_types:
77
+ results = self.prepare(predictions, iou_type)
78
+
79
+ # suppress pycocotools prints
80
+ with open(os.devnull, 'w') as devnull:
81
+ with contextlib.redirect_stdout(devnull):
82
+ coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
83
+ coco_eval = self.coco_eval[iou_type]
84
+
85
+ coco_eval.cocoDt = coco_dt
86
+ coco_eval.params.imgIds = list(img_ids)
87
+ img_ids, eval_imgs = evaluate(coco_eval)
88
+
89
+ self.eval_imgs[iou_type].append(eval_imgs)
90
+
91
+ def synchronize_between_processes(self):
92
+ for iou_type in self.iou_types:
93
+ self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
94
+ create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
95
+
96
+ def accumulate(self):
97
+ for coco_eval in self.coco_eval.values():
98
+ coco_eval.accumulate()
99
+
100
+ def summarize(self):
101
+ for iou_type, coco_eval in self.coco_eval.items():
102
+ print("IoU metric: {}".format(iou_type))
103
+ coco_eval.summarize()
104
+
105
+ def prepare(self, predictions, iou_type):
106
+ if iou_type == "bbox":
107
+ return self.prepare_for_coco_detection(predictions)
108
+ else:
109
+ raise ValueError("Unknown iou type {}".format(iou_type))
110
+
111
+ def prepare_for_coco_detection(self, predictions):
112
+ coco_results = []
113
+ for original_id, prediction in predictions.items():
114
+ if len(prediction) == 0:
115
+ continue
116
+
117
+ boxes = prediction["boxes"]
118
+ boxes = convert_to_xywh(boxes).tolist()
119
+ scores = prediction["scores"].tolist()
120
+ labels = prediction["labels"].tolist()
121
+
122
+ coco_results.extend(
123
+ [
124
+ {
125
+ "image_id": original_id,
126
+ "category_id": labels[k],
127
+ "bbox": box,
128
+ "score": scores[k],
129
+ }
130
+ for k, box in enumerate(boxes)
131
+ ]
132
+ )
133
+ return coco_results
134
+
135
+ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
136
+ img_ids = list(img_ids)
137
+ eval_imgs = list(eval_imgs.flatten())
138
+
139
+ coco_eval.evalImgs = eval_imgs
140
+ coco_eval.params.imgIds = img_ids
141
+ coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
142
+
143
+ def evaluate(self):
144
+ '''
145
+ Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
146
+ :return: None
147
+ '''
148
+ # Running per image evaluation...
149
+ p = self.params
150
+ # add backward compatibility if useSegm is specified in params
151
+ if p.useSegm is not None:
152
+ p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
153
+ print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
154
+ # print('Evaluate annotation type *{}*'.format(p.iouType))
155
+ p.imgIds = list(np.unique(p.imgIds))
156
+ if p.useCats:
157
+ p.catIds = list(np.unique(p.catIds))
158
+ p.maxDets = sorted(p.maxDets)
159
+ self.params = p
160
+
161
+ self._prepare()
162
+ # loop through images, area range, max detection number
163
+ catIds = p.catIds if p.useCats else [-1]
164
+
165
+ if p.iouType == 'segm' or p.iouType == 'bbox':
166
+ computeIoU = self.computeIoU
167
+ elif p.iouType == 'keypoints':
168
+ computeIoU = self.computeOks
169
+ self.ious = {
170
+ (imgId, catId): computeIoU(imgId, catId)
171
+ for imgId in p.imgIds
172
+ for catId in catIds}
173
+
174
+ evaluateImg = self.evaluateImg
175
+ maxDet = p.maxDets[-1]
176
+ evalImgs = [
177
+ evaluateImg(imgId, catId, areaRng, maxDet)
178
+ for catId in catIds
179
+ for areaRng in p.areaRng
180
+ for imgId in p.imgIds
181
+ ]
182
+ # this is NOT in the pycocotools code, but could be done outside
183
+ evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
184
+ self._paramsEval = copy.deepcopy(self.params)
185
+ return p.imgIds, evalImgs
186
+
187
+ def convert_to_xywh(boxes):
188
+ boxes[:, 2:] -= boxes[:, :2]
189
+ return boxes
190
+
191
+
192
+ def get_image_list(ann_file):
193
+ with open(ann_file, 'r') as fin:
194
+ data = json.load(fin)
195
+ return data['images']
196
+
197
+
198
+ def load_image(file_path):
199
+ return Image.open(file_path).convert("RGB")
200
+
201
+
202
+ class Compose(object):
203
+ def __init__(self, transforms):
204
+ self.transforms = transforms
205
+
206
+ def __call__(self, image, target):
207
+ for t in self.transforms:
208
+ image, target = t(image, target)
209
+ return image, target
210
+
211
+ def __repr__(self):
212
+ format_string = self.__class__.__name__ + "("
213
+ for t in self.transforms:
214
+ format_string += "\n"
215
+ format_string += " {0}".format(t)
216
+ format_string += "\n)"
217
+ return format_string
218
+
219
+
220
+ class ToTensor(object):
221
+ def __call__(self, img, target):
222
+ return F.to_tensor(img), target
223
+
224
+
225
+ class Normalize(object):
226
+ def __init__(self, mean, std):
227
+ self.mean = mean
228
+ self.std = std
229
+
230
+ def __call__(self, image, target=None):
231
+ image = F.normalize(image, mean=self.mean, std=self.std)
232
+ if target is None:
233
+ return image, None
234
+ target = target.copy()
235
+ h, w = image.shape[-2:]
236
+ if "boxes" in target:
237
+ boxes = target["boxes"]
238
+ boxes = box_xyxy_to_cxcywh(boxes)
239
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
240
+ target["boxes"] = boxes
241
+ return image, target
242
+
243
+
244
+ class SquareResize(object):
245
+ def __init__(self, sizes):
246
+ assert isinstance(sizes, (list, tuple))
247
+ self.sizes = sizes
248
+
249
+ def __call__(self, img, target=None):
250
+ size = random.choice(self.sizes)
251
+ rescaled_img=F.resize(img, (size, size))
252
+ w, h = rescaled_img.size
253
+ if target is None:
254
+ return rescaled_img, None
255
+ ratios = tuple(
256
+ float(s) / float(s_orig) for s, s_orig in zip(rescaled_img.size, img.size))
257
+ ratio_width, ratio_height = ratios
258
+
259
+ target = target.copy()
260
+ if "boxes" in target:
261
+ boxes = target["boxes"]
262
+ scaled_boxes = boxes * torch.as_tensor(
263
+ [ratio_width, ratio_height, ratio_width, ratio_height])
264
+ target["boxes"] = scaled_boxes
265
+
266
+ if "area" in target:
267
+ area = target["area"]
268
+ scaled_area = area * (ratio_width * ratio_height)
269
+ target["area"] = scaled_area
270
+
271
+ target["size"] = torch.tensor([h, w])
272
+
273
+ return rescaled_img, target
274
+
275
+
276
+ def infer_transforms():
277
+ normalize = Compose([
278
+ ToTensor(),
279
+ Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
280
+ ])
281
+ return Compose([
282
+ SquareResize([640]),
283
+ normalize,
284
+ ])
285
+
286
+
287
+ def box_cxcywh_to_xyxy(x):
288
+ x_c, y_c, w, h = x.unbind(-1)
289
+ b = [(x_c - 0.5 * w.clamp(min=0.0)), (y_c - 0.5 * h.clamp(min=0.0)),
290
+ (x_c + 0.5 * w.clamp(min=0.0)), (y_c + 0.5 * h.clamp(min=0.0))]
291
+ return torch.stack(b, dim=-1)
292
+
293
+
294
+ def post_process(outputs, target_sizes):
295
+ out_logits, out_bbox = outputs['labels'], outputs['dets']
296
+
297
+ assert len(out_logits) == len(target_sizes)
298
+ assert target_sizes.shape[1] == 2
299
+
300
+ prob = out_logits.sigmoid()
301
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
302
+ scores = topk_values
303
+ topk_boxes = topk_indexes // out_logits.shape[2]
304
+ labels = topk_indexes % out_logits.shape[2]
305
+ boxes = box_cxcywh_to_xyxy(out_bbox)
306
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
307
+
308
+ # and from relative [0, 1] to absolute [0, height] coordinates
309
+ img_h, img_w = target_sizes.unbind(1)
310
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
311
+ boxes = boxes * scale_fct[:, None, :]
312
+
313
+ results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
314
+
315
+ return results
316
+
317
+
318
+ def infer_onnx(sess, coco_evaluator, time_profile, prefix, img_list, device, repeats=1):
319
+ time_list = []
320
+ for img_dict in tqdm.tqdm(img_list):
321
+ image = load_image(os.path.join(prefix, img_dict['file_name']))
322
+ width, height = image.size
323
+ orig_target_sizes = torch.Tensor([height, width])
324
+ image_tensor, _ = infer_transforms()(image, None) # target is None
325
+
326
+ samples = image_tensor[None].numpy()
327
+
328
+ time_profile.reset()
329
+ with time_profile:
330
+ for _ in range(repeats):
331
+ res = sess.run(None, {"input": samples})
332
+ time_list.append(time_profile.total / repeats)
333
+ outputs = {}
334
+ outputs['labels'] = torch.Tensor(res[1]).to(device)
335
+ outputs['dets'] = torch.Tensor(res[0]).to(device)
336
+
337
+ orig_target_sizes = torch.stack([orig_target_sizes], dim=0).to(device)
338
+ results = post_process(outputs, orig_target_sizes)
339
+ res = {img_dict['id']: results[0]}
340
+ if coco_evaluator is not None:
341
+ coco_evaluator.update(res)
342
+
343
+ print("Model latency with ONNX Runtime: {}ms".format(1000 * sum(time_list) / len(img_list)))
344
+
345
+ # accumulate predictions from all images
346
+ stats = {}
347
+ if coco_evaluator is not None:
348
+ coco_evaluator.synchronize_between_processes()
349
+ coco_evaluator.accumulate()
350
+ coco_evaluator.summarize()
351
+ stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
352
+ print(stats)
353
+
354
+
355
+ def infer_engine(model, coco_evaluator, time_profile, prefix, img_list, device, repeats=1):
356
+ time_list = []
357
+ for img_dict in tqdm.tqdm(img_list):
358
+ image = load_image(os.path.join(prefix, img_dict['file_name']))
359
+ width, height = image.size
360
+ orig_target_sizes = torch.Tensor([height, width])
361
+ image_tensor, _ = infer_transforms()(image, None) # target is None
362
+
363
+ samples = image_tensor[None].to(device)
364
+ _, _, h, w = samples.shape
365
+ im_shape = torch.Tensor(np.array([h, w]).reshape((1, 2)).astype(np.float32)).to(device)
366
+ scale_factor = torch.Tensor(np.array([h / height, w / width]).reshape((1, 2)).astype(np.float32)).to(device)
367
+
368
+ time_profile.reset()
369
+ with time_profile:
370
+ for _ in range(repeats):
371
+ outputs = model({"input": samples})
372
+
373
+ time_list.append(time_profile.total / repeats)
374
+ orig_target_sizes = torch.stack([orig_target_sizes], dim=0).to(device)
375
+ if coco_evaluator is not None:
376
+ results = post_process(outputs, orig_target_sizes)
377
+ res = {img_dict['id']: results[0]}
378
+ coco_evaluator.update(res)
379
+
380
+ print("Model latency with TensorRT: {}ms".format(1000 * sum(time_list) / len(img_list)))
381
+
382
+ # accumulate predictions from all images
383
+ stats = {}
384
+ if coco_evaluator is not None:
385
+ coco_evaluator.synchronize_between_processes()
386
+ coco_evaluator.accumulate()
387
+ coco_evaluator.summarize()
388
+ stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
389
+ print(stats)
390
+
391
+
392
+ class TRTInference(object):
393
+ """TensorRT inference engine
394
+ """
395
+ def __init__(self, engine_path='dino.engine', device='cuda:0', sync_mode:bool=False, max_batch_size=32, verbose=False):
396
+ self.engine_path = engine_path
397
+ self.device = device
398
+ self.sync_mode = sync_mode
399
+ self.max_batch_size = max_batch_size
400
+
401
+ self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO)
402
+
403
+ self.engine = self.load_engine(engine_path)
404
+
405
+ self.context = self.engine.create_execution_context()
406
+
407
+ self.bindings = self.get_bindings(self.engine, self.context, self.max_batch_size, self.device)
408
+ self.bindings_addr = OrderedDict((n, v.ptr) for n, v in self.bindings.items())
409
+
410
+ self.input_names = self.get_input_names()
411
+ self.output_names = self.get_output_names()
412
+
413
+ if not self.sync_mode:
414
+ self.stream = cuda.Stream()
415
+
416
+ # self.time_profile = TimeProfiler()
417
+ self.time_profile = None
418
+
419
+ def get_dummy_input(self, batch_size:int):
420
+ blob = {}
421
+ for name, binding in self.bindings.items():
422
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
423
+ print(f"make dummy input {name} with shape {binding.shape}")
424
+ blob[name] = torch.rand(batch_size, *binding.shape[1:]).float().to('cuda:0')
425
+ return blob
426
+
427
+ def load_engine(self, path):
428
+ '''load engine
429
+ '''
430
+ trt.init_libnvinfer_plugins(self.logger, '')
431
+ with open(path, 'rb') as f, trt.Runtime(self.logger) as runtime:
432
+ return runtime.deserialize_cuda_engine(f.read())
433
+
434
+ def get_input_names(self, ):
435
+ names = []
436
+ for _, name in enumerate(self.engine):
437
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
438
+ names.append(name)
439
+ return names
440
+
441
+ def get_output_names(self, ):
442
+ names = []
443
+ for _, name in enumerate(self.engine):
444
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
445
+ names.append(name)
446
+ return names
447
+
448
+ def get_bindings(self, engine, context, max_batch_size=32, device=None):
449
+ '''build bindings
450
+ '''
451
+ Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
452
+ bindings = OrderedDict()
453
+
454
+ for i, name in enumerate(engine):
455
+ shape = engine.get_tensor_shape(name)
456
+ dtype = trt.nptype(engine.get_tensor_dtype(name))
457
+
458
+ if shape[0] == -1:
459
+ raise NotImplementedError
460
+
461
+ if False:
462
+ if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
463
+ data = np.random.randn(*shape).astype(dtype)
464
+ ptr = cuda.mem_alloc(data.nbytes)
465
+ bindings[name] = Binding(name, dtype, shape, data, ptr)
466
+ else:
467
+ data = cuda.pagelocked_empty(trt.volume(shape), dtype)
468
+ ptr = cuda.mem_alloc(data.nbytes)
469
+ bindings[name] = Binding(name, dtype, shape, data, ptr)
470
+
471
+ else:
472
+ data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
473
+ bindings[name] = Binding(name, dtype, shape, data, data.data_ptr())
474
+
475
+ return bindings
476
+
477
+ def run_sync(self, blob):
478
+ self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names})
479
+ self.context.execute_v2(list(self.bindings_addr.values()))
480
+ outputs = {n: self.bindings[n].data for n in self.output_names}
481
+ return outputs
482
+
483
+ def run_async(self, blob):
484
+ self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names})
485
+ bindings_addr = [int(v) for _, v in self.bindings_addr.items()]
486
+ self.context.execute_async_v2(bindings=bindings_addr, stream_handle=self.stream.handle)
487
+ outputs = {n: self.bindings[n].data for n in self.output_names}
488
+ self.stream.synchronize()
489
+ return outputs
490
+
491
+ def __call__(self, blob):
492
+ if self.sync_mode:
493
+ return self.run_sync(blob)
494
+ else:
495
+ return self.run_async(blob)
496
+
497
+ def synchronize(self, ):
498
+ if not self.sync_mode and torch.cuda.is_available():
499
+ torch.cuda.synchronize()
500
+ elif self.sync_mode:
501
+ self.stream.synchronize()
502
+
503
+ def speed(self, blob, n):
504
+ self.time_profile.reset()
505
+ with self.time_profile:
506
+ for _ in range(n):
507
+ _ = self(blob)
508
+ return self.time_profile.total / n
509
+
510
+
511
+ def build_engine(self, onnx_file_path, engine_file_path, max_batch_size=32):
512
+ '''Takes an ONNX file and creates a TensorRT engine to run inference with
513
+ http://gitlab.baidu.com/paddle-inference/benchmark/blob/main/backend_trt.py#L57
514
+ '''
515
+ EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
516
+ with trt.Builder(self.logger) as builder, \
517
+ builder.create_network(EXPLICIT_BATCH) as network, \
518
+ trt.OnnxParser(network, self.logger) as parser, \
519
+ builder.create_builder_config() as config:
520
+
521
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1024 MiB
522
+ config.set_flag(trt.BuilderFlag.FP16)
523
+
524
+ with open(onnx_file_path, 'rb') as model:
525
+ if not parser.parse(model.read()):
526
+ print('ERROR: Failed to parse the ONNX file.')
527
+ for error in range(parser.num_errors):
528
+ print(parser.get_error(error))
529
+ return None
530
+
531
+ serialized_engine = builder.build_serialized_network(network, config)
532
+ with open(engine_file_path, 'wb') as f:
533
+ f.write(serialized_engine)
534
+
535
+ return serialized_engine
536
+
537
+
538
+ class TimeProfiler(contextlib.ContextDecorator):
539
+ def __init__(self, ):
540
+ self.total = 0
541
+
542
+ def __enter__(self, ):
543
+ self.start = self.time()
544
+ return self
545
+
546
+ def __exit__(self, type, value, traceback):
547
+ self.total += self.time() - self.start
548
+
549
+ def reset(self, ):
550
+ self.total = 0
551
+
552
+ def time(self, ):
553
+ if torch.cuda.is_available():
554
+ torch.cuda.synchronize()
555
+ return time.perf_counter()
556
+
557
+
558
+ def main(args):
559
+ print(args)
560
+
561
+ coco_gt = osp.join(args.coco_path, 'annotations/instances_val2017.json')
562
+ img_list = get_image_list(coco_gt)
563
+ prefix = osp.join(args.coco_path, 'val2017')
564
+ if args.run_benchmark:
565
+ repeats = 10
566
+ print('Inference for each image will be repeated 10 times to obtain '
567
+ 'a reliable measurement of inference latency.')
568
+ else:
569
+ repeats = 1
570
+
571
+ if args.disable_eval:
572
+ coco_evaluator = None
573
+ else:
574
+ coco_evaluator = CocoEvaluator(coco_gt, ('bbox',))
575
+
576
+ time_profile = TimeProfiler()
577
+
578
+ if args.path.endswith(".onnx"):
579
+ sess = nxrun.InferenceSession(args.path, providers=['CUDAExecutionProvider'])
580
+ infer_onnx(sess, coco_evaluator, time_profile, prefix, img_list, device=f'cuda:{args.device}', repeats=repeats)
581
+ elif args.path.endswith(".engine"):
582
+ model = TRTInference(args.path, sync_mode=True, device=f'cuda:{args.device}')
583
+ infer_engine(model, coco_evaluator, time_profile, prefix, img_list, device=f'cuda:{args.device}', repeats=repeats)
584
+ else:
585
+ raise NotImplementedError('Only model file names ending with ".onnx" and ".engine" are supported.')
586
+
587
+
588
+ if __name__ == '__main__':
589
+ args = parser_args()
590
+ main(args)
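The TRTInference and TimeProfiler classes can also be used on their own to time a prebuilt engine on random inputs; a minimal sketch (the engine path is a placeholder):

    # Standalone latency sketch using the classes above (engine path is a placeholder).
    from rfdetr.deploy.benchmark import TRTInference, TimeProfiler

    engine = TRTInference("inference_model.engine", device="cuda:0", sync_mode=True)
    blob = engine.get_dummy_input(batch_size=1)   # random tensors shaped like the engine bindings
    timer = TimeProfiler()

    with timer:
        for _ in range(100):
            _ = engine(blob)                      # sync_mode=True dispatches to run_sync()
    print(f"mean latency: {1000 * timer.total / 100:.2f} ms")

The script's CLI entry point (parser_args/main) runs the same measurement end to end against the COCO val2017 set, optionally with accuracy evaluation.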
rfdetr/deploy/export.py ADDED
@@ -0,0 +1,276 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ export ONNX model and TensorRT engine for deployment
12
+ """
13
+ import os
14
+ import ast
15
+ import random
16
+ import argparse
17
+ import subprocess
18
+ import torch.nn as nn
19
+ from pathlib import Path
20
+ import time
21
+ from collections import defaultdict
22
+
23
+ import onnx
24
+ import torch
25
+ import onnxsim
26
+ import numpy as np
27
+ from PIL import Image
28
+
29
+ import rfdetr.util.misc as utils
30
+ import rfdetr.datasets.transforms as T
31
+ from rfdetr.models import build_model
32
+ from rfdetr.deploy._onnx import OnnxOptimizer
33
+ import re
34
+ import sys
35
+
36
+
37
+ def run_command_shell(command, dry_run:bool = False) -> int:
38
+ if dry_run:
39
+ print("")
40
+ print(f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']} {command}")
41
+ print("")
42
+ try:
43
+ result = subprocess.run(command, shell=True, capture_output=True, text=True)
44
+ return result
45
+ except subprocess.CalledProcessError as e:
46
+ print(f"Command failed with exit code {e.returncode}")
47
+ print(f"Error output:\n{e.stderr.decode('utf-8')}")
48
+ raise
49
+
50
+
51
+ def make_infer_image(infer_dir, shape, batch_size, device="cuda"):
52
+ if infer_dir is None:
53
+ dummy = np.random.randint(0, 256, (shape[0], shape[1], 3), dtype=np.uint8)
54
+ image = Image.fromarray(dummy, mode="RGB")
55
+ else:
56
+ image = Image.open(infer_dir).convert("RGB")
57
+
58
+ transforms = T.Compose([
59
+ T.SquareResize([shape[0]]),
60
+ T.ToTensor(),
61
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
62
+ ])
63
+
64
+ inps, _ = transforms(image, None)
65
+ inps = inps.to(device)
66
+ # inps = utils.nested_tensor_from_tensor_list([inps for _ in range(args.batch_size)])
67
+ inps = torch.stack([inps for _ in range(batch_size)])
68
+ return inps
69
+
70
+ def export_onnx(output_dir, model, input_names, input_tensors, output_names, dynamic_axes, backbone_only=False, verbose=True, opset_version=17):
71
+ export_name = "backbone_model" if backbone_only else "inference_model"
72
+ output_file = os.path.join(output_dir, f"{export_name}.onnx")
73
+
74
+ # Prepare model for export
75
+ if hasattr(model, "export"):
76
+ model.export()
77
+
78
+ torch.onnx.export(
79
+ model,
80
+ input_tensors,
81
+ output_file,
82
+ input_names=input_names,
83
+ output_names=output_names,
84
+ export_params=True,
85
+ keep_initializers_as_inputs=False,
86
+ do_constant_folding=True,
87
+ verbose=verbose,
88
+ opset_version=opset_version,
89
+ dynamic_axes=dynamic_axes)
90
+
91
+ print(f'\nSuccessfully exported ONNX model: {output_file}')
92
+ return output_file
93
+
94
+
95
+ def onnx_simplify(onnx_dir:str, input_names, input_tensors, force=False):
96
+ sim_onnx_dir = onnx_dir.replace(".onnx", ".sim.onnx")
97
+ if os.path.isfile(sim_onnx_dir) and not force:
98
+ return sim_onnx_dir
99
+
100
+ if isinstance(input_tensors, torch.Tensor):
101
+ input_tensors = [input_tensors]
102
+
103
+ print(f'Start simplifying ONNX model: {onnx_dir}')
104
+ opt = OnnxOptimizer(onnx_dir)
105
+ opt.info('Model: original')
106
+ opt.common_opt()
107
+ opt.info('Model: optimized')
108
+ opt.save_onnx(sim_onnx_dir)
109
+ input_dict = {name: tensor.detach().cpu().numpy() for name, tensor in zip(input_names, input_tensors)}
110
+ model_opt, check_ok = onnxsim.simplify(
111
+ onnx_dir,
112
+ check_n = 3,
113
+ input_data=input_dict,
114
+ dynamic_input_shape=False)
115
+ if check_ok:
116
+ onnx.save(model_opt, sim_onnx_dir)
117
+ else:
118
+ raise RuntimeError("Failed to simplify ONNX model.")
119
+ print(f'Successfully simplified ONNX model: {sim_onnx_dir}')
120
+ return sim_onnx_dir
121
+
122
+
123
+ def trtexec(onnx_dir:str, args) -> None:
124
+ engine_dir = onnx_dir.replace(".onnx", f".engine")
125
+
126
+ # Base trtexec command
127
+ trt_command = " ".join([
128
+ "trtexec",
129
+ f"--onnx={onnx_dir}",
130
+ f"--saveEngine={engine_dir}",
131
+ f"--memPoolSize=workspace:4096 --fp16",
132
+ f"--useCudaGraph --useSpinWait --warmUp=500 --avgRuns=1000 --duration=10",
133
+ f"{'--verbose' if args.verbose else ''}"])
134
+
135
+ if args.profile:
136
+ profile_dir = onnx_dir.replace(".onnx", f".nsys-rep")
137
+ # Wrap with nsys profile command
138
+ command = " ".join([
139
+ "nsys profile",
140
+ f"--output={profile_dir}",
141
+ "--trace=cuda,nvtx",
142
+ "--force-overwrite true",
143
+ trt_command
144
+ ])
145
+ print(f'Profile data will be saved to: {profile_dir}')
146
+ else:
147
+ command = trt_command
148
+
149
+ output = run_command_shell(command, args.dry_run)
150
+ stats = parse_trtexec_output(output.stdout)
151
+
152
+ def parse_trtexec_output(output_text):
153
+ print(output_text)
154
+ # Common patterns in trtexec output
155
+ gpu_compute_pattern = r"GPU Compute Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms, median = (\d+\.\d+) ms"
156
+ h2d_pattern = r"Host to Device Transfer Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
157
+ d2h_pattern = r"Device to Host Transfer Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
158
+ latency_pattern = r"Latency: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms"
159
+ throughput_pattern = r"Throughput: (\d+\.\d+) qps"
160
+
161
+ stats = {}
162
+
163
+ # Extract compute times
164
+ if match := re.search(gpu_compute_pattern, output_text):
165
+ stats.update({
166
+ 'compute_min_ms': float(match.group(1)),
167
+ 'compute_max_ms': float(match.group(2)),
168
+ 'compute_mean_ms': float(match.group(3)),
169
+ 'compute_median_ms': float(match.group(4))
170
+ })
171
+
172
+ # Extract H2D times
173
+ if match := re.search(h2d_pattern, output_text):
174
+ stats.update({
175
+ 'h2d_min_ms': float(match.group(1)),
176
+ 'h2d_max_ms': float(match.group(2)),
177
+ 'h2d_mean_ms': float(match.group(3))
178
+ })
179
+
180
+ # Extract D2H times
181
+ if match := re.search(d2h_pattern, output_text):
182
+ stats.update({
183
+ 'd2h_min_ms': float(match.group(1)),
184
+ 'd2h_max_ms': float(match.group(2)),
185
+ 'd2h_mean_ms': float(match.group(3))
186
+ })
187
+
188
+ if match := re.search(latency_pattern, output_text):
189
+ stats.update({
190
+ 'latency_min_ms': float(match.group(1)),
191
+ 'latency_max_ms': float(match.group(2)),
192
+ 'latency_mean_ms': float(match.group(3))
193
+ })
194
+
195
+ # Extract throughput
196
+ if match := re.search(throughput_pattern, output_text):
197
+ stats['throughput_qps'] = float(match.group(1))
198
+
199
+ return stats
200
+
201
+ def no_batch_norm(model):
202
+ for module in model.modules():
203
+ if isinstance(module, nn.BatchNorm2d):
204
+ raise ValueError("BatchNorm2d found in the model. Please remove it.")
205
+
206
+ def main(args):
207
+ print("git:\n {}\n".format(utils.get_sha()))
208
+ print(args)
209
+ # convert device to device_id
210
+ if args.device == 'cuda':
211
+ device_id = "0"
212
+ elif args.device == 'cpu':
213
+ device_id = ""
214
+ else:
215
+ device_id = str(int(args.device))
216
+ args.device = f"cuda:{device_id}"
217
+
218
+ # device for export onnx
219
+ # TODO: export onnx with cuda failed with onnx error
220
+ device = torch.device("cpu")
221
+ os.environ["CUDA_VISIBLE_DEVICES"] = device_id
222
+
223
+ # fix the seed for reproducibility
224
+ seed = args.seed + utils.get_rank()
225
+ torch.manual_seed(seed)
226
+ np.random.seed(seed)
227
+ random.seed(seed)
228
+
229
+ model, criterion, postprocessors = build_model(args)
230
+ n_parameters = sum(p.numel() for p in model.parameters())
231
+ print(f"number of parameters: {n_parameters}")
232
+ n_backbone_parameters = sum(p.numel() for p in model.backbone.parameters())
233
+ print(f"number of backbone parameters: {n_backbone_parameters}")
234
+ n_projector_parameters = sum(p.numel() for p in model.backbone[0].projector.parameters())
235
+ print(f"number of projector parameters: {n_projector_parameters}")
236
+ n_backbone_encoder_parameters = sum(p.numel() for p in model.backbone[0].encoder.parameters())
237
+ print(f"number of backbone encoder parameters: {n_backbone_encoder_parameters}")
238
+ n_transformer_parameters = sum(p.numel() for p in model.transformer.parameters())
239
+ print(f"number of transformer parameters: {n_transformer_parameters}")
240
+ if args.resume:
241
+ checkpoint = torch.load(args.resume, map_location='cpu')
242
+ model.load_state_dict(checkpoint['model'], strict=True)
243
+ print(f"load checkpoints {args.resume}")
244
+
245
+ if args.layer_norm:
246
+ no_batch_norm(model)
247
+
248
+ model.to(device)
249
+
250
+ input_tensors = make_infer_image(args, device)
251
+ input_names = ['input']
252
+ output_names = ['features'] if args.backbone_only else ['dets', 'labels']
253
+ dynamic_axes = None
254
+ # Run model inference in pytorch mode
255
+ model.eval().to("cuda")
256
+ input_tensors = input_tensors.to("cuda")
257
+ with torch.no_grad():
258
+ if args.backbone_only:
259
+ features = model(input_tensors)
260
+ print(f"PyTorch inference output shape: {features.shape}")
261
+ else:
262
+ outputs = model(input_tensors)
263
+ dets = outputs['pred_boxes']
264
+ labels = outputs['pred_logits']
265
+ print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
266
+ model.cpu()
267
+ input_tensors = input_tensors.cpu()
268
+
269
+
270
+ output_file = export_onnx(model, args, input_names, input_tensors, output_names, dynamic_axes)
271
+
272
+ if args.simplify:
273
+ output_file = onnx_simplify(output_file, input_names, input_tensors, args)
274
+
275
+ if args.tensorrt:
276
+ output_file = trtexec(output_file, args)
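For reference, the regex parser above can be exercised on its own. A minimal sketch, assuming the package is importable as rfdetr and feeding fabricated trtexec-style log lines (the numbers are illustrative, not real benchmark results):

# Sketch: parse a fabricated trtexec log with parse_trtexec_output.
from rfdetr.deploy.export import parse_trtexec_output

sample_log = "\n".join([
    "GPU Compute Time: min = 1.10 ms, max = 2.50 ms, mean = 1.40 ms, median = 1.35 ms",
    "Latency: min = 1.20 ms, max = 2.60 ms, mean = 1.55 ms",
    "Throughput: 640.25 qps",
])

stats = parse_trtexec_output(sample_log)   # prints the raw text, then returns a dict
print(stats["compute_mean_ms"], stats["latency_mean_ms"], stats["throughput_qps"])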
rfdetr/detr.py ADDED
@@ -0,0 +1,451 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+
7
+
8
+ import json
9
+ import os
10
+ from collections import defaultdict
11
+ from logging import getLogger
12
+ from typing import Union, List
13
+ from copy import deepcopy
14
+
15
+ import numpy as np
16
+ import supervision as sv
17
+ import torch
18
+ import torchvision.transforms.functional as F
19
+ from PIL import Image
20
+
21
+ try:
22
+ torch.set_float32_matmul_precision('high')
23
+ except:
24
+ pass
25
+
26
+ from rfdetr.config import (
27
+ RFDETRBaseConfig,
28
+ RFDETRLargeConfig,
29
+ RFDETRNanoConfig,
30
+ RFDETRSmallConfig,
31
+ RFDETRMediumConfig,
32
+ TrainConfig,
33
+ ModelConfig
34
+ )
35
+ from rfdetr.main import Model, download_pretrain_weights
36
+ from rfdetr.util.metrics import MetricsPlotSink, MetricsTensorBoardSink, MetricsWandBSink
37
+ from rfdetr.util.coco_classes import COCO_CLASSES
38
+
39
+ logger = getLogger(__name__)
40
+ class RFDETR:
41
+ """
42
+ The base RF-DETR class implements the core methods for training RF-DETR models,
43
+ running inference on the models, optimising models, and uploading trained
44
+ models for deployment.
45
+ """
46
+ means = [0.485, 0.456, 0.406]
47
+ stds = [0.229, 0.224, 0.225]
48
+ size = None
49
+
50
+ def __init__(self, **kwargs):
51
+ self.model_config = self.get_model_config(**kwargs)
52
+ self.maybe_download_pretrain_weights()
53
+ self.model = self.get_model(self.model_config)
54
+ self.callbacks = defaultdict(list)
55
+
56
+ self.model.inference_model = None
57
+ self._is_optimized_for_inference = False
58
+ self._has_warned_about_not_being_optimized_for_inference = False
59
+ self._optimized_has_been_compiled = False
60
+ self._optimized_batch_size = None
61
+ self._optimized_resolution = None
62
+ self._optimized_dtype = None
63
+
64
+ def maybe_download_pretrain_weights(self):
65
+ """
66
+ Download pre-trained weights if they are not already downloaded.
67
+ """
68
+ download_pretrain_weights(self.model_config.pretrain_weights)
69
+
70
+ def get_model_config(self, **kwargs):
71
+ """
72
+ Retrieve the configuration parameters used by the model.
73
+ """
74
+ return ModelConfig(**kwargs)
75
+
76
+ def train(self, **kwargs):
77
+ """
78
+ Train an RF-DETR model.
79
+ """
80
+ config = self.get_train_config(**kwargs)
81
+ self.train_from_config(config, **kwargs)
82
+
83
+ def optimize_for_inference(self, compile=True, batch_size=1, dtype=torch.float32):
84
+ self.remove_optimized_model()
85
+
86
+ self.model.inference_model = deepcopy(self.model.model)
87
+ self.model.inference_model.eval()
88
+ self.model.inference_model.export()
89
+
90
+ self._optimized_resolution = self.model.resolution
91
+ self._is_optimized_for_inference = True
92
+
93
+ self.model.inference_model = self.model.inference_model.to(dtype=dtype)
94
+ self._optimized_dtype = dtype
95
+
96
+ if compile:
97
+ self.model.inference_model = torch.jit.trace(
98
+ self.model.inference_model,
99
+ torch.randn(
100
+ batch_size, 3, self.model.resolution, self.model.resolution,
101
+ device=self.model.device,
102
+ dtype=dtype
103
+ )
104
+ )
105
+ self._optimized_has_been_compiled = True
106
+ self._optimized_batch_size = batch_size
107
+
108
+ def remove_optimized_model(self):
109
+ self.model.inference_model = None
110
+ self._is_optimized_for_inference = False
111
+ self._optimized_has_been_compiled = False
112
+ self._optimized_batch_size = None
113
+ self._optimized_resolution = None
114
+ self._optimized_dtype = None
115
+
116
+ def export(self, **kwargs):
117
+ """
118
+ Export your model to an ONNX file.
119
+
120
+ See [the ONNX export documentation](https://rfdetr.roboflow.com/learn/train/#onnx-export) for more information.
121
+ """
122
+ self.model.export(**kwargs)
123
+
124
+ def train_from_config(self, config: TrainConfig, **kwargs):
125
+ with open(
126
+ os.path.join(config.dataset_dir, "train", "_annotations.coco.json"), "r"
127
+ ) as f:
128
+ anns = json.load(f)
129
+ num_classes = len(anns["categories"])
130
+ class_names = [c["name"] for c in anns["categories"] if c["supercategory"] != "none"]
131
+ self.model.class_names = class_names
132
+
133
+ if self.model_config.num_classes != num_classes:
134
+ logger.warning(
135
+ f"num_classes mismatch: model has {self.model_config.num_classes} classes, but your dataset has {num_classes} classes\n"
136
+ f"reinitializing your detection head with {num_classes} classes."
137
+ )
138
+ self.model.reinitialize_detection_head(num_classes)
139
+
140
+
141
+ train_config = config.dict()
142
+ model_config = self.model_config.dict()
143
+ model_config.pop("num_classes")
144
+ if "class_names" in model_config:
145
+ model_config.pop("class_names")
146
+
147
+ if "class_names" in train_config and train_config["class_names"] is None:
148
+ train_config["class_names"] = class_names
149
+
150
+ for k, v in train_config.items():
151
+ if k in model_config:
152
+ model_config.pop(k)
153
+ if k in kwargs:
154
+ kwargs.pop(k)
155
+
156
+ all_kwargs = {**model_config, **train_config, **kwargs, "num_classes": num_classes}
157
+
158
+ metrics_plot_sink = MetricsPlotSink(output_dir=config.output_dir)
159
+ self.callbacks["on_fit_epoch_end"].append(metrics_plot_sink.update)
160
+ self.callbacks["on_train_end"].append(metrics_plot_sink.save)
161
+
162
+ if config.tensorboard:
163
+ metrics_tensor_board_sink = MetricsTensorBoardSink(output_dir=config.output_dir)
164
+ self.callbacks["on_fit_epoch_end"].append(metrics_tensor_board_sink.update)
165
+ self.callbacks["on_train_end"].append(metrics_tensor_board_sink.close)
166
+
167
+ if config.wandb:
168
+ metrics_wandb_sink = MetricsWandBSink(
169
+ output_dir=config.output_dir,
170
+ project=config.project,
171
+ run=config.run,
172
+ config=config.model_dump()
173
+ )
174
+ self.callbacks["on_fit_epoch_end"].append(metrics_wandb_sink.update)
175
+ self.callbacks["on_train_end"].append(metrics_wandb_sink.close)
176
+
177
+ if config.early_stopping:
178
+ from rfdetr.util.early_stopping import EarlyStoppingCallback
179
+ early_stopping_callback = EarlyStoppingCallback(
180
+ model=self.model,
181
+ patience=config.early_stopping_patience,
182
+ min_delta=config.early_stopping_min_delta,
183
+ use_ema=config.early_stopping_use_ema
184
+ )
185
+ self.callbacks["on_fit_epoch_end"].append(early_stopping_callback.update)
186
+
187
+ self.model.train(
188
+ **all_kwargs,
189
+ callbacks=self.callbacks,
190
+ )
191
+
192
+ def get_train_config(self, **kwargs):
193
+ """
194
+ Retrieve the configuration parameters that will be used for training.
195
+ """
196
+ return TrainConfig(**kwargs)
197
+
198
+ def get_model(self, config: ModelConfig):
199
+ """
200
+ Retrieve a model instance based on the provided configuration.
201
+ """
202
+ return Model(**config.dict())
203
+
204
+ # Get class_names from the model
205
+ @property
206
+ def class_names(self):
207
+ """
208
+ Retrieve the class names supported by the loaded model.
209
+
210
+ Returns:
211
+ dict: A dictionary mapping class IDs to class names. The keys are integers starting from 1.
212
+ """
213
+ if hasattr(self.model, 'class_names') and self.model.class_names:
214
+ return {i+1: name for i, name in enumerate(self.model.class_names)}
215
+
216
+ return COCO_CLASSES
217
+
218
+ def predict(
219
+ self,
220
+ images: Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]],
221
+ threshold: float = 0.5,
222
+ **kwargs,
223
+ ) -> Union[sv.Detections, List[sv.Detections]]:
224
+ """Performs object detection on the input images and returns bounding box
225
+ predictions.
226
+
227
+ This method accepts a single image or a list of images in various formats
228
+ (file path, PIL Image, NumPy array, or torch.Tensor). The images should be in
229
+ RGB channel order. If a torch.Tensor is provided, it must already be normalized
230
+ to values in the [0, 1] range and have the shape (C, H, W).
231
+
232
+ Args:
233
+ images (Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]]):
234
+ A single image or a list of images to process. Images can be provided
235
+ as file paths, PIL Images, NumPy arrays, or torch.Tensors.
236
+ threshold (float, optional):
237
+ The minimum confidence score needed to consider a detected bounding box valid.
238
+ **kwargs:
239
+ Additional keyword arguments.
240
+
241
+ Returns:
242
+ Union[sv.Detections, List[sv.Detections]]: A single or multiple Detections
243
+ objects, each containing bounding box coordinates, confidence scores,
244
+ and class IDs.
245
+ """
246
+ if not self._is_optimized_for_inference and not self._has_warned_about_not_being_optimized_for_inference:
247
+ logger.warning(
248
+ "Model is not optimized for inference. "
249
+ "Latency may be higher than expected. "
250
+ "You can optimize the model for inference by calling model.optimize_for_inference()."
251
+ )
252
+ self._has_warned_about_not_being_optimized_for_inference = True
253
+
254
+ self.model.model.eval()
255
+
256
+ if not isinstance(images, list):
257
+ images = [images]
258
+
259
+ orig_sizes = []
260
+ processed_images = []
261
+
262
+ for img in images:
263
+
264
+ if isinstance(img, str):
265
+ img = Image.open(img)
266
+
267
+ if not isinstance(img, torch.Tensor):
268
+ img = F.to_tensor(img)
269
+
270
+ if (img > 1).any():
271
+ raise ValueError(
272
+ "Image has pixel values above 1. Please ensure the image is "
273
+ "normalized (scaled to [0, 1])."
274
+ )
275
+ if img.shape[0] != 3:
276
+ raise ValueError(
277
+ f"Invalid image shape. Expected 3 channels (RGB), but got "
278
+ f"{img.shape[0]} channels."
279
+ )
280
+ img_tensor = img
281
+
282
+ h, w = img_tensor.shape[1:]
283
+ orig_sizes.append((h, w))
284
+
285
+ img_tensor = img_tensor.to(self.model.device)
286
+ img_tensor = F.normalize(img_tensor, self.means, self.stds)
287
+ img_tensor = F.resize(img_tensor, (self.model.resolution, self.model.resolution))
288
+
289
+ processed_images.append(img_tensor)
290
+
291
+ batch_tensor = torch.stack(processed_images)
292
+
293
+ if self._is_optimized_for_inference:
294
+ if self._optimized_resolution != batch_tensor.shape[2]:
295
+ # this could happen if someone manually changes self.model.resolution after optimizing the model
296
+ raise ValueError(f"Resolution mismatch. "
297
+ f"Model was optimized for resolution {self._optimized_resolution}, "
298
+ f"but got {batch_tensor.shape[2]}. "
299
+ "You can explicitly remove the optimized model by calling model.remove_optimized_model().")
300
+ if self._optimized_has_been_compiled:
301
+ if self._optimized_batch_size != batch_tensor.shape[0]:
302
+ raise ValueError(f"Batch size mismatch. "
303
+ f"Optimized model was compiled for batch size {self._optimized_batch_size}, "
304
+ f"but got {batch_tensor.shape[0]}. "
305
+ "You can explicitly remove the optimized model by calling model.remove_optimized_model(). "
306
+ "Alternatively, you can recompile the optimized model for a different batch size "
307
+ "by calling model.optimize_for_inference(batch_size=<new_batch_size>).")
308
+
309
+ with torch.inference_mode():
310
+ if self._is_optimized_for_inference:
311
+ predictions = self.model.inference_model(batch_tensor.to(dtype=self._optimized_dtype))
312
+ else:
313
+ predictions = self.model.model(batch_tensor)
314
+ if isinstance(predictions, tuple):
315
+ predictions = {
316
+ "pred_logits": predictions[1],
317
+ "pred_boxes": predictions[0]
318
+ }
319
+ target_sizes = torch.tensor(orig_sizes, device=self.model.device)
320
+ results = self.model.postprocessors["bbox"](predictions, target_sizes=target_sizes)
321
+
322
+ detections_list = []
323
+ for result in results:
324
+ scores = result["scores"]
325
+ labels = result["labels"]
326
+ boxes = result["boxes"]
327
+
328
+ keep = scores > threshold
329
+ scores = scores[keep]
330
+ labels = labels[keep]
331
+ boxes = boxes[keep]
332
+
333
+ detections = sv.Detections(
334
+ xyxy=boxes.float().cpu().numpy(),
335
+ confidence=scores.float().cpu().numpy(),
336
+ class_id=labels.cpu().numpy(),
337
+ )
338
+ detections_list.append(detections)
339
+
340
+ return detections_list if len(detections_list) > 1 else detections_list[0]
341
+
342
+ def deploy_to_roboflow(self, workspace: str, project_id: str, version: str, api_key: str = None, size: str = None):
343
+ """
344
+ Deploy the trained RF-DETR model to Roboflow.
345
+
346
+ Deploying with Roboflow will create a Serverless API to which you can make requests.
347
+
348
+ You can also download weights into a Roboflow Inference deployment for use in Roboflow Workflows and on-device deployment.
349
+
350
+ Args:
351
+ workspace (str): The name of the Roboflow workspace to deploy to.
352
+ project_id (str): The ID of the Roboflow project to which the model will be deployed.
353
+ version (str): The version of the Roboflow project to deploy the model to.
354
+ api_key (str, optional): Your Roboflow API key. If not provided,
355
+ it will be read from the environment variable `ROBOFLOW_API_KEY`.
356
+ size (str, optional): The size of the model to deploy. If not provided,
357
+ it will default to the size of the model being trained
358
+ (e.g., "rfdetr-base", "rfdetr-large", etc.).
359
+ Raises:
360
+ ValueError: If the `api_key` is not provided and not found in the environment
361
+ variable `ROBOFLOW_API_KEY`, or if the `size` is not set for custom architectures.
362
+ """
363
+ from roboflow import Roboflow
364
+ import shutil
365
+ if api_key is None:
366
+ api_key = os.getenv("ROBOFLOW_API_KEY")
367
+ if api_key is None:
368
+ raise ValueError("Set api_key=<KEY> in deploy_to_roboflow or export ROBOFLOW_API_KEY=<KEY>")
369
+
370
+
371
+ rf = Roboflow(api_key=api_key)
372
+ workspace = rf.workspace(workspace)
373
+
374
+ if self.size is None and size is None:
375
+ raise ValueError("Must set size for custom architectures")
376
+
377
+ size = self.size or size
378
+ tmp_out_dir = ".roboflow_temp_upload"
379
+ os.makedirs(tmp_out_dir, exist_ok=True)
380
+ outpath = os.path.join(tmp_out_dir, "weights.pt")
381
+ torch.save(
382
+ {
383
+ "model": self.model.model.state_dict(),
384
+ "args": self.model.args
385
+ }, outpath
386
+ )
387
+ project = workspace.project(project_id)
388
+ version = project.version(version)
389
+ version.deploy(
390
+ model_type=size,
391
+ model_path=tmp_out_dir,
392
+ filename="weights.pt"
393
+ )
394
+ shutil.rmtree(tmp_out_dir)
395
+
396
+
397
+
398
+ class RFDETRBase(RFDETR):
399
+ """
400
+ Train an RF-DETR Base model (29M parameters).
401
+ """
402
+ size = "rfdetr-base"
403
+ def get_model_config(self, **kwargs):
404
+ return RFDETRBaseConfig(**kwargs)
405
+
406
+ def get_train_config(self, **kwargs):
407
+ return TrainConfig(**kwargs)
408
+
409
+ class RFDETRLarge(RFDETR):
410
+ """
411
+ Train an RF-DETR Large model.
412
+ """
413
+ size = "rfdetr-large"
414
+ def get_model_config(self, **kwargs):
415
+ return RFDETRLargeConfig(**kwargs)
416
+
417
+ def get_train_config(self, **kwargs):
418
+ return TrainConfig(**kwargs)
419
+
420
+ class RFDETRNano(RFDETR):
421
+ """
422
+ Train an RF-DETR Nano model.
423
+ """
424
+ size = "rfdetr-nano"
425
+ def get_model_config(self, **kwargs):
426
+ return RFDETRNanoConfig(**kwargs)
427
+
428
+ def get_train_config(self, **kwargs):
429
+ return TrainConfig(**kwargs)
430
+
431
+ class RFDETRSmall(RFDETR):
432
+ """
433
+ Train an RF-DETR Small model.
434
+ """
435
+ size = "rfdetr-small"
436
+ def get_model_config(self, **kwargs):
437
+ return RFDETRSmallConfig(**kwargs)
438
+
439
+ def get_train_config(self, **kwargs):
440
+ return TrainConfig(**kwargs)
441
+
442
+ class RFDETRMedium(RFDETR):
443
+ """
444
+ Train an RF-DETR Medium model.
445
+ """
446
+ size = "rfdetr-medium"
447
+ def get_model_config(self, **kwargs):
448
+ return RFDETRMediumConfig(**kwargs)
449
+
450
+ def get_train_config(self, **kwargs):
451
+ return TrainConfig(**kwargs)
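A minimal usage sketch of the classes defined above; the image path, dataset directory, and epoch count are placeholders rather than values taken from this repository:

# Sketch: load a pretrained RF-DETR Base model, run inference, and (optionally) fine-tune.
from rfdetr import RFDETRBase

model = RFDETRBase()              # downloads the hosted base checkpoint if it is missing
model.optimize_for_inference()    # optional: traces the model for lower latency (batch size 1)

detections = model.predict("image.jpg", threshold=0.5)   # placeholder path
print(model.class_names)
print(detections.xyxy, detections.confidence, detections.class_id)

# Fine-tuning on a COCO-format dataset with train/valid/test splits
# (directory name and epoch count are illustrative):
# model.train(dataset_dir="my_dataset", output_dir="output", epochs=10)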
rfdetr/engine.py ADDED
@@ -0,0 +1,340 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ # Conditional DETR
10
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
11
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
12
+ # ------------------------------------------------------------------------
13
+ # Copied from DETR (https://github.com/facebookresearch/detr)
14
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
15
+ # ------------------------------------------------------------------------
16
+
17
+ """
18
+ Train and eval functions used in main.py
19
+ """
20
+ import math
21
+ import sys
22
+ from typing import Iterable
23
+ import random
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+
28
+ import rfdetr.util.misc as utils
29
+ from rfdetr.datasets.coco_eval import CocoEvaluator
30
+ from rfdetr.datasets.coco import compute_multi_scale_scales
31
+
32
+ try:
33
+ from torch.amp import autocast, GradScaler
34
+ DEPRECATED_AMP = False
35
+ except ImportError:
36
+ from torch.cuda.amp import autocast, GradScaler
37
+ DEPRECATED_AMP = True
38
+ from typing import DefaultDict, List, Callable
39
+ from rfdetr.util.misc import NestedTensor
40
+ import numpy as np
41
+
42
+ def get_autocast_args(args):
43
+ if DEPRECATED_AMP:
44
+ return {'enabled': args.amp, 'dtype': torch.bfloat16}
45
+ else:
46
+ return {'device_type': 'cuda', 'enabled': args.amp, 'dtype': torch.bfloat16}
47
+
48
+
49
+ def train_one_epoch(
50
+ model: torch.nn.Module,
51
+ criterion: torch.nn.Module,
52
+ lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
53
+ data_loader: Iterable,
54
+ optimizer: torch.optim.Optimizer,
55
+ device: torch.device,
56
+ epoch: int,
57
+ batch_size: int,
58
+ max_norm: float = 0,
59
+ ema_m: torch.nn.Module = None,
60
+ schedules: dict = {},
61
+ num_training_steps_per_epoch=None,
62
+ vit_encoder_num_layers=None,
63
+ args=None,
64
+ callbacks: DefaultDict[str, List[Callable]] = None,
65
+ ):
66
+ metric_logger = utils.MetricLogger(delimiter=" ")
67
+ metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
68
+ metric_logger.add_meter(
69
+ "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
70
+ )
71
+ header = "Epoch: [{}]".format(epoch)
72
+ print_freq = 10
73
+ start_steps = epoch * num_training_steps_per_epoch
74
+
75
+ print("Grad accum steps: ", args.grad_accum_steps)
76
+ print("Total batch size: ", batch_size * utils.get_world_size())
77
+
78
+ # Add gradient scaler for AMP
79
+ if DEPRECATED_AMP:
80
+ scaler = GradScaler(enabled=args.amp)
81
+ else:
82
+ scaler = GradScaler('cuda', enabled=args.amp)
83
+
84
+ optimizer.zero_grad()
85
+ assert batch_size % args.grad_accum_steps == 0
86
+ sub_batch_size = batch_size // args.grad_accum_steps
87
+ print("LENGTH OF DATA LOADER:", len(data_loader))
88
+ for data_iter_step, (samples, targets) in enumerate(
89
+ metric_logger.log_every(data_loader, print_freq, header)
90
+ ):
91
+ it = start_steps + data_iter_step
92
+ callback_dict = {
93
+ "step": it,
94
+ "model": model,
95
+ "epoch": epoch,
96
+ }
97
+ for callback in callbacks["on_train_batch_start"]:
98
+ callback(callback_dict)
99
+ if "dp" in schedules:
100
+ if args.distributed:
101
+ model.module.update_drop_path(
102
+ schedules["dp"][it], vit_encoder_num_layers
103
+ )
104
+ else:
105
+ model.update_drop_path(schedules["dp"][it], vit_encoder_num_layers)
106
+ if "do" in schedules:
107
+ if args.distributed:
108
+ model.module.update_dropout(schedules["do"][it])
109
+ else:
110
+ model.update_dropout(schedules["do"][it])
111
+
112
+ if args.multi_scale and not args.do_random_resize_via_padding:
113
+ scales = compute_multi_scale_scales(args.resolution, args.expanded_scales, args.patch_size, args.num_windows)
114
+ random.seed(it)
115
+ scale = random.choice(scales)
116
+ with torch.inference_mode():
117
+ samples.tensors = F.interpolate(samples.tensors, size=scale, mode='bilinear', align_corners=False)
118
+ samples.mask = F.interpolate(samples.mask.unsqueeze(1).float(), size=scale, mode='nearest').squeeze(1).bool()
119
+
120
+ for i in range(args.grad_accum_steps):
121
+ start_idx = i * sub_batch_size
122
+ final_idx = start_idx + sub_batch_size
123
+ new_samples_tensors = samples.tensors[start_idx:final_idx]
124
+ new_samples = NestedTensor(new_samples_tensors, samples.mask[start_idx:final_idx])
125
+ new_samples = new_samples.to(device)
126
+ new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]]
127
+
128
+ with autocast(**get_autocast_args(args)):
129
+ outputs = model(new_samples, new_targets)
130
+ loss_dict = criterion(outputs, new_targets)
131
+ weight_dict = criterion.weight_dict
132
+ losses = sum(
133
+ (1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k]
134
+ for k in loss_dict.keys()
135
+ if k in weight_dict
136
+ )
137
+
138
+
139
+ scaler.scale(losses).backward()
140
+
141
+ # reduce losses over all GPUs for logging purposes
142
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
143
+ loss_dict_reduced_unscaled = {
144
+ f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
145
+ }
146
+ loss_dict_reduced_scaled = {
147
+ k: v * weight_dict[k]
148
+ for k, v in loss_dict_reduced.items()
149
+ if k in weight_dict
150
+ }
151
+ losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
152
+
153
+ loss_value = losses_reduced_scaled.item()
154
+
155
+ if not math.isfinite(loss_value):
156
+ print(loss_dict_reduced)
157
+ raise ValueError("Loss is {}, stopping training".format(loss_value))
158
+
159
+ if max_norm > 0:
160
+ scaler.unscale_(optimizer)
161
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
162
+
163
+ scaler.step(optimizer)
164
+ scaler.update()
165
+ lr_scheduler.step()
166
+ optimizer.zero_grad()
167
+ if ema_m is not None:
168
+ if epoch >= 0:
169
+ ema_m.update(model)
170
+ metric_logger.update(
171
+ loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
172
+ )
173
+ metric_logger.update(class_error=loss_dict_reduced["class_error"])
174
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
175
+ # gather the stats from all processes
176
+ metric_logger.synchronize_between_processes()
177
+ print("Averaged stats:", metric_logger)
178
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
179
+
180
+
181
+ def coco_extended_metrics(coco_eval):
182
+ """
183
+ Safe version: ignores the -1 sentinel entries so precision/F1 never explode.
184
+ """
185
+
186
+ iou_thrs, rec_thrs = coco_eval.params.iouThrs, coco_eval.params.recThrs
187
+ iou50_idx, area_idx, maxdet_idx = (
188
+ int(np.argwhere(np.isclose(iou_thrs, 0.50))), 0, 2)
189
+
190
+ P = coco_eval.eval["precision"]
191
+ S = coco_eval.eval["scores"]
192
+
193
+ prec_raw = P[iou50_idx, :, :, area_idx, maxdet_idx]
194
+
195
+ prec = prec_raw.copy().astype(float)
196
+ prec[prec < 0] = np.nan
197
+
198
+ f1_cls = 2 * prec * rec_thrs[:, None] / (prec + rec_thrs[:, None])
199
+ f1_macro = np.nanmean(f1_cls, axis=1)
200
+
201
+ best_j = int(f1_macro.argmax())
202
+
203
+ macro_precision = float(np.nanmean(prec[best_j]))
204
+ macro_recall = float(rec_thrs[best_j])
205
+ macro_f1 = float(f1_macro[best_j])
206
+
207
+ score_vec = S[iou50_idx, best_j, :, area_idx, maxdet_idx].astype(float)
208
+ score_vec[prec_raw[best_j] < 0] = np.nan
209
+ score_thr = float(np.nanmean(score_vec))
210
+
211
+ map_50_95, map_50 = float(coco_eval.stats[0]), float(coco_eval.stats[1])
212
+
213
+ per_class = []
214
+ cat_ids = coco_eval.params.catIds
215
+ cat_id_to_name = {c["id"]: c["name"] for c in coco_eval.cocoGt.loadCats(cat_ids)}
216
+ for k, cid in enumerate(cat_ids):
217
+ p_slice = P[:, :, k, area_idx, maxdet_idx]
218
+ valid = p_slice > -1
219
+ ap_50_95 = float(p_slice[valid].mean()) if valid.any() else float("nan")
220
+ ap_50 = float(p_slice[iou50_idx][p_slice[iou50_idx] > -1].mean()) if (p_slice[iou50_idx] > -1).any() else float("nan")
221
+
222
+ pc = float(prec[best_j, k]) if prec_raw[best_j, k] > -1 else float("nan")
223
+ rc = macro_recall
224
+
225
+ # Skip classes with undefined (NaN) metrics, e.g. placeholder dataset categories
226
+ if np.isnan(ap_50_95) or np.isnan(ap_50) or np.isnan(pc) or np.isnan(rc):
227
+ continue
228
+
229
+ per_class.append({
230
+ "class" : cat_id_to_name[int(cid)],
231
+ "map@50:95" : ap_50_95,
232
+ "map@50" : ap_50,
233
+ "precision" : pc,
234
+ "recall" : rc,
235
+ })
236
+
237
+ per_class.append({
238
+ "class" : "all",
239
+ "map@50:95" : map_50_95,
240
+ "map@50" : map_50,
241
+ "precision" : macro_precision,
242
+ "recall" : macro_recall,
243
+ })
244
+
245
+ return {
246
+ "class_map": per_class,
247
+ "map" : map_50,
248
+ "precision": macro_precision,
249
+ "recall" : macro_recall
250
+ }
251
+
252
+ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None):
253
+ model.eval()
254
+ if args.fp16_eval:
255
+ model.half()
256
+ criterion.eval()
257
+
258
+ metric_logger = utils.MetricLogger(delimiter=" ")
259
+ metric_logger.add_meter(
260
+ "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
261
+ )
262
+ header = "Test:"
263
+
264
+ iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
265
+ coco_evaluator = CocoEvaluator(base_ds, iou_types)
266
+
267
+ for samples, targets in metric_logger.log_every(data_loader, 10, header):
268
+ samples = samples.to(device)
269
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
270
+
271
+ if args.fp16_eval:
272
+ samples.tensors = samples.tensors.half()
273
+
274
+ # Add autocast for evaluation
275
+ with autocast(**get_autocast_args(args)):
276
+ outputs = model(samples)
277
+
278
+ if args.fp16_eval:
279
+ for key in outputs.keys():
280
+ if key == "enc_outputs":
281
+ for sub_key in outputs[key].keys():
282
+ outputs[key][sub_key] = outputs[key][sub_key].float()
283
+ elif key == "aux_outputs":
284
+ for idx in range(len(outputs[key])):
285
+ for sub_key in outputs[key][idx].keys():
286
+ outputs[key][idx][sub_key] = outputs[key][idx][
287
+ sub_key
288
+ ].float()
289
+ else:
290
+ outputs[key] = outputs[key].float()
291
+
292
+ loss_dict = criterion(outputs, targets)
293
+ weight_dict = criterion.weight_dict
294
+
295
+ # reduce losses over all GPUs for logging purposes
296
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
297
+ loss_dict_reduced_scaled = {
298
+ k: v * weight_dict[k]
299
+ for k, v in loss_dict_reduced.items()
300
+ if k in weight_dict
301
+ }
302
+ loss_dict_reduced_unscaled = {
303
+ f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
304
+ }
305
+ metric_logger.update(
306
+ loss=sum(loss_dict_reduced_scaled.values()),
307
+ **loss_dict_reduced_scaled,
308
+ **loss_dict_reduced_unscaled,
309
+ )
310
+ metric_logger.update(class_error=loss_dict_reduced["class_error"])
311
+
312
+ orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
313
+ results = postprocessors["bbox"](outputs, orig_target_sizes)
314
+ res = {
315
+ target["image_id"].item(): output
316
+ for target, output in zip(targets, results)
317
+ }
318
+ if coco_evaluator is not None:
319
+ coco_evaluator.update(res)
320
+
321
+ # gather the stats from all processes
322
+ metric_logger.synchronize_between_processes()
323
+ print("Averaged stats:", metric_logger)
324
+ if coco_evaluator is not None:
325
+ coco_evaluator.synchronize_between_processes()
326
+
327
+ # accumulate predictions from all images
328
+ if coco_evaluator is not None:
329
+ coco_evaluator.accumulate()
330
+ coco_evaluator.summarize()
331
+ stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
332
+ if coco_evaluator is not None:
333
+ results_json = coco_extended_metrics(coco_evaluator.coco_eval["bbox"])
334
+ stats["results_json"] = results_json
335
+ if "bbox" in postprocessors.keys():
336
+ stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
337
+
338
+ if "segm" in postprocessors.keys():
339
+ stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
340
+ return stats, coco_evaluator
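A toy illustration of the operating-point selection performed in coco_extended_metrics above, using made-up precision values in place of a real COCOeval run:

# Sketch: pick the recall threshold that maximises macro-F1, treating invalid
# entries as NaN (the stand-in for COCOeval's -1 sentinel). Values are synthetic.
import numpy as np

rec_thrs = np.linspace(0.0, 1.00, 101)            # COCO recall thresholds
rng = np.random.default_rng(0)
prec = rng.uniform(0.3, 0.9, size=(101, 3))       # precision[recall_thr, class], 3 toy classes
prec[rng.random((101, 3)) < 0.05] = np.nan        # sprinkle some invalid entries

f1_cls = 2 * prec * rec_thrs[:, None] / (prec + rec_thrs[:, None])
f1_macro = np.nanmean(f1_cls, axis=1)             # macro-F1 per recall threshold
best_j = int(np.nanargmax(f1_macro))

print("best recall threshold:", rec_thrs[best_j])
print("macro precision:", float(np.nanmean(prec[best_j])))
print("macro F1:", float(f1_macro[best_j]))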
rfdetr/main.py ADDED
@@ -0,0 +1,1062 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
10
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
11
+ # ------------------------------------------------------------------------
12
+ # Modified from DETR (https://github.com/facebookresearch/detr)
13
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
14
+ # ------------------------------------------------------------------------
15
+
16
+ """
17
+ cleaned main file
18
+ """
19
+ import argparse
20
+ import ast
21
+ import copy
22
+ import datetime
23
+ import json
24
+ import math
25
+ import os
26
+ import random
27
+ import shutil
28
+ import time
29
+ from copy import deepcopy
30
+ from logging import getLogger
31
+ from pathlib import Path
32
+ from typing import DefaultDict, List, Callable
33
+
34
+ import numpy as np
35
+ import torch
36
+ from peft import LoraConfig, get_peft_model
37
+ from torch.utils.data import DataLoader, DistributedSampler
38
+
39
+ import rfdetr.util.misc as utils
40
+ from rfdetr.datasets import build_dataset, get_coco_api_from_dataset
41
+ from rfdetr.engine import evaluate, train_one_epoch
42
+ from rfdetr.models import build_model, build_criterion_and_postprocessors
43
+ from rfdetr.util.benchmark import benchmark
44
+ from rfdetr.util.drop_scheduler import drop_scheduler
45
+ from rfdetr.util.files import download_file
46
+ from rfdetr.util.get_param_dicts import get_param_dict
47
+ from rfdetr.util.utils import ModelEma, BestMetricHolder, clean_state_dict
48
+
49
+ if str(os.environ.get("USE_FILE_SYSTEM_SHARING", "False")).lower() in ["true", "1"]:
50
+ import torch.multiprocessing
51
+ torch.multiprocessing.set_sharing_strategy('file_system')
52
+
53
+ logger = getLogger(__name__)
54
+
55
+ HOSTED_MODELS = {
56
+ "rf-detr-base.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
57
+ # below is a less converged model that may be better for finetuning but worse for inference
58
+ "rf-detr-base-2.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
59
+ "rf-detr-large.pth": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth",
60
+ "rf-detr-nano.pth": "https://storage.googleapis.com/rfdetr/nano_coco/checkpoint_best_regular.pth",
61
+ "rf-detr-small.pth": "https://storage.googleapis.com/rfdetr/small_coco/checkpoint_best_regular.pth",
62
+ "rf-detr-medium.pth": "https://storage.googleapis.com/rfdetr/medium_coco/checkpoint_best_regular.pth",
63
+ }
64
+
65
+ def download_pretrain_weights(pretrain_weights: str, redownload=False):
66
+ if pretrain_weights in HOSTED_MODELS:
67
+ if redownload or not os.path.exists(pretrain_weights):
68
+ logger.info(
69
+ f"Downloading pretrained weights for {pretrain_weights}"
70
+ )
71
+ download_file(
72
+ HOSTED_MODELS[pretrain_weights],
73
+ pretrain_weights,
74
+ )
75
+
76
+ class Model:
77
+ def __init__(self, **kwargs):
78
+ args = populate_args(**kwargs)
79
+ self.args = args
80
+ self.resolution = args.resolution
81
+ self.model = build_model(args)
82
+ self.device = torch.device(args.device)
83
+ if args.pretrain_weights is not None:
84
+ print("Loading pretrain weights")
85
+ try:
86
+ checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
87
+ except Exception as e:
88
+ print(f"Failed to load pretrain weights: {e}")
89
+ # re-download weights if they are corrupted
90
+ print("Failed to load pretrain weights, re-downloading")
91
+ download_pretrain_weights(args.pretrain_weights, redownload=True)
92
+ checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
93
+
94
+ # Extract class_names from checkpoint if available
95
+ if 'args' in checkpoint and hasattr(checkpoint['args'], 'class_names'):
96
+ self.args.class_names = checkpoint['args'].class_names
97
+ self.class_names = checkpoint['args'].class_names
98
+
99
+ checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0]
100
+ if checkpoint_num_classes != args.num_classes + 1:
101
+ logger.warning(
102
+ f"num_classes mismatch: pretrain weights has {checkpoint_num_classes - 1} classes, but your model has {args.num_classes} classes\n"
103
+ f"reinitializing detection head with {checkpoint_num_classes - 1} classes"
104
+ )
105
+ self.reinitialize_detection_head(checkpoint_num_classes)
106
+ # add support to exclude_keys
107
+ # e.g., when load object365 pretrain, do not load `class_embed.[weight, bias]`
108
+ if args.pretrain_exclude_keys is not None:
109
+ assert isinstance(args.pretrain_exclude_keys, list)
110
+ for exclude_key in args.pretrain_exclude_keys:
111
+ checkpoint['model'].pop(exclude_key)
112
+ if args.pretrain_keys_modify_to_load is not None:
113
+ from rfdetr.util.obj365_to_coco_model import get_coco_pretrain_from_obj365
114
+ assert isinstance(args.pretrain_keys_modify_to_load, list)
115
+ for modify_key_to_load in args.pretrain_keys_modify_to_load:
116
+ try:
117
+ checkpoint['model'][modify_key_to_load] = get_coco_pretrain_from_obj365(
118
+ self.model.state_dict()[modify_key_to_load],
119
+ checkpoint['model'][modify_key_to_load]
120
+ )
121
+ except:
122
+ print(f"Failed to load {modify_key_to_load}, deleting from checkpoint")
123
+ checkpoint['model'].pop(modify_key_to_load)
124
+
125
+ # we may want to resume training with a smaller number of groups for group detr
126
+ num_desired_queries = args.num_queries * args.group_detr
127
+ query_param_names = ["refpoint_embed.weight", "query_feat.weight"]
128
+ for name, state in checkpoint['model'].items():
129
+ if any(name.endswith(x) for x in query_param_names):
130
+ checkpoint['model'][name] = state[:num_desired_queries]
131
+
132
+ self.model.load_state_dict(checkpoint['model'], strict=False)
133
+
134
+ if args.backbone_lora:
135
+ print("Applying LORA to backbone")
136
+ lora_config = LoraConfig(
137
+ r=16,
138
+ lora_alpha=16,
139
+ use_dora=True,
140
+ target_modules=[
141
+ "q_proj", "v_proj", "k_proj", # covers OWL-ViT
142
+ "qkv", # covers open_clip ie Siglip2
143
+ "query", "key", "value", "cls_token", "register_tokens", # covers Dinov2 with windowed attn
144
+ ]
145
+ )
146
+ self.model.backbone[0].encoder = get_peft_model(self.model.backbone[0].encoder, lora_config)
147
+ self.model = self.model.to(self.device)
148
+ self.criterion, self.postprocessors = build_criterion_and_postprocessors(args)
149
+ self.stop_early = False
150
+
151
+ def reinitialize_detection_head(self, num_classes):
152
+ self.model.reinitialize_detection_head(num_classes)
153
+
154
+ def request_early_stop(self):
155
+ self.stop_early = True
156
+ print("Early stopping requested, will complete current epoch and stop")
157
+
158
+ def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
159
+ currently_supported_callbacks = ["on_fit_epoch_end", "on_train_batch_start", "on_train_end"]
160
+ for key in callbacks.keys():
161
+ if key not in currently_supported_callbacks:
162
+ raise ValueError(
163
+ f"Callback {key} is not currently supported, please file an issue if you need it!\n"
164
+ f"Currently supported callbacks: {currently_supported_callbacks}"
165
+ )
166
+ args = populate_args(**kwargs)
167
+ if getattr(args, 'class_names') is not None:
168
+ self.args.class_names = args.class_names
169
+ self.args.num_classes = args.num_classes
170
+
171
+ utils.init_distributed_mode(args)
172
+ print("git:\n {}\n".format(utils.get_sha()))
173
+ print(args)
174
+ device = torch.device(args.device)
175
+
176
+ # fix the seed for reproducibility
177
+ seed = args.seed + utils.get_rank()
178
+ torch.manual_seed(seed)
179
+ np.random.seed(seed)
180
+ random.seed(seed)
181
+
182
+ criterion, postprocessors = build_criterion_and_postprocessors(args)
183
+ model = self.model
184
+ model.to(device)
185
+
186
+ model_without_ddp = model
187
+ if args.distributed:
188
+ if args.sync_bn:
189
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
190
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
191
+ model_without_ddp = model.module
192
+
193
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
194
+ print('number of params:', n_parameters)
195
+ param_dicts = get_param_dict(args, model_without_ddp)
196
+
197
+ param_dicts = [p for p in param_dicts if p['params'].requires_grad]
198
+
199
+ optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
200
+ weight_decay=args.weight_decay)
201
+ # Choose the learning rate scheduler based on the new argument
202
+
203
+ dataset_train = build_dataset(image_set='train', args=args, resolution=args.resolution)
204
+ dataset_val = build_dataset(image_set='val', args=args, resolution=args.resolution)
205
+ dataset_test = build_dataset(image_set='test', args=args, resolution=args.resolution)
206
+
207
+ # for cosine annealing, calculate total training steps and warmup steps
208
+ total_batch_size_for_lr = args.batch_size * utils.get_world_size() * args.grad_accum_steps
209
+ num_training_steps_per_epoch_lr = (len(dataset_train) + total_batch_size_for_lr - 1) // total_batch_size_for_lr
210
+ total_training_steps_lr = num_training_steps_per_epoch_lr * args.epochs
211
+ warmup_steps_lr = num_training_steps_per_epoch_lr * args.warmup_epochs
212
+ def lr_lambda(current_step: int):
213
+ if current_step < warmup_steps_lr:
214
+ # Linear warmup
215
+ return float(current_step) / float(max(1, warmup_steps_lr))
216
+ else:
217
+ # Cosine annealing from multiplier 1.0 down to lr_min_factor
218
+ if args.lr_scheduler == 'cosine':
219
+ progress = float(current_step - warmup_steps_lr) / float(max(1, total_training_steps_lr - warmup_steps_lr))
220
+ return args.lr_min_factor + (1 - args.lr_min_factor) * 0.5 * (1 + math.cos(math.pi * progress))
221
+ elif args.lr_scheduler == 'step':
222
+ if current_step < args.lr_drop * num_training_steps_per_epoch_lr:
223
+ return 1.0
224
+ else:
225
+ return 0.1
226
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
227
+
228
+ if args.distributed:
229
+ sampler_train = DistributedSampler(dataset_train)
230
+ sampler_val = DistributedSampler(dataset_val, shuffle=False)
231
+ sampler_test = DistributedSampler(dataset_test, shuffle=False)
232
+ else:
233
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
234
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
235
+ sampler_test = torch.utils.data.SequentialSampler(dataset_test)
236
+
237
+ effective_batch_size = args.batch_size * args.grad_accum_steps
238
+ min_batches = kwargs.get('min_batches', 5)
239
+ if len(dataset_train) < effective_batch_size * min_batches:
240
+ logger.info(
241
+ f"Training with uniform sampler because dataset is too small: {len(dataset_train)} < {effective_batch_size * min_batches}"
242
+ )
243
+ sampler = torch.utils.data.RandomSampler(
244
+ dataset_train,
245
+ replacement=True,
246
+ num_samples=effective_batch_size * min_batches,
247
+ )
248
+ data_loader_train = DataLoader(
249
+ dataset_train,
250
+ batch_size=effective_batch_size,
251
+ collate_fn=utils.collate_fn,
252
+ num_workers=args.num_workers,
253
+ sampler=sampler,
254
+ )
255
+ else:
256
+ batch_sampler_train = torch.utils.data.BatchSampler(
257
+ sampler_train, effective_batch_size, drop_last=True)
258
+ data_loader_train = DataLoader(
259
+ dataset_train,
260
+ batch_sampler=batch_sampler_train,
261
+ collate_fn=utils.collate_fn,
262
+ num_workers=args.num_workers
263
+ )
264
+
265
+ data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
266
+ drop_last=False, collate_fn=utils.collate_fn,
267
+ num_workers=args.num_workers)
268
+ data_loader_test = DataLoader(dataset_test, args.batch_size, sampler=sampler_test,
269
+ drop_last=False, collate_fn=utils.collate_fn,
270
+ num_workers=args.num_workers)
271
+
272
+ base_ds = get_coco_api_from_dataset(dataset_val)
273
+ base_ds_test = get_coco_api_from_dataset(dataset_test)
274
+ if args.use_ema:
275
+ self.ema_m = ModelEma(model_without_ddp, decay=args.ema_decay, tau=args.ema_tau)
276
+ else:
277
+ self.ema_m = None
278
+
279
+
280
+ output_dir = Path(args.output_dir)
281
+
282
+ if utils.is_main_process():
283
+ print("Get benchmark")
284
+ if args.do_benchmark:
285
+ benchmark_model = copy.deepcopy(model_without_ddp)
286
+ bm = benchmark(benchmark_model.float(), dataset_val, output_dir)
287
+ print(json.dumps(bm, indent=2))
288
+ del benchmark_model
289
+
290
+ if args.resume:
291
+ checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
292
+ model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
293
+ if args.use_ema:
294
+ if 'ema_model' in checkpoint:
295
+ self.ema_m.module.load_state_dict(clean_state_dict(checkpoint['ema_model']))
296
+ else:
297
+ del self.ema_m
298
+ self.ema_m = ModelEma(model, decay=args.ema_decay, tau=args.ema_tau)
299
+ if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
300
+ optimizer.load_state_dict(checkpoint['optimizer'])
301
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
302
+ args.start_epoch = checkpoint['epoch'] + 1
303
+
304
+ if args.eval:
305
+ test_stats, coco_evaluator = evaluate(
306
+ model, criterion, postprocessors, data_loader_val, base_ds, device, args)
307
+ if args.output_dir:
308
+ utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
309
+ return
310
+
311
+ # for drop
312
+ total_batch_size = effective_batch_size * utils.get_world_size()
313
+ num_training_steps_per_epoch = (len(dataset_train) + total_batch_size - 1) // total_batch_size
314
+ schedules = {}
315
+ if args.dropout > 0:
316
+ schedules['do'] = drop_scheduler(
317
+ args.dropout, args.epochs, num_training_steps_per_epoch,
318
+ args.cutoff_epoch, args.drop_mode, args.drop_schedule)
319
+ print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do'])))
320
+
321
+ if args.drop_path > 0:
322
+ schedules['dp'] = drop_scheduler(
323
+ args.drop_path, args.epochs, num_training_steps_per_epoch,
324
+ args.cutoff_epoch, args.drop_mode, args.drop_schedule)
325
+ print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))
326
+
327
+ print("Start training")
328
+ start_time = time.time()
329
+ best_map_holder = BestMetricHolder(use_ema=args.use_ema)
330
+ best_map_5095 = 0
331
+ best_map_50 = 0
332
+ best_map_ema_5095 = 0
333
+ best_map_ema_50 = 0
334
+ for epoch in range(args.start_epoch, args.epochs):
335
+ epoch_start_time = time.time()
336
+ if args.distributed:
337
+ sampler_train.set_epoch(epoch)
338
+
339
+ model.train()
340
+ criterion.train()
341
+ train_stats = train_one_epoch(
342
+ model, criterion, lr_scheduler, data_loader_train, optimizer, device, epoch,
343
+ effective_batch_size, args.clip_max_norm, ema_m=self.ema_m, schedules=schedules,
344
+ num_training_steps_per_epoch=num_training_steps_per_epoch,
345
+ vit_encoder_num_layers=args.vit_encoder_num_layers, args=args, callbacks=callbacks)
346
+ train_epoch_time = time.time() - epoch_start_time
347
+ train_epoch_time_str = str(datetime.timedelta(seconds=int(train_epoch_time)))
348
+ if args.output_dir:
349
+ checkpoint_paths = [output_dir / 'checkpoint.pth']
350
+ # extra checkpoint before LR drop and every `checkpoint_interval` epochs
351
+ if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % args.checkpoint_interval == 0:
352
+ checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
353
+ for checkpoint_path in checkpoint_paths:
354
+ weights = {
355
+ 'model': model_without_ddp.state_dict(),
356
+ 'optimizer': optimizer.state_dict(),
357
+ 'lr_scheduler': lr_scheduler.state_dict(),
358
+ 'epoch': epoch,
359
+ 'args': args,
360
+ }
361
+ if args.use_ema:
362
+ weights.update({
363
+ 'ema_model': self.ema_m.module.state_dict(),
364
+ })
365
+ if not args.dont_save_weights:
366
+ # create checkpoint dir
367
+ checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
368
+
369
+ utils.save_on_master(weights, checkpoint_path)
370
+
371
+ with torch.inference_mode():
372
+ test_stats, coco_evaluator = evaluate(
373
+ model, criterion, postprocessors, data_loader_val, base_ds, device, args=args
374
+ )
375
+ map_regular = test_stats["coco_eval_bbox"][0]
376
+ _isbest = best_map_holder.update(map_regular, epoch, is_ema=False)
377
+ if _isbest:
378
+ best_map_5095 = max(best_map_5095, map_regular)
379
+ best_map_50 = max(best_map_50, test_stats["coco_eval_bbox"][1])
380
+ checkpoint_path = output_dir / 'checkpoint_best_regular.pth'
381
+ if not args.dont_save_weights:
382
+ utils.save_on_master({
383
+ 'model': model_without_ddp.state_dict(),
384
+ 'optimizer': optimizer.state_dict(),
385
+ 'lr_scheduler': lr_scheduler.state_dict(),
386
+ 'epoch': epoch,
387
+ 'args': args,
388
+ }, checkpoint_path)
389
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
390
+ **{f'test_{k}': v for k, v in test_stats.items()},
391
+ 'epoch': epoch,
392
+ 'n_parameters': n_parameters}
393
+ if args.use_ema:
394
+ ema_test_stats, _ = evaluate(
395
+ self.ema_m.module, criterion, postprocessors, data_loader_val, base_ds, device, args=args
396
+ )
397
+ log_stats.update({f'ema_test_{k}': v for k,v in ema_test_stats.items()})
398
+ map_ema = ema_test_stats["coco_eval_bbox"][0]
399
+ best_map_ema_5095 = max(best_map_ema_5095, map_ema)
400
+ _isbest = best_map_holder.update(map_ema, epoch, is_ema=True)
401
+ if _isbest:
402
+ best_map_ema_50 = max(best_map_ema_50, ema_test_stats["coco_eval_bbox"][1])
403
+ checkpoint_path = output_dir / 'checkpoint_best_ema.pth'
404
+ if not args.dont_save_weights:
405
+ utils.save_on_master({
406
+ 'model': self.ema_m.module.state_dict(),
407
+ 'optimizer': optimizer.state_dict(),
408
+ 'lr_scheduler': lr_scheduler.state_dict(),
409
+ 'epoch': epoch,
410
+ 'args': args,
411
+ }, checkpoint_path)
412
+ log_stats.update(best_map_holder.summary())
413
+
414
+ # epoch parameters
415
+ ep_paras = {
416
+ 'epoch': epoch,
417
+ 'n_parameters': n_parameters
418
+ }
419
+ log_stats.update(ep_paras)
420
+ try:
421
+ log_stats.update({'now_time': str(datetime.datetime.now())})
422
+ except:
423
+ pass
424
+ log_stats['train_epoch_time'] = train_epoch_time_str
425
+ epoch_time = time.time() - epoch_start_time
426
+ epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
427
+ log_stats['epoch_time'] = epoch_time_str
428
+ if args.output_dir and utils.is_main_process():
429
+ with (output_dir / "log.txt").open("a") as f:
430
+ f.write(json.dumps(log_stats) + "\n")
431
+
432
+ # for evaluation logs
433
+ if coco_evaluator is not None:
434
+ (output_dir / 'eval').mkdir(exist_ok=True)
435
+ if "bbox" in coco_evaluator.coco_eval:
436
+ filenames = ['latest.pth']
437
+ if epoch % 50 == 0:
438
+ filenames.append(f'{epoch:03}.pth')
439
+ for name in filenames:
440
+ torch.save(coco_evaluator.coco_eval["bbox"].eval,
441
+ output_dir / "eval" / name)
442
+
443
+ for callback in callbacks["on_fit_epoch_end"]:
444
+ callback(log_stats)
445
+
446
+ if self.stop_early:
447
+ print(f"Early stopping requested, stopping at epoch {epoch}")
448
+ break
449
+
450
+ best_is_ema = best_map_ema_5095 > best_map_5095
451
+
452
+ if utils.is_main_process():
453
+ if best_is_ema:
454
+ shutil.copy2(output_dir / 'checkpoint_best_ema.pth', output_dir / 'checkpoint_best_total.pth')
455
+ else:
456
+ shutil.copy2(output_dir / 'checkpoint_best_regular.pth', output_dir / 'checkpoint_best_total.pth')
457
+
458
+ utils.strip_checkpoint(output_dir / 'checkpoint_best_total.pth')
459
+
460
+ best_map_5095 = max(best_map_5095, best_map_ema_5095)
461
+ if best_is_ema:
462
+ results = ema_test_stats["results_json"]
463
+ else:
464
+ results = test_stats["results_json"]
465
+
466
+ class_map = results["class_map"]
467
+ results["class_map"] = {"valid": class_map}
468
+ with open(output_dir / "results.json", "w") as f:
469
+ json.dump(results, f)
470
+
471
+ total_time = time.time() - start_time
472
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
473
+ print('Training time {}'.format(total_time_str))
474
+ print('Results saved to {}'.format(output_dir / "results.json"))
475
+
476
+
477
+ if best_is_ema:
478
+ self.model = self.ema_m.module
479
+ self.model.eval()
480
+
481
+
482
+ if args.run_test:
483
+ best_state_dict = torch.load(output_dir / 'checkpoint_best_total.pth', map_location='cpu', weights_only=False)['model']
484
+ model.load_state_dict(best_state_dict)
485
+ model.eval()
486
+
487
+ test_stats, _ = evaluate(
488
+ model, criterion, postprocessors, data_loader_test, base_ds_test, device, args=args
489
+ )
490
+ print(f"Test results: {test_stats}")
491
+ with open(output_dir / "results.json", "r") as f:
492
+ results = json.load(f)
493
+ test_metrics = test_stats["results_json"]["class_map"]
494
+ results["class_map"]["test"] = test_metrics
495
+ with open(output_dir / "results.json", "w") as f:
496
+ json.dump(results, f)
497
+
498
+ for callback in callbacks["on_train_end"]:
499
+ callback()
500
+
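The loop above ends by copying the best (EMA or regular) checkpoint to checkpoint_best_total.pth and writing the final metrics to results.json, with the validation class_map stored under "valid" and, when run_test is enabled, the held-out metrics under "test". A minimal sketch of inspecting that file after a run; the path assumes the default output_dir of "output":

# Hedged sketch: read the results.json written at the end of training (path assumes the default output_dir).
import json
from pathlib import Path

with (Path("output") / "results.json").open() as f:
    results = json.load(f)

print(results["class_map"].keys())    # dict_keys(['valid']) or dict_keys(['valid', 'test'])
print(results["class_map"]["valid"])  # metrics dict produced by the evaluator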
501
+ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs):
502
+ """Export the trained model to ONNX format"""
503
+ print(f"Exporting model to ONNX format")
504
+ try:
505
+ from rfdetr.deploy.export import export_onnx, onnx_simplify, make_infer_image
506
+ except ImportError:
507
+ print("It seems some dependencies for ONNX export are missing. Please run `pip install rfdetr[onnxexport]` and try again.")
508
+ raise
509
+
510
+
511
+ device = self.device
512
+ model = deepcopy(self.model.to("cpu"))
513
+ model.to(device)
514
+
515
+ os.makedirs(output_dir, exist_ok=True)
516
+ output_dir = Path(output_dir)
517
+ if shape is None:
518
+ shape = (self.resolution, self.resolution)
519
+ else:
520
+ if shape[0] % 14 != 0 or shape[1] % 14 != 0:
521
+ raise ValueError("Shape must be divisible by 14")
522
+
523
+ input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)
524
+ input_names = ['input']
525
+ output_names = ['features'] if backbone_only else ['dets', 'labels']
526
+ dynamic_axes = None
527
+ self.model.eval()
528
+ with torch.no_grad():
529
+ if backbone_only:
530
+ features = model(input_tensors)
531
+ print(f"PyTorch inference output shape: {features.shape}")
532
+ else:
533
+ outputs = model(input_tensors)
534
+ dets = outputs['pred_boxes']
535
+ labels = outputs['pred_logits']
536
+ print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}")
537
+ model.cpu()
538
+ input_tensors = input_tensors.cpu()
539
+
540
+ # Export to ONNX
541
+ output_file = export_onnx(
542
+ output_dir=output_dir,
543
+ model=model,
544
+ input_names=input_names,
545
+ input_tensors=input_tensors,
546
+ output_names=output_names,
547
+ dynamic_axes=dynamic_axes,
548
+ backbone_only=backbone_only,
549
+ verbose=verbose,
550
+ opset_version=opset_version
551
+ )
552
+
553
+ print(f"Successfully exported ONNX model to: {output_file}")
554
+
555
+ if simplify:
556
+ sim_output_file = onnx_simplify(
557
+ onnx_dir=output_file,
558
+ input_names=input_names,
559
+ input_tensors=input_tensors,
560
+ force=force
561
+ )
562
+ print(f"Successfully simplified ONNX model to: {sim_output_file}")
563
+
564
+ print("ONNX export completed successfully")
565
+ self.model = self.model.to(device)
566
+
567
+
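A minimal usage sketch for the export method above. It assumes the RFDETRBase wrapper exported by the package accepts a pretrain_weights checkpoint path and forwards export() to this method; the checkpoint path is a hypothetical example, and simplify=True needs the optional dependencies from pip install rfdetr[onnxexport]:

# Hedged sketch: export a trained checkpoint to ONNX.
from rfdetr import RFDETRBase

model = RFDETRBase(pretrain_weights="output/checkpoint_best_total.pth")  # hypothetical checkpoint path
model.export(output_dir="output", simplify=True)  # writes the ONNX file (and a simplified copy) under output/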
568
+ if __name__ == '__main__':
569
+ parser = argparse.ArgumentParser('LWDETR training and evaluation script', parents=[get_args_parser()])
570
+ args = parser.parse_args()
571
+
572
+ if args.output_dir:
573
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
574
+
575
+ config = vars(args) # Convert Namespace to dictionary
576
+
577
+ if args.subcommand == 'distill':
578
+ distill(**config)
579
+ elif args.subcommand is None:
580
+ main(**config)
581
+ elif args.subcommand == 'export_model':
582
+ filter_keys = [
583
+ "num_classes",
584
+ "grad_accum_steps",
585
+ "lr",
586
+ "lr_encoder",
587
+ "weight_decay",
588
+ "epochs",
589
+ "lr_drop",
590
+ "clip_max_norm",
591
+ "lr_vit_layer_decay",
592
+ "lr_component_decay",
593
+ "dropout",
594
+ "drop_path",
595
+ "drop_mode",
596
+ "drop_schedule",
597
+ "cutoff_epoch",
598
+ "pretrained_encoder",
599
+ "pretrain_weights",
600
+ "pretrain_exclude_keys",
601
+ "pretrain_keys_modify_to_load",
602
+ "freeze_florence",
603
+ "freeze_aimv2",
604
+ "decoder_norm",
605
+ "set_cost_class",
606
+ "set_cost_bbox",
607
+ "set_cost_giou",
608
+ "cls_loss_coef",
609
+ "bbox_loss_coef",
610
+ "giou_loss_coef",
611
+ "focal_alpha",
612
+ "aux_loss",
613
+ "sum_group_losses",
614
+ "use_varifocal_loss",
615
+ "use_position_supervised_loss",
616
+ "ia_bce_loss",
617
+ "dataset_file",
618
+ "coco_path",
619
+ "dataset_dir",
620
+ "square_resize_div_64",
621
+ "output_dir",
622
+ "checkpoint_interval",
623
+ "seed",
624
+ "resume",
625
+ "start_epoch",
626
+ "eval",
627
+ "use_ema",
628
+ "ema_decay",
629
+ "ema_tau",
630
+ "num_workers",
631
+ "device",
632
+ "world_size",
633
+ "dist_url",
634
+ "sync_bn",
635
+ "fp16_eval",
636
+ "infer_dir",
637
+ "verbose",
638
+ "opset_version",
639
+ "dry_run",
640
+ "shape",
641
+ ]
642
+ for key in filter_keys:
643
+ config.pop(key, None) # Use pop with None to avoid KeyError
644
+
645
+ from rfdetr.deploy.export import main as export_main
646
+ if args.batch_size != 1:
647
+ config['batch_size'] = 1
648
+ print(f"Only batch_size 1 is supported for onnx export, \
649
+ f"but got batch_size = {args.batch_size}. batch_size is forcibly set to 1.")
650
+ export_main(**config)
651
+
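The dispatch above covers three paths: no subcommand calls main(**config), 'distill' calls distill(**config), and 'export_model' strips training-only keys and hands the rest to the deploy exporter. A hedged sketch of exercising the parser defined just below without a shell; the dataset path is hypothetical:

# Hedged sketch: drive get_args_parser programmatically.
parser = get_args_parser()

train_args = parser.parse_args(['--dataset_dir', 'datasets/my_coco_dataset', '--output_dir', 'output'])
print(train_args.subcommand)   # None -> the main(**config) path

export_args = parser.parse_args(['export_model', '--simplify', '--shape', '640', '640'])
print(export_args.subcommand)  # 'export_model' -> the export path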
652
+ def get_args_parser():
653
+ parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
654
+ parser.add_argument('--num_classes', default=2, type=int)
655
+ parser.add_argument('--grad_accum_steps', default=1, type=int)
656
+ parser.add_argument('--amp', default=False, type=bool)
657
+ parser.add_argument('--lr', default=1e-4, type=float)
658
+ parser.add_argument('--lr_encoder', default=1.5e-4, type=float)
659
+ parser.add_argument('--batch_size', default=2, type=int)
660
+ parser.add_argument('--weight_decay', default=1e-4, type=float)
661
+ parser.add_argument('--epochs', default=12, type=int)
662
+ parser.add_argument('--lr_drop', default=11, type=int)
663
+ parser.add_argument('--clip_max_norm', default=0.1, type=float,
664
+ help='gradient clipping max norm')
665
+ parser.add_argument('--lr_vit_layer_decay', default=0.8, type=float)
666
+ parser.add_argument('--lr_component_decay', default=1.0, type=float)
667
+ parser.add_argument('--do_benchmark', action='store_true', help='benchmark the model')
668
+
669
+ # drop args
670
+ # dropout and stochastic depth drop rate; set at most one to non-zero
671
+ parser.add_argument('--dropout', type=float, default=0,
672
+ help='Dropout rate (default: 0.0)')
673
+ parser.add_argument('--drop_path', type=float, default=0,
674
+ help='Drop path rate (default: 0.0)')
675
+
676
+ # early / late dropout and stochastic depth settings
677
+ parser.add_argument('--drop_mode', type=str, default='standard',
678
+ choices=['standard', 'early', 'late'], help='drop mode')
679
+ parser.add_argument('--drop_schedule', type=str, default='constant',
680
+ choices=['constant', 'linear'],
681
+ help='drop schedule for early dropout / s.d. only')
682
+ parser.add_argument('--cutoff_epoch', type=int, default=0,
683
+ help='if drop_mode is early / late, this is the epoch where dropout ends / starts')
684
+
685
+ # Model parameters
686
+ parser.add_argument('--pretrained_encoder', type=str, default=None,
687
+ help="Path to the pretrained encoder.")
688
+ parser.add_argument('--pretrain_weights', type=str, default=None,
689
+ help="Path to the pretrained model.")
690
+ parser.add_argument('--pretrain_exclude_keys', type=str, default=None, nargs='+',
691
+ help="Keys you do not want to load.")
692
+ parser.add_argument('--pretrain_keys_modify_to_load', type=str, default=None, nargs='+',
693
+ help="Keys you want to modify to load. Only used when loading objects365 pre-trained weights.")
694
+
695
+ # * Backbone
696
+ parser.add_argument('--encoder', default='vit_tiny', type=str,
697
+ help="Name of the transformer or convolutional encoder to use")
698
+ parser.add_argument('--vit_encoder_num_layers', default=12, type=int,
699
+ help="Number of layers used in ViT encoder")
700
+ parser.add_argument('--window_block_indexes', default=None, type=int, nargs='+')
701
+ parser.add_argument('--position_embedding', default='sine', type=str,
702
+ choices=('sine', 'learned'),
703
+ help="Type of positional embedding to use on top of the image features")
704
+ parser.add_argument('--out_feature_indexes', default=[-1], type=int, nargs='+', help='indexes of the encoder blocks whose features are used (ViT encoders only for now)')
705
+ parser.add_argument("--freeze_encoder", action="store_true", dest="freeze_encoder")
706
+ parser.add_argument("--layer_norm", action="store_true", dest="layer_norm")
707
+ parser.add_argument("--rms_norm", action="store_true", dest="rms_norm")
708
+ parser.add_argument("--backbone_lora", action="store_true", dest="backbone_lora")
709
+ parser.add_argument("--force_no_pretrain", action="store_true", dest="force_no_pretrain")
710
+
711
+ # * Transformer
712
+ parser.add_argument('--dec_layers', default=3, type=int,
713
+ help="Number of decoding layers in the transformer")
714
+ parser.add_argument('--dim_feedforward', default=2048, type=int,
715
+ help="Intermediate size of the feedforward layers in the transformer blocks")
716
+ parser.add_argument('--hidden_dim', default=256, type=int,
717
+ help="Size of the embeddings (dimension of the transformer)")
718
+ parser.add_argument('--sa_nheads', default=8, type=int,
719
+ help="Number of attention heads inside the transformer's self-attentions")
720
+ parser.add_argument('--ca_nheads', default=8, type=int,
721
+ help="Number of attention heads inside the transformer's cross-attentions")
722
+ parser.add_argument('--num_queries', default=300, type=int,
723
+ help="Number of query slots")
724
+ parser.add_argument('--group_detr', default=13, type=int,
725
+ help="Number of groups to speed up detr training")
726
+ parser.add_argument('--two_stage', action='store_true')
727
+ parser.add_argument('--projector_scale', default='P4', type=str, nargs='+', choices=('P3', 'P4', 'P5', 'P6'))
728
+ parser.add_argument('--lite_refpoint_refine', action='store_true', help='lite refpoint refine mode for speed-up')
729
+ parser.add_argument('--num_select', default=100, type=int,
730
+ help='the number of predictions selected for evaluation')
731
+ parser.add_argument('--dec_n_points', default=4, type=int,
732
+ help='the number of sampling points')
733
+ parser.add_argument('--decoder_norm', default='LN', type=str)
734
+ parser.add_argument('--bbox_reparam', action='store_true')
735
+ parser.add_argument('--freeze_batch_norm', action='store_true')
736
+ # * Matcher
737
+ parser.add_argument('--set_cost_class', default=2, type=float,
738
+ help="Class coefficient in the matching cost")
739
+ parser.add_argument('--set_cost_bbox', default=5, type=float,
740
+ help="L1 box coefficient in the matching cost")
741
+ parser.add_argument('--set_cost_giou', default=2, type=float,
742
+ help="giou box coefficient in the matching cost")
743
+
744
+ # * Loss coefficients
745
+ parser.add_argument('--cls_loss_coef', default=2, type=float)
746
+ parser.add_argument('--bbox_loss_coef', default=5, type=float)
747
+ parser.add_argument('--giou_loss_coef', default=2, type=float)
748
+ parser.add_argument('--focal_alpha', default=0.25, type=float)
749
+
750
+ # Loss
751
+ parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
752
+ help="Disables auxiliary decoding losses (loss at each layer)")
753
+ parser.add_argument('--sum_group_losses', action='store_true',
754
+ help="To sum losses across groups or mean losses.")
755
+ parser.add_argument('--use_varifocal_loss', action='store_true')
756
+ parser.add_argument('--use_position_supervised_loss', action='store_true')
757
+ parser.add_argument('--ia_bce_loss', action='store_true')
758
+
759
+ # dataset parameters
760
+ parser.add_argument('--dataset_file', default='coco')
761
+ parser.add_argument('--coco_path', type=str)
762
+ parser.add_argument('--dataset_dir', type=str)
763
+ parser.add_argument('--square_resize_div_64', action='store_true')
764
+
765
+ parser.add_argument('--output_dir', default='output',
766
+ help='path where to save, empty for no saving')
767
+ parser.add_argument('--dont_save_weights', action='store_true')
768
+ parser.add_argument('--checkpoint_interval', default=10, type=int,
769
+ help='epoch interval to save checkpoint')
770
+ parser.add_argument('--seed', default=42, type=int)
771
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
772
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
773
+ help='start epoch')
774
+ parser.add_argument('--eval', action='store_true')
775
+ parser.add_argument('--use_ema', action='store_true')
776
+ parser.add_argument('--ema_decay', default=0.9997, type=float)
777
+ parser.add_argument('--ema_tau', default=0, type=float)
778
+
779
+ parser.add_argument('--num_workers', default=2, type=int)
780
+
781
+ # distributed training parameters
782
+ parser.add_argument('--device', default='cuda',
783
+ help='device to use for training / testing')
784
+ parser.add_argument('--world_size', default=1, type=int,
785
+ help='number of distributed processes')
786
+ parser.add_argument('--dist_url', default='env://',
787
+ help='url used to set up distributed training')
788
+ parser.add_argument('--sync_bn', default=True, type=bool,
789
+ help='setup synchronized BatchNorm for distributed training')
790
+
791
+ # fp16
792
+ parser.add_argument('--fp16_eval', default=False, action='store_true',
793
+ help='evaluate in fp16 precision.')
794
+
795
+ # custom args
796
+ parser.add_argument('--encoder_only', action='store_true', help='Export and benchmark encoder only')
797
+ parser.add_argument('--backbone_only', action='store_true', help='Export and benchmark backbone only')
798
+ parser.add_argument('--resolution', type=int, default=640, help="input resolution")
799
+ parser.add_argument('--use_cls_token', action='store_true', help='use cls token')
800
+ parser.add_argument('--multi_scale', action='store_true', help='use multi scale')
801
+ parser.add_argument('--expanded_scales', action='store_true', help='use expanded scales')
802
+ parser.add_argument('--do_random_resize_via_padding', action='store_true', help='use random resize via padding')
803
+ parser.add_argument('--warmup_epochs', default=1, type=float,
804
+ help='Number of warmup epochs for linear warmup before cosine annealing')
805
+ # Add scheduler type argument: 'step' or 'cosine'
806
+ parser.add_argument(
807
+ '--lr_scheduler',
808
+ default='step',
809
+ choices=['step', 'cosine'],
810
+ help="Type of learning rate scheduler to use: 'step' (default) or 'cosine'"
811
+ )
812
+ parser.add_argument('--lr_min_factor', default=0.0, type=float,
813
+ help='Minimum learning rate factor (as a fraction of initial lr) at the end of cosine annealing')
814
+ # Early stopping parameters
815
+ parser.add_argument('--early_stopping', action='store_true',
816
+ help='Enable early stopping based on mAP improvement')
817
+ parser.add_argument('--early_stopping_patience', default=10, type=int,
818
+ help='Number of epochs with no improvement after which training will be stopped')
819
+ parser.add_argument('--early_stopping_min_delta', default=0.001, type=float,
820
+ help='Minimum change in mAP to qualify as an improvement')
821
+ parser.add_argument('--early_stopping_use_ema', action='store_true',
822
+ help='Use EMA model metrics for early stopping')
823
+ # subparsers
824
+ subparsers = parser.add_subparsers(title='sub-commands', dest='subcommand',
825
+ description='valid subcommands', help='additional help')
826
+
827
+ # subparser for export model
828
+ parser_export = subparsers.add_parser('export_model', help='LWDETR model export')
829
+ parser_export.add_argument('--infer_dir', type=str, default=None)
830
+ parser_export.add_argument('--verbose', type=ast.literal_eval, default=False, nargs="?", const=True)
831
+ parser_export.add_argument('--opset_version', type=int, default=17)
832
+ parser_export.add_argument('--simplify', action='store_true', help="Simplify onnx model")
833
+ parser_export.add_argument('--tensorrt', '--trtexec', '--trt', action='store_true',
834
+ help="build tensorrt engine")
835
+ parser_export.add_argument('--dry-run', '--test', '-t', action='store_true', help="just print command")
836
+ parser_export.add_argument('--profile', action='store_true', help='Run nsys profiling during TensorRT export')
837
+ parser_export.add_argument('--shape', type=int, nargs=2, default=(640, 640), help="input shape (width, height)")
838
+ return parser
839
+
840
+ def populate_args(
841
+ # Basic training parameters
842
+ num_classes=2,
843
+ grad_accum_steps=1,
844
+ amp=False,
845
+ lr=1e-4,
846
+ lr_encoder=1.5e-4,
847
+ batch_size=2,
848
+ weight_decay=1e-4,
849
+ epochs=12,
850
+ lr_drop=11,
851
+ clip_max_norm=0.1,
852
+ lr_vit_layer_decay=0.8,
853
+ lr_component_decay=1.0,
854
+ do_benchmark=False,
855
+
856
+ # Drop parameters
857
+ dropout=0,
858
+ drop_path=0,
859
+ drop_mode='standard',
860
+ drop_schedule='constant',
861
+ cutoff_epoch=0,
862
+
863
+ # Model parameters
864
+ pretrained_encoder=None,
865
+ pretrain_weights=None,
866
+ pretrain_exclude_keys=None,
867
+ pretrain_keys_modify_to_load=None,
868
+ pretrained_distiller=None,
869
+
870
+ # Backbone parameters
871
+ encoder='vit_tiny',
872
+ vit_encoder_num_layers=12,
873
+ window_block_indexes=None,
874
+ position_embedding='sine',
875
+ out_feature_indexes=[-1],
876
+ freeze_encoder=False,
877
+ layer_norm=False,
878
+ rms_norm=False,
879
+ backbone_lora=False,
880
+ force_no_pretrain=False,
881
+
882
+ # Transformer parameters
883
+ dec_layers=3,
884
+ dim_feedforward=2048,
885
+ hidden_dim=256,
886
+ sa_nheads=8,
887
+ ca_nheads=8,
888
+ num_queries=300,
889
+ group_detr=13,
890
+ two_stage=False,
891
+ projector_scale='P4',
892
+ lite_refpoint_refine=False,
893
+ num_select=100,
894
+ dec_n_points=4,
895
+ decoder_norm='LN',
896
+ bbox_reparam=False,
897
+ freeze_batch_norm=False,
898
+
899
+ # Matcher parameters
900
+ set_cost_class=2,
901
+ set_cost_bbox=5,
902
+ set_cost_giou=2,
903
+
904
+ # Loss coefficients
905
+ cls_loss_coef=2,
906
+ bbox_loss_coef=5,
907
+ giou_loss_coef=2,
908
+ focal_alpha=0.25,
909
+ aux_loss=True,
910
+ sum_group_losses=False,
911
+ use_varifocal_loss=False,
912
+ use_position_supervised_loss=False,
913
+ ia_bce_loss=False,
914
+
915
+ # Dataset parameters
916
+ dataset_file='coco',
917
+ coco_path=None,
918
+ dataset_dir=None,
919
+ square_resize_div_64=False,
920
+
921
+ # Output parameters
922
+ output_dir='output',
923
+ dont_save_weights=False,
924
+ checkpoint_interval=10,
925
+ seed=42,
926
+ resume='',
927
+ start_epoch=0,
928
+ eval=False,
929
+ use_ema=False,
930
+ ema_decay=0.9997,
931
+ ema_tau=0,
932
+ num_workers=2,
933
+
934
+ # Distributed training parameters
935
+ device='cuda',
936
+ world_size=1,
937
+ dist_url='env://',
938
+ sync_bn=True,
939
+
940
+ # FP16
941
+ fp16_eval=False,
942
+
943
+ # Custom args
944
+ encoder_only=False,
945
+ backbone_only=False,
946
+ resolution=640,
947
+ use_cls_token=False,
948
+ multi_scale=False,
949
+ expanded_scales=False,
950
+ do_random_resize_via_padding=False,
951
+ warmup_epochs=1,
952
+ lr_scheduler='step',
953
+ lr_min_factor=0.0,
954
+ # Early stopping parameters
955
+ early_stopping=True,
956
+ early_stopping_patience=10,
957
+ early_stopping_min_delta=0.001,
958
+ early_stopping_use_ema=False,
959
+ gradient_checkpointing=False,
960
+ # Additional
961
+ subcommand=None,
962
+ **extra_kwargs # To handle any unexpected arguments
963
+ ):
964
+ args = argparse.Namespace(
965
+ num_classes=num_classes,
966
+ grad_accum_steps=grad_accum_steps,
967
+ amp=amp,
968
+ lr=lr,
969
+ lr_encoder=lr_encoder,
970
+ batch_size=batch_size,
971
+ weight_decay=weight_decay,
972
+ epochs=epochs,
973
+ lr_drop=lr_drop,
974
+ clip_max_norm=clip_max_norm,
975
+ lr_vit_layer_decay=lr_vit_layer_decay,
976
+ lr_component_decay=lr_component_decay,
977
+ do_benchmark=do_benchmark,
978
+ dropout=dropout,
979
+ drop_path=drop_path,
980
+ drop_mode=drop_mode,
981
+ drop_schedule=drop_schedule,
982
+ cutoff_epoch=cutoff_epoch,
983
+ pretrained_encoder=pretrained_encoder,
984
+ pretrain_weights=pretrain_weights,
985
+ pretrain_exclude_keys=pretrain_exclude_keys,
986
+ pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
987
+ pretrained_distiller=pretrained_distiller,
988
+ encoder=encoder,
989
+ vit_encoder_num_layers=vit_encoder_num_layers,
990
+ window_block_indexes=window_block_indexes,
991
+ position_embedding=position_embedding,
992
+ out_feature_indexes=out_feature_indexes,
993
+ freeze_encoder=freeze_encoder,
994
+ layer_norm=layer_norm,
995
+ rms_norm=rms_norm,
996
+ backbone_lora=backbone_lora,
997
+ force_no_pretrain=force_no_pretrain,
998
+ dec_layers=dec_layers,
999
+ dim_feedforward=dim_feedforward,
1000
+ hidden_dim=hidden_dim,
1001
+ sa_nheads=sa_nheads,
1002
+ ca_nheads=ca_nheads,
1003
+ num_queries=num_queries,
1004
+ group_detr=group_detr,
1005
+ two_stage=two_stage,
1006
+ projector_scale=projector_scale,
1007
+ lite_refpoint_refine=lite_refpoint_refine,
1008
+ num_select=num_select,
1009
+ dec_n_points=dec_n_points,
1010
+ decoder_norm=decoder_norm,
1011
+ bbox_reparam=bbox_reparam,
1012
+ freeze_batch_norm=freeze_batch_norm,
1013
+ set_cost_class=set_cost_class,
1014
+ set_cost_bbox=set_cost_bbox,
1015
+ set_cost_giou=set_cost_giou,
1016
+ cls_loss_coef=cls_loss_coef,
1017
+ bbox_loss_coef=bbox_loss_coef,
1018
+ giou_loss_coef=giou_loss_coef,
1019
+ focal_alpha=focal_alpha,
1020
+ aux_loss=aux_loss,
1021
+ sum_group_losses=sum_group_losses,
1022
+ use_varifocal_loss=use_varifocal_loss,
1023
+ use_position_supervised_loss=use_position_supervised_loss,
1024
+ ia_bce_loss=ia_bce_loss,
1025
+ dataset_file=dataset_file,
1026
+ coco_path=coco_path,
1027
+ dataset_dir=dataset_dir,
1028
+ square_resize_div_64=square_resize_div_64,
1029
+ output_dir=output_dir,
1030
+ dont_save_weights=dont_save_weights,
1031
+ checkpoint_interval=checkpoint_interval,
1032
+ seed=seed,
1033
+ resume=resume,
1034
+ start_epoch=start_epoch,
1035
+ eval=eval,
1036
+ use_ema=use_ema,
1037
+ ema_decay=ema_decay,
1038
+ ema_tau=ema_tau,
1039
+ num_workers=num_workers,
1040
+ device=device,
1041
+ world_size=world_size,
1042
+ dist_url=dist_url,
1043
+ sync_bn=sync_bn,
1044
+ fp16_eval=fp16_eval,
1045
+ encoder_only=encoder_only,
1046
+ backbone_only=backbone_only,
1047
+ resolution=resolution,
1048
+ use_cls_token=use_cls_token,
1049
+ multi_scale=multi_scale,
1050
+ expanded_scales=expanded_scales,
1051
+ do_random_resize_via_padding=do_random_resize_via_padding,
1052
+ warmup_epochs=warmup_epochs,
1053
+ lr_scheduler=lr_scheduler,
1054
+ lr_min_factor=lr_min_factor,
1055
+ early_stopping=early_stopping,
1056
+ early_stopping_patience=early_stopping_patience,
1057
+ early_stopping_min_delta=early_stopping_min_delta,
1058
+ early_stopping_use_ema=early_stopping_use_ema,
1059
+ gradient_checkpointing=gradient_checkpointing,
1060
+ **extra_kwargs
1061
+ )
1062
+ return args
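populate_args mirrors the CLI defaults as keyword arguments so callers can build the same configuration without argparse. A minimal sketch, assuming this module is importable as rfdetr.main; the dataset path is hypothetical:

# Hedged sketch: build the training Namespace programmatically.
from rfdetr.main import populate_args  # assumed import path for this module

args = populate_args(
    dataset_dir="datasets/my_coco_dataset",  # hypothetical dataset location
    output_dir="output",
    epochs=12,
    use_ema=True,
)
print(args.lr, args.lr_scheduler, args.early_stopping)  # 0.0001 step True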
rfdetr/models/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Copied from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ # Copied from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
10
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
11
+ # ------------------------------------------------------------------------
12
+ # Copied from DETR (https://github.com/facebookresearch/detr)
13
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
14
+ # ------------------------------------------------------------------------
15
+
16
+ from .lwdetr import build_model, build_criterion_and_postprocessors
rfdetr/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (309 Bytes). View file
 
rfdetr/models/__pycache__/lwdetr.cpython-313.pyc ADDED
Binary file (39.9 kB). View file
 
rfdetr/models/__pycache__/matcher.cpython-313.pyc ADDED
Binary file (6.68 kB). View file
 
rfdetr/models/__pycache__/position_encoding.cpython-313.pyc ADDED
Binary file (8.77 kB). View file
 
rfdetr/models/__pycache__/transformer.cpython-313.pyc ADDED
Binary file (29.4 kB). View file
 
rfdetr/models/backbone/__init__.py ADDED
@@ -0,0 +1,110 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ from typing import Dict, List
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from rfdetr.util.misc import NestedTensor
16
+ from rfdetr.models.position_encoding import build_position_encoding
17
+ from rfdetr.models.backbone.backbone import *
18
+ from typing import Callable
19
+
20
+ class Joiner(nn.Sequential):
21
+ def __init__(self, backbone, position_embedding):
22
+ super().__init__(backbone, position_embedding)
23
+ self._export = False
24
+
25
+ def forward(self, tensor_list: NestedTensor):
26
+ """ """
27
+ x = self[0](tensor_list)
28
+ pos = []
29
+ for x_ in x:
30
+ pos.append(self[1](x_, align_dim_orders=False).to(x_.tensors.dtype))
31
+ return x, pos
32
+
33
+ def export(self):
34
+ self._export = True
35
+ self._forward_origin = self.forward
36
+ self.forward = self.forward_export
37
+ for name, m in self.named_modules():
38
+ if (
39
+ hasattr(m, "export")
40
+ and isinstance(m.export, Callable)
41
+ and hasattr(m, "_export")
42
+ and not m._export
43
+ ):
44
+ m.export()
45
+
46
+ def forward_export(self, inputs: torch.Tensor):
47
+ feats, masks = self[0](inputs)
48
+ poss = []
49
+ for feat, mask in zip(feats, masks):
50
+ poss.append(self[1](mask, align_dim_orders=False).to(feat.dtype))
51
+ return feats, None, poss
52
+
53
+
54
+ def build_backbone(
55
+ encoder,
56
+ vit_encoder_num_layers,
57
+ pretrained_encoder,
58
+ window_block_indexes,
59
+ drop_path,
60
+ out_channels,
61
+ out_feature_indexes,
62
+ projector_scale,
63
+ use_cls_token,
64
+ hidden_dim,
65
+ position_embedding,
66
+ freeze_encoder,
67
+ layer_norm,
68
+ target_shape,
69
+ rms_norm,
70
+ backbone_lora,
71
+ force_no_pretrain,
72
+ gradient_checkpointing,
73
+ load_dinov2_weights,
74
+ patch_size,
75
+ num_windows,
76
+ positional_encoding_size,
77
+ ):
78
+ """
79
+ Useful args:
+ - encoder: encoder name, e.g. "dinov2_windowed_small" or "dinov2_registers_windowed_base"
+ - out_feature_indexes: indexes of the encoder blocks whose features are used
+ - projector_scale: output scales of the multi-scale projector (P3/P4/P5/P6, ascending)
+ - hidden_dim / position_embedding: configuration of the positional encoding
84
+
85
+ """
86
+ position_embedding = build_position_encoding(hidden_dim, position_embedding)
87
+
88
+ backbone = Backbone(
89
+ encoder,
90
+ pretrained_encoder,
91
+ window_block_indexes=window_block_indexes,
92
+ drop_path=drop_path,
93
+ out_channels=out_channels,
94
+ out_feature_indexes=out_feature_indexes,
95
+ projector_scale=projector_scale,
96
+ use_cls_token=use_cls_token,
97
+ layer_norm=layer_norm,
98
+ freeze_encoder=freeze_encoder,
99
+ target_shape=target_shape,
100
+ rms_norm=rms_norm,
101
+ backbone_lora=backbone_lora,
102
+ gradient_checkpointing=gradient_checkpointing,
103
+ load_dinov2_weights=load_dinov2_weights,
104
+ patch_size=patch_size,
105
+ num_windows=num_windows,
106
+ positional_encoding_size=positional_encoding_size,
107
+ )
108
+
109
+ model = Joiner(backbone, position_embedding)
110
+ return model
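For reference, a hedged sketch of building the Joiner via build_backbone and running it on a padded batch. Every value below is an illustrative assumption rather than an official default; load_dinov2_weights=False skips the pretrained-weight download, and NestedTensor is the helper imported at the top of this file:

# Hedged sketch: construct and run the backbone + positional-encoding Joiner.
import torch
from rfdetr.models.backbone import build_backbone  # assumed import path
from rfdetr.util.misc import NestedTensor

joiner = build_backbone(
    encoder="dinov2_windowed_small",    # name format parsed in Backbone: dinov2[_registers][_windowed]_<size>
    vit_encoder_num_layers=12,
    pretrained_encoder=None,
    window_block_indexes=None,
    drop_path=0.0,
    out_channels=256,
    out_feature_indexes=[2, 5, 8, 11],  # illustrative choice of encoder blocks
    projector_scale=["P4"],
    use_cls_token=False,
    hidden_dim=256,
    position_embedding="sine",
    freeze_encoder=False,
    layer_norm=False,
    target_shape=(560, 560),
    rms_norm=False,
    backbone_lora=False,
    force_no_pretrain=False,
    gradient_checkpointing=False,
    load_dinov2_weights=False,          # avoid downloading DINOv2 weights for this sketch
    patch_size=14,
    num_windows=4,
    positional_encoding_size=37,
)

# 560 is divisible by patch_size * num_windows (14 * 4 = 56), as the DINOv2 encoder requires.
images = torch.randn(2, 3, 560, 560)
masks = torch.zeros(2, 560, 560, dtype=torch.bool)  # True marks padded pixels
features, pos = joiner(NestedTensor(images, masks))
for feat, p in zip(features, pos):
    print(feat.tensors.shape, p.shape)  # one feature map and one positional encoding per projector scale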
rfdetr/models/backbone/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (3.85 kB). View file
 
rfdetr/models/backbone/__pycache__/backbone.cpython-313.pyc ADDED
Binary file (8.13 kB). View file
 
rfdetr/models/backbone/__pycache__/base.cpython-313.pyc ADDED
Binary file (1.05 kB). View file
 
rfdetr/models/backbone/__pycache__/dinov2.cpython-313.pyc ADDED
Binary file (8.66 kB). View file
 
rfdetr/models/backbone/__pycache__/dinov2_with_windowed_attn.cpython-313.pyc ADDED
Binary file (61.2 kB). View file
 
rfdetr/models/backbone/__pycache__/projector.cpython-313.pyc ADDED
Binary file (15 kB). View file
 
rfdetr/models/backbone/backbone.py ADDED
@@ -0,0 +1,205 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
10
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
11
+ # ------------------------------------------------------------------------
12
+ # Copied from DETR (https://github.com/facebookresearch/detr)
13
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
14
+ # ------------------------------------------------------------------------
15
+
16
+ """
17
+ Backbone modules.
18
+ """
19
+ from functools import partial
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+
24
+ from transformers import AutoModel, AutoProcessor, AutoModelForCausalLM, AutoConfig, AutoBackbone
25
+ from peft import LoraConfig, get_peft_model, PeftModel
26
+
27
+ from rfdetr.util.misc import NestedTensor, is_main_process
28
+
29
+ from rfdetr.models.backbone.base import BackboneBase
30
+ from rfdetr.models.backbone.projector import MultiScaleProjector
31
+ from rfdetr.models.backbone.dinov2 import DinoV2
32
+
33
+ __all__ = ["Backbone"]
34
+
35
+
36
+ class Backbone(BackboneBase):
37
+ """backbone."""
38
+ def __init__(self,
39
+ name: str,
40
+ pretrained_encoder: str=None,
41
+ window_block_indexes: list=None,
42
+ drop_path=0.0,
43
+ out_channels=256,
44
+ out_feature_indexes: list=None,
45
+ projector_scale: list=None,
46
+ use_cls_token: bool = False,
47
+ freeze_encoder: bool = False,
48
+ layer_norm: bool = False,
49
+ target_shape: tuple[int, int] = (640, 640),
50
+ rms_norm: bool = False,
51
+ backbone_lora: bool = False,
52
+ gradient_checkpointing: bool = False,
53
+ load_dinov2_weights: bool = True,
54
+ patch_size: int = 14,
55
+ num_windows: int = 4,
56
+ positional_encoding_size: int = 37,
57
+ ):
58
+ super().__init__()
59
+ # an example name here would be "dinov2_base" or "dinov2_registers_windowed_base"
60
+ # if "registers" is in the name, then use_registers is set to True, otherwise it is set to False
61
+ # similarly, if "windowed" is in the name, then use_windowed_attn is set to True, otherwise it is set to False
62
+ # the last part of the name should be the size
63
+ # and the start should be dinov2
64
+ name_parts = name.split("_")
65
+ assert name_parts[0] == "dinov2"
66
+ size = name_parts[-1]
67
+ use_registers = False
68
+ if "registers" in name_parts:
69
+ use_registers = True
70
+ name_parts.remove("registers")
71
+ use_windowed_attn = False
72
+ if "windowed" in name_parts:
73
+ use_windowed_attn = True
74
+ name_parts.remove("windowed")
75
+ assert len(name_parts) == 2, "name should be dinov2, then either registers, windowed, both, or none, then the size"
76
+ self.encoder = DinoV2(
77
+ size=name_parts[-1],
78
+ out_feature_indexes=out_feature_indexes,
79
+ shape=target_shape,
80
+ use_registers=use_registers,
81
+ use_windowed_attn=use_windowed_attn,
82
+ gradient_checkpointing=gradient_checkpointing,
83
+ load_dinov2_weights=load_dinov2_weights,
84
+ patch_size=patch_size,
85
+ num_windows=num_windows,
86
+ positional_encoding_size=positional_encoding_size,
87
+ )
88
+ # build encoder + projector as backbone module
89
+ if freeze_encoder:
90
+ for param in self.encoder.parameters():
91
+ param.requires_grad = False
92
+
93
+ self.projector_scale = projector_scale
94
+ assert len(self.projector_scale) > 0
95
+ # x[0]
96
+ assert (
97
+ sorted(self.projector_scale) == self.projector_scale
98
+ ), "only support projector scale P3/P4/P5/P6 in ascending order."
99
+ level2scalefactor = dict(P3=2.0, P4=1.0, P5=0.5, P6=0.25)
100
+ scale_factors = [level2scalefactor[lvl] for lvl in self.projector_scale]
101
+
102
+ self.projector = MultiScaleProjector(
103
+ in_channels=self.encoder._out_feature_channels,
104
+ out_channels=out_channels,
105
+ scale_factors=scale_factors,
106
+ layer_norm=layer_norm,
107
+ rms_norm=rms_norm,
108
+ )
109
+
110
+ self._export = False
111
+
112
+ def export(self):
113
+ self._export = True
114
+ self._forward_origin = self.forward
115
+ self.forward = self.forward_export
116
+
117
+ if isinstance(self.encoder, PeftModel):
118
+ print("Merging and unloading LoRA weights")
119
+ self.encoder.merge_and_unload()
120
+
121
+ def forward(self, tensor_list: NestedTensor):
122
+ """ """
123
+ # (H, W, B, C)
124
+ feats = self.encoder(tensor_list.tensors)
125
+ feats = self.projector(feats)
126
+ # x: [(B, C, H, W)]
127
+ out = []
128
+ for feat in feats:
129
+ m = tensor_list.mask
130
+ assert m is not None
131
+ mask = F.interpolate(m[None].float(), size=feat.shape[-2:]).to(torch.bool)[
132
+ 0
133
+ ]
134
+ out.append(NestedTensor(feat, mask))
135
+ return out
136
+
137
+ def forward_export(self, tensors: torch.Tensor):
138
+ feats = self.encoder(tensors)
139
+ feats = self.projector(feats)
140
+ out_feats = []
141
+ out_masks = []
142
+ for feat in feats:
143
+ # x: [(B, C, H, W)]
144
+ b, _, h, w = feat.shape
145
+ out_masks.append(
146
+ torch.zeros((b, h, w), dtype=torch.bool, device=feat.device)
147
+ )
148
+ out_feats.append(feat)
149
+ return out_feats, out_masks
150
+
151
+ def get_named_param_lr_pairs(self, args, prefix: str = "backbone.0"):
152
+ num_layers = args.out_feature_indexes[-1] + 1
153
+ backbone_key = "backbone.0.encoder"
154
+ named_param_lr_pairs = {}
155
+ for n, p in self.named_parameters():
156
+ n = prefix + "." + n
157
+ if backbone_key in n and p.requires_grad:
158
+ lr = (
159
+ args.lr_encoder
160
+ * get_dinov2_lr_decay_rate(
161
+ n,
162
+ lr_decay_rate=args.lr_vit_layer_decay,
163
+ num_layers=num_layers,
164
+ )
165
+ * args.lr_component_decay**2
166
+ )
167
+ wd = args.weight_decay * get_dinov2_weight_decay_rate(n)
168
+ named_param_lr_pairs[n] = {
169
+ "params": p,
170
+ "lr": lr,
171
+ "weight_decay": wd,
172
+ }
173
+ return named_param_lr_pairs
174
+
175
+
176
+ def get_dinov2_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
177
+ """
178
+ Calculate lr decay rate for different ViT blocks.
179
+
180
+ Args:
181
+ name (string): parameter name.
182
+ lr_decay_rate (float): base lr decay rate.
183
+ num_layers (int): number of ViT blocks.
184
+ Returns:
185
+ lr decay rate for the given parameter.
186
+ """
187
+ layer_id = num_layers + 1
188
+ if name.startswith("backbone"):
189
+ if "embeddings" in name:
190
+ layer_id = 0
191
+ elif ".layer." in name and ".residual." not in name:
192
+ layer_id = int(name[name.find(".layer.") :].split(".")[2]) + 1
193
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
194
+
195
+ def get_dinov2_weight_decay_rate(name, weight_decay_rate=1.0):
196
+ if (
197
+ ("gamma" in name)
198
+ or ("pos_embed" in name)
199
+ or ("rel_pos" in name)
200
+ or ("bias" in name)
201
+ or ("norm" in name)
202
+ or ("embeddings" in name)
203
+ ):
204
+ weight_decay_rate = 0.0
205
+ return weight_decay_rate
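The two helpers above drive the layer-wise learning-rate decay and selective weight decay used in get_named_param_lr_pairs. A short worked example with lr_decay_rate=0.8 and num_layers=12 (the CLI defaults); the parameter names are illustrative, chosen only to exercise the name-matching rules:

# Worked example of the layer-wise decay factors.
from rfdetr.models.backbone.backbone import get_dinov2_lr_decay_rate  # assumed import path

print(get_dinov2_lr_decay_rate("backbone.0.encoder.embeddings.patch_embeddings.weight",
                               lr_decay_rate=0.8, num_layers=12))  # embeddings -> layer 0 -> 0.8**13 ~= 0.055
print(get_dinov2_lr_decay_rate("backbone.0.encoder.encoder.layer.0.attention.weight",
                               lr_decay_rate=0.8, num_layers=12))  # first block -> 0.8**12 ~= 0.069
print(get_dinov2_lr_decay_rate("backbone.0.encoder.encoder.layer.11.mlp.weight",
                               lr_decay_rate=0.8, num_layers=12))  # last block -> 0.8**1 = 0.8
print(get_dinov2_lr_decay_rate("transformer.decoder.layers.0.weight",
                               lr_decay_rate=0.8, num_layers=12))  # non-backbone param -> 1.0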
rfdetr/models/backbone/base.py ADDED
@@ -0,0 +1,20 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+
15
+ class BackboneBase(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def get_named_param_lr_pairs(self, args, prefix:str):
20
+ raise NotImplementedError
rfdetr/models/backbone/dinov2.py ADDED
@@ -0,0 +1,197 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import AutoBackbone
10
+ import torch.nn.functional as F
11
+ import types
12
+ import math
13
+ import json
14
+ import os
15
+
16
+ from .dinov2_with_windowed_attn import WindowedDinov2WithRegistersConfig, WindowedDinov2WithRegistersBackbone
17
+
18
+
19
+ size_to_width = {
20
+ "tiny": 192,
21
+ "small": 384,
22
+ "base": 768,
23
+ "large": 1024,
24
+ }
25
+
26
+ size_to_config = {
27
+ "small": "dinov2_small.json",
28
+ "base": "dinov2_base.json",
29
+ "large": "dinov2_large.json",
30
+ }
31
+
32
+ size_to_config_with_registers = {
33
+ "small": "dinov2_with_registers_small.json",
34
+ "base": "dinov2_with_registers_base.json",
35
+ "large": "dinov2_with_registers_large.json",
36
+ }
37
+
38
+ def get_config(size, use_registers):
39
+ config_dict = size_to_config_with_registers if use_registers else size_to_config
40
+ current_dir = os.path.dirname(os.path.abspath(__file__))
41
+ configs_dir = os.path.join(current_dir, "dinov2_configs")
42
+ config_path = os.path.join(configs_dir, config_dict[size])
43
+ with open(config_path, "r") as f:
44
+ dino_config = json.load(f)
45
+ return dino_config
46
+
47
+
48
+ class DinoV2(nn.Module):
49
+ def __init__(self,
50
+ shape=(640, 640),
51
+ out_feature_indexes=[2, 4, 5, 9],
52
+ size="base",
53
+ use_registers=True,
54
+ use_windowed_attn=True,
55
+ gradient_checkpointing=False,
56
+ load_dinov2_weights=True,
57
+ patch_size=14,
58
+ num_windows=4,
59
+ positional_encoding_size=37,
60
+ ):
61
+ super().__init__()
62
+
63
+ name = f"facebook/dinov2-with-registers-{size}" if use_registers else f"facebook/dinov2-{size}"
64
+
65
+ self.shape = shape
66
+ self.patch_size = patch_size
67
+ self.num_windows = num_windows
68
+
69
+ # Create the encoder
70
+
71
+ if not use_windowed_attn:
72
+ assert not gradient_checkpointing, "Gradient checkpointing is not supported for non-windowed attention"
73
+ assert load_dinov2_weights, "Using non-windowed attention requires loading dinov2 weights from hub"
74
+ self.encoder = AutoBackbone.from_pretrained(
75
+ name,
76
+ out_features=[f"stage{i}" for i in out_feature_indexes],
77
+ return_dict=False,
78
+ )
79
+ else:
80
+ window_block_indexes = set(range(out_feature_indexes[-1] + 1))
81
+ window_block_indexes.difference_update(out_feature_indexes)
82
+ window_block_indexes = list(window_block_indexes)
83
+
84
+ dino_config = get_config(size, use_registers)
85
+
86
+ dino_config["return_dict"] = False
87
+ dino_config["out_features"] = [f"stage{i}" for i in out_feature_indexes]
88
+
89
+ implied_resolution = positional_encoding_size * patch_size
90
+
91
+ if implied_resolution != dino_config["image_size"]:
92
+ print(f"Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.")
93
+ dino_config["image_size"] = implied_resolution
94
+ load_dinov2_weights = False
95
+
96
+ if patch_size != 14:
97
+ print(f"Using patch size {patch_size} instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.")
98
+ dino_config["patch_size"] = patch_size
99
+ load_dinov2_weights = False
100
+
101
+ if use_registers:
102
+ windowed_dino_config = WindowedDinov2WithRegistersConfig(
103
+ **dino_config,
104
+ num_windows=num_windows,
105
+ window_block_indexes=window_block_indexes,
106
+ gradient_checkpointing=gradient_checkpointing,
107
+ )
108
+ else:
109
+ windowed_dino_config = WindowedDinov2WithRegistersConfig(
110
+ **dino_config,
111
+ num_windows=num_windows,
112
+ window_block_indexes=window_block_indexes,
113
+ num_register_tokens=0,
114
+ gradient_checkpointing=gradient_checkpointing,
115
+ )
116
+ self.encoder = WindowedDinov2WithRegistersBackbone.from_pretrained(
117
+ name,
118
+ config=windowed_dino_config,
119
+ ) if load_dinov2_weights else WindowedDinov2WithRegistersBackbone(windowed_dino_config)
120
+
121
+
122
+ self._out_feature_channels = [size_to_width[size]] * len(out_feature_indexes)
123
+ self._export = False
124
+
125
+ def export(self):
126
+ if self._export:
127
+ return
128
+ self._export = True
129
+ shape = self.shape
130
+ def make_new_interpolated_pos_encoding(
131
+ position_embeddings, patch_size, height, width
132
+ ):
133
+
134
+ num_positions = position_embeddings.shape[1] - 1
135
+ dim = position_embeddings.shape[-1]
136
+ height = height // patch_size
137
+ width = width // patch_size
138
+
139
+ class_pos_embed = position_embeddings[:, 0]
140
+ patch_pos_embed = position_embeddings[:, 1:]
141
+
142
+ # Reshape and permute
143
+ patch_pos_embed = patch_pos_embed.reshape(
144
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
145
+ )
146
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
147
+
148
+ # Interpolate patch position embeddings to the target grid (bicubic, antialias=True)
149
+ patch_pos_embed = F.interpolate(
150
+ patch_pos_embed,
151
+ size=(height, width),
152
+ mode="bicubic",
153
+ align_corners=False,
154
+ antialias=True,
155
+ )
156
+
157
+ # Reshape back
158
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
159
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
160
+
161
+ # Precompute position embeddings interpolated to the export shape and
+ # install them as a new Parameter on the embeddings module below.
163
+ with torch.no_grad():
164
+ new_positions = make_new_interpolated_pos_encoding(
165
+ self.encoder.embeddings.position_embeddings,
166
+ self.encoder.config.patch_size,
167
+ shape[0],
168
+ shape[1],
169
+ )
170
+ # Create a new Parameter with the new size
171
+ old_interpolate_pos_encoding = self.encoder.embeddings.interpolate_pos_encoding
172
+ def new_interpolate_pos_encoding(self_mod, embeddings, height, width):
173
+ num_patches = embeddings.shape[1] - 1
174
+ num_positions = self_mod.position_embeddings.shape[1] - 1
175
+ if num_patches == num_positions and height == width:
176
+ return self_mod.position_embeddings
177
+ return old_interpolate_pos_encoding(embeddings, height, width)
178
+
179
+ self.encoder.embeddings.position_embeddings = nn.Parameter(new_positions)
180
+ self.encoder.embeddings.interpolate_pos_encoding = types.MethodType(
181
+ new_interpolate_pos_encoding,
182
+ self.encoder.embeddings
183
+ )
184
+
185
+ def forward(self, x):
186
+ block_size = self.patch_size * self.num_windows
187
+ assert x.shape[2] % block_size == 0 and x.shape[3] % block_size == 0, f"Backbone requires input shape to be divisible by {block_size}, but got {x.shape}"
188
+ x = self.encoder(x)
189
+ return list(x[0])
190
+
191
+ if __name__ == "__main__":
192
+ model = DinoV2()
193
+ model.export()
194
+ # input resolution must be divisible by patch_size * num_windows (14 * 4 = 56), so 640 would trip the assert above
+ x = torch.randn(1, 3, 560, 560)
195
+ print(model(x))
196
+ for j in model(x):
197
+ print(j.shape)
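Finally, the assertion in DinoV2.forward ties usable input resolutions to patch_size * num_windows. A tiny sketch of that constraint with the defaults above (patch_size=14, num_windows=4):

# Resolutions must satisfy res % (patch_size * num_windows) == 0.
patch_size, num_windows = 14, 4
block = patch_size * num_windows   # 56
for res in (560, 640, 672):
    print(res, res % block == 0)   # 560 True, 640 False, 672 True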