Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
| from tqdm import tqdm | |
| import numpy as np | |
| from PIL import Image | |
| from .utils import AverageMeter, Summary, intersectionAndUnionGPU, compute_logits_from_mask, masks_sample_points | |
| def run_in_process_evaluation(model, accelerator, eval_dataloader, sam_predictor=None): | |
| # --- 1. Initialize all metric recorders --- | |
| intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM) | |
| union_meter = AverageMeter("Union", ":6.3f", Summary.SUM) | |
| acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM) | |
| use_sam = sam_predictor is not None | |
| if use_sam: | |
| intersection_meter_sam = AverageMeter("Intersec", ":6.3f", Summary.SUM) | |
| union_meter_sam = AverageMeter("Union", ":6.3f", Summary.SUM) | |
| acc_iou_meter_sam = AverageMeter("gIoU", ":6.3f", Summary.SUM) | |
| progress_bar = tqdm(eval_dataloader, disable=not accelerator.is_main_process, desc="Running Evaluation") | |
| # --- 2. Iterate over the evaluation dataset --- | |
| for batch in progress_bar: | |
| images, masks, image_names, questions, image_paths = batch | |
| image, gt_masks, image_name, prompts = images[0], masks[0], image_names[0], questions[0] | |
| w_ori, h_ori = image.size | |
| total_intersection = torch.zeros(2, device=accelerator.device) | |
| total_union = torch.zeros(2, device=accelerator.device) | |
| total_acc_iou = torch.zeros(2, device=accelerator.device) | |
| if use_sam: | |
| total_intersection_sam = torch.zeros(2, device=accelerator.device) | |
| total_union_sam = torch.zeros(2, device=accelerator.device) | |
| total_acc_iou_sam = torch.zeros(2, device=accelerator.device) | |
| num_masks_in_image = len(prompts) | |
| if num_masks_in_image == 0: | |
| continue | |
| with torch.inference_mode(): | |
| if use_sam: | |
| predictor = sam_predictor | |
| predictor.set_image(np.array(image)) | |
| for i, question in enumerate(prompts): | |
| gt_mask = gt_masks[i].to(accelerator.device) | |
| # --- Key difference: directly use the passed-in model object for inference --- | |
| segmentation_masks, _ = model.generate_with_segmentation(image, question) | |
| if segmentation_masks is None or len(segmentation_masks) == 0: | |
| pred_mask = torch.zeros((h_ori, w_ori), device=accelerator.device) | |
| else: | |
| mask = segmentation_masks[0].to(accelerator.device) | |
| pred_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).double(), size=(h_ori, w_ori), | |
| mode='nearest').squeeze() | |
| # --- SAM post-processing (if enabled) --- | |
| if use_sam: | |
| sam_refined_mask = torch.zeros((h_ori, w_ori), device=accelerator.device) | |
| unique_classes = torch.unique(pred_mask) | |
| for class_id in unique_classes: | |
| if class_id == 0: continue | |
| binary_mask = (pred_mask == class_id).double() | |
| try: | |
| logits = compute_logits_from_mask(binary_mask) | |
| point_coords, point_labels = masks_sample_points(binary_mask) | |
| sam_mask, _, logit = predictor.predict(point_coords=point_coords, point_labels=point_labels, | |
| mask_input=logits, multimask_output=False) | |
| for _ in range(2): | |
| sam_mask, _, logit = predictor.predict(point_coords=point_coords, point_labels=point_labels, | |
| mask_input=logit, multimask_output=False) | |
| except Exception: | |
| sam_mask = np.zeros((1, h_ori, w_ori)) | |
| sam_refined_mask[torch.from_numpy(sam_mask[0] > 0).to(accelerator.device)] = class_id | |
| # --- Metric calculation (current mask) --- | |
| # Use the original model for prediction | |
| # <-- MODIFICATION 1: Add ignore_index=255 | |
| intersection_i, union_i, _ = intersectionAndUnionGPU(pred_mask, gt_mask, 2, ignore_index=255) | |
| total_intersection += intersection_i | |
| total_union += union_i | |
| # <-- MODIFICATION 2: Change epsilon to 1e-5 | |
| iou_per_sample = intersection_i / (union_i + 1e-5) | |
| iou_per_sample[union_i == 0] = 1.0 | |
| total_acc_iou += iou_per_sample | |
| # SAM optimized prediction | |
| if use_sam: | |
| # <-- MODIFICATION 1 (SAM): Add ignore_index=255 | |
| intersection_sam_i, union_sam_i, _ = intersectionAndUnionGPU(sam_refined_mask, gt_mask, 2, ignore_index=255) | |
| total_intersection_sam += intersection_sam_i | |
| total_union_sam += union_sam_i | |
| # <-- MODIFICATION 2 (SAM): Change epsilon to 1e-5 | |
| iou_per_sample_sam = intersection_sam_i / (union_sam_i + 1e-5) | |
| iou_per_sample_sam[union_sam_i == 0] = 1.0 | |
| total_acc_iou_sam += iou_per_sample_sam | |
| # Update global recorders | |
| intersection_meter.update(total_intersection.cpu().numpy()) | |
| union_meter.update(total_union.cpu().numpy()) | |
| acc_iou_meter.update(total_acc_iou.cpu().numpy(), n=num_masks_in_image) | |
| if use_sam: | |
| intersection_meter_sam.update(total_intersection_sam.cpu().numpy()) | |
| union_meter_sam.update(total_union_sam.cpu().numpy()) | |
| acc_iou_meter_sam.update(total_acc_iou_sam.cpu().numpy(), n=num_masks_in_image) | |
| # --- 3. Aggregate metrics from all GPUs --- | |
| all_intersections = accelerator.gather_for_metrics(torch.from_numpy(intersection_meter.sum).to(accelerator.device)) | |
| all_unions = accelerator.gather_for_metrics(torch.from_numpy(union_meter.sum).to(accelerator.device)) | |
| all_giou_sum = accelerator.gather_for_metrics(torch.from_numpy(acc_iou_meter.sum).to(accelerator.device)) | |
| all_giou_count = accelerator.gather_for_metrics(torch.tensor(acc_iou_meter.count, device=accelerator.device)) | |
| if use_sam: | |
| all_intersections_sam = accelerator.gather_for_metrics( | |
| torch.from_numpy(intersection_meter_sam.sum).to(accelerator.device)) | |
| all_unions_sam = accelerator.gather_for_metrics(torch.from_numpy(union_meter_sam.sum).to(accelerator.device)) | |
| all_giou_sum_sam = accelerator.gather_for_metrics( | |
| torch.from_numpy(acc_iou_meter_sam.sum).to(accelerator.device)) | |
| all_giou_count_sam = accelerator.gather_for_metrics( | |
| torch.tensor(acc_iou_meter_sam.count, device=accelerator.device)) | |
| # --- 4. Calculate final results and return on the main process --- | |
| final_metrics = {} | |
| if accelerator.is_main_process: | |
| # original model metrics | |
| # <-- MODIFICATION 2 (cIoU): Change epsilon to 1e-5 | |
| iou_class = torch.sum(all_intersections, dim=0) / (torch.sum(all_unions, dim=0) + 1e-5) | |
| ciou = iou_class[1].item() | |
| giou_sum = torch.sum(all_giou_sum, dim=0)[1] | |
| giou_count = torch.sum(all_giou_count) | |
| giou = (giou_sum / giou_count).item() if giou_count > 0 else 0.0 | |
| final_metrics['giou'] = giou | |
| final_metrics['ciou'] = ciou | |
| # SAM optimized metrics | |
| if use_sam: | |
| # <-- MODIFICATION 2 (cIoU SAM): Change epsilon to 1e-5 | |
| iou_class_sam = torch.sum(all_intersections_sam, dim=0) / (torch.sum(all_unions_sam, dim=0) + 1e-5) | |
| ciou_sam = iou_class_sam[1].item() | |
| giou_sum_sam = torch.sum(all_giou_sum_sam, dim=0)[1] | |
| giou_count_sam = torch.sum(all_giou_count_sam) | |
| giou_sam = (giou_sum_sam / giou_count_sam).item() if giou_count_sam > 0 else 0.0 | |
| final_metrics['sam_giou'] = giou_sam | |
| final_metrics['sam_ciou'] = ciou_sam | |
| return final_metrics |