# STAMP-2B-uni/eval/eval_refer_seg.py
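"""Referring-expression segmentation evaluation (RefCOCO-style and gRefCOCO splits).

For every image/expression pair, the generative segmenter produces a mask that
is scored against the ground truth with gIoU (mean per-prompt IoU) and cIoU
(cumulative intersection over cumulative union). If a SAM checkpoint is given,
the raw masks are additionally refined with point and logit prompts.

Multi-GPU evaluation runs under HuggingFace Accelerate, e.g. (paths are
placeholders):

    accelerate launch eval/eval_refer_seg.py \
        --model_path /path/to/checkpoint --dataset_split "refcoco|unc|val"
"""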
import argparse
import torch
import os
from tqdm import tqdm
import random
from PIL import Image
import numpy as np
import torch.nn.functional as F
import datetime
from model.segment_anything import SamPredictor, sam_model_registry
from dataset.refer_seg_dataset import ValDataset
from dataset.grefer_seg_dataset import grefcocoValDataset
from data.question_answer_list import QUESTION_PARTIAL
from segment_predictor import GenerativeSegmenter
from eval.utils import AverageMeter, Summary, intersectionAndUnionGPU, \
compute_logits_from_mask, masks_sample_points
from torch.utils.data import Dataset, DataLoader
import math
# --- Accelerate Import ---
from accelerate import Accelerator
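# Work is split at two levels: --num_chunks/--chunk_idx carve the dataset into
# contiguous slices for independent jobs, and within a job Accelerate shards
# the DataLoader across the available processes.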
def get_chunk(ds, n, k):
    """Keep only the k-th of n contiguous chunks of the dataset's image list (in place)."""
    chunk_size = math.ceil(len(ds) / n)
    start_index = chunk_size * k
    end_index = start_index + chunk_size
    ds.refer_seg_ds["images"] = ds.refer_seg_ds["images"][start_index:end_index]
    return ds
def gget_chunk(ds, n, k):
    """Same chunking as `get_chunk`, for the gRefCOCO dataset's `loaded_images` list."""
    chunk_size = math.ceil(len(ds) / n)
    start_index = chunk_size * k
    end_index = start_index + chunk_size
    ds.loaded_images = ds.loaded_images[start_index:end_index]
    return ds
class CustomDataset(Dataset):
    """Wraps a validation dataset and rewrites each class name into a question prompt."""

    def __init__(self, sub_dataset):
        self.dataset = sub_dataset

    def __getitem__(self, index):
        image, masks, questions, image_path = self.dataset[index]
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        # Wrap every referring expression in a randomly chosen question template.
        questions = [random.choice(QUESTION_PARTIAL).replace("[class_name]", q) for q in questions]
        return image, masks, image_name, questions, image_path

    def __len__(self):
        return len(self.dataset)
def collate_fn(batch):
    # Identity collate: PIL images and variable-length mask lists cannot be
    # stacked into tensors, so keep each field as a tuple.
    images, masks, image_names, questions, image_paths = zip(*batch)
    return images, masks, image_names, questions, image_paths
def create_data_loader(sub_dataset, batch_size=1):
    # The evaluation loop unwraps batch element 0 directly, so larger batches
    # are not supported.
    assert batch_size == 1, "Batch size must be 1 for this evaluation script."
    dataset = CustomDataset(sub_dataset)
    return DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=collate_fn)
def eval_model(args):
# --- Initialize Accelerator ---
accelerator = Accelerator()
# --- Model and Predictor Initialization ---
segmenter = GenerativeSegmenter(args.model_path, device_map=accelerator.device, min_pixels=args.min_pixels,
max_pixels=args.max_pixels)
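    # min_pixels/max_pixels bound the image token budget of the vision
    # processor (Qwen2-VL-style, judging by the model defaults); setting them
    # equal pins inputs to a fixed resolution budget.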
sam_help = args.sam_path is not None
if sam_help:
sam = sam_model_registry["vit_h"](checkpoint=args.sam_path)
sam = sam.to(dtype=torch.float32, device=accelerator.device)
predictor = SamPredictor(sam)
else:
predictor = None
# --- Dataset and DataLoader Initialization ---
if accelerator.is_main_process:
print("Loading dataset...")
# First, load the full dataset based on the parameters
if "grefcoco" in args.dataset_split:
val_dataset = grefcocoValDataset(args.image_folder, args.dataset_split)
else:
val_dataset = ValDataset(args.image_folder, args.dataset_split)
if accelerator.is_main_process:
total_data_size = len(val_dataset)
print(f"Total evaluation data volume (full dataset): {total_data_size} samples.")
# Then, get a chunk of the dataset as needed
if "grefcoco" in args.dataset_split:
sub_dataset = gget_chunk(val_dataset, args.num_chunks, args.chunk_idx)
else:
sub_dataset = get_chunk(val_dataset, args.num_chunks, args.chunk_idx)
    data_loader = create_data_loader(sub_dataset)
data_loader = accelerator.prepare(data_loader)
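    # After prepare(), each process iterates a disjoint shard of the chunk;
    # per-process metric sums are gathered and reduced after the loop.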
# --- Metric Meters Initialization ---
intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
if sam_help:
intersection_meter_sam = AverageMeter("Intersec", ":6.3f", Summary.SUM)
union_meter_sam = AverageMeter("Union", ":6.3f", Summary.SUM)
acc_iou_meter_sam = AverageMeter("gIoU", ":6.3f", Summary.SUM)
progress_bar = tqdm(data_loader, disable=not accelerator.is_main_process, total=len(data_loader))
for batch in progress_bar:
images, masks, image_names, questions, image_paths = batch
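        # batch_size is fixed at 1, so unwrap the singleton batch.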
image, gt_masks, image_name, prompts = images[0], masks[0], image_names[0], questions[0]
w_ori, h_ori = image.size
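        # Per-image accumulators over the two classes (0 = background,
        # 1 = foreground): summed intersection/union feeds cIoU, summed
        # per-prompt IoU feeds gIoU.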
total_intersection = torch.zeros(2, device=accelerator.device)
total_union = torch.zeros(2, device=accelerator.device)
total_acc_iou = torch.zeros(2, device=accelerator.device)
if sam_help:
total_intersection_sam = torch.zeros(2, device=accelerator.device)
total_union_sam = torch.zeros(2, device=accelerator.device)
total_acc_iou_sam = torch.zeros(2, device=accelerator.device)
num_masks_in_image = len(prompts)
with torch.inference_mode():
if sam_help:
predictor.set_image(np.array(image))
for i, question in enumerate(prompts):
gt_mask = gt_masks[i].to(accelerator.device).float().contiguous()
segmentation_masks, _ = segmenter.generate_with_segmentation(image, question)
if segmentation_masks is None or len(segmentation_masks) == 0:
pred_mask = torch.zeros((h_ori, w_ori), device=accelerator.device)
                else:
                    mask = segmentation_masks[0].to(accelerator.device)
                    # Nearest-neighbour upsampling back to the original size keeps
                    # the mask's hard labels intact.
                    pred_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).double(), size=(h_ori, w_ori),
                                              mode='nearest').squeeze()
                # SAM refinement: for each predicted foreground class, sample
                # point prompts from the mask, seed SAM with logits derived from
                # it, then run two more rounds feeding SAM's own logits back in.
                sam_refined_mask = torch.zeros_like(pred_mask)
                if sam_help:
                    unique_classes = torch.unique(pred_mask)
                    for class_id in unique_classes:
                        if class_id == 0:
                            continue
                        binary_mask = (pred_mask == class_id).double().cpu()
                        try:
                            logits = compute_logits_from_mask(binary_mask)
                            point_coords, point_labels = masks_sample_points(binary_mask)
                            sam_mask, _, logit = predictor.predict(point_coords=point_coords,
                                                                   point_labels=point_labels,
                                                                   mask_input=logits, multimask_output=False)
                            for _ in range(2):
                                sam_mask, _, logit = predictor.predict(point_coords=point_coords,
                                                                       point_labels=point_labels,
                                                                       mask_input=logit, multimask_output=False)
                            sam_mask = sam_mask[0].astype(np.float32)
                        except Exception as e:
                            print(f"SAM refinement failed on {image_name} (prompt {i}): {e}")
                            sam_mask = np.zeros((h_ori, w_ori), dtype=np.float32)
                        # Binary setting: the single foreground class overwrites the
                        # canvas; multi-class masks would need per-class compositing,
                        # e.g. sam_refined_mask[sam_mask > 0] = class_id.
                        sam_refined_mask = torch.from_numpy(sam_mask).to(accelerator.device)
intersection_i, union_i, _ = intersectionAndUnionGPU(pred_mask, gt_mask, 2, ignore_index=255)
total_intersection += intersection_i
total_union += union_i
iou_per_sample = intersection_i / (union_i + 1e-5)
iou_per_sample[union_i == 0] = 1.0
total_acc_iou += iou_per_sample
if sam_help:
intersection_sam_i, union_sam_i, _ = intersectionAndUnionGPU(sam_refined_mask, gt_mask, 2,
ignore_index=255)
total_intersection_sam += intersection_sam_i
total_union_sam += union_sam_i
iou_per_sample_sam = intersection_sam_i / (union_sam_i + 1e-5)
iou_per_sample_sam[union_sam_i == 0] = 1.0
total_acc_iou_sam += iou_per_sample_sam
                # Only the main process writes masks, so with multiple GPUs just
                # its shard of the data is saved.
                if args.save_masks and accelerator.is_main_process:
                    ds_split_sanitized = args.dataset_split.replace("|", "_")
                    model_name = os.path.basename(args.model_path.strip('/'))
                    save_path = os.path.join(args.save_file, model_name, ds_split_sanitized, "masks", image_name)
                    os.makedirs(save_path, exist_ok=True)
Image.fromarray(pred_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
os.path.join(save_path, f"{i}_pred_mask.png"))
if sam_help:
Image.fromarray(sam_refined_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
os.path.join(save_path, f"{i}_sam_mask.png"))
Image.fromarray(gt_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
os.path.join(save_path, f"{i}_gt_mask.png"))
image.save(os.path.join(save_path, f"{i}_image.png"))
intersection_meter.update(total_intersection.cpu().numpy())
union_meter.update(total_union.cpu().numpy())
if sam_help:
intersection_meter_sam.update(total_intersection_sam.cpu().numpy())
union_meter_sam.update(total_union_sam.cpu().numpy())
if num_masks_in_image > 0:
total_acc_iou = total_acc_iou / num_masks_in_image
acc_iou_meter.update(total_acc_iou.cpu().numpy(), n=num_masks_in_image)
if sam_help:
total_acc_iou_sam = total_acc_iou_sam / num_masks_in_image
acc_iou_meter_sam.update(total_acc_iou_sam.cpu().numpy(), n=num_masks_in_image)
# --- Synchronize metrics across all processes ---
all_intersections = accelerator.gather_for_metrics(torch.from_numpy(intersection_meter.sum).to(accelerator.device))
all_unions = accelerator.gather_for_metrics(torch.from_numpy(union_meter.sum).to(accelerator.device))
all_giou_sum = accelerator.gather_for_metrics(torch.from_numpy(acc_iou_meter.sum).to(accelerator.device))
all_giou_count = accelerator.gather_for_metrics(torch.tensor(acc_iou_meter.count, device=accelerator.device))
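    # Each gather concatenates one entry per process along dim 0; reshape so
    # dim 0 indexes processes and dim 1 the two classes.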
all_intersections = all_intersections.view(-1, 2)
all_unions = all_unions.view(-1, 2)
all_giou_sum = all_giou_sum.view(-1, 2)
all_giou_count = all_giou_count.view(-1, 1)
if sam_help:
all_intersections_sam = accelerator.gather_for_metrics(
torch.from_numpy(intersection_meter_sam.sum).to(accelerator.device))
all_unions_sam = accelerator.gather_for_metrics(torch.from_numpy(union_meter_sam.sum).to(accelerator.device))
all_giou_sum_sam = accelerator.gather_for_metrics(
torch.from_numpy(acc_iou_meter_sam.sum).to(accelerator.device))
all_giou_count_sam = accelerator.gather_for_metrics(
torch.tensor(acc_iou_meter_sam.count, device=accelerator.device))
all_intersections_sam = all_intersections_sam.view(-1, 2)
all_unions_sam = all_unions_sam.view(-1, 2)
all_giou_sum_sam = all_giou_sum_sam.view(-1, 2)
all_giou_count_sam = all_giou_count_sam.view(-1, 1)
# --- Only calculate and output final results on the main process ---
if accelerator.is_main_process:
        # cIoU: pool intersection and union over the whole split, then divide.
        iou_class = torch.sum(all_intersections, dim=0) / (torch.sum(all_unions, dim=0) + 1e-5)
        ciou = iou_class[1].item()
        # gIoU: mean of the per-prompt foreground IoUs.
        giou = (torch.sum(all_giou_sum, dim=0)[1] / torch.sum(all_giou_count)).item()
if sam_help:
iou_class_sam = torch.sum(all_intersections_sam, dim=0) / (torch.sum(all_unions_sam, dim=0) + 1e-5)
ciou_sam = iou_class_sam[1].item()
giou_sam = (torch.sum(all_giou_sum_sam, dim=0)[1] / torch.sum(all_giou_count_sam)).item()
else:
giou_sam, ciou_sam = 0.0, 0.0
        total_evaluated_images = len(sub_dataset)  # images in this chunk
        total_evaluated_masks = torch.sum(all_giou_count).item()  # prompts/masks across all processes
print("\n" + "=" * 50)
print(f"Evaluation finished for: {args.model_path}")
print(f"Dataset: {args.dataset_split}")
print("-" * 50)
        print(f"Total images evaluated: {total_evaluated_images}")
        print(f"Total masks/prompts evaluated: {total_evaluated_masks}")
        print("-" * 50)
print(f"Raw Model Mask -> gIoU: {giou:.4f}, cIoU: {ciou:.4f}")
if sam_help:
print(f"SAM-Refined Mask -> gIoU: {giou_sam:.4f}, cIoU: {ciou_sam:.4f}")
print("=" * 50 + "\n")
# --- Dynamically construct output file path and write results ---
model_name = os.path.basename(args.model_path.strip('/'))
ds_split_sanitized = args.dataset_split.replace("|", "_")
output_dir = os.path.join(args.save_file, model_name, ds_split_sanitized)
os.makedirs(output_dir, exist_ok=True)
output_filepath = os.path.join(output_dir, "evaluation_results.txt")
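        # With the defaults this resolves to something like
        # output_eval_accelerated/<model_name>/refcoco_unc_val/evaluation_results.txt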
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chunk_info = f"(chunk {args.chunk_idx + 1}/{args.num_chunks})" if args.num_chunks > 1 else ""
header_text = f"[{current_time}] Model: {args.model_path}, Dataset: {args.dataset_split} {chunk_info}\n"
        eval_stats_text = f" - Evaluated on {total_evaluated_images} images and {total_evaluated_masks} masks.\n"
output_text = f" - Raw Model Mask -> gIoU: {giou:.4f}, cIoU: {ciou:.4f}\n"
if sam_help:
output_text += f" - SAM-Refined Mask -> gIoU: {giou_sam:.4f}, cIoU: {ciou_sam:.4f}\n"
with open(output_filepath, "a") as file:
file.write(header_text)
            file.write(eval_stats_text)
file.write(output_text)
file.write("-" * 60 + "\n")
print(f"Results appended to: {output_filepath}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default='/raid2/DATA/text4seg/model_trained_qwen_2b/',
help="Path to your GenerativeSegmenter model checkpoint.")
parser.add_argument("--sam_path", type=str, default='/efficient_sag4text/sam_vit_h_4b8939.pth', help="Path to the SAM checkpoint.")
parser.add_argument("--image_folder", type=str, default='/efficient_sag4text/seg_data/refer_seg', help="Root folder for the dataset images.")
parser.add_argument("--dataset_split", type=str, default="refcoco|unc|val", help="Dataset split to evaluate on.")
parser.add_argument("--save_file", type=str, default="output_eval_accelerated/",
help="Root directory to save evaluation outputs (masks and metrics).")
parser.add_argument("--save_masks", action='store_true', help="Set this flag to save output masks and images.")
parser.add_argument("--num_chunks", type=int, default=1)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--min_pixels", type=int, default=1024*28 * 28, help="Minimum pixels for segmentation.")
parser.add_argument("--max_pixels", type=int, default=1024*28 * 28, help="Maximum pixels for segmentation.")
args = parser.parse_args()
eval_model(args)