STAMP-2B-uni / eval /val_utils.py
realzliu
init
96f36aa
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from PIL import Image
from .utils import AverageMeter, Summary, intersectionAndUnionGPU, compute_logits_from_mask, masks_sample_points
class _IoUTotals:
    """Per-process running sums for two-class (background / foreground) IoU metrics.

    Sums are held as float64 tensors on ``device``. This keeps large accumulated
    counts from losing float32 precision. It also means a process that saw no
    samples still contributes correctly shaped zeros to the cross-process reduce.
    """

    def __init__(self, device):
        # Index 0 = background, 1 = foreground (K=2 in the intersectionAndUnionGPU call).
        self.intersection = torch.zeros(2, dtype=torch.float64, device=device)
        self.union = torch.zeros(2, dtype=torch.float64, device=device)
        self.iou_sum = torch.zeros(2, dtype=torch.float64, device=device)  # sum of per-mask IoUs
        self.count = torch.zeros((), dtype=torch.int64, device=device)  # number of masks added

    def add(self, pred_mask, gt_mask):
        """Accumulate the statistics of one predicted mask against its ground truth.

        Pixels labelled 255 in ``gt_mask`` are passed as ``ignore_index``. A class
        whose union is 0 (absent from both masks) is scored IoU 1.0 for this sample.

        NOTE(review): intersectionAndUnionGPU may modify ``pred_mask`` in place.
        Callers should not rely on ``pred_mask`` afterwards.
        """
        intersection, union, _ = intersectionAndUnionGPU(pred_mask, gt_mask, 2, ignore_index=255)
        iou = intersection / (union + 1e-5)
        iou[union == 0] = 1.0
        self.intersection += intersection
        self.union += union
        self.iou_sum += iou
        self.count += 1

    def finalize(self, accelerator):
        """Sum the totals across all processes and return ``(giou, ciou)`` as floats.

        This performs collective communication, so every process must call it.

        - gIoU is the mean foreground IoU over every mask added on any process.
        - cIoU is total foreground intersection divided by total foreground union.
        """
        # reduce() keeps the (2,) shape. gather*/concatenation would flatten
        # bg/fg entries from all processes into one vector.
        intersection = accelerator.reduce(self.intersection, reduction="sum")
        union = accelerator.reduce(self.union, reduction="sum")
        iou_sum = accelerator.reduce(self.iou_sum, reduction="sum")
        count = accelerator.reduce(self.count, reduction="sum")
        ciou = (intersection[1] / (union[1] + 1e-5)).item()
        giou = (iou_sum[1] / count).item() if count > 0 else 0.0
        return giou, ciou


def _refine_with_sam(predictor, pred_mask, h_ori, w_ori, device):
    """Re-segment every non-zero class of ``pred_mask`` with SAM.

    For each class value other than 0:

    1. A binary mask is turned into SAM prompts via ``compute_logits_from_mask``
       and ``masks_sample_points``.
    2. ``predictor.predict`` is called three times. Each call after the first
       feeds back the previous low-res logits as ``mask_input``.
    3. Pixels in the final SAM mask that are > 0 receive the class value.

    Refinement is best-effort. If any step raises, that class contributes an
    empty mask instead of aborting the evaluation.

    Returns:
        A float32 tensor of shape ``(h_ori, w_ori)`` on ``device``.
    """
    refined = torch.zeros((h_ori, w_ori), device=device)
    for class_id in torch.unique(pred_mask):
        if class_id == 0:
            continue
        binary_mask = (pred_mask == class_id).double()
        try:
            logits = compute_logits_from_mask(binary_mask)
            point_coords, point_labels = masks_sample_points(binary_mask)
            sam_mask, _, logit = predictor.predict(point_coords=point_coords, point_labels=point_labels,
                                                   mask_input=logits, multimask_output=False)
            for _ in range(2):
                sam_mask, _, logit = predictor.predict(point_coords=point_coords, point_labels=point_labels,
                                                       mask_input=logit, multimask_output=False)
        except Exception:
            # Deliberate best-effort fallback: treat a failed refinement as an empty mask.
            sam_mask = np.zeros((1, h_ori, w_ori))
        refined[torch.from_numpy(sam_mask[0] > 0).to(device)] = class_id
    return refined


def run_in_process_evaluation(model, accelerator, eval_dataloader, sam_predictor=None):
    """Compute gIoU / cIoU of ``model``'s segmentation masks over ``eval_dataloader``.

    Each batch is unpacked as ``(images, masks, image_names, questions, image_paths)``.
    Only element 0 of each batch is evaluated (presumably batch_size == 1).
    Images with no prompts are skipped. For every prompt of an image:

    - ``model.generate_with_segmentation(image, question)`` is called.
    - The first returned mask is resized (nearest) to the image size.
    - If no mask is returned, an all-zero prediction is scored instead.

    Args:
        model: object exposing ``generate_with_segmentation(image, question)``.
        accelerator: ``accelerate.Accelerator``-like object. It supplies the device
            and the main-process flag, and sums the metric totals across processes.
        eval_dataloader: iterable of batches laid out as described above.
        sam_predictor: optional SAM predictor (``set_image`` / ``predict``). When
            given, every prediction is also refined with SAM and scored separately.

    Returns:
        On the main process, a dict with ``'giou'`` and ``'ciou'``. It also contains
        ``'sam_giou'`` and ``'sam_ciou'`` when ``sam_predictor`` is given. Every other
        process gets an empty dict. All processes must call this function, because
        the totals are reduced collectively at the end.
    """
    device = accelerator.device
    use_sam = sam_predictor is not None
    totals = _IoUTotals(device)
    sam_totals = _IoUTotals(device) if use_sam else None

    progress_bar = tqdm(eval_dataloader, disable=not accelerator.is_main_process, desc="Running Evaluation")
    for batch in progress_bar:
        images, masks, _, questions, _ = batch
        image, gt_masks, prompts = images[0], masks[0], questions[0]
        if len(prompts) == 0:
            continue
        w_ori, h_ori = image.size  # PIL size is (width, height)

        with torch.inference_mode():
            if use_sam:
                sam_predictor.set_image(np.array(image))
            for i, question in enumerate(prompts):
                gt_mask = gt_masks[i].to(device)
                # NOTE(review): if `model` was wrapped by accelerator.prepare (e.g. DDP),
                # custom methods may require accelerator.unwrap_model(model) — confirm.
                segmentation_masks, _ = model.generate_with_segmentation(image, question)
                if segmentation_masks is None or len(segmentation_masks) == 0:
                    pred_mask = torch.zeros((h_ori, w_ori), device=device)
                else:
                    mask = segmentation_masks[0].to(device)
                    pred_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).double(), size=(h_ori, w_ori),
                                              mode='nearest').squeeze()

                # Refine before scoring: intersectionAndUnionGPU may write into
                # its prediction argument, so pred_mask must be consumed by SAM first.
                if use_sam:
                    sam_refined_mask = _refine_with_sam(sam_predictor, pred_mask, h_ori, w_ori, device)
                totals.add(pred_mask, gt_mask)
                if use_sam:
                    sam_totals.add(sam_refined_mask, gt_mask)

    # Collective reductions: must run on every process, not just the main one.
    giou, ciou = totals.finalize(accelerator)
    if use_sam:
        sam_giou, sam_ciou = sam_totals.finalize(accelerator)

    final_metrics = {}
    if accelerator.is_main_process:
        final_metrics['giou'] = giou
        final_metrics['ciou'] = ciou
        if use_sam:
            final_metrics['sam_giou'] = sam_giou
            final_metrics['sam_ciou'] = sam_ciou
    return final_metrics