import sys
import json
import os
import os.path
import gc
import logging
import argparse
from collections import defaultdict

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import clip
import faiss
from PIL import Image
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)

class ClipRetrieval:

    def __init__(self, index_name):
        self.datastore = faiss.read_index(index_name)
        # self.datastore.nprobe = 25

    def get_nns(self, query_img, k=20):
        # Return distances and indices of the k nearest datastore entries.
        D, I = self.datastore.search(query_img, k)
        return D, I[:, :k]
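
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script): one way a datastore that is
# compatible with ClipRetrieval above could be built. It assumes the captions
# are encoded with the same CLIP model ("RN50x64") and L2-normalized, so an
# inner-product FAISS index behaves like cosine similarity. The function name
# and batching scheme are illustrative; the default paths mirror the argparse
# defaults in check_args below.
def build_caption_datastore(captions,
                            index_path="datastore2/vizwiz/vizwiz",
                            captions_path="datastore2/vizwiz/vizwiz.json",
                            batch_size=256):
    model, _ = clip.load("RN50x64", device=device)
    model.eval()
    features = []
    with torch.no_grad():
        for start in range(0, len(captions), batch_size):
            tokens = clip.tokenize(captions[start:start + batch_size], truncate=True).to(device)
            feats = model.encode_text(tokens)
            feats /= feats.norm(dim=-1, keepdim=True)  # normalize, as evaluate() does for images
            features.append(feats.cpu().numpy().astype(np.float32))
    features = np.concatenate(features, axis=0)
    index = faiss.IndexFlatIP(features.shape[1])  # inner product over normalized vectors
    index.add(features)
    faiss.write_index(index, index_path)
    # evaluate() looks captions up by the stringified FAISS row id.
    with open(captions_path, "w") as f:
        json.dump({str(i): cap for i, cap in enumerate(captions)}, f)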

class EvalDataset(torch.utils.data.Dataset):

    def __init__(self, dataset_splits, images_dir, images_names, clip_retrieval_processor, eval_split="val_images"):
        super().__init__()
        with open(dataset_splits) as f:
            self.split = json.load(f)[eval_split]
        self.images_dir = images_dir
        with open(images_names) as f:  # fixed: previously read the global `args` instead of the parameter
            self.images_names = json.load(f)
        self.clip_retrieval_processor = clip_retrieval_processor

    def __getitem__(self, i):
        coco_id = self.split[i]
        image_filename = os.path.join(self.images_dir, self.images_names[coco_id])
        img_open = Image.open(image_filename).copy()
        img = np.array(img_open)
        if len(img.shape) == 2 or img.shape[-1] != 3:  # convert grayscale or CMYK to RGB
            img_open = img_open.convert('RGB')
        gc.collect()
        return self.clip_retrieval_processor(img_open).unsqueeze(0), coco_id

    def __len__(self):
        return len(self.split)

def evaluate(args):
    # Load the datastore contents (i.e., the captions, keyed by FAISS row id).
    with open(args.index_captions) as f:
        data_datastore = json.load(f)

    datastore = ClipRetrieval(args.datastore_path)
    datastore_name = args.datastore_path.split("/")[-1]

    # Load CLIP to encode the images that we want to retrieve captions for.
    clip_retrieval_model, clip_retrieval_feature_extractor = clip.load("RN50x64", device=device)
    clip_retrieval_model.eval()

    # DataLoader over the images that we want to retrieve captions for.
    data_loader = torch.utils.data.DataLoader(
        EvalDataset(
            args.dataset_splits,
            args.images_dir,
            args.images_names,
            clip_retrieval_feature_extractor,
            args.split),
        batch_size=1,
        shuffle=True,
        num_workers=1,
        pin_memory=True
    )

    print("device", device)

    nearest_caps = {}
    for data in tqdm(data_loader):
        inputs_features_retrieval, coco_id = data
        coco_id = coco_id[0]

        # Normalize the image features (the datastore holds normalized caption features).
        inputs_features_retrieval = inputs_features_retrieval.to(device)
        image_retrieval_features = clip_retrieval_model.encode_image(inputs_features_retrieval[0])
        image_retrieval_features /= image_retrieval_features.norm(dim=-1, keepdim=True)
        image_retrieval_features = image_retrieval_features.detach().cpu().numpy().astype(np.float32)

        D, nearest_ids = datastore.get_nns(image_retrieval_features, k=5)
        gc.collect()

        # Batch size is 1 at inference time.
        D = D[0]
        nearest_ids = nearest_ids[0]

        list_of_similar_caps = defaultdict(list)
        for index in range(len(nearest_ids)):
            nearest_id = str(nearest_ids[index])
            nearest_cap = data_datastore[nearest_id]

            if len(nearest_cap.split()) > args.max_caption_len:
                print("retrieved caption too long, skipping")
                continue

            # The corresponding distance is available as D[index] if needed.
            list_of_similar_caps[datastore_name].append(nearest_cap)

        nearest_caps[str(coco_id)] = list_of_similar_caps

    # Save results.
    outputs_dir = os.path.join(args.output_path, "retrieved_caps")
    if not os.path.exists(outputs_dir):
        os.makedirs(outputs_dir)

    data_name = args.dataset_splits.split("/")[-1]
    name = "nearest_caps_" + data_name + "_w_" + datastore_name + "_" + args.split
    results_output_file_name = os.path.join(outputs_dir, name + ".json")
    with open(results_output_file_name, "w") as f:
        json.dump(nearest_caps, f)

def check_args(args):
    parser = argparse.ArgumentParser()

    # Info of the dataset to evaluate on (vizwiz, flickr30k, msr-vtt)
    parser.add_argument("--images_dir", help="Folder where the preprocessed image data is located", default="data/vizwiz/images")
    parser.add_argument("--dataset_splits", help="File containing the dataset splits", default="data/vizwiz/dataset_splits.json")
    parser.add_argument("--images_names", help="File containing the image names per id", default="data/vizwiz/images_names.json")
    parser.add_argument("--split", default="val_images", choices=["val_images", "test_images"])
    parser.add_argument("--max-caption-len", type=int, default=25)

    # Which datastore to use (web, human)
    parser.add_argument("--datastore_path", type=str, default="datastore2/vizwiz/vizwiz")
    parser.add_argument("--index_captions",
                        help="File containing the captions of the datastore per id", default="datastore2/vizwiz/vizwiz.json")
    parser.add_argument("--output-path", help="Folder where to store outputs", default="eval_vizwiz_with_datastore_from_vizwiz.json")

    parsed_args = parser.parse_args(args)
    return parsed_args

if __name__ == "__main__":
    args = check_args(sys.argv[1:])
    logging.basicConfig(
        format='%(levelname)s: %(message)s', level=logging.INFO)
    logging.info(args)
    evaluate(args)
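
# Example invocation (the script file name is illustrative; all other values
# fall back to the argparse defaults defined in check_args above):
#   python retrieve_caps.py --split val_images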