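"""Distributed referring-segmentation evaluation with Hugging Face Accelerate.

Usage sketch (the script filename below is an assumption; substitute your own):

    accelerate launch eval_refer_seg.py \
        --model_path /path/to/checkpoint \
        --dataset_split "refcoco|unc|val" \
        --sam_path /path/to/sam_vit_h_4b8939.pth
"""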
import argparse
import datetime
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# --- Accelerate Import ---
from accelerate import Accelerator

from model.segment_anything import SamPredictor, sam_model_registry
from dataset.refer_seg_dataset import ValDataset
from dataset.grefer_seg_dataset import grefcocoValDataset
from data.question_answer_list import QUESTION_PARTIAL
from segment_predictor import GenerativeSegmenter
from eval.utils import AverageMeter, Summary, intersectionAndUnionGPU, \
    compute_logits_from_mask, masks_sample_points
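

# The two helpers below shard the *dataset* across independent jobs
# (--num_chunks / --chunk-idx); Accelerate additionally shards each chunk's
# DataLoader across processes. They differ only in which attribute holds
# the image list for the two dataset classes.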
def get_chunk(ds, n, k):
    chunk_size = math.ceil(len(ds) / n)
    start_index = chunk_size * k
    end_index = start_index + chunk_size
    ds.refer_seg_ds["images"] = ds.refer_seg_ds["images"][start_index:end_index]
    return ds


def gget_chunk(ds, n, k):
    chunk_size = math.ceil(len(ds) / n)
    start_index = chunk_size * k
    end_index = start_index + chunk_size
    ds.loaded_images = ds.loaded_images[start_index:end_index]
    return ds
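

# Wraps a validation dataset so each referring expression is turned into a
# natural-language question by filling a random QUESTION_PARTIAL template.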
class CustomDataset(Dataset):
    def __init__(self, sub_dataset):
        self.dataset = sub_dataset

    def __getitem__(self, index):
        image, masks, questions, image_path = self.dataset[index]
        image_name = os.path.basename(image_path).split(".")[0]
        questions = [random.choice(QUESTION_PARTIAL).replace("[class_name]", q) for q in questions]
        return image, masks, image_name, questions, image_path

    def __len__(self):
        return len(self.dataset)


def collate_fn(batch):
    # Keep samples as Python objects: PIL images and per-image mask lists
    # cannot be stacked into a single tensor.
    images, masks, image_names, questions, image_paths = zip(*batch)
    return images, masks, image_names, questions, image_paths


def create_data_loader(args, sub_dataset, batch_size=1):
    assert batch_size == 1, "Batch size must be 1 for this evaluation script."
    dataset = CustomDataset(sub_dataset)
    return DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=collate_fn)
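

# Evaluation flow: for every image, generate one mask per referring prompt,
# optionally refine it with SAM (points sampled from the predicted mask plus
# its low-resolution logits), accumulate intersection/union sums (cIoU) and
# per-prompt IoUs (gIoU), then gather the partial sums from all processes.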
def eval_model(args):
    # --- Initialize Accelerator ---
    accelerator = Accelerator()

    # --- Model and Predictor Initialization ---
    segmenter = GenerativeSegmenter(args.model_path, device_map=accelerator.device, min_pixels=args.min_pixels,
                                    max_pixels=args.max_pixels)
    # SAM refinement is optional; it is skipped when no checkpoint is given.
    sam_help = args.sam_path is not None
    if sam_help:
        sam = sam_model_registry["vit_h"](checkpoint=args.sam_path)
        sam = sam.to(dtype=torch.float32, device=accelerator.device)
        predictor = SamPredictor(sam)
    else:
        predictor = None
    # --- Dataset and DataLoader Initialization ---
    if accelerator.is_main_process:
        print("Loading dataset...")

    # First, load the full dataset for the requested split
    if "grefcoco" in args.dataset_split:
        val_dataset = grefcocoValDataset(args.image_folder, args.dataset_split)
    else:
        val_dataset = ValDataset(args.image_folder, args.dataset_split)

    if accelerator.is_main_process:
        total_data_size = len(val_dataset)
        print(f"Total evaluation data (full dataset): {total_data_size} samples.")

    # Then take this job's chunk of the dataset
    if "grefcoco" in args.dataset_split:
        sub_dataset = gget_chunk(val_dataset, args.num_chunks, args.chunk_idx)
    else:
        sub_dataset = get_chunk(val_dataset, args.num_chunks, args.chunk_idx)

    data_loader = create_data_loader(args, sub_dataset)
    data_loader = accelerator.prepare(data_loader)
    # --- Metric Meters Initialization ---
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
    if sam_help:
        intersection_meter_sam = AverageMeter("Intersec", ":6.3f", Summary.SUM)
        union_meter_sam = AverageMeter("Union", ":6.3f", Summary.SUM)
        acc_iou_meter_sam = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    progress_bar = tqdm(data_loader, disable=not accelerator.is_main_process, total=len(data_loader))
    for batch in progress_bar:
        images, masks, image_names, questions, image_paths = batch
        image, gt_masks, image_name, prompts = images[0], masks[0], image_names[0], questions[0]
        w_ori, h_ori = image.size
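        # Per-image accumulators: intersection/union sums feed cIoU, while the
        # per-prompt IoUs averaged below feed gIoU (index 1 = foreground class).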
        total_intersection = torch.zeros(2, device=accelerator.device)
        total_union = torch.zeros(2, device=accelerator.device)
        total_acc_iou = torch.zeros(2, device=accelerator.device)
        if sam_help:
            total_intersection_sam = torch.zeros(2, device=accelerator.device)
            total_union_sam = torch.zeros(2, device=accelerator.device)
            total_acc_iou_sam = torch.zeros(2, device=accelerator.device)
        num_masks_in_image = len(prompts)
        with torch.inference_mode():
            if sam_help:
                predictor.set_image(np.array(image))
            for i, question in enumerate(prompts):
                gt_mask = gt_masks[i].to(accelerator.device).float().contiguous()
                segmentation_masks, _ = segmenter.generate_with_segmentation(image, question)
                if segmentation_masks is None or len(segmentation_masks) == 0:
                    # No mask generated: score an all-background prediction
                    pred_mask = torch.zeros((h_ori, w_ori), device=accelerator.device)
                else:
                    mask = segmentation_masks[0].to(accelerator.device)
                    # Nearest-neighbour upsampling back to the original resolution
                    # keeps the mask values discrete
                    pred_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).double(), size=(h_ori, w_ori),
                                              mode='nearest').squeeze()
                sam_refined_mask = torch.zeros_like(pred_mask)
                if sam_help:
                    unique_classes = torch.unique(pred_mask)
                    for class_id in unique_classes:
                        if class_id == 0:
                            continue
                        binary_mask = (pred_mask == class_id).double().cpu()
                        try:
                            logits = compute_logits_from_mask(pred_mask.cpu())
                            point_coords, point_labels = masks_sample_points(binary_mask)
                            sam_mask, _, logit = predictor.predict(point_coords=point_coords,
                                                                   point_labels=point_labels,
                                                                   mask_input=logits, multimask_output=False)
                            # Two extra refinement rounds, reusing the returned logits
                            for _ in range(2):
                                sam_mask, _, logit = predictor.predict(point_coords=point_coords,
                                                                       point_labels=point_labels,
                                                                       mask_input=logit, multimask_output=False)
                            sam_mask = sam_mask[0].astype(np.float32)
                        except Exception as e:
                            print(f"SAM refinement failed: {e}")
                            sam_mask = np.zeros((h_ori, w_ori))
                        # NOTE: with binary masks there is a single foreground class,
                        # so overwriting the refined mask here is safe; multi-class
                        # masks would need per-class assignment instead.
                        sam_refined_mask = torch.from_numpy(sam_mask).to(accelerator.device)
                intersection_i, union_i, _ = intersectionAndUnionGPU(pred_mask, gt_mask, 2, ignore_index=255)
                total_intersection += intersection_i
                total_union += union_i
                iou_per_sample = intersection_i / (union_i + 1e-5)
                # Convention: an empty prediction matching an empty GT counts as IoU 1
                iou_per_sample[union_i == 0] = 1.0
                total_acc_iou += iou_per_sample
                if sam_help:
                    intersection_sam_i, union_sam_i, _ = intersectionAndUnionGPU(sam_refined_mask, gt_mask, 2,
                                                                                 ignore_index=255)
                    total_intersection_sam += intersection_sam_i
                    total_union_sam += union_sam_i
                    iou_per_sample_sam = intersection_sam_i / (union_sam_i + 1e-5)
                    iou_per_sample_sam[union_sam_i == 0] = 1.0
                    total_acc_iou_sam += iou_per_sample_sam
                if args.save_masks and accelerator.is_main_process:
                    ds_split_sanitized = args.dataset_split.replace("|", "_")
                    model_name = os.path.basename(args.model_path.strip('/'))
                    save_path = os.path.join(args.save_file, model_name, ds_split_sanitized, "masks", image_name)
                    os.makedirs(save_path, exist_ok=True)
                    Image.fromarray(pred_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                        os.path.join(save_path, f"{i}_pred_mask.png"))
                    if sam_help:
                        Image.fromarray(sam_refined_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                            os.path.join(save_path, f"{i}_sam_mask.png"))
                    Image.fromarray(gt_mask.cpu().numpy().astype("uint8") * 255).convert('L').save(
                        os.path.join(save_path, f"{i}_gt_mask.png"))
                    image.save(os.path.join(save_path, f"{i}_image.png"))
        intersection_meter.update(total_intersection.cpu().numpy())
        union_meter.update(total_union.cpu().numpy())
        if sam_help:
            intersection_meter_sam.update(total_intersection_sam.cpu().numpy())
            union_meter_sam.update(total_union_sam.cpu().numpy())
        if num_masks_in_image > 0:
            total_acc_iou = total_acc_iou / num_masks_in_image
            acc_iou_meter.update(total_acc_iou.cpu().numpy(), n=num_masks_in_image)
            if sam_help:
                total_acc_iou_sam = total_acc_iou_sam / num_masks_in_image
                acc_iou_meter_sam.update(total_acc_iou_sam.cpu().numpy(), n=num_masks_in_image)
    # --- Synchronize metrics across all processes ---
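    # Each process holds partial sums for its shard; gather_for_metrics
    # concatenates them across processes so they can be reduced below.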
    all_intersections = accelerator.gather_for_metrics(torch.from_numpy(intersection_meter.sum).to(accelerator.device))
    all_unions = accelerator.gather_for_metrics(torch.from_numpy(union_meter.sum).to(accelerator.device))
    all_giou_sum = accelerator.gather_for_metrics(torch.from_numpy(acc_iou_meter.sum).to(accelerator.device))
    all_giou_count = accelerator.gather_for_metrics(torch.tensor(acc_iou_meter.count, device=accelerator.device))
    all_intersections = all_intersections.view(-1, 2)
    all_unions = all_unions.view(-1, 2)
    all_giou_sum = all_giou_sum.view(-1, 2)
    all_giou_count = all_giou_count.view(-1, 1)
    if sam_help:
        all_intersections_sam = accelerator.gather_for_metrics(
            torch.from_numpy(intersection_meter_sam.sum).to(accelerator.device))
        all_unions_sam = accelerator.gather_for_metrics(torch.from_numpy(union_meter_sam.sum).to(accelerator.device))
        all_giou_sum_sam = accelerator.gather_for_metrics(
            torch.from_numpy(acc_iou_meter_sam.sum).to(accelerator.device))
        all_giou_count_sam = accelerator.gather_for_metrics(
            torch.tensor(acc_iou_meter_sam.count, device=accelerator.device))
        all_intersections_sam = all_intersections_sam.view(-1, 2)
        all_unions_sam = all_unions_sam.view(-1, 2)
        all_giou_sum_sam = all_giou_sum_sam.view(-1, 2)
        all_giou_count_sam = all_giou_count_sam.view(-1, 1)
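    # cIoU: cumulative intersection over cumulative union for the foreground
    # class across all prompts; gIoU: mean of the per-prompt IoUs.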
    # --- Only calculate and output final results on the main process ---
    if accelerator.is_main_process:
        iou_class = torch.sum(all_intersections, dim=0) / (torch.sum(all_unions, dim=0) + 1e-5)
        ciou = iou_class[1].item()
        giou = (torch.sum(all_giou_sum, dim=0)[1] / torch.sum(all_giou_count)).item()
        if sam_help:
            iou_class_sam = torch.sum(all_intersections_sam, dim=0) / (torch.sum(all_unions_sam, dim=0) + 1e-5)
            ciou_sam = iou_class_sam[1].item()
            giou_sam = (torch.sum(all_giou_sum_sam, dim=0)[1] / torch.sum(all_giou_count_sam)).item()
        else:
            giou_sam, ciou_sam = 0.0, 0.0

        # Exact evaluation totals
        total_evaluated_images = len(sub_dataset)  # images evaluated in this chunk
        total_evaluated_masks = torch.sum(all_giou_count).item()  # masks/prompts evaluated
| print("\n" + "=" * 50) | |
| print(f"Evaluation finished for: {args.model_path}") | |
| print(f"Dataset: {args.dataset_split}") | |
| print("-" * 50) | |
| # <--- Added: Print evaluation sample counts --- | |
| print(f"Total images evaluated: {total_evaluated_images}") | |
| print(f"Total masks/prompts evaluated: {total_evaluated_masks}") | |
| print("-" * 50) | |
| # <--- End added --- | |
| print(f"Raw Model Mask -> gIoU: {giou:.4f}, cIoU: {ciou:.4f}") | |
| if sam_help: | |
| print(f"SAM-Refined Mask -> gIoU: {giou_sam:.4f}, cIoU: {ciou_sam:.4f}") | |
| print("=" * 50 + "\n") | |
        # --- Dynamically construct output file path and write results ---
        model_name = os.path.basename(args.model_path.strip('/'))
        ds_split_sanitized = args.dataset_split.replace("|", "_")
        output_dir = os.path.join(args.save_file, model_name, ds_split_sanitized)
        os.makedirs(output_dir, exist_ok=True)
        output_filepath = os.path.join(output_dir, "evaluation_results.txt")

        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        chunk_info = f"(chunk {args.chunk_idx + 1}/{args.num_chunks})" if args.num_chunks > 1 else ""
        header_text = f"[{current_time}] Model: {args.model_path}, Dataset: {args.dataset_split} {chunk_info}\n"
        eval_stats_text = f" - Evaluated on {total_evaluated_images} images and {total_evaluated_masks} masks.\n"
        output_text = f" - Raw Model Mask -> gIoU: {giou:.4f}, cIoU: {ciou:.4f}\n"
        if sam_help:
            output_text += f" - SAM-Refined Mask -> gIoU: {giou_sam:.4f}, cIoU: {ciou_sam:.4f}\n"

        with open(output_filepath, "a") as file:
            file.write(header_text)
            file.write(eval_stats_text)
            file.write(output_text)
            file.write("-" * 60 + "\n")
        print(f"Results appended to: {output_filepath}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default='/raid2/DATA/text4seg/model_trained_qwen_2b/',
                        help="Path to your GenerativeSegmenter model checkpoint.")
    parser.add_argument("--sam_path", type=str, default='/efficient_sag4text/sam_vit_h_4b8939.pth',
                        help="Path to the SAM checkpoint.")
    parser.add_argument("--image_folder", type=str, default='/efficient_sag4text/seg_data/refer_seg',
                        help="Root folder for the dataset images.")
    parser.add_argument("--dataset_split", type=str, default="refcoco|unc|val", help="Dataset split to evaluate on.")
    parser.add_argument("--save_file", type=str, default="output_eval_accelerated/",
                        help="Root directory to save evaluation outputs (masks and metrics).")
    parser.add_argument("--save_masks", action='store_true', help="Set this flag to save output masks and images.")
    parser.add_argument("--num_chunks", type=int, default=1, help="Number of dataset chunks (independent jobs).")
    parser.add_argument("--chunk-idx", type=int, default=0, help="Index of the chunk this job evaluates.")
    parser.add_argument("--min_pixels", type=int, default=1024 * 28 * 28, help="Minimum pixels for segmentation.")
    parser.add_argument("--max_pixels", type=int, default=1024 * 28 * 28, help="Maximum pixels for segmentation.")
    args = parser.parse_args()
    eval_model(args)