Theo Viel committed
Commit 9954323 · 1 Parent(s): c2a54f1
model.py ADDED
@@ -0,0 +1,210 @@
+import os
+import sys
+import torch
+import importlib
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from yolox.boxes import postprocess
+
+
+def define_model(config_name="page_element_v3", verbose=True):
+    """
+    Defines and initializes the model based on the configuration.
+
+    Args:
+        config_name (str): Configuration name. Defaults to "page_element_v3".
+        verbose (bool): Whether to print verbose output. Defaults to True.
+
+    Returns:
+        torch.nn.Module: The initialized YOLOX model.
+    """
+    # Load model from exp_file
+    sys.path.append(os.path.dirname(config_name))
+    exp_module = importlib.import_module(os.path.basename(config_name).split(".")[0])
+
+    config = exp_module.Exp()
+    model = config.get_model()
+
+    # Load weights
+    if verbose:
+        print(" -> Loading weights from", config.ckpt)
+
+    ckpt = torch.load(config.ckpt, map_location="cpu", weights_only=False)
+    model.load_state_dict(ckpt["model"], strict=True)
+
+    model = YoloXWrapper(model, config)
+    return model.eval().to(config.device)
+
+
+def resize_pad(img: torch.Tensor, size: tuple) -> torch.Tensor:
+    """
+    Resizes and pads an image to a given size.
+    The goal is to preserve the aspect ratio of the image.
+
+    Args:
+        img (torch.Tensor[C x H x W]): The image to resize and pad.
+        size (tuple[2]): The size to resize and pad the image to.
+
+    Returns:
+        torch.Tensor: The resized and padded image.
+    """
+    img = img.float()
+    _, h, w = img.shape
+    scale = min(size[0] / h, size[1] / w)
+    nh = int(h * scale)
+    nw = int(w * scale)
+    img = F.interpolate(
+        img.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False
+    ).squeeze(0)
+    img = torch.clamp(img, 0, 255)
+    pad_b = size[0] - nh
+    pad_r = size[1] - nw
+    img = F.pad(img, (0, pad_r, 0, pad_b), value=114.0)
+    return img
+
+
+class YoloXWrapper(nn.Module):
+    """
+    Wrapper for YoloX models.
+    """
+    def __init__(self, model, config):
+        """
+        Constructor
+
+        Args:
+            model (torch model): Yolo model.
+            config (Config): Config.
+        """
+        super().__init__()
+        self.model = model
+        self.config = config
+
+        # Copy config parameters
+        self.device = config.device
+        self.img_size = config.size
+        self.min_bbox_size = config.min_bbox_size
+        self.normalize_boxes = config.normalize_boxes
+        self.conf_thresh = config.conf_thresh
+        self.iou_thresh = config.iou_thresh
+        self.class_agnostic = config.class_agnostic
+        self.thresholds_per_class = config.thresholds_per_class
+        self.labels = config.labels
+        self.num_classes = config.num_classes
+
+    def reformat_input(self, x, orig_sizes):
+        """
+        Reformats the input data and original sizes to the correct format.
+
+        Args:
+            x (torch.Tensor[BS x C x H x W]): Input image batch.
+            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.
+        Returns:
+            torch tensor [BS x C x H x W]: Input image batch.
+            torch tensor [BS x 2]: Original image sizes (before resizing and padding).
+        """
+        # Convert image size to tensor
+        if isinstance(orig_sizes, (list, tuple)):
+            orig_sizes = np.array(orig_sizes)
+        if orig_sizes.shape[-1] == 3:  # remove channel
+            orig_sizes = orig_sizes[..., :2]
+        if isinstance(orig_sizes, np.ndarray):
+            orig_sizes = torch.from_numpy(orig_sizes).to(self.device)
+
+        # Add batch dimension if not present
+        if len(x.size()) == 3:
+            x = x.unsqueeze(0)
+        if len(orig_sizes.size()) == 1:
+            orig_sizes = orig_sizes.unsqueeze(0)
+
+        return x, orig_sizes
+
+    def preprocess(self, image):
+        """
+        YoloX preprocessing function:
+        - Resizes the longest edge to img_size while preserving the aspect ratio
+        - Pads the shortest edge to img_size
+
+        Args:
+            image (torch tensor or np array [H x W x 3]): Input image in uint8 format.
+
+        Returns:
+            torch tensor [3 x H x W]: Processed image.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = torch.from_numpy(image)
+        image = image.permute(2, 0, 1)  # [H, W, 3] -> [3, H, W]
+        image = resize_pad(image, self.img_size)
+        return image.float()
+
+    def forward(self, x, orig_sizes):
+        """
+        Forward pass of the model.
+        Applies NMS and reformats the predictions.
+
+        Args:
+            x (torch.Tensor[BS x C x H x W]): Input image batch.
+            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.
+
+        Returns:
+            list[dict]: List of prediction dictionaries. Each dictionary contains:
+                - labels (torch.Tensor[N]): Class labels
+                - boxes (torch.Tensor[N x 4]): Bounding boxes
+                - scores (torch.Tensor[N]): Confidence scores.
+        """
+        x, orig_sizes = self.reformat_input(x, orig_sizes)
+
+        # Scale to 0-255 if in range 0-1
+        if x.max() <= 1:
+            x *= 255
+
+        pred_boxes = self.model(x.to(self.device))
+
+        # NMS
+        pred_boxes = postprocess(
+            pred_boxes,
+            self.config.num_classes,
+            self.conf_thresh,
+            self.iou_thresh,
+            class_agnostic=self.class_agnostic,
+        )
+
+        # Reformat output
+        preds = []
+        for i, (p, size) in enumerate(zip(pred_boxes, orig_sizes)):
+            if p is None:  # No detections
+                preds.append({
+                    "labels": torch.empty(0),
+                    "boxes": torch.empty((0, 4)),
+                    "scores": torch.empty(0),
+                })
+                continue
+
+            p = p.view(-1, p.size(-1))
+            ratio = min(self.img_size[0] / size[0], self.img_size[1] / size[1])
+            bboxes = p[:, :4] / ratio
+
+            # Clip
+            bboxes[:, [0, 2]] = torch.clamp(bboxes[:, [0, 2]], 0, size[1])
+            bboxes[:, [1, 3]] = torch.clamp(bboxes[:, [1, 3]], 0, size[0])
+
+            # Remove too small
+            kept = (
+                (bboxes[:, 2] - bboxes[:, 0] > self.min_bbox_size) &
+                (bboxes[:, 3] - bboxes[:, 1] > self.min_bbox_size)
+            )
+            bboxes = bboxes[kept]
+            p = p[kept]
+
+            # Normalize to 0-1
+            if self.normalize_boxes:
+                bboxes[:, [0, 2]] /= size[1]
+                bboxes[:, [1, 3]] /= size[0]
+
+            scores = p[:, 4] * p[:, 5]
+            labels = p[:, 6]
+
+            preds.append({"labels": labels, "boxes": bboxes, "scores": scores})
+
+        return preds
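
A minimal inference sketch (not part of the commit), assuming `weights.pth` and `page_element_v3.py` sit next to `model.py` and that a `page.png` test image exists:

import cv2
from model import define_model

model = define_model("page_element_v3")     # builds YOLOX, loads weights, wraps it
img = cv2.imread("page.png")[..., ::-1]     # BGR -> RGB, uint8 [H x W x 3]
x = model.preprocess(img.copy())            # aspect-preserving resize + pad to 1024 x 1024
preds = model(x, [img.shape])               # list with one dict: labels / boxes / scores
print(preds[0]["boxes"].shape)              # [N x 4], normalized x1, y1, x2, y2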
page_element_v3.py ADDED
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+
+
+class Exp:
+    """
+    Configuration class for the page element model.
+    """
+    def __init__(self):
+        self.name = "page-element-v3"
+        self.ckpt = "weights.pth"
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+        # YOLOX architecture parameters
+        self.act = "silu"
+        self.depth = 1.00
+        self.width = 1.00
+        self.labels = ["table", "chart", "title", "infographic", "paragraph", "header_footer"]
+        self.num_classes = len(self.labels)
+
+        # Inference parameters
+        self.size = (1024, 1024)
+        self.min_bbox_size = 0
+        self.normalize_boxes = True
+
+        # NMS & thresholding. These can be updated
+        self.conf_thresh = 0.01
+        self.iou_thresh = 0.5
+        self.class_agnostic = True
+
+        self.thresholds_per_class = {
+            "table": 0.1,
+            "chart": 0.01,
+            "infographic": 0.01,
+            "title": 0.1,
+            "paragraph": 0.1,
+            "header_footer": 0.1,
+        }
+
+    def get_model(self):
+        """
+        Get the YOLOX model.
+        """
+        from yolox import YOLOX, YOLOPAFPN, YOLOXHead
+
+        # Build model
+        if getattr(self, "model", None) is None:
+            in_channels = [256, 512, 1024]
+            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act)
+            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act)
+            self.model = YOLOX(backbone, head)
+
+            # Update batch-norm parameters
+            def init_yolo(M):
+                for m in M.modules():
+                    if isinstance(m, nn.BatchNorm2d):
+                        m.eps = 1e-3
+                        m.momentum = 0.03
+            self.model.apply(init_yolo)
+
+        return self.model
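
The NMS and thresholding attributes are meant to be tweaked; a small sketch (hypothetical values) of building the raw network with different settings:

from page_element_v3 import Exp

config = Exp()
config.conf_thresh = 0.1        # drop low-confidence candidates earlier
config.class_agnostic = False   # run NMS per class instead of across all classes
model = config.get_model()      # raw YOLOX module; weights are loaded by define_model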
utils.py ADDED
@@ -0,0 +1,117 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.patches import Rectangle
+
+COLORS = [
+    "#003EFF",
+    "#FF8F00",
+    "#079700",
+    "#A123FF",
+    "#87CEEB",
+    "#FF5733",
+    "#C70039",
+    "#900C3F",
+    "#581845",
+    "#11998E",
+]
+
+
+def reformat_for_plotting(labels, bboxes, scores, shape, num_classes):
+    """
+    Reformat YOLOX predictions for plotting.
+
+    Args:
+        labels (np.ndarray): Array of labels.
+        bboxes (np.ndarray): Array of bounding boxes.
+        scores (np.ndarray): Array of confidence scores.
+        shape (tuple): Shape of the image.
+        num_classes (int): Number of classes.
+
+    Returns:
+        list[np.ndarray]: List of bounding boxes per class.
+        list[np.ndarray]: List of confidence scores per class.
+    """
+    boxes_plot = bboxes.copy()
+    boxes_plot[:, [0, 2]] *= shape[1]
+    boxes_plot[:, [1, 3]] *= shape[0]
+    boxes_plot = boxes_plot.astype(int)
+    boxes_plot[:, 2] -= boxes_plot[:, 0]
+    boxes_plot[:, 3] -= boxes_plot[:, 1]
+    boxes_plot = [boxes_plot[labels == c] for c in range(num_classes)]
+    confs = [scores[labels == c] for c in range(num_classes)]
+    return boxes_plot, confs
+
+
+def plot_sample(img, boxes_list, confs_list, labels):
+    """
+    Plots an image with bounding boxes.
+    Coordinates are expected in format [x_min, y_min, width, height].
+
+    Args:
+        img (numpy.ndarray): The input image to be plotted.
+        boxes_list (list[np.ndarray]): List of bounding boxes per class.
+        confs_list (list[np.ndarray]): List of confidence scores per class.
+        labels (list): List of class labels.
+    """
+    plt.imshow(img, cmap="gray")
+    plt.axis(False)
+
+    for boxes, confs, col, l in zip(boxes_list, confs_list, COLORS, labels):
+        for box_idx, box in enumerate(boxes):
+            # Better display around boundaries
+            h, w, _ = img.shape
+            box = np.copy(box)
+            box[:2] = np.clip(box[:2], 2, max(h, w))
+            box[2] = min(box[2], w - 2 - box[0])
+            box[3] = min(box[3], h - 2 - box[1])
+
+            rect = Rectangle(
+                (box[0], box[1]),
+                box[2],
+                box[3],
+                linewidth=2,
+                facecolor="none",
+                edgecolor=col,
+            )
+            plt.gca().add_patch(rect)
+
+            # Add class and index label with proper alignment
+            plt.text(
+                box[0], box[1],
+                f"{l}_{box_idx} conf={confs[box_idx]:.3f}",
+                color='white',
+                fontsize=8,
+                bbox=dict(facecolor=col, alpha=1, edgecolor=col, pad=0, linewidth=2),
+                verticalalignment='bottom',
+                horizontalalignment='left'
+            )
+
+
+def postprocess_preds_page_element(preds, thresholds_per_class, class_labels):
+    """
+    Post-processes predictions for the page element task.
+    - Applies thresholding
+
+    Args:
+        preds (dict): Predictions. Keys are "scores", "boxes", "labels".
+        thresholds_per_class (dict): Thresholds per class.
+        class_labels (list): List of class labels.
+
+    Returns:
+        labels (numpy.ndarray): Array of labels.
+        bboxes (numpy.ndarray): Array of bounding boxes.
+        scores (numpy.ndarray): Array of scores.
+    """
+    labels = preds["labels"].cpu().numpy()
+    boxes = preds["boxes"].cpu().numpy()
+    scores = preds["scores"].cpu().numpy()
+
+    # Threshold per class
+    thresholds = np.array(
+        [thresholds_per_class[class_labels[int(x)]] for x in labels]
+    )
+    labels = labels[scores > thresholds]
+    boxes = boxes[scores > thresholds]
+    scores = scores[scores > thresholds]
+
+    return labels, boxes, scores
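
A plotting sketch chaining these helpers with the wrapper output from model.py (assumes `model`, `img` and `preds` from the inference sketch above; not part of the commit):

import matplotlib.pyplot as plt
from utils import postprocess_preds_page_element, reformat_for_plotting, plot_sample

labels, boxes, scores = postprocess_preds_page_element(
    preds[0], model.thresholds_per_class, model.labels
)
boxes_plot, confs = reformat_for_plotting(labels, boxes, scores, img.shape, model.num_classes)
plot_sample(img, boxes_plot, confs, model.labels)
plt.show()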
yolox/__init__.py ADDED
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+from .yolox import YOLOX
yolox/boxes.py ADDED
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torchvision
+
+
+def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
+    """
+    Copied from YOLOX/yolox/utils/boxes.py
+    """
+    box_corner = prediction.new(prediction.shape)
+    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+    prediction[:, :, :4] = box_corner[:, :, :4]
+
+    output = [None for _ in range(len(prediction))]
+    for i, image_pred in enumerate(prediction):
+
+        # If none are remaining => process next image
+        if not image_pred.size(0):
+            continue
+        # Get score and class with highest confidence
+        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)
+
+        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
+        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
+        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
+        detections = detections[conf_mask]
+        if not detections.size(0):
+            continue
+
+        if class_agnostic:
+            nms_out_index = torchvision.ops.nms(
+                detections[:, :4],
+                detections[:, 4] * detections[:, 5],
+                nms_thre,
+            )
+        else:
+            nms_out_index = torchvision.ops.batched_nms(
+                detections[:, :4],
+                detections[:, 4] * detections[:, 5],
+                detections[:, 6],
+                nms_thre,
+            )
+
+        detections = detections[nms_out_index]
+        if output[i] is None:
+            output[i] = detections
+        else:
+            output[i] = torch.cat((output[i], detections))
+
+    return output
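
For reference, `postprocess` expects raw decoded predictions of shape `[batch, n_anchors, 5 + num_classes]` with boxes in (cx, cy, w, h) format; a toy shape check with random scores (illustrative only):

import torch
from yolox.boxes import postprocess

dummy = torch.rand(2, 100, 5 + 6)   # 2 images, 100 candidate boxes, 6 classes
dummy[..., :4] *= 1024              # cx, cy, w, h in pixels
out = postprocess(dummy, num_classes=6, conf_thre=0.5, nms_thre=0.5)
# out is a list of length 2; each entry is None or an [N x 7] tensor:
# (x1, y1, x2, y2, obj_conf, class_conf, class_id)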
yolox/darknet.py ADDED
@@ -0,0 +1,179 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+from torch import nn
+
+from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck
+
+
+class Darknet(nn.Module):
+    # number of blocks from dark2 to dark5.
+    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}
+
+    def __init__(
+        self,
+        depth,
+        in_channels=3,
+        stem_out_channels=32,
+        out_features=("dark3", "dark4", "dark5"),
+    ):
+        """
+        Args:
+            depth (int): depth of darknet used in model, usually use [21, 53] for this param.
+            in_channels (int): number of input channels, for example, use 3 for RGB image.
+            stem_out_channels (int): number of output channels of darknet stem.
+                It decides channels of darknet layer2 to layer5.
+            out_features (Tuple[str]): desired output layer name.
+        """
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        self.stem = nn.Sequential(
+            BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
+            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
+        )
+        in_channels = stem_out_channels * 2  # 64
+
+        num_blocks = Darknet.depth2blocks[depth]
+        # create darknet with `stem_out_channels` and `num_blocks` layers.
+        # to make model structure more clear, we don't use `for` statement in python.
+        self.dark2 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[0], stride=2)
+        )
+        in_channels *= 2  # 128
+        self.dark3 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[1], stride=2)
+        )
+        in_channels *= 2  # 256
+        self.dark4 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[2], stride=2)
+        )
+        in_channels *= 2  # 512
+
+        self.dark5 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[3], stride=2),
+            *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2),
+        )
+
+    def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
+        "starts with conv layer then has `num_blocks` `ResLayer`"
+        return [
+            BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"),
+            *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)],
+        ]
+
+    def make_spp_block(self, filters_list, in_filters):
+        m = nn.Sequential(
+            *[
+                BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                SPPBottleneck(
+                    in_channels=filters_list[1],
+                    out_channels=filters_list[0],
+                    activation="lrelu",
+                ),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"),
+            ]
+        )
+        return m
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
+
+
+class CSPDarknet(nn.Module):
+    def __init__(
+        self,
+        dep_mul,
+        wid_mul,
+        out_features=("dark3", "dark4", "dark5"),
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        Conv = DWConv if depthwise else BaseConv
+
+        base_channels = int(wid_mul * 64)  # 64
+        base_depth = max(round(dep_mul * 3), 1)  # 3
+
+        # stem
+        self.stem = Focus(3, base_channels, ksize=3, act=act)
+
+        # dark2
+        self.dark2 = nn.Sequential(
+            Conv(base_channels, base_channels * 2, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 2,
+                base_channels * 2,
+                n=base_depth,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark3
+        self.dark3 = nn.Sequential(
+            Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 4,
+                base_channels * 4,
+                n=base_depth * 3,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark4
+        self.dark4 = nn.Sequential(
+            Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 8,
+                base_channels * 8,
+                n=base_depth * 3,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark5
+        self.dark5 = nn.Sequential(
+            Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
+            SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
+            CSPLayer(
+                base_channels * 16,
+                base_channels * 16,
+                n=base_depth,
+                shortcut=False,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
yolox/network_blocks.py ADDED
@@ -0,0 +1,210 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_activation(name="silu", inplace=True):
+    if name == "silu":
+        module = nn.SiLU(inplace=inplace)
+    elif name == "relu":
+        module = nn.ReLU(inplace=inplace)
+    elif name == "lrelu":
+        module = nn.LeakyReLU(0.1, inplace=inplace)
+    else:
+        raise AttributeError("Unsupported act type: {}".format(name))
+    return module
+
+
+class BaseConv(nn.Module):
+    """A Conv2d -> Batchnorm -> silu/leaky relu block"""
+
+    def __init__(
+        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
+    ):
+        super().__init__()
+        # same padding
+        pad = (ksize - 1) // 2
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=ksize,
+            stride=stride,
+            padding=pad,
+            groups=groups,
+            bias=bias,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.act = get_activation(act, inplace=True)
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+class DWConv(nn.Module):
+    """Depthwise Conv + Conv"""
+
+    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
+        super().__init__()
+        self.dconv = BaseConv(
+            in_channels,
+            in_channels,
+            ksize=ksize,
+            stride=stride,
+            groups=in_channels,
+            act=act,
+        )
+        self.pconv = BaseConv(
+            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
+        )
+
+    def forward(self, x):
+        x = self.dconv(x)
+        return self.pconv(x)
+
+
+class Bottleneck(nn.Module):
+    # Standard bottleneck
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        shortcut=True,
+        expansion=0.5,
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)
+        Conv = DWConv if depthwise else BaseConv
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
+        self.use_add = shortcut and in_channels == out_channels
+
+    def forward(self, x):
+        y = self.conv2(self.conv1(x))
+        if self.use_add:
+            y = y + x
+        return y
+
+
+class ResLayer(nn.Module):
+    "Residual layer with `in_channels` inputs."
+
+    def __init__(self, in_channels: int):
+        super().__init__()
+        mid_channels = in_channels // 2
+        self.layer1 = BaseConv(
+            in_channels, mid_channels, ksize=1, stride=1, act="lrelu"
+        )
+        self.layer2 = BaseConv(
+            mid_channels, in_channels, ksize=3, stride=1, act="lrelu"
+        )
+
+    def forward(self, x):
+        out = self.layer2(self.layer1(x))
+        return x + out
+
+
+class SPPBottleneck(nn.Module):
+    """Spatial pyramid pooling layer used in YOLOv3-SPP"""
+
+    def __init__(
+        self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"
+    ):
+        super().__init__()
+        hidden_channels = in_channels // 2
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
+        self.m = nn.ModuleList(
+            [
+                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
+                for ks in kernel_sizes
+            ]
+        )
+        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
+        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
+        x = self.conv2(x)
+        return x
+
+
+class CSPLayer(nn.Module):
+    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        n=1,
+        shortcut=True,
+        expansion=0.5,
+        depthwise=False,
+        act="silu",
+    ):
+        """
+        Args:
+            in_channels (int): input channels.
+            out_channels (int): output channels.
+            n (int): number of Bottlenecks. Default value: 1.
+        """
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)  # hidden channels
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
+        module_list = [
+            Bottleneck(
+                hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
+            )
+            for _ in range(n)
+        ]
+        self.m = nn.Sequential(*module_list)
+
+    def forward(self, x):
+        x_1 = self.conv1(x)
+        x_2 = self.conv2(x)
+        x_1 = self.m(x_1)
+        x = torch.cat((x_1, x_2), dim=1)
+        return self.conv3(x)
+
+
+class Focus(nn.Module):
+    """Focus width and height information into channel space."""
+
+    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
+        super().__init__()
+        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
+
+    def forward(self, x):
+        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
+        patch_top_left = x[..., ::2, ::2]
+        patch_top_right = x[..., ::2, 1::2]
+        patch_bot_left = x[..., 1::2, ::2]
+        patch_bot_right = x[..., 1::2, 1::2]
+        x = torch.cat(
+            (
+                patch_top_left,
+                patch_bot_left,
+                patch_top_right,
+                patch_bot_right,
+            ),
+            dim=1,
+        )
+        return self.conv(x)
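
As a quick illustration of the `Focus` stem used by `CSPDarknet` (shapes only, not part of the commit):

import torch
from yolox.network_blocks import Focus

focus = Focus(in_channels=3, out_channels=64, ksize=3)
y = focus(torch.zeros(1, 3, 1024, 1024))
print(y.shape)  # [1, 64, 512, 512]: 2x2 pixel patches are moved into channels, then convolved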
yolox/yolo_fpn.py ADDED
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from .darknet import Darknet
+from .network_blocks import BaseConv
+
+
+class YOLOFPN(nn.Module):
+    """
+    YOLOFPN module. Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self,
+        depth=53,
+        in_features=["dark3", "dark4", "dark5"],
+    ):
+        super().__init__()
+
+        self.backbone = Darknet(depth)
+        self.in_features = in_features
+
+        # out 1
+        self.out1_cbl = self._make_cbl(512, 256, 1)
+        self.out1 = self._make_embedding([256, 512], 512 + 256)
+
+        # out 2
+        self.out2_cbl = self._make_cbl(256, 128, 1)
+        self.out2 = self._make_embedding([128, 256], 256 + 128)
+
+        # upsample
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+    def _make_cbl(self, _in, _out, ks):
+        return BaseConv(_in, _out, ks, stride=1, act="lrelu")
+
+    def _make_embedding(self, filters_list, in_filters):
+        m = nn.Sequential(
+            *[
+                self._make_cbl(in_filters, filters_list[0], 1),
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+            ]
+        )
+        return m
+
+    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
+        with open(filename, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu")
+        print("loading pretrained weights...")
+        self.backbone.load_state_dict(state_dict)
+
+    def forward(self, inputs):
+        """
+        Args:
+            inputs (Tensor): input image.
+
+        Returns:
+            Tuple[Tensor]: FPN output features.
+        """
+        # backbone
+        out_features = self.backbone(inputs)
+        x2, x1, x0 = [out_features[f] for f in self.in_features]
+
+        # yolo branch 1
+        x1_in = self.out1_cbl(x0)
+        x1_in = self.upsample(x1_in)
+        x1_in = torch.cat([x1_in, x1], 1)
+        out_dark4 = self.out1(x1_in)
+
+        # yolo branch 2
+        x2_in = self.out2_cbl(out_dark4)
+        x2_in = self.upsample(x2_in)
+        x2_in = torch.cat([x2_in, x2], 1)
+        out_dark3 = self.out2(x2_in)
+
+        outputs = (out_dark3, out_dark4, x0)
+        return outputs
yolox/yolo_head.py ADDED
@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+from .network_blocks import BaseConv, DWConv
+
+
+_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
+
+
+def meshgrid(*tensors):
+    """
+    Copied from YOLOX/yolox/utils/compat.py
+    """
+    if _TORCH_VER >= [1, 10]:
+        return torch.meshgrid(*tensors, indexing="ij")
+    else:
+        return torch.meshgrid(*tensors)
+
+
+def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
+    """
+    Copied from YOLOX/yolox/utils/boxes.py
+    """
+    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
+        raise IndexError
+
+    if xyxy:
+        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
+        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
+        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
+        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
+    else:
+        tl = torch.max(
+            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
+        )
+        br = torch.min(
+            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
+        )
+
+        area_a = torch.prod(bboxes_a[:, 2:], 1)
+        area_b = torch.prod(bboxes_b[:, 2:], 1)
+    en = (tl < br).type(tl.type()).prod(dim=2)
+    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
+    return area_i / (area_a[:, None] + area_b - area_i)
+
+
+class YOLOXHead(nn.Module):
+    def __init__(
+        self,
+        num_classes,
+        width=1.0,
+        strides=[8, 16, 32],
+        in_channels=[256, 512, 1024],
+        act="silu",
+        depthwise=False,
+    ):
+        """
+        Args:
+            act (str): activation type of conv. Default value: "silu".
+            depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
+        """
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.decode_in_inference = True  # for deploy, set to False
+
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        self.cls_preds = nn.ModuleList()
+        self.reg_preds = nn.ModuleList()
+        self.obj_preds = nn.ModuleList()
+        self.stems = nn.ModuleList()
+        Conv = DWConv if depthwise else BaseConv
+
+        for i in range(len(in_channels)):
+            self.stems.append(
+                BaseConv(
+                    in_channels=int(in_channels[i] * width),
+                    out_channels=int(256 * width),
+                    ksize=1,
+                    stride=1,
+                    act=act,
+                )
+            )
+            self.cls_convs.append(
+                nn.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.reg_convs.append(
+                nn.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.cls_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=self.num_classes,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.reg_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=4,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.obj_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=1,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+
+        self.use_l1 = False
+        self.l1_loss = nn.L1Loss(reduction="none")
+        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
+        self.iou_loss = None
+        self.strides = strides
+        self.grids = [torch.zeros(1)] * len(in_channels)
+
+    def forward(self, xin, labels=None, imgs=None):
+        outputs = []
+        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+            zip(self.cls_convs, self.reg_convs, self.strides, xin)
+        ):
+            x = self.stems[k](x)
+            cls_x = x
+            reg_x = x
+
+            cls_feat = cls_conv(cls_x)
+            cls_output = self.cls_preds[k](cls_feat)
+
+            reg_feat = reg_conv(reg_x)
+            reg_output = self.reg_preds[k](reg_feat)
+            obj_output = self.obj_preds[k](reg_feat)
+
+            output = torch.cat(
+                [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
+            )
+
+            outputs.append(output)
+
+        self.hw = [x.shape[-2:] for x in outputs]
+        # [batch, n_anchors_all, 85]
+        outputs = torch.cat(
+            [x.flatten(start_dim=2) for x in outputs], dim=2
+        ).permute(0, 2, 1)
+        if self.decode_in_inference:
+            return self.decode_outputs(outputs, dtype=xin[0].type())
+        else:
+            return outputs
+
+    def get_output_and_grid(self, output, k, stride, dtype):
+        grid = self.grids[k]
+
+        batch_size = output.shape[0]
+        n_ch = 5 + self.num_classes
+        hsize, wsize = output.shape[-2:]
+        if grid.shape[2:4] != output.shape[2:4]:
+            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
+            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
+            self.grids[k] = grid
+
+        output = output.view(batch_size, 1, n_ch, hsize, wsize)
+        output = output.permute(0, 1, 3, 4, 2).reshape(
+            batch_size, hsize * wsize, -1
+        )
+        grid = grid.view(1, -1, 2)
+        output[..., :2] = (output[..., :2] + grid) * stride
+        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
+        return output, grid
+
+    def decode_outputs(self, outputs, dtype):
+        grids = []
+        strides = []
+        for (hsize, wsize), stride in zip(self.hw, self.strides):
+            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
+            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
+            grids.append(grid)
+            shape = grid.shape[:2]
+            strides.append(torch.full((*shape, 1), stride))
+
+        grids = torch.cat(grids, dim=1).type(dtype)
+        strides = torch.cat(strides, dim=1).type(dtype)
+
+        outputs = torch.cat([
+            (outputs[..., 0:2] + grids) * strides,
+            torch.exp(outputs[..., 2:4]) * strides,
+            outputs[..., 4:]
+        ], dim=-1)
+        return outputs
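
At inference the head concatenates the three levels and decodes grid offsets into absolute (cx, cy, w, h) boxes; a shape sketch with dummy FPN features (width 1.0, 6 classes; illustrative only):

import torch
from yolox import YOLOXHead

head = YOLOXHead(num_classes=6, width=1.0).eval()
feats = [torch.zeros(1, 256, 128, 128),   # stride 8
         torch.zeros(1, 512, 64, 64),     # stride 16
         torch.zeros(1, 1024, 32, 32)]    # stride 32
out = head(feats)
print(out.shape)  # [1, 21504, 11]: (cx, cy, w, h, obj, 6 class scores) per anchor point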
yolox/yolo_pafpn.py ADDED
@@ -0,0 +1,116 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from .darknet import CSPDarknet
+from .network_blocks import BaseConv, CSPLayer, DWConv
+
+
+class YOLOPAFPN(nn.Module):
+    """
+    YOLOv3 model. Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self,
+        depth=1.0,
+        width=1.0,
+        in_features=("dark3", "dark4", "dark5"),
+        in_channels=[256, 512, 1024],
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
+        self.in_features = in_features
+        self.in_channels = in_channels
+        Conv = DWConv if depthwise else BaseConv
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+        self.lateral_conv0 = BaseConv(
+            int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
+        )
+        self.C3_p4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )  # cat
+
+        self.reduce_conv1 = BaseConv(
+            int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
+        )
+        self.C3_p3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[0] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+        # bottom-up conv
+        self.bu_conv2 = Conv(
+            int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
+        )
+        self.C3_n3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+        # bottom-up conv
+        self.bu_conv1 = Conv(
+            int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
+        )
+        self.C3_n4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[2] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+    def forward(self, input):
+        """
+        Args:
+            inputs: input images.
+
+        Returns:
+            Tuple[Tensor]: FPN feature.
+        """
+
+        # backbone
+        out_features = self.backbone(input)
+        features = [out_features[f] for f in self.in_features]
+        [x2, x1, x0] = features
+
+        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
+        f_out0 = self.upsample(fpn_out0)  # 512/16
+        f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
+        f_out0 = self.C3_p4(f_out0)  # 1024->512/16
+
+        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
+        f_out1 = self.upsample(fpn_out1)  # 256/8
+        f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
+        pan_out2 = self.C3_p3(f_out1)  # 512->256/8
+
+        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
+        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
+        pan_out1 = self.C3_n3(p_out1)  # 512->512/16
+
+        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
+        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
+        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32
+
+        outputs = (pan_out2, pan_out1, pan_out0)
+        return outputs
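
A shape sketch for the PAFPN neck (width 1.0), matching the channels the head expects; illustrative only:

import torch
from yolox import YOLOPAFPN

fpn = YOLOPAFPN(depth=1.0, width=1.0).eval()
p3, p4, p5 = fpn(torch.zeros(1, 3, 1024, 1024))
print(p3.shape, p4.shape, p5.shape)  # [1, 256, 128, 128], [1, 512, 64, 64], [1, 1024, 32, 32]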
yolox/yolox.py ADDED
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch.nn as nn
+
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+
+
+class YOLOX(nn.Module):
+    """
+    YOLOX model module. The module list is defined by create_yolov3_modules function.
+    The network returns loss values from three YOLO layers during training
+    and detection results during test.
+    """
+
+    def __init__(self, backbone=None, head=None):
+        super().__init__()
+        if backbone is None:
+            backbone = YOLOPAFPN()
+        if head is None:
+            head = YOLOXHead(80)
+
+        self.backbone = backbone
+        self.head = head
+
+    def forward(self, x, targets=None):
+        assert not self.training, "Training mode not supported, please refer to the YOLOX repo"
+        fpn_outs = self.backbone(x)
+        outputs = self.head(fpn_outs)
+        return outputs
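
Putting it together, `Exp.get_model` builds exactly this backbone + head pair; a standalone sketch of the assembled detector (illustrative, not part of the commit):

import torch
from yolox import YOLOX, YOLOPAFPN, YOLOXHead

backbone = YOLOPAFPN(depth=1.0, width=1.0, in_channels=[256, 512, 1024])
head = YOLOXHead(num_classes=6, width=1.0, in_channels=[256, 512, 1024])
model = YOLOX(backbone, head).eval()  # forward asserts eval mode
with torch.no_grad():
    out = model(torch.zeros(1, 3, 1024, 1024))
print(out.shape)  # [1, 21504, 11], ready for yolox.boxes.postprocess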