import torch


def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False):
    """
    Descriptor head loss: per-pixel triplet loss from https://arxiv.org/pdf/1902.11046.pdf.

    Parameters
    ----------
    source_des: torch.Tensor (B,256,H/8,W/8)
        Source image descriptors.
    target_des: torch.Tensor (B,256,H/8,W/8)
        Target image descriptors.
    tar_points_un: torch.Tensor (B,2,H/8,W/8)
        Unnormalized target image keypoints.
    top_kk: torch.Tensor (B,H/8,W/8), optional
        Boolean mask selecting which keypoints to use; if None, all locations are used (dense mode).
    relax_field: int
        Pixel tolerance within which a candidate match counts as correct.
    eval_only: bool
        If True, compute only the recall, not the loss.

    Returns
    -------
    loss: torch.Tensor
        Descriptor loss.
    recall: float
        Descriptor match recall.
    """
    device = source_des.device
    loss = 0
    batch_size = source_des.size(0)
    recall = 0.
    relax_field_size = [relax_field]
    margins = [1.0]
    weights = [1.0]
    is_dense = top_kk is None
    for b_id in range(batch_size):
        if is_dense:
            ref_desc = source_des[b_id].squeeze().view(256, -1)
            tar_desc = target_des[b_id].squeeze().view(256, -1)
            tar_points_raw = tar_points_un[b_id].view(2, -1)
        else:
            top_k = top_kk[b_id].squeeze()
            n_feat = top_k.sum().item()
            if n_feat < 20:
                # Skip frames with too few keypoints to form a stable loss.
                continue
            ref_desc = source_des[b_id].squeeze()[:, top_k]
            tar_desc = target_des[b_id].squeeze()[:, top_k]
            tar_points_raw = tar_points_un[b_id][:, top_k]
        # Compute the dense descriptor distance matrix and find nearest neighbors.
        ref_desc = ref_desc.div(torch.norm(ref_desc, p=2, dim=0))
        tar_desc = tar_desc.div(torch.norm(tar_desc, p=2, dim=0))
        dmat = torch.mm(ref_desc.t(), tar_desc)
        # For unit-norm descriptors, ||a - b|| = sqrt(2 - 2 * <a, b>); the clamp
        # guards against numerical drift outside [-1, 1].
        dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1))
        _, idx = torch.sort(dmat, dim=1)
        # Compute triplet loss and recall.
        for pyramid in range(len(relax_field_size)):
            candidates = idx.t()
            match_k_x = tar_points_raw[0, candidates]
            match_k_y = tar_points_raw[1, candidates]
            tru_x = tar_points_raw[0]
            tru_y = tar_points_raw[1]
            if pyramid == 0:
                # Recall counts exact nearest-neighbor matches.
                correct2 = (abs(match_k_x[0] - tru_x) == 0) & (abs(match_k_y[0] - tru_y) == 0)
                correct2_cnt = correct2.float().sum()
                recall += float(1.0 / batch_size) * (float(correct2_cnt) / float(ref_desc.size(1)))
            if eval_only:
                continue
            # A candidate is correct if it falls within the relaxation field of the true match.
            correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & (abs(match_k_y - tru_y) <= relax_field_size[pyramid])
            # Candidates are sorted nearest-first, so the argmax below selects the
            # closest incorrect candidate as the hardest negative.
            incorrect_index = torch.arange(start=correct_k.shape[0] - 1, end=-1, step=-1).unsqueeze(1).repeat(1, correct_k.shape[1]).to(device)
            incorrect_first = torch.argmax(incorrect_index * (1 - correct_k.long()), dim=0)
            incorrect_first_index = candidates.gather(0, incorrect_first.unsqueeze(0)).squeeze()
            anchor_var = ref_desc
            pos_var = tar_desc
            neg_var = tar_desc[:, incorrect_first_index]
            loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss(anchor_var.t(), pos_var.t(), neg_var.t(), margin=margins[pyramid]).mul(weights[pyramid])
    return loss, recall
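

# Hedged usage sketch (not part of the original file): exercises build_descriptor_loss
# in dense mode (top_kk=None) with random descriptors and a regular keypoint grid.
# The shapes below are assumptions consistent with the docstring above.
def _demo_build_descriptor_loss():
    B, C, Hc, Wc = 2, 256, 24, 32
    source_des = torch.randn(B, C, Hc, Wc)
    target_des = torch.randn(B, C, Hc, Wc)
    # Unnormalized (x, y) pixel coordinates for each cell, shape (B,2,Hc,Wc).
    ys, xs = torch.meshgrid(torch.arange(Hc, dtype=torch.float32),
                            torch.arange(Wc, dtype=torch.float32), indexing='ij')
    tar_points_un = torch.stack([xs, ys]).unsqueeze(0).repeat(B, 1, 1, 1)
    loss, recall = build_descriptor_loss(source_des, target_des, tar_points_un)
    print(f"descriptor loss: {loss.item():.4f}, recall: {recall:.4f}")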


class KeypointLoss(object):
    """
    Loss function class encapsulating the location loss, the descriptor loss,
    the score loss, and the correspondence loss.
    """
    def __init__(self, config):
        self.score_weight = config.score_weight
        self.loc_weight = config.loc_weight
        self.desc_weight = config.desc_weight
        self.corres_weight = config.corres_weight
        self.corres_threshold = config.corres_threshold
    def __call__(self, data):
        B, _, hc, wc = data['source_score'].shape
        # Pairwise distances between warped target coordinates and target coordinates.
        loc_mat_abs = torch.abs(data['target_coord_warped'].view(B, 2, -1).unsqueeze(3) - data['target_coord'].view(B, 2, -1).unsqueeze(2))
        l2_dist_loc_mat = torch.norm(loc_mat_abs, p=2, dim=1)
        l2_dist_loc_min, l2_dist_loc_min_index = l2_dist_loc_mat.min(dim=2)
        # Construct the pseudo ground-truth matching matrix: positives are row-wise
        # minima within 1 pixel, negatives are pairs at least 4 pixels apart.
        loc_min_mat = torch.repeat_interleave(l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1)
        pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.)
        neg_mask = l2_dist_loc_mat.ge(4.)
        # Correspondence loss: binary cross-entropy on the confidence matrix.
        pos_corres = - torch.log(data['confidence_matrix'][pos_mask])
        neg_corres = - torch.log(1.0 - data['confidence_matrix'][neg_mask])
        corres_loss = pos_corres.mean() + 5e5 * neg_corres.mean()
        # Valid correspondences fall within the distance threshold and the border mask.
        dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data['border_mask'].view(B, hc * wc)
        # Location loss.
        loc_loss = l2_dist_loc_min[dist_norm_valid_mask].mean()
        # Descriptor head loss: per-pixel triplet loss from https://arxiv.org/pdf/1902.11046.pdf.
        desc_loss, _ = build_descriptor_loss(data['source_desc'], data['target_desc_warped'], data['target_coord_warped'].detach(), top_kk=data['border_mask'], relax_field=8)
        # Score loss.
        target_score_associated = data['target_score'].view(B, hc * wc).gather(1, l2_dist_loc_min_index).view(B, hc, wc).unsqueeze(1)
        dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data['border_mask'].unsqueeze(1)
        l2_dist_loc_min = l2_dist_loc_min.view(B, hc, wc).unsqueeze(1)
        loc_err = l2_dist_loc_min[dist_norm_valid_mask]
        # Repeatability constraint in the score loss: pushes scores up where the
        # localization error is below average and down where it is above.
        repeatable_constrain = ((target_score_associated[dist_norm_valid_mask] + data['source_score'][dist_norm_valid_mask]) * (loc_err - loc_err.mean())).mean()
        # Consistency constraints in the score loss: warped target maps should match the source maps.
        consistent_constrain = torch.nn.functional.mse_loss(data['target_score_warped'][data['border_mask'].unsqueeze(1)], data['source_score'][data['border_mask'].unsqueeze(1)]).mean() * 2
        aware_consistent_loss = torch.nn.functional.mse_loss(data['target_aware_warped'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)], data['source_aware'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)]).mean() * 2
        score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss
        loss = self.loc_weight * loc_loss + self.desc_weight * desc_loss + self.score_weight * score_loss + self.corres_weight * corres_loss
        return loss, self.loc_weight * loc_loss, self.desc_weight * desc_loss, self.score_weight * score_loss, self.corres_weight * corres_loss
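

# Hedged usage sketch (not part of the original file): builds a synthetic `data`
# dict with self-consistent shapes and runs KeypointLoss once. All shapes, the
# config fields, and the noise model are assumptions inferred from __call__ above.
def _demo_keypoint_loss():
    from types import SimpleNamespace
    B, hc, wc = 1, 16, 16
    n = hc * wc
    ys, xs = torch.meshgrid(torch.arange(hc, dtype=torch.float32),
                            torch.arange(wc, dtype=torch.float32), indexing='ij')
    grid = torch.stack([xs, ys]).unsqueeze(0)    # (1,2,hc,wc) pixel coordinates
    border = torch.zeros(B, hc, wc, dtype=torch.bool)
    border[:, 1:-1, 1:-1] = True                 # keep interior cells only
    data = {
        'source_score': torch.rand(B, 1, hc, wc),
        'target_score': torch.rand(B, 1, hc, wc),
        'target_score_warped': torch.rand(B, 1, hc, wc),
        'source_aware': torch.rand(B, 2, hc, wc),
        'target_aware_warped': torch.rand(B, 2, hc, wc),
        'source_desc': torch.randn(B, 256, hc, wc),
        'target_desc_warped': torch.randn(B, 256, hc, wc),
        'target_coord': grid.repeat(B, 1, 1, 1),
        # Warped coordinates: the grid plus small noise, so positives stay within 1 px.
        'target_coord_warped': grid.repeat(B, 1, 1, 1) + 0.3 * torch.randn(B, 2, hc, wc),
        'confidence_matrix': torch.rand(B, n, n).clamp(1e-6, 1 - 1e-6),
        'border_mask': border,
    }
    config = SimpleNamespace(score_weight=1.0, loc_weight=1.0, desc_weight=1.0,
                             corres_weight=1.0, corres_threshold=4.0)
    loss, loc_l, desc_l, score_l, corres_l = KeypointLoss(config)(data)
    print(loss.item(), loc_l.item(), desc_l.item(), score_l.item(), corres_l.item())


if __name__ == '__main__':
    _demo_build_descriptor_loss()
    _demo_keypoint_loss()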