Spaces:
Runtime error
Runtime error
Upload tf_post_processing.py
Browse files- tf_post_processing.py +233 -0
tf_post_processing.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Mon Sep 4 16:03:42 2023
|
| 4 |
+
|
| 5 |
+
@author: SABARI
|
| 6 |
+
"""
|
| 7 |
+
import time
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
import numpy as np
|
| 10 |
+
#from lsnms import nms, wbc
|
| 11 |
+
|
| 12 |
+
def box_iou(box1, box2, eps=1e-7):
|
| 13 |
+
"""
|
| 14 |
+
Calculate intersection-over-union (IoU) of boxes.
|
| 15 |
+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
| 16 |
+
Args:
|
| 17 |
+
box1 (tf.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
|
| 18 |
+
box2 (tf.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
|
| 19 |
+
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
|
| 20 |
+
Returns:
|
| 21 |
+
(tf.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
a1, a2 = tf.split(box1, 2, axis=1)
|
| 25 |
+
b1, b2 = tf.split(box2, 2, axis=1)
|
| 26 |
+
|
| 27 |
+
inter = tf.reduce_prod(tf.maximum(tf.minimum(a2, b2) - tf.maximum(a1, b1), 0), axis=1)
|
| 28 |
+
|
| 29 |
+
return inter / (tf.reduce_prod(a2 - a1, axis=1) + tf.reduce_prod(b2 - b1, axis=1) - inter + eps)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def xywh2xyxy(x):
|
| 33 |
+
"""
|
| 34 |
+
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
|
| 35 |
+
top-left corner and (x2, y2) is the bottom-right corner.
|
| 36 |
+
Args:
|
| 37 |
+
x (tf.Tensor): The input bounding box coordinates in (x, y, width, height) format.
|
| 38 |
+
Returns:
|
| 39 |
+
y (tf.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
| 40 |
+
"""
|
| 41 |
+
# Assuming x is a NumPy array
|
| 42 |
+
y = np.copy(x)
|
| 43 |
+
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
| 44 |
+
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
| 45 |
+
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
| 46 |
+
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
| 47 |
+
return y
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False,
|
| 51 |
+
multi_label=False, max_det=300, nc=0, # number of classes (optional)
|
| 52 |
+
max_time_img=0.05,
|
| 53 |
+
max_nms=100,
|
| 54 |
+
max_wh=7680):
|
| 55 |
+
"""
|
| 56 |
+
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
|
| 57 |
+
Arguments:
|
| 58 |
+
prediction (tf.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
|
| 59 |
+
containing the predicted boxes, classes, and masks. The tensor should be in the format
|
| 60 |
+
output by a model, such as YOLO.
|
| 61 |
+
conf_thres (float): The confidence threshold below which boxes will be filtered out.
|
| 62 |
+
Valid values are between 0.0 and 1.0.
|
| 63 |
+
iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
|
| 64 |
+
Valid values are between 0.0 and 1.0.
|
| 65 |
+
agnostic (bool): If True, the model is agnostic to the number of classes, and all
|
| 66 |
+
classes will be considered as one.
|
| 67 |
+
multi_label (bool): If True, each box may have multiple labels.
|
| 68 |
+
max_det (int): The maximum number of boxes to keep after NMS.
|
| 69 |
+
nc (int): (optional) The number of classes output by the model. Any indices after this will be considered masks.
|
| 70 |
+
max_time_img (float): The maximum time (seconds) for processing one image.
|
| 71 |
+
max_nms (int): The maximum number of boxes into tf.image.combined_non_max_suppression().
|
| 72 |
+
max_wh (int): The maximum box width and height in pixels
|
| 73 |
+
Returns:
|
| 74 |
+
(List[tf.Tensor]): A list of length batch_size, where each element is a tensor of
|
| 75 |
+
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
|
| 76 |
+
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
# Checks
|
| 80 |
+
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
| 81 |
+
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
| 82 |
+
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
|
| 83 |
+
prediction = prediction[0] # select only inference output
|
| 84 |
+
|
| 85 |
+
bs = np.shape(prediction)[0] # batch size
|
| 86 |
+
nc = nc or (np.shape(prediction)[1] - 4) # number of classes
|
| 87 |
+
nm = np.shape(prediction)[1] - nc - 4
|
| 88 |
+
mi = 4 + nc # mask start index
|
| 89 |
+
#xc = tf.math.reduce_any(prediction[:, 4:mi] > conf_thres, axis=1) # candidates
|
| 90 |
+
xc = np.amax(prediction[:, 4:mi], axis=1) > conf_thres
|
| 91 |
+
|
| 92 |
+
# Settings
|
| 93 |
+
# min_wh = 2 # (pixels) minimum box width and height
|
| 94 |
+
time_limit = 0.5 + max_time_img * tf.cast(bs, tf.float32) # seconds to quit after
|
| 95 |
+
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
| 96 |
+
|
| 97 |
+
t = time.time()
|
| 98 |
+
output = [np.zeros((0, 6 + nm))] * bs
|
| 99 |
+
for xi, x in enumerate(prediction): # image index, image inference
|
| 100 |
+
# Apply constraints
|
| 101 |
+
# x = tf.where(tf.math.logical_or(x[:, 2:4] < min_wh, x[:, 2:4] > max_wh), tf.constant(0, dtype=tf.float32), x) # width-height
|
| 102 |
+
#x = tf.boolean_mask(x, xc[xi])
|
| 103 |
+
#x = x.transpose(0, -1)[xc[xi]] # confidence
|
| 104 |
+
# Assuming x, xc, and xi are NumPy arrays
|
| 105 |
+
x = np.transpose(x)
|
| 106 |
+
|
| 107 |
+
#x = x.transpose()[:, xc[xi]]
|
| 108 |
+
x = x[xc[xi]]
|
| 109 |
+
|
| 110 |
+
# If none remain process next image
|
| 111 |
+
if np.shape(x)[0] == 0:
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
# Detections matrix nx6 (xyxy, conf, cls)
|
| 115 |
+
#box, cls, mask = tf.split(x, [4, nc, nm], axis=1)
|
| 116 |
+
|
| 117 |
+
# Assuming x is a NumPy array
|
| 118 |
+
box = x[:, :4]
|
| 119 |
+
cls = x[:, 4:4 + nc]
|
| 120 |
+
mask = x[:, 4 + nc:]
|
| 121 |
+
box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
| 122 |
+
|
| 123 |
+
# Assuming cls is a NumPy array
|
| 124 |
+
if multi_label:
|
| 125 |
+
i, j = np.where(cls > conf_thres)
|
| 126 |
+
x = np.concatenate([box[i], np.expand_dims(cls[i, j], axis=-1), np.expand_dims(j, axis=-1).astype(np.float32), mask[i]], axis=1)
|
| 127 |
+
else:
|
| 128 |
+
conf = np.max(cls, axis=1)
|
| 129 |
+
j = np.argmax(cls, axis=1)
|
| 130 |
+
keep = np.where(conf > conf_thres)[0]
|
| 131 |
+
x = np.concatenate([box[keep], np.expand_dims(conf[keep], axis=-1), np.expand_dims(j[keep], axis=-1).astype(np.float32), mask[keep]], axis=1)
|
| 132 |
+
|
| 133 |
+
# Check shape
|
| 134 |
+
n = np.shape(x)[0] # number of boxes
|
| 135 |
+
if n == 0: # no boxes
|
| 136 |
+
continue
|
| 137 |
+
#x = x[tf.argsort(x[:, 4], direction='DESCENDING')[:max_nms]] # sort by confidence and remove excess boxes
|
| 138 |
+
sorted_indices = np.argsort(x[:, 4])[::-1] # Sort indices in descending order of confidence
|
| 139 |
+
x = x[sorted_indices[:max_nms]] # Keep the top max_nms boxes
|
| 140 |
+
|
| 141 |
+
# Batched NMS
|
| 142 |
+
c = x[:, 5:6] * (0.0 if agnostic else tf.cast(max_wh, tf.float32)) # classes
|
| 143 |
+
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
| 144 |
+
i = tf.image.non_max_suppression(boxes, scores, max_nms, iou_threshold=iou_thres) # NMS
|
| 145 |
+
i = i.numpy()
|
| 146 |
+
i = i[:max_det] # limit detections
|
| 147 |
+
|
| 148 |
+
output[xi] = x[i,:]
|
| 149 |
+
|
| 150 |
+
if (time.time() - t) > time_limit:
|
| 151 |
+
break # time limit exceeded
|
| 152 |
+
|
| 153 |
+
return output
|
| 154 |
+
|
| 155 |
+
import numpy as np
|
| 156 |
+
|
| 157 |
+
def optimized_object_detection(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False,
|
| 158 |
+
multi_label=False, max_det=300, nc=0, max_time_img=0.05,
|
| 159 |
+
max_nms=100, max_wh=7680):
|
| 160 |
+
|
| 161 |
+
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
| 162 |
+
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
| 163 |
+
|
| 164 |
+
if isinstance(prediction, (list, tuple)):
|
| 165 |
+
prediction = prediction[0]
|
| 166 |
+
|
| 167 |
+
bs, _, _ = prediction.shape # Get batch size and dimensions
|
| 168 |
+
|
| 169 |
+
if nc == 0:
|
| 170 |
+
nc = prediction.shape[1] - 4
|
| 171 |
+
|
| 172 |
+
nm = prediction.shape[1] - nc - 4
|
| 173 |
+
mi = 4 + nc
|
| 174 |
+
|
| 175 |
+
xc = np.amax(prediction[:, 4:mi], axis=1) > conf_thres
|
| 176 |
+
|
| 177 |
+
time_limit = 0.5 + max_time_img * bs
|
| 178 |
+
|
| 179 |
+
multi_label &= nc > 1
|
| 180 |
+
|
| 181 |
+
t = time.time()
|
| 182 |
+
output = [np.zeros((0, 6 + nm))] * bs
|
| 183 |
+
|
| 184 |
+
for xi, x in enumerate(prediction):
|
| 185 |
+
x = np.transpose(x)
|
| 186 |
+
x = x[xc[xi]]
|
| 187 |
+
|
| 188 |
+
if np.shape(x)[0] == 0:
|
| 189 |
+
continue
|
| 190 |
+
|
| 191 |
+
box = x[:, :4]
|
| 192 |
+
cls = x[:, 4:4 + nc]
|
| 193 |
+
mask = x[:, 4 + nc:]
|
| 194 |
+
box = xywh2xyxy(box)
|
| 195 |
+
|
| 196 |
+
if multi_label:
|
| 197 |
+
i, j = np.where(cls > conf_thres)
|
| 198 |
+
x = np.concatenate([box[i], np.expand_dims(cls[i, j], axis=-1), np.expand_dims(j, axis=-1).astype(np.float32), mask[i]], axis=1)
|
| 199 |
+
else:
|
| 200 |
+
conf = np.max(cls, axis=1)
|
| 201 |
+
j = np.argmax(cls, axis=1)
|
| 202 |
+
keep = np.where(conf > conf_thres)[0]
|
| 203 |
+
x = np.concatenate([box[keep], np.expand_dims(conf[keep], axis=-1), np.expand_dims(j[keep], axis=-1).astype(np.float32), mask[keep]], axis=1)
|
| 204 |
+
|
| 205 |
+
n = np.shape(x)[0]
|
| 206 |
+
if n == 0:
|
| 207 |
+
continue
|
| 208 |
+
|
| 209 |
+
sorted_indices = np.argsort(x[:, 4])[::-1]
|
| 210 |
+
x = x[sorted_indices[:max_nms]]
|
| 211 |
+
|
| 212 |
+
c = x[:, 5:6] * (0.0 if agnostic else max_wh)
|
| 213 |
+
boxes, scores = x[:, :4] + c, x[:, 4]
|
| 214 |
+
i = tf.image.non_max_suppression(boxes, scores, max_nms, iou_threshold=iou_thres)
|
| 215 |
+
|
| 216 |
+
#keep = nms(boxes, scores, iou_threshold=iou_thres)
|
| 217 |
+
|
| 218 |
+
i = i.numpy()
|
| 219 |
+
i = i[:max_det]
|
| 220 |
+
|
| 221 |
+
output[xi] = x[keep,:]
|
| 222 |
+
|
| 223 |
+
if (time.time() - t) > time_limit:
|
| 224 |
+
break
|
| 225 |
+
|
| 226 |
+
return output
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
#output_numpy = np.load(r"D:\object_face_person_detection\yolov8_tf_results\gustavo-alves-YOXSC4zRcxw-unsplash.npy")
|
| 230 |
+
|
| 231 |
+
#detections = non_max_suppression(output_numpy, conf_thres=0.4, iou_thres=0.4)[0]
|
| 232 |
+
|
| 233 |
+
#print(detections)
|