realzliu committed · Commit 96f36aa · Parent: c9049cc · "init"
Files changed:
- eval/eval_refer_seg.py +333 -0
- eval/refer.py +323 -0
- eval/transforms.py +97 -0
- eval/val_utils.py +149 -0
eval/eval_refer_seg.py (ADDED)
@@ -0,0 +1,333 @@
import argparse
import datetime
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# --- Accelerate Import ---
from accelerate import Accelerator

from model.segment_anything import SamPredictor, sam_model_registry
from dataset.refer_seg_dataset import ValDataset
from dataset.grefer_seg_dataset import grefcocoValDataset
from data.question_answer_list import QUESTION_PARTIAL
from segment_predictor import GenerativeSegmenter
from eval.utils import AverageMeter, Summary, intersectionAndUnionGPU, \
    compute_logits_from_mask, masks_sample_points


def get_chunk(ds, n, k):
    """Keep only the k-th of n contiguous chunks of a refer-seg ValDataset."""
    chunk_size = math.ceil(len(ds) / n)
    start_index = chunk_size * k
    end_index = start_index + chunk_size
    ds.refer_seg_ds["images"] = ds.refer_seg_ds["images"][start_index:end_index]
    return ds


def gget_chunk(ds, n, k):
    """Chunking variant for grefcocoValDataset, which stores its images in `loaded_images`."""
    chunk_size = math.ceil(len(ds) / n)
    start_index = chunk_size * k
    end_index = start_index + chunk_size
    ds.loaded_images = ds.loaded_images[start_index:end_index]
    return ds


class CustomDataset(Dataset):
    def __init__(self, sub_dataset):
        self.dataset = sub_dataset

    def __getitem__(self, index):
        image, masks, questions, image_path = self.dataset[index]
        image_name = os.path.basename(image_path).split(".")[0]
        # Wrap each class name in a randomly chosen question template
        questions = [random.choice(QUESTION_PARTIAL).replace("[class_name]", q) for q in questions]
        return image, masks, image_name, questions, image_path

    def __len__(self):
        return len(self.dataset)


def collate_fn(batch):
    images, masks, image_names, questions, image_paths = zip(*batch)
    return images, masks, image_names, questions, image_paths


def create_data_loader(args, sub_dataset, batch_size=1):
    assert batch_size == 1, "Batch size must be 1 for this evaluation script."
    dataset = CustomDataset(sub_dataset)
    return DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=collate_fn)


def eval_model(args):
    # --- Initialize Accelerator ---
    accelerator = Accelerator()

    # --- Model and Predictor Initialization ---
    segmenter = GenerativeSegmenter(args.model_path, device_map=accelerator.device,
                                    min_pixels=args.min_pixels, max_pixels=args.max_pixels)

    sam_help = args.sam_path is not None
    if sam_help:
        sam = sam_model_registry["vit_h"](checkpoint=args.sam_path)
        sam = sam.to(dtype=torch.float32, device=accelerator.device)
        predictor = SamPredictor(sam)
    else:
        predictor = None

    # --- Dataset and DataLoader Initialization ---
    if accelerator.is_main_process:
        print("Loading dataset...")

    # First, load the full dataset based on the parameters
    if "grefcoco" in args.dataset_split:
        val_dataset = grefcocoValDataset(args.image_folder, args.dataset_split)
    else:
        val_dataset = ValDataset(args.image_folder, args.dataset_split)

    if accelerator.is_main_process:
        total_data_size = len(val_dataset)
        print(f"Total evaluation data volume (full dataset): {total_data_size} samples.")

    # Then, take the chunk of the dataset assigned to this job
    if "grefcoco" in args.dataset_split:
        sub_dataset = gget_chunk(val_dataset, args.num_chunks, args.chunk_idx)
    else:
        sub_dataset = get_chunk(val_dataset, args.num_chunks, args.chunk_idx)

    data_loader = create_data_loader(args, sub_dataset)
    data_loader = accelerator.prepare(data_loader)

    # --- Metric Meters Initialization ---
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    if sam_help:
        intersection_meter_sam = AverageMeter("Intersec", ":6.3f", Summary.SUM)
        union_meter_sam = AverageMeter("Union", ":6.3f", Summary.SUM)
        acc_iou_meter_sam = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    progress_bar = tqdm(data_loader, disable=not accelerator.is_main_process, total=len(data_loader))

    for batch in progress_bar:
        images, masks, image_names, questions, image_paths = batch
        image, gt_masks, image_name, prompts = images[0], masks[0], image_names[0], questions[0]
        w_ori, h_ori = image.size

        total_intersection = torch.zeros(2, device=accelerator.device)
        total_union = torch.zeros(2, device=accelerator.device)
        total_acc_iou = torch.zeros(2, device=accelerator.device)

        if sam_help:
            total_intersection_sam = torch.zeros(2, device=accelerator.device)
            total_union_sam = torch.zeros(2, device=accelerator.device)
            total_acc_iou_sam = torch.zeros(2, device=accelerator.device)

        num_masks_in_image = len(prompts)

        with torch.inference_mode():
            if sam_help:
                predictor.set_image(np.array(image))

            for i, question in enumerate(prompts):
                gt_mask = gt_masks[i].to(accelerator.device).float().contiguous()
                segmentation_masks, _ = segmenter.generate_with_segmentation(image, question)

                if segmentation_masks is None or len(segmentation_masks) == 0:
                    pred_mask = torch.zeros((h_ori, w_ori), device=accelerator.device)
                else:
                    mask = segmentation_masks[0].to(accelerator.device)
                    pred_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).double(), size=(h_ori, w_ori),
                                              mode='nearest').squeeze()

                # Debug prints (disabled):
                # if accelerator.is_main_process:
                #     print("\n" + "=" * 20 + " DEBUG INFO (first iteration) " + "=" * 20)
                #     print(f"Image Name: {image_name}")
                #     print(f"GT Mask Shape: {gt_mask.shape}")
                #     print(f"GT Mask DType: {gt_mask.dtype}")
                #     print(f"Unique values in GT Mask: {torch.unique(gt_mask)}")
                #     print(f"Pred Mask Shape: {pred_mask.shape}")
                #     print(f"Pred Mask DType: {pred_mask.dtype}")
                #     # Count non-zero pixels to check whether the model produced an all-black mask
                #     print(f"Number of non-zero pixels in Pred Mask: {torch.count_nonzero(pred_mask)}")
                #     print("=" * 68)

                # Optionally refine the raw mask with SAM, prompted by points sampled
                # from the predicted mask plus a low-resolution logit prior.
                sam_refined_mask = torch.zeros_like(pred_mask)
                if sam_help:
                    unique_classes = torch.unique(pred_mask)
                    for class_id in unique_classes:
                        if class_id == 0:
                            continue
                        binary_mask = (pred_mask == class_id).double().cpu()
                        try:
                            logits = compute_logits_from_mask(pred_mask.cpu())
                            point_coords, point_labels = masks_sample_points(binary_mask)
                            sam_mask, _, logit = predictor.predict(point_coords=point_coords,
                                                                   point_labels=point_labels,
                                                                   mask_input=logits, multimask_output=False)
                            # Two extra self-refinement rounds, feeding SAM its own logits back
                            for _ in range(2):
                                sam_mask, _, logit = predictor.predict(point_coords=point_coords,
                                                                       point_labels=point_labels,
                                                                       mask_input=logit, multimask_output=False)
                            sam_mask = sam_mask[0].astype(np.float32)
                        except Exception as e:
                            print(f"Error: {e}")
                            sam_mask = np.zeros((h_ori, w_ori))
                        sam_refined_mask = torch.from_numpy(sam_mask).to(accelerator.device)
                        # sam_refined_mask[torch.from_numpy(sam_mask[0] > 0).to(accelerator.device)] = class_id

                intersection_i, union_i, _ = intersectionAndUnionGPU(pred_mask, gt_mask, 2, ignore_index=255)

                total_intersection += intersection_i
                total_union += union_i

                iou_per_sample = intersection_i / (union_i + 1e-5)
                iou_per_sample[union_i == 0] = 1.0
                total_acc_iou += iou_per_sample

                if sam_help:
                    intersection_sam_i, union_sam_i, _ = intersectionAndUnionGPU(sam_refined_mask, gt_mask, 2,
                                                                                 ignore_index=255)
                    total_intersection_sam += intersection_sam_i
                    total_union_sam += union_sam_i

                    iou_per_sample_sam = intersection_sam_i / (union_sam_i + 1e-5)
                    iou_per_sample_sam[union_sam_i == 0] = 1.0
                    total_acc_iou_sam += iou_per_sample_sam

                if args.save_masks and accelerator.is_main_process:
                    ds_split_sanitized = args.dataset_split.replace("|", "_")
                    model_name = os.path.basename(args.model_path.strip('/'))
                    save_path = os.path.join(args.save_file, model_name, ds_split_sanitized, "masks", image_name)
                    os.makedirs(save_path, exist_ok=True)

                    Image.fromarray(pred_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                        os.path.join(save_path, f"{i}_pred_mask.png"))
                    if sam_help:
                        Image.fromarray(sam_refined_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                            os.path.join(save_path, f"{i}_sam_mask.png"))
                    Image.fromarray(gt_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                        os.path.join(save_path, f"{i}_gt_mask.png"))
                    image.save(os.path.join(save_path, f"{i}_image.png"))

        intersection_meter.update(total_intersection.cpu().numpy())
        union_meter.update(total_union.cpu().numpy())
        if sam_help:
            intersection_meter_sam.update(total_intersection_sam.cpu().numpy())
            union_meter_sam.update(total_union_sam.cpu().numpy())
        if num_masks_in_image > 0:
            # Average the per-image gIoU over its prompts before updating the meter
            total_acc_iou = total_acc_iou / num_masks_in_image
            acc_iou_meter.update(total_acc_iou.cpu().numpy(), n=num_masks_in_image)
            if sam_help:
                total_acc_iou_sam = total_acc_iou_sam / num_masks_in_image
                acc_iou_meter_sam.update(total_acc_iou_sam.cpu().numpy(), n=num_masks_in_image)

    # --- Synchronize metrics across all processes ---
    all_intersections = accelerator.gather_for_metrics(torch.from_numpy(intersection_meter.sum).to(accelerator.device))
    all_unions = accelerator.gather_for_metrics(torch.from_numpy(union_meter.sum).to(accelerator.device))
    all_giou_sum = accelerator.gather_for_metrics(torch.from_numpy(acc_iou_meter.sum).to(accelerator.device))
    all_giou_count = accelerator.gather_for_metrics(torch.tensor(acc_iou_meter.count, device=accelerator.device))

    all_intersections = all_intersections.view(-1, 2)
    all_unions = all_unions.view(-1, 2)
    all_giou_sum = all_giou_sum.view(-1, 2)
    all_giou_count = all_giou_count.view(-1, 1)

    if sam_help:
        all_intersections_sam = accelerator.gather_for_metrics(
            torch.from_numpy(intersection_meter_sam.sum).to(accelerator.device))
        all_unions_sam = accelerator.gather_for_metrics(torch.from_numpy(union_meter_sam.sum).to(accelerator.device))
        all_giou_sum_sam = accelerator.gather_for_metrics(
            torch.from_numpy(acc_iou_meter_sam.sum).to(accelerator.device))
        all_giou_count_sam = accelerator.gather_for_metrics(
            torch.tensor(acc_iou_meter_sam.count, device=accelerator.device))

        all_intersections_sam = all_intersections_sam.view(-1, 2)
        all_unions_sam = all_unions_sam.view(-1, 2)
        all_giou_sum_sam = all_giou_sum_sam.view(-1, 2)
        all_giou_count_sam = all_giou_count_sam.view(-1, 1)

    # --- Only calculate and output final results on the main process ---
    if accelerator.is_main_process:
        iou_class = torch.sum(all_intersections, dim=0) / (torch.sum(all_unions, dim=0) + 1e-5)
        ciou = iou_class[1].item()
        giou = (torch.sum(all_giou_sum, dim=0)[1] / torch.sum(all_giou_count)).item()

        if sam_help:
            iou_class_sam = torch.sum(all_intersections_sam, dim=0) / (torch.sum(all_unions_sam, dim=0) + 1e-5)
            ciou_sam = iou_class_sam[1].item()
            giou_sam = (torch.sum(all_giou_sum_sam, dim=0)[1] / torch.sum(all_giou_count_sam)).item()
        else:
            giou_sam, ciou_sam = 0.0, 0.0

        total_evaluated_images = len(sub_dataset)                 # total images evaluated
        total_evaluated_masks = torch.sum(all_giou_count).item()  # total masks/prompts evaluated

        print("\n" + "=" * 50)
        print(f"Evaluation finished for: {args.model_path}")
        print(f"Dataset: {args.dataset_split}")
        print("-" * 50)
        print(f"Total images evaluated: {total_evaluated_images}")
        print(f"Total masks/prompts evaluated: {total_evaluated_masks}")
        print("-" * 50)
        print(f"Raw Model Mask -> gIoU: {giou:.4f}, cIoU: {ciou:.4f}")
        if sam_help:
            print(f"SAM-Refined Mask -> gIoU: {giou_sam:.4f}, cIoU: {ciou_sam:.4f}")
        print("=" * 50 + "\n")

        # --- Dynamically construct output file path and write results ---
        model_name = os.path.basename(args.model_path.strip('/'))
        ds_split_sanitized = args.dataset_split.replace("|", "_")
        output_dir = os.path.join(args.save_file, model_name, ds_split_sanitized)
        os.makedirs(output_dir, exist_ok=True)
        output_filepath = os.path.join(output_dir, "evaluation_results.txt")

        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        chunk_info = f"(chunk {args.chunk_idx + 1}/{args.num_chunks})" if args.num_chunks > 1 else ""

        header_text = f"[{current_time}] Model: {args.model_path}, Dataset: {args.dataset_split} {chunk_info}\n"
        eval_stats_text = f" - Evaluated on {total_evaluated_images} images and {total_evaluated_masks} masks.\n"
        output_text = f" - Raw Model Mask -> gIoU: {giou:.4f}, cIoU: {ciou:.4f}\n"
        if sam_help:
            output_text += f" - SAM-Refined Mask -> gIoU: {giou_sam:.4f}, cIoU: {ciou_sam:.4f}\n"

        with open(output_filepath, "a") as file:
            file.write(header_text)
            file.write(eval_stats_text)
            file.write(output_text)
            file.write("-" * 60 + "\n")

        print(f"Results appended to: {output_filepath}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default='/raid2/DATA/text4seg/model_trained_qwen_2b/',
                        help="Path to your GenerativeSegmenter model checkpoint.")
    parser.add_argument("--sam_path", type=str, default='/efficient_sag4text/sam_vit_h_4b8939.pth',
                        help="Path to the SAM checkpoint.")
    parser.add_argument("--image_folder", type=str, default='/efficient_sag4text/seg_data/refer_seg',
                        help="Root folder for the dataset images.")
    parser.add_argument("--dataset_split", type=str, default="refcoco|unc|val",
                        help="Dataset split to evaluate on.")
    parser.add_argument("--save_file", type=str, default="output_eval_accelerated/",
                        help="Root directory to save evaluation outputs (masks and metrics).")
    parser.add_argument("--save_masks", action='store_true', help="Set this flag to save output masks and images.")
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)  # argparse exposes this as args.chunk_idx
    parser.add_argument("--min_pixels", type=int, default=1024 * 28 * 28, help="Minimum pixels for segmentation.")
    parser.add_argument("--max_pixels", type=int, default=1024 * 28 * 28, help="Maximum pixels for segmentation.")
    args = parser.parse_args()

    eval_model(args)
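As a quick sanity check of the ceil-based chunking used by get_chunk/gget_chunk above, the sketch below reproduces the same slicing arithmetic on a plain list; chunk_slice and the toy data are illustrative only, not part of the repo:

import math

def chunk_slice(items, n, k):
    # same arithmetic as get_chunk/gget_chunk: the k-th of n contiguous chunks
    size = math.ceil(len(items) / n)
    return items[size * k : size * k + size]

data = list(range(10))
parts = [chunk_slice(data, 3, k) for k in range(3)]
print(parts)                    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert sum(parts, []) == data   # the chunks tile the dataset without overlap or gaps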
eval/refer.py (ADDED)
@@ -0,0 +1,323 @@
import itertools
import json
import os.path as osp
import pickle
import sys
import time
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import skimage.io as io
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from pycocotools import mask


class REFER:
    def __init__(self, data_root, dataset="refcoco", splitBy="unc"):
        print("loading dataset %s into memory..." % dataset)
        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
        self.DATA_DIR = osp.join(data_root, dataset)
        if dataset in ["refcoco", "refcoco+", "refcocog"]:
            self.IMAGE_DIR = osp.join(data_root, "images/coco_2014/train2014")
        elif dataset == "refclef":
            self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12")
        else:
            print("No refer dataset is called [%s]" % dataset)
            sys.exit()

        self.dataset = dataset

        # load refs from data/dataset/refs(splitBy).p
        tic = time.time()

        ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
        print("ref_file: ", ref_file)
        self.data = {}
        self.data["dataset"] = dataset
        self.data["refs"] = pickle.load(open(ref_file, "rb"))

        # load annotations from data/dataset/instances.json
        instances_file = osp.join(self.DATA_DIR, "instances.json")
        instances = json.load(open(instances_file, "r"))
        self.data["images"] = instances["images"]
        self.data["annotations"] = instances["annotations"]
        self.data["categories"] = instances["categories"]

        # create index
        self.createIndex()
        print("DONE (t=%.2fs)" % (time.time() - tic))

    def createIndex(self):
        # create mappings:
        # 1)  Refs:         {ref_id: ref}
        # 2)  Anns:         {ann_id: ann}
        # 3)  Imgs:         {image_id: image}
        # 4)  Cats:         {category_id: category_name}
        # 5)  Sents:        {sent_id: sent}
        # 6)  imgToRefs:    {image_id: refs}
        # 7)  imgToAnns:    {image_id: anns}
        # 8)  refToAnn:     {ref_id: ann}
        # 9)  annToRef:     {ann_id: ref}
        # 10) catToRefs:    {category_id: refs}
        # 11) sentToRef:    {sent_id: ref}
        # 12) sentToTokens: {sent_id: tokens}
        print("creating index...")
        # fetch info from instances
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data["annotations"]:
            Anns[ann["id"]] = ann
            imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
        for img in self.data["images"]:
            Imgs[img["id"]] = img
        for cat in self.data["categories"]:
            Cats[cat["id"]] = cat["name"]

        # fetch info from refs
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data["refs"]:
            # ids
            ref_id = ref["ref_id"]
            ann_id = ref["ann_id"]
            category_id = ref["category_id"]
            image_id = ref["image_id"]

            # add mapping related to ref
            Refs[ref_id] = ref
            imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
            catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            # add mapping of sent
            for sent in ref["sentences"]:
                Sents[sent["sent_id"]] = sent
                sentToRef[sent["sent_id"]] = ref
                sentToTokens[sent["sent_id"]] = sent["tokens"]

        # create class members
        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print("index created.")

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
            refs = self.data["refs"]
        else:
            if not len(image_ids) == 0:
                # flatten the per-image ref lists into one list of refs
                lists = [self.imgToRefs[image_id] for image_id in image_ids]
                refs = list(itertools.chain.from_iterable(lists))
            else:
                refs = self.data["refs"]
            if not len(cat_ids) == 0:
                refs = [ref for ref in refs if ref["category_id"] in cat_ids]
            if not len(ref_ids) == 0:
                refs = [ref for ref in refs if ref["ref_id"] in ref_ids]
            if not len(split) == 0:
                if split in ["testA", "testB", "testC"]:
                    # we also consider testAB, testBC, ...
                    refs = [ref for ref in refs if split[-1] in ref["split"]]
                elif split in ["testAB", "testBC", "testAC"]:
                    refs = [ref for ref in refs if ref["split"] == split]  # rarely used
                elif split == "test":
                    refs = [ref for ref in refs if "test" in ref["split"]]
                elif split == "train" or split == "val":
                    refs = [ref for ref in refs if ref["split"] == split]
                else:
                    print("No such split [%s]" % split)
                    sys.exit()
        ref_ids = [ref["ref_id"] for ref in refs]
        return ref_ids

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
            ann_ids = [ann["id"] for ann in self.data["annotations"]]
        else:
            if not len(image_ids) == 0:
                lists = [
                    self.imgToAnns[image_id]
                    for image_id in image_ids
                    if image_id in self.imgToAnns
                ]  # list of [anns]
                anns = list(itertools.chain.from_iterable(lists))
            else:
                anns = self.data["annotations"]
            if not len(cat_ids) == 0:
                anns = [ann for ann in anns if ann["category_id"] in cat_ids]
            ann_ids = [ann["id"] for ann in anns]
            if not len(ref_ids) == 0:
                # keep only the anns that the given refs point to
                ids = set(ann_ids).intersection(
                    set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids])
                )
                ann_ids = list(ids)
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if not len(ref_ids) == 0:
            image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
        else:
            image_ids = self.Imgs.keys()
        return image_ids

    def getCatIds(self):
        return self.Cats.keys()

    def loadRefs(self, ref_ids=[]):
        if type(ref_ids) == list:
            return [self.Refs[ref_id] for ref_id in ref_ids]
        elif type(ref_ids) == int:
            return [self.Refs[ref_ids]]

    def loadAnns(self, ann_ids=[]):
        if type(ann_ids) == list:
            return [self.Anns[ann_id] for ann_id in ann_ids]
        elif type(ann_ids) == int or type(ann_ids) == str:
            return [self.Anns[ann_ids]]

    def loadImgs(self, image_ids=[]):
        if type(image_ids) == list:
            return [self.Imgs[image_id] for image_id in image_ids]
        elif type(image_ids) == int:
            return [self.Imgs[image_ids]]

    def loadCats(self, cat_ids=[]):
        if type(cat_ids) == list:
            return [self.Cats[cat_id] for cat_id in cat_ids]
        elif type(cat_ids) == int:
            return [self.Cats[cat_ids]]

    def getRefBox(self, ref_id):
        ann = self.refToAnn[ref_id]
        return ann["bbox"]  # [x, y, w, h]

    def showRef(self, ref, seg_box="seg"):
        ax = plt.gca()
        # show image
        image = self.Imgs[ref["image_id"]]
        I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
        ax.imshow(I)
        # show referring expressions
        for sid, sent in enumerate(ref["sentences"]):
            print("%s. %s" % (sid + 1, sent["sent"]))
        # show segmentations
        if seg_box == "seg":
            ann_id = ref["ann_id"]
            ann = self.Anns[ann_id]
            polygons = []
            color = []
            c = "none"
            if type(ann["segmentation"][0]) == list:
                # polygon used for refcoco*
                for seg in ann["segmentation"]:
                    poly = np.array(seg).reshape((len(seg) // 2, 2))
                    polygons.append(Polygon(poly, closed=True, alpha=0.4))
                    color.append(c)
                p = PatchCollection(
                    polygons,
                    facecolors=color,
                    edgecolors=(1, 1, 0, 0),
                    linewidths=3,
                    alpha=1,
                )
                ax.add_collection(p)  # thick yellow polygon
                p = PatchCollection(
                    polygons,
                    facecolors=color,
                    edgecolors=(1, 0, 0, 0),
                    linewidths=1,
                    alpha=1,
                )
                ax.add_collection(p)  # thin red polygon
            else:
                # mask used for refclef
                rle = ann["segmentation"]
                m = mask.decode(rle)
                img = np.ones((m.shape[0], m.shape[1], 3))
                color_mask = np.array([2.0, 166.0, 101.0]) / 255
                for i in range(3):
                    img[:, :, i] = color_mask[i]
                ax.imshow(np.dstack((img, m * 0.5)))
        # show bounding box
        elif seg_box == "box":
            bbox = self.getRefBox(ref["ref_id"])
            box_plot = Rectangle(
                (bbox[0], bbox[1]),
                bbox[2],
                bbox[3],
                fill=False,
                edgecolor="green",
                linewidth=3,
            )
            ax.add_patch(box_plot)

    def getMask(self, ref):
        # return mask and area
        ann = self.refToAnn[ref["ref_id"]]
        image = self.Imgs[ref["image_id"]]
        if type(ann["segmentation"][0]) == list:  # polygon
            rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
        else:
            rle = ann["segmentation"]
        m = mask.decode(rle)
        # sometimes there are multiple binary maps (one per polygon)
        m = np.sum(m, axis=2)
        m = m.astype(np.uint8)  # convert to np.uint8
        # compute area
        area = sum(mask.area(rle))  # should be close to ann['area']
        return {"mask": m, "area": area}

    def showMask(self, ref):
        M = self.getMask(ref)
        msk = M["mask"]
        ax = plt.gca()
        ax.imshow(msk)


if __name__ == "__main__":
    # NOTE: data_root is required by __init__; "data" below is a placeholder path.
    refer = REFER(data_root="data", dataset="refcocog", splitBy="google")
    ref_ids = refer.getRefIds()
    print(len(ref_ids))

    print(len(refer.Imgs))
    print(len(refer.imgToRefs))

    ref_ids = refer.getRefIds(split="train")
    print("There are %s training referred objects." % len(ref_ids))

    for ref_id in ref_ids:
        ref = refer.loadRefs(ref_id)[0]
        if len(ref["sentences"]) < 2:
            continue

        pprint(ref)
        print("The label is %s." % refer.Cats[ref["category_id"]])
        plt.figure()
        refer.showRef(ref, seg_box="box")
        plt.show()
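For reference, a minimal usage sketch of the REFER API above; it assumes the standard refer_seg layout (refs(unc).p and instances.json under <data_root>/refcoco), and the data_root value is a placeholder:

from eval.refer import REFER

refer = REFER(data_root="seg_data/refer_seg", dataset="refcoco", splitBy="unc")  # path is illustrative
val_ids = refer.getRefIds(split="val")
ref = refer.loadRefs(val_ids[0])[0]
print([s["sent"] for s in ref["sentences"]])  # the referring expressions for this object
m = refer.getMask(ref)                        # {"mask": HxW uint8 array, "area": float}
print(m["mask"].shape, m["area"])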
eval/transforms.py (ADDED)
@@ -0,0 +1,97 @@
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore

from copy import deepcopy
from typing import Tuple


class ResizeLongestSide:
    """
    Resizes images so the longest side equals 'target_length', and provides
    methods for resizing coordinates and boxes accordingly. Handles both
    numpy arrays and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # Expects an image in BCHW format. May not exactly match apply_image.
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
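A short, self-contained example of ResizeLongestSide in action (the 480x640 input is synthetic); with target_length=1024 the longest side is scaled to 1024 and coordinates are scaled by the same factor:

import numpy as np
from eval.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)
image = np.zeros((480, 640, 3), dtype=np.uint8)         # H x W x C, uint8
resized = transform.apply_image(image)
print(resized.shape)                                    # (768, 1024, 3): both sides scaled by 1024/640
coords = np.array([[320.0, 240.0]])                     # (x, y) in the original frame
print(transform.apply_coords(coords, image.shape[:2]))  # [[512. 384.]]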
eval/val_utils.py (ADDED)
@@ -0,0 +1,149 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from PIL import Image

from .utils import AverageMeter, Summary, intersectionAndUnionGPU, compute_logits_from_mask, masks_sample_points


def run_in_process_evaluation(model, accelerator, eval_dataloader, sam_predictor=None):
    # --- 1. Initialize all metric recorders ---
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    use_sam = sam_predictor is not None
    if use_sam:
        intersection_meter_sam = AverageMeter("Intersec", ":6.3f", Summary.SUM)
        union_meter_sam = AverageMeter("Union", ":6.3f", Summary.SUM)
        acc_iou_meter_sam = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    progress_bar = tqdm(eval_dataloader, disable=not accelerator.is_main_process, desc="Running Evaluation")

    # --- 2. Iterate over the evaluation dataset ---
    for batch in progress_bar:
        images, masks, image_names, questions, image_paths = batch
        image, gt_masks, image_name, prompts = images[0], masks[0], image_names[0], questions[0]
        w_ori, h_ori = image.size

        total_intersection = torch.zeros(2, device=accelerator.device)
        total_union = torch.zeros(2, device=accelerator.device)
        total_acc_iou = torch.zeros(2, device=accelerator.device)

        if use_sam:
            total_intersection_sam = torch.zeros(2, device=accelerator.device)
            total_union_sam = torch.zeros(2, device=accelerator.device)
            total_acc_iou_sam = torch.zeros(2, device=accelerator.device)

        num_masks_in_image = len(prompts)
        if num_masks_in_image == 0:
            continue

        with torch.inference_mode():
            if use_sam:
                predictor = sam_predictor
                predictor.set_image(np.array(image))

            for i, question in enumerate(prompts):
                gt_mask = gt_masks[i].to(accelerator.device)

                # --- Key difference from the standalone script: inference uses the
                # passed-in (in-training) model object directly ---
                segmentation_masks, _ = model.generate_with_segmentation(image, question)

                if segmentation_masks is None or len(segmentation_masks) == 0:
                    pred_mask = torch.zeros((h_ori, w_ori), device=accelerator.device)
                else:
                    mask = segmentation_masks[0].to(accelerator.device)
                    pred_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).double(), size=(h_ori, w_ori),
                                              mode='nearest').squeeze()

                # --- SAM post-processing (if enabled) ---
                if use_sam:
                    sam_refined_mask = torch.zeros((h_ori, w_ori), device=accelerator.device)
                    unique_classes = torch.unique(pred_mask)
                    for class_id in unique_classes:
                        if class_id == 0:
                            continue
                        binary_mask = (pred_mask == class_id).double()
                        try:
                            logits = compute_logits_from_mask(binary_mask)
                            point_coords, point_labels = masks_sample_points(binary_mask)
                            sam_mask, _, logit = predictor.predict(point_coords=point_coords, point_labels=point_labels,
                                                                   mask_input=logits, multimask_output=False)
                            # two extra self-refinement rounds, feeding SAM its own logits back
                            for _ in range(2):
                                sam_mask, _, logit = predictor.predict(point_coords=point_coords, point_labels=point_labels,
                                                                       mask_input=logit, multimask_output=False)
                        except Exception:
                            sam_mask = np.zeros((1, h_ori, w_ori))
                        sam_refined_mask[torch.from_numpy(sam_mask[0] > 0).to(accelerator.device)] = class_id

                # --- Metric calculation for the raw model mask ---
                # ignore_index=255 excludes void pixels from the IoU computation
                intersection_i, union_i, _ = intersectionAndUnionGPU(pred_mask, gt_mask, 2, ignore_index=255)
                total_intersection += intersection_i
                total_union += union_i
                # the 1e-5 epsilon guards against division by zero on empty unions
                iou_per_sample = intersection_i / (union_i + 1e-5)
                iou_per_sample[union_i == 0] = 1.0
                total_acc_iou += iou_per_sample

                # --- Metric calculation for the SAM-refined mask ---
                if use_sam:
                    intersection_sam_i, union_sam_i, _ = intersectionAndUnionGPU(sam_refined_mask, gt_mask, 2,
                                                                                 ignore_index=255)
                    total_intersection_sam += intersection_sam_i
                    total_union_sam += union_sam_i
                    iou_per_sample_sam = intersection_sam_i / (union_sam_i + 1e-5)
                    iou_per_sample_sam[union_sam_i == 0] = 1.0
                    total_acc_iou_sam += iou_per_sample_sam

        # Update global recorders (average the per-image gIoU over its prompts,
        # as eval_refer_seg.py does)
        intersection_meter.update(total_intersection.cpu().numpy())
        union_meter.update(total_union.cpu().numpy())
        acc_iou_meter.update((total_acc_iou / num_masks_in_image).cpu().numpy(), n=num_masks_in_image)
        if use_sam:
            intersection_meter_sam.update(total_intersection_sam.cpu().numpy())
            union_meter_sam.update(total_union_sam.cpu().numpy())
            acc_iou_meter_sam.update((total_acc_iou_sam / num_masks_in_image).cpu().numpy(), n=num_masks_in_image)

    # --- 3. Aggregate metrics from all GPUs ---
    all_intersections = accelerator.gather_for_metrics(torch.from_numpy(intersection_meter.sum).to(accelerator.device))
    all_unions = accelerator.gather_for_metrics(torch.from_numpy(union_meter.sum).to(accelerator.device))
    all_giou_sum = accelerator.gather_for_metrics(torch.from_numpy(acc_iou_meter.sum).to(accelerator.device))
    all_giou_count = accelerator.gather_for_metrics(torch.tensor(acc_iou_meter.count, device=accelerator.device))

    # reshape gathered tensors to (num_processes, 2) before summing, as in eval_refer_seg.py
    all_intersections = all_intersections.view(-1, 2)
    all_unions = all_unions.view(-1, 2)
    all_giou_sum = all_giou_sum.view(-1, 2)

    if use_sam:
        all_intersections_sam = accelerator.gather_for_metrics(
            torch.from_numpy(intersection_meter_sam.sum).to(accelerator.device))
        all_unions_sam = accelerator.gather_for_metrics(torch.from_numpy(union_meter_sam.sum).to(accelerator.device))
        all_giou_sum_sam = accelerator.gather_for_metrics(
            torch.from_numpy(acc_iou_meter_sam.sum).to(accelerator.device))
        all_giou_count_sam = accelerator.gather_for_metrics(
            torch.tensor(acc_iou_meter_sam.count, device=accelerator.device))

        all_intersections_sam = all_intersections_sam.view(-1, 2)
        all_unions_sam = all_unions_sam.view(-1, 2)
        all_giou_sum_sam = all_giou_sum_sam.view(-1, 2)

    # --- 4. Calculate final results and return them from the main process ---
    final_metrics = {}
    if accelerator.is_main_process:
        # raw model metrics
        iou_class = torch.sum(all_intersections, dim=0) / (torch.sum(all_unions, dim=0) + 1e-5)
        ciou = iou_class[1].item()
        giou_sum = torch.sum(all_giou_sum, dim=0)[1]
        giou_count = torch.sum(all_giou_count)
        giou = (giou_sum / giou_count).item() if giou_count > 0 else 0.0
        final_metrics['giou'] = giou
        final_metrics['ciou'] = ciou

        # SAM-refined metrics
        if use_sam:
            iou_class_sam = torch.sum(all_intersections_sam, dim=0) / (torch.sum(all_unions_sam, dim=0) + 1e-5)
            ciou_sam = iou_class_sam[1].item()
            giou_sum_sam = torch.sum(all_giou_sum_sam, dim=0)[1]
            giou_count_sam = torch.sum(all_giou_count_sam)
            giou_sam = (giou_sum_sam / giou_count_sam).item() if giou_count_sam > 0 else 0.0
            final_metrics['sam_giou'] = giou_sam
            final_metrics['sam_ciou'] = ciou_sam

    return final_metrics
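A hedged sketch of how run_in_process_evaluation might be wired into a training loop; model, build_eval_dataloader, and the surrounding setup are placeholders rather than repo code, and the loader must yield the same (images, masks, image_names, questions, image_paths) batches of size 1 that the function unpacks above:

from accelerate import Accelerator
from eval.val_utils import run_in_process_evaluation

accelerator = Accelerator()
model = ...                                                 # must expose generate_with_segmentation(image, question)
eval_loader = accelerator.prepare(build_eval_dataloader())  # placeholder dataloader factory

metrics = run_in_process_evaluation(model, accelerator, eval_loader, sam_predictor=None)
if accelerator.is_main_process:
    print(f"gIoU={metrics['giou']:.4f}, cIoU={metrics['ciou']:.4f}")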