realzliu committed on
Commit 96f36aa · 1 Parent(s): c9049cc
Files changed (4)
  1. eval/eval_refer_seg.py +333 -0
  2. eval/refer.py +323 -0
  3. eval/transforms.py +97 -0
  4. eval/val_utils.py +149 -0
eval/eval_refer_seg.py ADDED
@@ -0,0 +1,333 @@
import argparse
import torch
import os
from tqdm import tqdm
import random
from PIL import Image
import numpy as np
import torch.nn.functional as F
import datetime
from model.segment_anything import SamPredictor, sam_model_registry
from dataset.refer_seg_dataset import ValDataset
from dataset.grefer_seg_dataset import grefcocoValDataset
from data.question_answer_list import QUESTION_PARTIAL
from segment_predictor import GenerativeSegmenter
from eval.utils import AverageMeter, Summary, intersectionAndUnionGPU, \
    compute_logits_from_mask, masks_sample_points

from torch.utils.data import Dataset, DataLoader
import math

# --- Accelerate Import ---
from accelerate import Accelerator


def get_chunk(ds, n, k):
    chunk_size = math.ceil(len(ds) / n)
    i = chunk_size * k
    start_index = i
    end_index = i + chunk_size
    ds.refer_seg_ds["images"] = ds.refer_seg_ds["images"][start_index:end_index]
    return ds


def gget_chunk(ds, n, k):
    chunk_size = math.ceil(len(ds) / n)
    i = chunk_size * k
    start_index = i
    end_index = i + chunk_size
    ds.loaded_images = ds.loaded_images[start_index:end_index]
    return ds


class CustomDataset(Dataset):
    def __init__(self, sub_dataset):
        self.dataset = sub_dataset

    def __getitem__(self, index):
        image, masks, questions, image_path = self.dataset[index]
        image_name = os.path.basename(image_path).split(".")[0]
        questions = [random.choice(QUESTION_PARTIAL).replace("[class_name]", q) for q in questions]
        return image, masks, image_name, questions, image_path

    def __len__(self):
        return len(self.dataset)


def collate_fn(batch):
    images, masks, image_names, questions, image_paths = zip(*batch)
    return images, masks, image_names, questions, image_paths


def create_data_loader(args, sub_dataset, batch_size=1):
    assert batch_size == 1, "Batch size must be 1 for this evaluation script."
    dataset = CustomDataset(sub_dataset)
    return DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=collate_fn)


def eval_model(args):
    # --- Initialize Accelerator ---
    accelerator = Accelerator()

    # --- Model and Predictor Initialization ---
    segmenter = GenerativeSegmenter(args.model_path, device_map=accelerator.device, min_pixels=args.min_pixels,
                                    max_pixels=args.max_pixels)

    sam_help = args.sam_path is not None
    if sam_help:
        sam = sam_model_registry["vit_h"](checkpoint=args.sam_path)
        sam = sam.to(dtype=torch.float32, device=accelerator.device)
        predictor = SamPredictor(sam)
    else:
        predictor = None

    # --- Dataset and DataLoader Initialization ---
    if accelerator.is_main_process:
        print("Loading dataset...")

    # First, load the full dataset based on the parameters
    if "grefcoco" in args.dataset_split:
        val_dataset = grefcocoValDataset(args.image_folder, args.dataset_split)
    else:
        val_dataset = ValDataset(args.image_folder, args.dataset_split)

    if accelerator.is_main_process:
        total_data_size = len(val_dataset)
        print(f"Total evaluation data volume (full dataset): {total_data_size} samples.")

    # Then, get a chunk of the dataset as needed
    if "grefcoco" in args.dataset_split:
        sub_dataset = gget_chunk(val_dataset, args.num_chunks, args.chunk_idx)
    else:
        sub_dataset = get_chunk(val_dataset, args.num_chunks, args.chunk_idx)

    data_loader = create_data_loader(args, sub_dataset)
    data_loader = accelerator.prepare(data_loader)

    # --- Metric Meters Initialization ---
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    if sam_help:
        intersection_meter_sam = AverageMeter("Intersec", ":6.3f", Summary.SUM)
        union_meter_sam = AverageMeter("Union", ":6.3f", Summary.SUM)
        acc_iou_meter_sam = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    progress_bar = tqdm(data_loader, disable=not accelerator.is_main_process, total=len(data_loader))

    for batch in progress_bar:
        images, masks, image_names, questions, image_paths = batch
        image, gt_masks, image_name, prompts = images[0], masks[0], image_names[0], questions[0]
        w_ori, h_ori = image.size

        total_intersection = torch.zeros(2, device=accelerator.device)
        total_union = torch.zeros(2, device=accelerator.device)
        total_acc_iou = torch.zeros(2, device=accelerator.device)

        if sam_help:
            total_intersection_sam = torch.zeros(2, device=accelerator.device)
            total_union_sam = torch.zeros(2, device=accelerator.device)
            total_acc_iou_sam = torch.zeros(2, device=accelerator.device)

        num_masks_in_image = len(prompts)

        with torch.inference_mode():
            if sam_help:
                predictor.set_image(np.array(image))

            for i, question in enumerate(prompts):
                gt_mask = gt_masks[i].to(accelerator.device).float().contiguous()
                segmentation_masks, _ = segmenter.generate_with_segmentation(image, question)

                if segmentation_masks is None or len(segmentation_masks) == 0:
                    pred_mask = torch.zeros((h_ori, w_ori), device=accelerator.device)
                else:
                    mask = segmentation_masks[0].to(accelerator.device)
                    pred_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).double(), size=(h_ori, w_ori),
                                              mode='nearest').squeeze()

                # if accelerator.is_main_process:
                #     print("\n" + "=" * 20 + " DEBUG INFO (first iteration) " + "=" * 20)
                #     print(f"Image Name: {image_name}")
                #     print(f"GT Mask Shape: {gt_mask.shape}")
                #     print(f"GT Mask DType: {gt_mask.dtype}")
                #     print(f"Unique values in GT Mask: {torch.unique(gt_mask)}")
                #     print(f"Pred Mask Shape: {pred_mask.shape}")
                #     print(f"Pred Mask DType: {pred_mask.dtype}")
                #     # Print the number of non-zero pixels in the predicted mask to check if the model generated an all-black image
                #     print(f"Number of non-zero pixels in Pred Mask: {torch.count_nonzero(pred_mask)}")
                #     print("=" * 68)

                sam_refined_mask = torch.zeros_like(pred_mask)
                if sam_help:
                    unique_classes = torch.unique(pred_mask)
                    for class_id in unique_classes:
                        if class_id == 0: continue
                        binary_mask = (pred_mask == class_id).double().cpu()
                        try:
                            logits = compute_logits_from_mask(pred_mask.cpu())
                            point_coords, point_labels = masks_sample_points(binary_mask)
                            sam_mask, _, logit = predictor.predict(point_coords=point_coords,
                                                                   point_labels=point_labels,
                                                                   mask_input=logits, multimask_output=False)
                            for _ in range(2):
                                sam_mask, _, logit = predictor.predict(point_coords=point_coords,
                                                                       point_labels=point_labels,
                                                                       mask_input=logit, multimask_output=False)
                            sam_mask = sam_mask[0].astype(np.float32)
                        except Exception as E:
                            print(f"Error: {E}")
                            sam_mask = np.zeros((h_ori, w_ori))
                        sam_refined_mask = torch.from_numpy(sam_mask).to(accelerator.device)
                        # sam_refined_mask[torch.from_numpy(sam_mask[0] > 0).to(accelerator.device)] = class_id

                intersection_i, union_i, _ = intersectionAndUnionGPU(pred_mask, gt_mask, 2, ignore_index=255)

                total_intersection += intersection_i
                total_union += union_i

                iou_per_sample = intersection_i / (union_i + 1e-5)
                iou_per_sample[union_i == 0] = 1.0
                total_acc_iou += iou_per_sample

                if sam_help:
                    intersection_sam_i, union_sam_i, _ = intersectionAndUnionGPU(sam_refined_mask, gt_mask, 2,
                                                                                 ignore_index=255)
                    total_intersection_sam += intersection_sam_i
                    total_union_sam += union_sam_i

                    iou_per_sample_sam = intersection_sam_i / (union_sam_i + 1e-5)
                    iou_per_sample_sam[union_sam_i == 0] = 1.0
                    total_acc_iou_sam += iou_per_sample_sam

                if args.save_masks and accelerator.is_main_process:
                    ds_split_sanitized = args.dataset_split.replace("|", "_")
                    model_name = os.path.basename(args.model_path.strip('/'))
                    save_path = os.path.join(args.save_file, model_name, ds_split_sanitized, "masks", image_name)
                    if not os.path.exists(save_path): os.makedirs(save_path)

                    Image.fromarray(pred_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                        os.path.join(save_path, f"{i}_pred_mask.png"))
                    if sam_help:
                        Image.fromarray(sam_refined_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                            os.path.join(save_path, f"{i}_sam_mask.png"))
                    Image.fromarray(gt_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                        os.path.join(save_path, f"{i}_gt_mask.png"))
                    image.save(os.path.join(save_path, f"{i}_image.png"))

        intersection_meter.update(total_intersection.cpu().numpy())
        union_meter.update(total_union.cpu().numpy())
        if sam_help:
            intersection_meter_sam.update(total_intersection_sam.cpu().numpy())
            union_meter_sam.update(total_union_sam.cpu().numpy())
        if num_masks_in_image > 0:
            total_acc_iou = total_acc_iou / num_masks_in_image
            acc_iou_meter.update(total_acc_iou.cpu().numpy(), n=num_masks_in_image)
            if sam_help:
                total_acc_iou_sam = total_acc_iou_sam / num_masks_in_image
                acc_iou_meter_sam.update(total_acc_iou_sam.cpu().numpy(), n=num_masks_in_image)
        # break

    # --- Synchronize metrics across all processes ---
    all_intersections = accelerator.gather_for_metrics(torch.from_numpy(intersection_meter.sum).to(accelerator.device))
    all_unions = accelerator.gather_for_metrics(torch.from_numpy(union_meter.sum).to(accelerator.device))
    all_giou_sum = accelerator.gather_for_metrics(torch.from_numpy(acc_iou_meter.sum).to(accelerator.device))
    all_giou_count = accelerator.gather_for_metrics(torch.tensor(acc_iou_meter.count, device=accelerator.device))

    all_intersections = all_intersections.view(-1, 2)
    all_unions = all_unions.view(-1, 2)
    all_giou_sum = all_giou_sum.view(-1, 2)
    all_giou_count = all_giou_count.view(-1, 1)

    if sam_help:
        all_intersections_sam = accelerator.gather_for_metrics(
            torch.from_numpy(intersection_meter_sam.sum).to(accelerator.device))
        all_unions_sam = accelerator.gather_for_metrics(torch.from_numpy(union_meter_sam.sum).to(accelerator.device))
        all_giou_sum_sam = accelerator.gather_for_metrics(
            torch.from_numpy(acc_iou_meter_sam.sum).to(accelerator.device))
        all_giou_count_sam = accelerator.gather_for_metrics(
            torch.tensor(acc_iou_meter_sam.count, device=accelerator.device))

        all_intersections_sam = all_intersections_sam.view(-1, 2)
        all_unions_sam = all_unions_sam.view(-1, 2)
        all_giou_sum_sam = all_giou_sum_sam.view(-1, 2)
        all_giou_count_sam = all_giou_count_sam.view(-1, 1)

    # --- Only calculate and output final results on the main process ---
    if accelerator.is_main_process:
        iou_class = torch.sum(all_intersections, dim=0) / (torch.sum(all_unions, dim=0) + 1e-5)
        # print(all_intersections, all_unions, iou_class)
        ciou = iou_class[1].item()
        giou = (torch.sum(all_giou_sum, dim=0)[1] / torch.sum(all_giou_count)).item()

        if sam_help:
            iou_class_sam = torch.sum(all_intersections_sam, dim=0) / (torch.sum(all_unions_sam, dim=0) + 1e-5)
            ciou_sam = iou_class_sam[1].item()
            giou_sam = (torch.sum(all_giou_sum_sam, dim=0)[1] / torch.sum(all_giou_count_sam)).item()
        else:
            giou_sam, ciou_sam = 0.0, 0.0

        # <--- Added: Calculate and print accurate evaluation totals ---
        total_evaluated_images = len(sub_dataset)  # Total images evaluated
        total_evaluated_masks = torch.sum(all_giou_count).item()  # Total masks/prompts evaluated
        # <--- End added ---

        print("\n" + "=" * 50)
        print(f"Evaluation finished for: {args.model_path}")
        print(f"Dataset: {args.dataset_split}")
        print("-" * 50)
        # <--- Added: Print evaluation sample counts ---
        print(f"Total images evaluated: {total_evaluated_images}")
        print(f"Total masks/prompts evaluated: {total_evaluated_masks}")
        print("-" * 50)
        # <--- End added ---
        print(f"Raw Model Mask -> gIoU: {giou:.4f}, cIoU: {ciou:.4f}")
        if sam_help:
            print(f"SAM-Refined Mask -> gIoU: {giou_sam:.4f}, cIoU: {ciou_sam:.4f}")
        print("=" * 50 + "\n")

        # --- Dynamically construct output file path and write results ---
        model_name = os.path.basename(args.model_path.strip('/'))
        ds_split_sanitized = args.dataset_split.replace("|", "_")
        output_dir = os.path.join(args.save_file, model_name, ds_split_sanitized)
        os.makedirs(output_dir, exist_ok=True)
        output_filepath = os.path.join(output_dir, "evaluation_results.txt")

        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        chunk_info = f"(chunk {args.chunk_idx + 1}/{args.num_chunks})" if args.num_chunks > 1 else ""

        header_text = f"[{current_time}] Model: {args.model_path}, Dataset: {args.dataset_split} {chunk_info}\n"
        # <--- Added: Also record evaluation sample counts in the file ---
        eval_stats_text = f" - Evaluated on {total_evaluated_images} images and {total_evaluated_masks} masks.\n"
        # <--- End added ---
        output_text = f" - Raw Model Mask -> gIoU: {giou:.4f}, cIoU: {ciou:.4f}\n"
        if sam_help:
            output_text += f" - SAM-Refined Mask -> gIoU: {giou_sam:.4f}, cIoU: {ciou_sam:.4f}\n"

        with open(output_filepath, "a") as file:
            file.write(header_text)
            file.write(eval_stats_text)  # <--- Added
            file.write(output_text)
            file.write("-" * 60 + "\n")

        print(f"Results appended to: {output_filepath}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default='/raid2/DATA/text4seg/model_trained_qwen_2b/',
                        help="Path to your GenerativeSegmenter model checkpoint.")
    parser.add_argument("--sam_path", type=str, default='/efficient_sag4text/sam_vit_h_4b8939.pth', help="Path to the SAM checkpoint.")
    parser.add_argument("--image_folder", type=str, default='/efficient_sag4text/seg_data/refer_seg', help="Root folder for the dataset images.")
    parser.add_argument("--dataset_split", type=str, default="refcoco|unc|val", help="Dataset split to evaluate on.")
    parser.add_argument("--save_file", type=str, default="output_eval_accelerated/",
                        help="Root directory to save evaluation outputs (masks and metrics).")
    parser.add_argument("--save_masks", action='store_true', help="Set this flag to save output masks and images.")
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--min_pixels", type=int, default=1024 * 28 * 28, help="Minimum pixels for segmentation.")
    parser.add_argument("--max_pixels", type=int, default=1024 * 28 * 28, help="Maximum pixels for segmentation.")
    args = parser.parse_args()

    eval_model(args)
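
For a quick sanity check, here is a standalone sketch of the ceil-based chunking that get_chunk/gget_chunk implement above. chunk_slice is a hypothetical helper operating on a plain list; the real functions slice ds.refer_seg_ds["images"] / ds.loaded_images in place.

# Standalone sketch of the chunking used by get_chunk/gget_chunk.
import math


def chunk_slice(items, num_chunks, chunk_idx):
    # Same arithmetic as above: ceil-sized chunks indexed by chunk_idx.
    chunk_size = math.ceil(len(items) / num_chunks)
    start = chunk_size * chunk_idx
    return items[start:start + chunk_size]


if __name__ == "__main__":
    data = list(range(10))
    parts = [chunk_slice(data, 3, k) for k in range(3)]
    print(parts)                   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    assert sum(parts, []) == data  # every sample lands in exactly one chunk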
eval/refer.py ADDED
@@ -0,0 +1,323 @@
import itertools
import json
import os.path as osp
import pickle
import sys
import time
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import skimage.io as io
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from pycocotools import mask


class REFER:
    def __init__(self, data_root, dataset="refcoco", splitBy="unc"):
        print("loading dataset %s into memory..." % dataset)
        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
        self.DATA_DIR = osp.join(data_root, dataset)
        if dataset in ["refcoco", "refcoco+", "refcocog"]:
            self.IMAGE_DIR = osp.join(data_root, "images/coco_2014/train2014")
        elif dataset == "refclef":
            self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12")
        else:
            print("No refer dataset is called [%s]" % dataset)
            sys.exit()

        self.dataset = dataset

        # load refs from data/dataset/refs(splitBy).p
        tic = time.time()

        ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
        print("ref_file: ", ref_file)
        self.data = {}
        self.data["dataset"] = dataset
        self.data["refs"] = pickle.load(open(ref_file, "rb"))

        # load annotations from data/dataset/instances.json
        instances_file = osp.join(self.DATA_DIR, "instances.json")
        instances = json.load(open(instances_file, "r"))
        self.data["images"] = instances["images"]
        self.data["annotations"] = instances["annotations"]
        self.data["categories"] = instances["categories"]

        # create index
        self.createIndex()
        print("DONE (t=%.2fs)" % (time.time() - tic))

    def createIndex(self):
        # create sets of mapping
        # 1)  Refs:         {ref_id: ref}
        # 2)  Anns:         {ann_id: ann}
        # 3)  Imgs:         {image_id: image}
        # 4)  Cats:         {category_id: category_name}
        # 5)  Sents:        {sent_id: sent}
        # 6)  imgToRefs:    {image_id: refs}
        # 7)  imgToAnns:    {image_id: anns}
        # 8)  refToAnn:     {ref_id: ann}
        # 9)  annToRef:     {ann_id: ref}
        # 10) catToRefs:    {category_id: refs}
        # 11) sentToRef:    {sent_id: ref}
        # 12) sentToTokens: {sent_id: tokens}
        print("creating index...")
        # fetch info from instances
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data["annotations"]:
            Anns[ann["id"]] = ann
            imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
        for img in self.data["images"]:
            Imgs[img["id"]] = img
        for cat in self.data["categories"]:
            Cats[cat["id"]] = cat["name"]

        # fetch info from refs
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data["refs"]:
            # ids
            ref_id = ref["ref_id"]
            ann_id = ref["ann_id"]
            category_id = ref["category_id"]
            image_id = ref["image_id"]

            # add mapping related to ref
            Refs[ref_id] = ref
            imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
            catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            # add mapping of sent
            for sent in ref["sentences"]:
                Sents[sent["sent_id"]] = sent
                sentToRef[sent["sent_id"]] = ref
                sentToTokens[sent["sent_id"]] = sent["tokens"]

        # create class members
        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print("index created.")

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
            refs = self.data["refs"]
        else:
            if not len(image_ids) == 0:
                refs = [self.imgToRefs[image_id] for image_id in image_ids]
            else:
                refs = self.data["refs"]
            if not len(cat_ids) == 0:
                refs = [ref for ref in refs if ref["category_id"] in cat_ids]
            if not len(ref_ids) == 0:
                refs = [ref for ref in refs if ref["ref_id"] in ref_ids]
            if not len(split) == 0:
                if split in ["testA", "testB", "testC"]:
                    refs = [
                        ref for ref in refs if split[-1] in ref["split"]
                    ]  # we also consider testAB, testBC, ...
                elif split in ["testAB", "testBC", "testAC"]:
                    refs = [
                        ref for ref in refs if ref["split"] == split
                    ]  # rarely used I guess...
                elif split == "test":
                    refs = [ref for ref in refs if "test" in ref["split"]]
                elif split == "train" or split == "val":
                    refs = [ref for ref in refs if ref["split"] == split]
                else:
                    print("No such split [%s]" % split)
                    sys.exit()
        ref_ids = [ref["ref_id"] for ref in refs]
        return ref_ids

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
            ann_ids = [ann["id"] for ann in self.data["annotations"]]
        else:
            if not len(image_ids) == 0:
                lists = [
                    self.imgToAnns[image_id]
                    for image_id in image_ids
                    if image_id in self.imgToAnns
                ]  # list of [anns]
                anns = list(itertools.chain.from_iterable(lists))
            else:
                anns = self.data["annotations"]
            if not len(cat_ids) == 0:
                anns = [ann for ann in anns if ann["category_id"] in cat_ids]
            ann_ids = [ann["id"] for ann in anns]
            if not len(ref_ids) == 0:
                ids = set(ann_ids).intersection(
                    set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids])
                )
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if not len(ref_ids) == 0:
            image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
        else:
            image_ids = self.Imgs.keys()
        return image_ids

    def getCatIds(self):
        return self.Cats.keys()

    def loadRefs(self, ref_ids=[]):
        if type(ref_ids) == list:
            return [self.Refs[ref_id] for ref_id in ref_ids]
        elif type(ref_ids) == int:
            return [self.Refs[ref_ids]]

    def loadAnns(self, ann_ids=[]):
        if type(ann_ids) == list:
            return [self.Anns[ann_id] for ann_id in ann_ids]
        elif type(ann_ids) == int or type(ann_ids) == str:
            return [self.Anns[ann_ids]]

    def loadImgs(self, image_ids=[]):
        if type(image_ids) == list:
            return [self.Imgs[image_id] for image_id in image_ids]
        elif type(image_ids) == int:
            return [self.Imgs[image_ids]]

    def loadCats(self, cat_ids=[]):
        if type(cat_ids) == list:
            return [self.Cats[cat_id] for cat_id in cat_ids]
        elif type(cat_ids) == int:
            return [self.Cats[cat_ids]]

    def getRefBox(self, ref_id):
        ref = self.Refs[ref_id]
        ann = self.refToAnn[ref_id]
        return ann["bbox"]  # [x, y, w, h]

    def showRef(self, ref, seg_box="seg"):
        ax = plt.gca()
        # show image
        image = self.Imgs[ref["image_id"]]
        I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
        ax.imshow(I)
        # show refer expression
        for sid, sent in enumerate(ref["sentences"]):
            print("%s. %s" % (sid + 1, sent["sent"]))
        # show segmentations
        if seg_box == "seg":
            ann_id = ref["ann_id"]
            ann = self.Anns[ann_id]
            polygons = []
            color = []
            c = "none"
            if type(ann["segmentation"][0]) == list:
                # polygon used for refcoco*
                for seg in ann["segmentation"]:
                    poly = np.array(seg).reshape((len(seg) // 2, 2))
                    polygons.append(Polygon(poly, True, alpha=0.4))
                    color.append(c)
                p = PatchCollection(
                    polygons,
                    facecolors=color,
                    edgecolors=(1, 1, 0, 0),
                    linewidths=3,
                    alpha=1,
                )
                ax.add_collection(p)  # thick yellow polygon
                p = PatchCollection(
                    polygons,
                    facecolors=color,
                    edgecolors=(1, 0, 0, 0),
                    linewidths=1,
                    alpha=1,
                )
                ax.add_collection(p)  # thin red polygon
            else:
                # mask used for refclef
                rle = ann["segmentation"]
                m = mask.decode(rle)
                img = np.ones((m.shape[0], m.shape[1], 3))
                color_mask = np.array([2.0, 166.0, 101.0]) / 255
                for i in range(3):
                    img[:, :, i] = color_mask[i]
                ax.imshow(np.dstack((img, m * 0.5)))
        # show bounding-box
        elif seg_box == "box":
            ann_id = ref["ann_id"]
            ann = self.Anns[ann_id]
            bbox = self.getRefBox(ref["ref_id"])
            box_plot = Rectangle(
                (bbox[0], bbox[1]),
                bbox[2],
                bbox[3],
                fill=False,
                edgecolor="green",
                linewidth=3,
            )
            ax.add_patch(box_plot)

    def getMask(self, ref):
        # return mask, area and mask-center
        ann = self.refToAnn[ref["ref_id"]]
        image = self.Imgs[ref["image_id"]]
        if type(ann["segmentation"][0]) == list:  # polygon
            rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
        else:
            rle = ann["segmentation"]
        m = mask.decode(rle)
        m = np.sum(
            m, axis=2
        )  # sometimes there are multiple binary maps (corresponding to multiple segs)
        m = m.astype(np.uint8)  # convert to np.uint8
        # compute area
        area = sum(mask.area(rle))  # should be close to ann['area']
        return {"mask": m, "area": area}

    def showMask(self, ref):
        M = self.getMask(ref)
        msk = M["mask"]
        ax = plt.gca()
        ax.imshow(msk)


if __name__ == "__main__":
    # data_root is a placeholder; point it at the directory that contains refcocog/refs(google).p
    refer = REFER(data_root="path/to/refer_seg", dataset="refcocog", splitBy="google")
    ref_ids = refer.getRefIds()
    print(len(ref_ids))

    print(len(refer.Imgs))
    print(len(refer.imgToRefs))

    ref_ids = refer.getRefIds(split="train")
    print("There are %s training referred objects." % len(ref_ids))

    for ref_id in ref_ids:
        ref = refer.loadRefs(ref_id)[0]
        if len(ref["sentences"]) < 2:
            continue

        pprint(ref)
        print("The label is %s." % refer.Cats[ref["category_id"]])
        plt.figure()
        refer.showRef(ref, seg_box="box")
        plt.show()
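
For reference, a minimal usage sketch of the REFER API above, assuming the repository root is on PYTHONPATH and the refcoco annotations (refs(unc).p, instances.json) sit under data_root/refcoco as __init__ expects; the data_root value is a placeholder.

# Load refcoco (UNC split) and pull the binary mask for one reference.
from eval.refer import REFER

refer = REFER(data_root="path/to/refer_seg", dataset="refcoco", splitBy="unc")
val_ref_ids = refer.getRefIds(split="val")
print("val refs:", len(val_ref_ids))

ref = refer.loadRefs(val_ref_ids[0])[0]
print([s["sent"] for s in ref["sentences"]])  # the referring expressions
mask_info = refer.getMask(ref)                # {"mask": HxW uint8 array, "area": ...}
print(mask_info["mask"].shape, mask_info["area"])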
eval/transforms.py ADDED
@@ -0,0 +1,97 @@
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore

from copy import deepcopy
from typing import Tuple


class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # Expects an image in BCHW format. May not exactly match apply_image.
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
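
A small worked example of the resizing math above (assuming torch and torchvision are installed and the repository root is on PYTHONPATH): an 800x1200 (HxW) image whose long side is scaled to 1024, and a point remapped into the resized frame.

import numpy as np
from eval.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)
print(transform.get_preprocess_shape(800, 1200, 1024))             # (683, 1024)

coords = np.array([[600.0, 400.0]])                                # (x, y) in the original image
print(transform.apply_coords(coords, original_size=(800, 1200)))   # [[512.  341.5]]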
eval/val_utils.py ADDED
@@ -0,0 +1,149 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from PIL import Image
from .utils import AverageMeter, Summary, intersectionAndUnionGPU, compute_logits_from_mask, masks_sample_points


def run_in_process_evaluation(model, accelerator, eval_dataloader, sam_predictor=None):
    # --- 1. Initialize all metric recorders ---
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    use_sam = sam_predictor is not None
    if use_sam:
        intersection_meter_sam = AverageMeter("Intersec", ":6.3f", Summary.SUM)
        union_meter_sam = AverageMeter("Union", ":6.3f", Summary.SUM)
        acc_iou_meter_sam = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    progress_bar = tqdm(eval_dataloader, disable=not accelerator.is_main_process, desc="Running Evaluation")

    # --- 2. Iterate over the evaluation dataset ---
    for batch in progress_bar:
        images, masks, image_names, questions, image_paths = batch
        image, gt_masks, image_name, prompts = images[0], masks[0], image_names[0], questions[0]
        w_ori, h_ori = image.size

        total_intersection = torch.zeros(2, device=accelerator.device)
        total_union = torch.zeros(2, device=accelerator.device)
        total_acc_iou = torch.zeros(2, device=accelerator.device)

        if use_sam:
            total_intersection_sam = torch.zeros(2, device=accelerator.device)
            total_union_sam = torch.zeros(2, device=accelerator.device)
            total_acc_iou_sam = torch.zeros(2, device=accelerator.device)

        num_masks_in_image = len(prompts)
        if num_masks_in_image == 0:
            continue

        with torch.inference_mode():
            if use_sam:
                predictor = sam_predictor
                predictor.set_image(np.array(image))

            for i, question in enumerate(prompts):
                gt_mask = gt_masks[i].to(accelerator.device)

                # --- Key difference: directly use the passed-in model object for inference ---
                segmentation_masks, _ = model.generate_with_segmentation(image, question)

                if segmentation_masks is None or len(segmentation_masks) == 0:
                    pred_mask = torch.zeros((h_ori, w_ori), device=accelerator.device)
                else:
                    mask = segmentation_masks[0].to(accelerator.device)
                    pred_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).double(), size=(h_ori, w_ori),
                                              mode='nearest').squeeze()

                # --- SAM post-processing (if enabled) ---
                if use_sam:
                    sam_refined_mask = torch.zeros((h_ori, w_ori), device=accelerator.device)
                    unique_classes = torch.unique(pred_mask)
                    for class_id in unique_classes:
                        if class_id == 0: continue
                        binary_mask = (pred_mask == class_id).double()
                        try:
                            logits = compute_logits_from_mask(binary_mask)
                            point_coords, point_labels = masks_sample_points(binary_mask)
                            sam_mask, _, logit = predictor.predict(point_coords=point_coords, point_labels=point_labels,
                                                                   mask_input=logits, multimask_output=False)
                            for _ in range(2):
                                sam_mask, _, logit = predictor.predict(point_coords=point_coords, point_labels=point_labels,
                                                                       mask_input=logit, multimask_output=False)
                        except Exception:
                            sam_mask = np.zeros((1, h_ori, w_ori))
                        sam_refined_mask[torch.from_numpy(sam_mask[0] > 0).to(accelerator.device)] = class_id

                # --- Metric calculation (current mask) ---
                # Use the original model for prediction
                # <-- MODIFICATION 1: Add ignore_index=255
                intersection_i, union_i, _ = intersectionAndUnionGPU(pred_mask, gt_mask, 2, ignore_index=255)
                total_intersection += intersection_i
                total_union += union_i
                # <-- MODIFICATION 2: Change epsilon to 1e-5
                iou_per_sample = intersection_i / (union_i + 1e-5)
                iou_per_sample[union_i == 0] = 1.0
                total_acc_iou += iou_per_sample

                # SAM optimized prediction
                if use_sam:
                    # <-- MODIFICATION 1 (SAM): Add ignore_index=255
                    intersection_sam_i, union_sam_i, _ = intersectionAndUnionGPU(sam_refined_mask, gt_mask, 2, ignore_index=255)
                    total_intersection_sam += intersection_sam_i
                    total_union_sam += union_sam_i
                    # <-- MODIFICATION 2 (SAM): Change epsilon to 1e-5
                    iou_per_sample_sam = intersection_sam_i / (union_sam_i + 1e-5)
                    iou_per_sample_sam[union_sam_i == 0] = 1.0
                    total_acc_iou_sam += iou_per_sample_sam

        # Update global recorders
        intersection_meter.update(total_intersection.cpu().numpy())
        union_meter.update(total_union.cpu().numpy())
        acc_iou_meter.update(total_acc_iou.cpu().numpy(), n=num_masks_in_image)
        if use_sam:
            intersection_meter_sam.update(total_intersection_sam.cpu().numpy())
            union_meter_sam.update(total_union_sam.cpu().numpy())
            acc_iou_meter_sam.update(total_acc_iou_sam.cpu().numpy(), n=num_masks_in_image)

    # --- 3. Aggregate metrics from all GPUs ---
    all_intersections = accelerator.gather_for_metrics(torch.from_numpy(intersection_meter.sum).to(accelerator.device))
    all_unions = accelerator.gather_for_metrics(torch.from_numpy(union_meter.sum).to(accelerator.device))
    all_giou_sum = accelerator.gather_for_metrics(torch.from_numpy(acc_iou_meter.sum).to(accelerator.device))
    all_giou_count = accelerator.gather_for_metrics(torch.tensor(acc_iou_meter.count, device=accelerator.device))

    if use_sam:
        all_intersections_sam = accelerator.gather_for_metrics(
            torch.from_numpy(intersection_meter_sam.sum).to(accelerator.device))
        all_unions_sam = accelerator.gather_for_metrics(torch.from_numpy(union_meter_sam.sum).to(accelerator.device))
        all_giou_sum_sam = accelerator.gather_for_metrics(
            torch.from_numpy(acc_iou_meter_sam.sum).to(accelerator.device))
        all_giou_count_sam = accelerator.gather_for_metrics(
            torch.tensor(acc_iou_meter_sam.count, device=accelerator.device))

    # --- 4. Calculate final results and return on the main process ---
    final_metrics = {}
    if accelerator.is_main_process:
        # original model metrics
        # <-- MODIFICATION 2 (cIoU): Change epsilon to 1e-5
        iou_class = torch.sum(all_intersections, dim=0) / (torch.sum(all_unions, dim=0) + 1e-5)
        ciou = iou_class[1].item()
        giou_sum = torch.sum(all_giou_sum, dim=0)[1]
        giou_count = torch.sum(all_giou_count)
        giou = (giou_sum / giou_count).item() if giou_count > 0 else 0.0
        final_metrics['giou'] = giou
        final_metrics['ciou'] = ciou

        # SAM optimized metrics
        if use_sam:
            # <-- MODIFICATION 2 (cIoU SAM): Change epsilon to 1e-5
            iou_class_sam = torch.sum(all_intersections_sam, dim=0) / (torch.sum(all_unions_sam, dim=0) + 1e-5)
            ciou_sam = iou_class_sam[1].item()
            giou_sum_sam = torch.sum(all_giou_sum_sam, dim=0)[1]
            giou_count_sam = torch.sum(all_giou_count_sam)
            giou_sam = (giou_sum_sam / giou_count_sam).item() if giou_count_sam > 0 else 0.0
            final_metrics['sam_giou'] = giou_sam
            final_metrics['sam_ciou'] = ciou_sam

    return final_metrics
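
For clarity, a standalone sketch of the two-class IoU bookkeeping both evaluation paths rely on. binary_iou is a stand-in for intersectionAndUnionGPU, which lives in eval/utils.py and is not part of this commit; it follows the usual convention of index 0 = background and index 1 = foreground, matching how the scripts read cIoU from index 1.

# Standalone sketch of the per-mask IoU arithmetic used above.
import torch


def binary_iou(pred, gt, eps=1e-5):
    # Stand-in for intersectionAndUnionGPU (assumed convention: 2 classes, no ignore pixels here).
    inter = torch.zeros(2)
    union = torch.zeros(2)
    for cls in (0, 1):
        p, g = pred == cls, gt == cls
        inter[cls] = (p & g).sum()
        union[cls] = (p | g).sum()
    iou = inter / (union + eps)
    iou[union == 0] = 1.0  # same empty-union convention as the scripts above
    return inter, union, iou


pred = torch.tensor([[0, 1], [1, 1]])
gt = torch.tensor([[0, 1], [0, 1]])
inter, union, iou = binary_iou(pred, gt)
print(inter.tolist(), union.tolist())  # [1.0, 2.0] [2.0, 3.0]
print(round(iou[1].item(), 4))         # foreground IoU: 0.6667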