import torch
import numpy as np
from fairseq import utils, tasks
from utils.checkpoint_utils import load_model_ensemble_and_task
from utils.eval_utils import eval_step
from tasks.refcoco import RefcocoTask
from models.polyformer import PolyFormerModel
from PIL import Image
import cv2
import math
from skimage import draw

tasks.register_task('refcoco', RefcocoTask)
# turn on cuda if GPU is available
use_cuda = torch.cuda.is_available()
# use fp16 only when GPU is available
use_fp16 = use_cuda

# Load pretrained ckpt & config
overrides = {"bpe_dir": "utils/BPE"}
models, cfg, task = load_model_ensemble_and_task(
    utils.split_paths('polyformer_l_refcocog.pt'),
    arg_overrides=overrides
)

# print(cfg)
cfg.common.seed = 7
cfg.generation.beam = 5
cfg.generation.min_len = 12
cfg.generation.max_len_a = 0
cfg.generation.max_len_b = 420
cfg.generation.no_repeat_ngram_size = 3
# cfg.max_tgt_length = 256
# cfg.num_bins = 1000
cfg.task.patch_image_size = 512

from bert.tokenization_bert import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Fix seed for stochastic decoding
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)
# Move models to GPU
for model in models:
    model.eval()
    if use_fp16:
        model.half()
    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
        model.cuda()
    model.prepare_for_inference_(cfg)
# Initialize generator
generator = task.build_generator(models, cfg.generation)

# Image transform
from torchvision import transforms

mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

patch_resize_transform = transforms.Compose([
    lambda image: image.convert("RGB"),
    transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# Text preprocess
bos_item = torch.LongTensor([task.src_dict.bos()])
eos_item = torch.LongTensor([task.src_dict.eos()])
pad_idx = task.src_dict.pad()

# Construct input for refcoco task
patch_image_size = cfg.task.patch_image_size
def construct_sample(image: Image, text: str):
    w, h = image.size
    w_resize_ratio = torch.tensor(patch_image_size / w).unsqueeze(0)
    h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
    patch_image = patch_resize_transform(image).unsqueeze(0)
    patch_mask = torch.tensor([True])
    prompt = ' which region does the text " {} " describe?'.format(text)
    tokenized = tokenizer.batch_encode_plus([prompt], padding="longest", return_tensors="pt")
    src_tokens = tokenized["input_ids"]
    att_masks = tokenized["attention_mask"]
    src_lengths = torch.LongTensor(att_masks.ne(0).long().sum())
    sample = {
        "id": np.array(['42']),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "att_masks": att_masks,
            "patch_images": patch_image,
            "patch_masks": patch_mask,
        },
        "w_resize_ratios": w_resize_ratio,
        "h_resize_ratios": h_resize_ratio,
        "region_coords": torch.randn(1, 4),
        "label": np.zeros((512, 512)),
        "poly": 'None',
        "text": text
    }
    return sample
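# Example (a sketch; 'example.jpg' and the phrase are placeholders, not shipped assets):
#   sample = construct_sample(Image.open('example.jpg'), 'the dog on the left')
# returns a batch of size 1 containing the 512x512 normalized patch image, the
# BERT-tokenized prompt, and placeholder "region_coords"/"label"/"poly" fields
# that the refcoco task expects.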
# Function to turn FP32 to FP16
def apply_half(t):
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t
from io import BytesIO
import base64
import re


def pre_caption(caption):
    caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')
    return caption
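# For example (illustrative input, not from the dataset):
#   pre_caption("A man's T-shirt / jacket")  ->  "a man's t shirt jacket"
# i.e. lowercase, hyphens and slashes become spaces, runs of whitespace collapse,
# and the "<person>" placeholder is rewritten to "person".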
def convert_pts(coeffs):
    # flat [x0, y0, x1, y1, ...] -> array of (row, col) points for skimage
    pts = []
    for i in range(len(coeffs) // 2):
        pts.append([coeffs[2 * i + 1], coeffs[2 * i]])  # y, x
    return np.array(pts, np.int32)
def get_mask_from_codes(codes, img_size):
    masks = [np.zeros(img_size)]
    for code in codes:
        mask = draw.polygon2mask(img_size, convert_pts(code))
        mask = np.array(mask, np.uint8)
        masks.append(mask)
    mask = sum(masks)
    mask = mask > 0
    return mask.astype(np.uint8)
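# `codes` is a list of flat polygons in pixel coordinates ([x0, y0, x1, y1, ...]),
# one entry per predicted polygon; each polygon is rasterized with
# skimage.draw.polygon2mask and the results are merged into a single binary
# uint8 mask of shape img_size = (height, width).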
def overlay_predictions(img, mask=None, polygons=None, bbox=None, color_box=(0, 255, 0), color_mask=[255, 102, 102], color_poly=[255, 0, 0], thickness=3, radius=6):
    overlayed = img.copy()
    if bbox is not None:
        overlayed = draw_bbox(overlayed, bbox, color=color_box, thickness=thickness)
    if mask is not None:
        overlayed = overlay_davis(overlayed, mask, colors=[[0, 0, 0], color_mask])
    if polygons is not None:
        overlayed = plot_polygons(overlayed, polygons, color=color_poly, radius=radius)
    return overlayed
def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 102, 102]], cscale=1, alpha=0.4):
    # alternative mask colors: [255, 178, 102], [102, 178, 255]
    colors = np.reshape(colors, (-1, 3))
    colors = np.atleast_2d(colors) * cscale
    im_overlay = image.copy()
    object_ids = np.unique(mask)
    h_i, w_i = image.shape[0:2]
    h_m, w_m = mask.shape[0:2]
    if h_i != h_m:
        # cv2.resize expects dsize as (width, height)
        mask = cv2.resize(mask, (w_i, h_i), interpolation=cv2.INTER_NEAREST)
    for object_id in object_ids[1:]:
        # Overlay color on binary mask
        foreground = image * alpha + np.ones(image.shape) * (1 - alpha) * np.array(colors[object_id])
        binary_mask = mask == object_id
        # Compose image
        im_overlay[binary_mask] = foreground[binary_mask]
    return im_overlay.astype(image.dtype)
def draw_bbox(img, box, color=(0, 255, 0), thickness=2):
    x1, y1, x2, y2 = box
    return cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=thickness)
def plot_polygons(img, polygons, color=(255, 0, 0), radius=7):
    for polygon in polygons:
        if len(polygon) > 0:
            # drop a trailing odd coordinate, then reshape to (n_points, 2)
            polygon = np.reshape(polygon[:len(polygon) - len(polygon) % 2], (len(polygon) // 2, 2)).astype(np.int16)
            for i, point in enumerate(polygon):
                img = cv2.circle(img, point, radius, color, thickness=-1)
            img = cv2.circle(img, polygon[0], radius, color, thickness=-1)
    return img
def plot_arrow(img, polygons, color=(128, 128, 128), thickness=3, tip_length=0.3):
    for polygon in polygons:
        if len(polygon) > 0:
            polygon = np.reshape(polygon[:len(polygon) - len(polygon) % 2], (len(polygon) // 2, 2)).astype(np.int16)
            for i, point in enumerate(polygon):
                if i > 0:
                    img = cv2.arrowedLine(img, polygon[i - 1], point, color, thickness=thickness, tipLength=tip_length)
    return img
def downsample_polygon(polygon, ds_rate=25):
    points = np.array(polygon).reshape(int(len(polygon) / 2), 2)
    points = points[::ds_rate]
    return list(points.flatten())


def downsample_polygons(polygons, ds_rate=25):
    polygons_ds = []
    for polygon in polygons:
        polygons_ds.append(downsample_polygon(polygon, ds_rate))
    return polygons_ds
def visual_grounding(image, text):
    # Construct input sample & preprocess for GPU if cuda available
    sample = construct_sample(image, text.lower())
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample

    with torch.no_grad():
        if isinstance(models, list):
            model = models[0]
        model = model.float()
        min_len = 6
        max_len = 210
        model.eval()
        img = sample["net_input"]["patch_images"]
        b = img.shape[0]
        prev_output_token_11 = [[0] for _ in range(b)]
        prev_output_token_12 = [[0] for _ in range(b)]
        prev_output_token_21 = [[0] for _ in range(b)]
        prev_output_token_22 = [[0] for _ in range(b)]
        delta_x1 = [[0] for _ in range(b)]
        delta_y1 = [[0] for _ in range(b)]
        delta_x2 = [[1] for _ in range(b)]
        delta_y2 = [[1] for _ in range(b)]
        gen_out = [[] for _ in range(b)]
        n_bins = 64
        unfinish_flag = np.ones(b)
        i = 0

        encoder_out = model.encoder(
            sample['net_input']['src_tokens'],
            src_lengths=sample['net_input']['src_lengths'],
            att_masks=sample['net_input']['att_masks'],
            patch_images=sample['net_input']['patch_images'],
            patch_masks=sample['net_input']['patch_masks'],
            token_embeddings=None,
            return_all_hiddens=False,
            sample_patch_num=None
        )
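        # The loop below decodes the target sequence autoregressively with the
        # regression-based decoder: at each step the model predicts a token type
        # (coordinate / separator / eos) and a continuous (x, y) in [0, 1]. The
        # continuous point is fed back as the four surrounding tokens of an
        # n_bins x n_bins grid (floor/ceil combinations) plus the fractional
        # offsets delta_*, i.e. a bilinear interpolation over the coordinate
        # vocabulary. This reading is inferred from the code below, not from an
        # external reference.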
        attn_masks = []
        while i < max_len and unfinish_flag.any():
            # print(i)
            prev_output_tokens_11_tensor = torch.tensor(np.array(prev_output_token_11)).long()
            prev_output_tokens_12_tensor = torch.tensor(np.array(prev_output_token_12)).long()
            prev_output_tokens_21_tensor = torch.tensor(np.array(prev_output_token_21)).long()
            prev_output_tokens_22_tensor = torch.tensor(np.array(prev_output_token_22)).long()
            delta_x1_tensor = torch.tensor(np.array(delta_x1)).float()
            delta_x2_tensor = torch.tensor(np.array(delta_x2)).float()
            delta_y1_tensor = torch.tensor(np.array(delta_y1)).float()
            delta_y2_tensor = torch.tensor(np.array(delta_y2)).float()
            net_output = model.decoder(
                prev_output_tokens_11_tensor,
                prev_output_tokens_12_tensor,
                prev_output_tokens_21_tensor,
                prev_output_tokens_22_tensor,
                delta_x1_tensor,
                delta_y1_tensor,
                delta_x2_tensor,
                delta_y2_tensor,
                code_masks=None,
                encoder_out=encoder_out,
                features_only=False,
                alignment_layer=None,
                alignment_heads=None,
                src_lengths=sample['net_input']['src_lengths'],
                return_all_hiddens=False
            )
            cls_output = net_output[0]
            cls_type = torch.argmax(cls_output, 2)
            reg_output = net_output[1].squeeze(-1)

            # collect the decoder attention over the first 256 (16x16) image-patch positions for this step
            attn = net_output[2]['attn']
            attn_arrays = [att.detach().cpu().numpy() for att in attn]
            attn_arrays = np.concatenate(attn_arrays, 0)
            attn_arrays = np.mean(attn_arrays, 0)
            attn_arrays = attn_arrays[i, :256].reshape(16, 16)
            w, h = image.size  # PIL size is (width, height)
            attn_mask = cv2.resize(attn_arrays.astype(np.float32), (w, h))
            attn_masks.append(attn_mask)
            for j in range(b):
                # print(j)
                if unfinish_flag[j] == 1:  # prediction is not finished
                    cls_j = cls_type[j, i].item()
                    if cls_j == 0 or (cls_j == 2 and i < min_len):  # 0 for coordinate tokens; 2 for eos
                        output_j_x, output_j_y = reg_output[j, i].cpu().numpy()
                        output_j_x = min(output_j_x, 1)
                        output_j_y = min(output_j_y, 1)
                        gen_out[j].extend([output_j_x, output_j_y])

                        output_j_x = output_j_x * (n_bins - 1)
                        output_j_y = output_j_y * (n_bins - 1)

                        output_j_x_floor = math.floor(output_j_x)
                        output_j_y_floor = math.floor(output_j_y)
                        output_j_x_ceil = math.ceil(output_j_x)
                        output_j_y_ceil = math.ceil(output_j_y)

                        # convert to token
                        prev_output_token_11[j].append(output_j_x_floor * n_bins + output_j_y_floor + 4)
                        prev_output_token_12[j].append(output_j_x_floor * n_bins + output_j_y_ceil + 4)
                        prev_output_token_21[j].append(output_j_x_ceil * n_bins + output_j_y_floor + 4)
                        prev_output_token_22[j].append(output_j_x_ceil * n_bins + output_j_y_ceil + 4)

                        delta_x = output_j_x - output_j_x_floor
                        delta_y = output_j_y - output_j_y_floor
                    elif cls_j == 1:  # 1 for separator tokens
                        gen_out[j].append(2)  # insert 2 indicating separator tokens
                        prev_output_token_11[j].append(3)
                        prev_output_token_12[j].append(3)
                        prev_output_token_21[j].append(3)
                        prev_output_token_22[j].append(3)
                        delta_x = 0
                        delta_y = 0
                    else:  # eos is predicted and i >= min_len
                        unfinish_flag[j] = 0
                        gen_out[j].append(-1)
                        prev_output_token_11[j].append(2)  # 2 is eos token
                        prev_output_token_12[j].append(2)  # 2 is eos token
                        prev_output_token_21[j].append(2)  # 2 is eos token
                        prev_output_token_22[j].append(2)  # 2 is eos token
                        delta_x = 0
                        delta_y = 0
                else:  # prediction is finished
                    gen_out[j].append(-1)
                    prev_output_token_11[j].append(1)  # 1 is padding token
                    prev_output_token_12[j].append(1)
                    prev_output_token_21[j].append(1)
                    prev_output_token_22[j].append(1)
                    delta_x = 0
                    delta_y = 0
                delta_x1[j].append(delta_x)
                delta_y1[j].append(delta_y)
                delta_x2[j].append(1 - delta_x)
                delta_y2[j].append(1 - delta_y)
            i += 1
        print("inference step: ", i)
        # Post-process: the first four values are the bounding box, the rest are polygons
        hyps = []
        hyps_det = []
        n_poly_pred = []
        b = len(gen_out)
        for i in range(b):
            gen_out_i = np.array(gen_out[i])
            gen_out_i = gen_out_i[gen_out_i != -1]  # excluding eos and padding indices

            gen_out_i_det = gen_out_i[:4]
            w, h = image.size
            gen_out_i_det[::2] *= w
            gen_out_i_det[1::2] *= h

            polygons_pred = gen_out_i[4:]
            polygons_pred = np.append(polygons_pred, [2])
            size = len(polygons_pred)
            idx_list = [idx for idx, val in enumerate(polygons_pred) if val == 2]  # 2 indicates separator token
            polygons_pred[::2] *= w
            polygons_pred[1::2] *= h

            if len(idx_list) > 0:  # multiple polygons
                polygons = []
                pred_idx = 0
                for idx in idx_list:
                    cur_idx = idx
                    if pred_idx == cur_idx or pred_idx == size:
                        pass
                    else:
                        polygons.append(polygons_pred[pred_idx: cur_idx])
                    pred_idx = cur_idx + 1
            else:
                polygons = [polygons_pred]

            n_poly_pred.append(len(polygons))
            hyps.append(polygons)
            hyps_det.append(gen_out_i_det)

        pred_mask = get_mask_from_codes(hyps[0], (h, w))
        pred_overlayed = overlay_predictions(np.asarray(image), pred_mask, hyps[0], hyps_det[0])

        return pred_overlayed, np.array(pred_mask * 255, dtype=np.uint8)
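# Example usage of the demo entry point (a minimal sketch; 'example.jpg' and the
# referring expression below are placeholders, not assets shipped with this Space):
#
#   image = Image.open('example.jpg')
#   overlay, mask = visual_grounding(image, 'the dog on the left')
#   Image.fromarray(overlay).save('overlay.png')   # input with box/polygon/mask drawn
#   Image.fromarray(mask).save('mask.png')         # binary segmentation mask (0/255)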