| """ | |
| "LiftFeat: 3D Geometry-Aware Local Feature Matching" | |
| """ | |
| import numpy as np | |
| import os | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import tqdm | |
| import math | |
| import cv2 | |
| import sys | |
| sys.path.append('/home/yepeng_liu/code_python/laiwenpeng/LiftFeat') | |
| from utils.featurebooster import FeatureBooster | |
| from utils.config import featureboost_config | |
| # from models.model_dfb import LiftFeatModel | |
| # from models.interpolator import InterpolateSparse2d | |
| # from third_party.config import featureboost_config | |
| """ | |
| foundational functions | |
| """ | |
def simple_nms(scores, radius):
    """Perform non-maximum suppression on the heatmap using max-pooling.
    This method does not suppress contiguous points that have the same score.
    Args:
        scores: the score heatmap of size `(B, H, W)`.
        radius: an integer scalar, the radius of the NMS window.
    """

    def max_pool(x):
        return torch.nn.functional.max_pool2d(
            x, kernel_size=radius * 2 + 1, stride=1, padding=radius
        )

    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    for _ in range(2):
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)
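
# Illustrative sketch (not part of the original pipeline; shapes are assumed):
# `simple_nms` keeps only local maxima of a score heatmap and is typically
# followed by a detection threshold. Defined as a function so the module stays
# import-safe.
def _demo_simple_nms():
    scores = torch.rand(1, 60, 80)             # (B, H, W) score heatmap
    nms_scores = simple_nms(scores, radius=4)  # same shape, non-maxima zeroed
    keep = nms_scores > 0.005                  # e.g. a detection threshold
    print(nms_scores.shape, int(keep.sum()))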
def top_k_keypoints(keypoints, scores, k):
    if k >= len(keypoints):
        return keypoints, scores
    scores, indices = torch.topk(scores, k, dim=0, sorted=True)
    return keypoints[indices], scores


def sample_k_keypoints(keypoints, scores, k):
    if k >= len(keypoints):
        return keypoints, scores
    indices = torch.multinomial(scores, k, replacement=False)
    return keypoints[indices], scores[indices]
def soft_argmax_refinement(keypoints, scores, radius: int):
    width = 2 * radius + 1
    sum_ = torch.nn.functional.avg_pool2d(
        scores[:, None], width, 1, radius, divisor_override=1
    )
    ar = torch.arange(-radius, radius + 1).to(scores)
    kernel_x = ar[None].expand(width, -1)[None, None]
    dx = torch.nn.functional.conv2d(scores[:, None], kernel_x, padding=radius)
    dy = torch.nn.functional.conv2d(
        scores[:, None], kernel_x.transpose(2, 3), padding=radius
    )
    dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None]
    refined_keypoints = []
    for i, kpts in enumerate(keypoints):
        delta = dydx[i][tuple(kpts.t())]
        refined_keypoints.append(kpts.float() + delta)
    return refined_keypoints
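
# Illustrative sketch (assumed shapes and coordinate order): refines integer
# keypoint locations with a local soft-argmax over the score map. `keypoints`
# is a list with one (N, 2) integer tensor of (row, col) coordinates per image.
def _demo_soft_argmax_refinement():
    scores = torch.rand(1, 60, 80)                # (B, H, W)
    kpts = [torch.tensor([[10, 12], [30, 40]])]   # per-image (y, x) integer coords
    refined = soft_argmax_refinement(kpts, scores, radius=2)
    print(refined[0])                             # sub-pixel (y, x) locations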
# Legacy (broken) sampling of the descriptors
def sample_descriptors(keypoints, descriptors, s):
    b, c, h, w = descriptors.shape
    keypoints = keypoints - s / 2 + 0.5
    keypoints /= torch.tensor(
        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
    ).to(
        keypoints
    )[None]
    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
    args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
    )
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1
    )
    return descriptors
# The original keypoint sampling is incorrect. We patch it here but
# keep the original one above for legacy.
def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
    """Interpolate descriptors at keypoint locations."""
    b, c, h, w = descriptors.shape
    keypoints = keypoints / (keypoints.new_tensor([w, h]) * s)
    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
    )
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1
    )
    return descriptors
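
# Illustrative sketch (assumed shapes): sampling descriptors for a few keypoints
# from a dense (B, C, H/8, W/8) descriptor map produced at stride s=8. Keypoints
# are given in image-pixel (x, y) coordinates.
def _demo_sample_descriptors():
    descs = torch.rand(1, 64, 76, 100)                     # dense map of a 608x800 image
    kpts = torch.tensor([[[10.0, 20.0], [400.0, 300.0]]])  # (B, N, 2) pixel coords
    sampled = sample_descriptors_fix_sampling(kpts, descs, s=8)
    print(sampled.shape)                                   # torch.Size([1, 64, 2])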
class UpsampleLayer(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # feature-extraction conv: halve the channel count while refining features
        self.conv = nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1)
        # batch normalization
        self.bn = nn.BatchNorm2d(in_channels // 2)
        # LeakyReLU activation
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        x = self.leaky_relu(self.bn(self.conv(x)))
        return x
class KeypointHead(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer1 = BaseLayer(in_channels, 32)
        self.layer2 = BaseLayer(32, 32)
        self.layer3 = BaseLayer(32, 64)
        self.layer4 = BaseLayer(64, 64)
        self.layer5 = BaseLayer(64, 128)
        self.conv = nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1)
        # the original hard-coded BatchNorm2d(65); use out_channels so the norm
        # stays consistent with the conv above (65 is the value used in this model)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.bn(self.conv(x))
        return x
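
# Illustrative sketch (assumption, not taken from this file): the 65-channel
# keypoint logits are commonly decoded SuperPoint-style, i.e. softmax over the
# 64 positions of an 8x8 cell plus a "dustbin" channel, drop the dustbin, then
# pixel-shuffle back to full resolution to obtain a per-pixel score heatmap.
def _demo_decode_keypoint_logits(keypoint_map):
    scores = torch.nn.functional.softmax(keypoint_map, dim=1)[:, :-1]  # (B, 64, H/8, W/8)
    heatmap = torch.nn.functional.pixel_shuffle(scores, 8)             # (B, 1, H, W)
    return heatmap.squeeze(1)                                          # (B, H, W)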
class DescriptorHead(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer = nn.Sequential(
            BaseLayer(in_channels, 32),
            BaseLayer(32, 32, activation=False),
            BaseLayer(32, 64, activation=False),
            BaseLayer(64, out_channels, activation=False)
        )

    def forward(self, x):
        x = self.layer(x)
        # x = nn.functional.softmax(x, dim=1)
        return x


class HeatmapHead(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.convHa = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.bnHa = nn.BatchNorm2d(mid_channels)
        self.convHb = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bnHb = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.leaky_relu(self.bnHa(self.convHa(x)))
        x = self.leaky_relu(self.bnHb(self.convHb(x)))
        x = torch.sigmoid(x)
        return x
class DepthHead(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.upsampleDa = UpsampleLayer(in_channels)
        self.upsampleDb = UpsampleLayer(in_channels // 2)
        self.upsampleDc = UpsampleLayer(in_channels // 4)
        self.convDepa = nn.Conv2d(in_channels // 2 + in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1)
        self.bnDepa = nn.BatchNorm2d(in_channels // 2)
        self.convDepb = nn.Conv2d(in_channels // 4 + in_channels // 2, in_channels // 4, kernel_size=3, stride=1, padding=1)
        self.bnDepb = nn.BatchNorm2d(in_channels // 4)
        self.convDepc = nn.Conv2d(in_channels // 8 + in_channels // 4, 3, kernel_size=3, stride=1, padding=1)
        self.bnDepc = nn.BatchNorm2d(3)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        x0 = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x1 = self.upsampleDa(x)
        x1 = torch.cat([x0, x1], dim=1)
        x1 = self.leaky_relu(self.bnDepa(self.convDepa(x1)))
        x1_0 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=False)
        x2 = self.upsampleDb(x1)
        x2 = torch.cat([x1_0, x2], dim=1)
        x2 = self.leaky_relu(self.bnDepb(self.convDepb(x2)))
        x2_0 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False)
        x3 = self.upsampleDc(x2)
        x3 = torch.cat([x2_0, x3], dim=1)
        x = self.leaky_relu(self.bnDepc(self.convDepc(x3)))
        x = F.normalize(x, p=2, dim=1)
        return x
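
# Illustrative sketch (assumed input size): DepthHead upsamples the 1/8-resolution
# fused feature three times (2x each), so a (B, 64, H/8, W/8) input yields a
# (B, 3, H, W) output whose per-pixel 3-vectors are L2-normalized (normal-like
# geometry features).
def _demo_depth_head():
    head = DepthHead(in_channels=64)
    feat = torch.rand(2, 64, 76, 100)    # e.g. a 608x800 image at stride 8
    normals = head(feat)
    print(normals.shape)                 # torch.Size([2, 3, 608, 800])
    print(normals.norm(dim=1).mean())    # ~1.0 per pixel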
class BaseLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False, activation=True):
        super().__init__()
        if activation:
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels, affine=False),
                nn.ReLU(inplace=True)
            )
        else:
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels, affine=False)
            )

    def forward(self, x):
        return self.layer(x)
class LiftFeatSPModel(nn.Module):
    default_conf = {
        "has_detector": True,
        "has_descriptor": True,
        "descriptor_dim": 64,
        # Inference
        "sparse_outputs": True,
        "dense_outputs": False,
        "nms_radius": 4,
        "refinement_radius": 0,
        "detection_threshold": 0.005,
        "max_num_keypoints": -1,
        "max_num_keypoints_val": None,
        "force_num_keypoints": False,
        "randomize_keypoints_training": False,
        "remove_borders": 4,
        "legacy_sampling": True,  # True to use the old broken sampling
    }
    def __init__(self, featureboost_config, use_kenc=False, use_normal=True, use_cross=True):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.descriptor_dim = 64
        self.norm = nn.InstanceNorm2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5 = 24, 24, 64, 64, 128
        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
        self.conv5a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.conv5b = nn.Conv2d(c5, c5, kernel_size=3, stride=1, padding=1)
        self.upsample4 = UpsampleLayer(c4)
        self.upsample5 = UpsampleLayer(c5)
        self.conv_fusion45 = nn.Conv2d(c5 // 2 + c4, c4, kernel_size=3, stride=1, padding=1)
        self.conv_fusion34 = nn.Conv2d(c4 // 2 + c3, c3, kernel_size=3, stride=1, padding=1)
        # detector
        self.keypoint_head = KeypointHead(in_channels=c3, out_channels=65)
        # descriptor
        self.descriptor_head = DescriptorHead(in_channels=c3, out_channels=self.descriptor_dim)
        # # heatmap
        # self.heatmap_head = HeatmapHead(in_channels=c3, mid_channels=c3, out_channels=1)
        # depth
        self.depth_head = DepthHead(c3)
        self.fine_matcher = nn.Sequential(
            nn.Linear(128, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 64),
        )
        # feature_booster
        self.feature_boost = FeatureBooster(featureboost_config, use_kenc=use_kenc, use_cross=use_cross, use_normal=use_normal)
    def feature_extract(self, x):
        x1 = self.relu(self.conv1a(x))
        x1 = self.relu(self.conv1b(x1))
        x1 = self.pool(x1)
        x2 = self.relu(self.conv2a(x1))
        x2 = self.relu(self.conv2b(x2))
        x2 = self.pool(x2)
        x3 = self.relu(self.conv3a(x2))
        x3 = self.relu(self.conv3b(x3))
        x3 = self.pool(x3)
        x4 = self.relu(self.conv4a(x3))
        x4 = self.relu(self.conv4b(x4))
        x4 = self.pool(x4)
        x5 = self.relu(self.conv5a(x4))
        x5 = self.relu(self.conv5b(x5))
        x5 = self.pool(x5)
        return x3, x4, x5

    def fuse_multi_features(self, x3, x4, x5):
        # upsample x5 feature
        x5 = self.upsample5(x5)
        x4 = torch.cat([x4, x5], dim=1)
        x4 = self.conv_fusion45(x4)
        # upsample x4 feature
        x4 = self.upsample4(x4)
        x3 = torch.cat([x3, x4], dim=1)
        x = self.conv_fusion34(x3)
        return x
    def _unfold2d(self, x, ws=2):
        """
        Unfolds tensor in 2D with desired ws (window size) and concat the channels
        """
        B, C, H, W = x.shape
        x = x.unfold(2, ws, ws).unfold(3, ws, ws).reshape(B, C, H // ws, W // ws, ws ** 2)
        return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H // ws, W // ws)
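    # Shape note (added for clarity): with ws=8, _unfold2d maps (B, C, H, W) to
    # (B, C*64, H/8, W/8) by stacking each 8x8 window into the channel dimension.
    # forward2 uses this to align the full-resolution 3-channel normal map with
    # the stride-8 descriptor grid, e.g. (1, 3, 608, 800) -> (1, 192, 76, 100).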
    def forward1(self, x):
        """
        input:
            x -> torch.Tensor(B, C, H, W) grayscale or rgb images
        return:
            des_map      -> torch.Tensor(B, 64, H/8, W/8) dense local descriptors
            keypoint_map -> torch.Tensor(B, 65, H/8, W/8) keypoint logit map
            d_feats      -> torch.Tensor(B, 3, H, W) L2-normalized 3D geometry (normal) features
        """
        with torch.no_grad():
            x = x.mean(dim=1, keepdim=True)
            x = self.norm(x)
        x3, x4, x5 = self.feature_extract(x)
        # feature fusion
        x = self.fuse_multi_features(x3, x4, x5)
        # keypoints
        keypoint_map = self.keypoint_head(x)
        # descriptors
        des_map = self.descriptor_head(x)
        # # heatmap
        # heatmap = self.heatmap_head(x)
        # depth / normal features
        d_feats = self.depth_head(x)
        return des_map, keypoint_map, d_feats
        # return des_map, keypoint_map, heatmap, d_feats
    def forward2(self, descs, kpts, normals):
        # flatten the stride-8 maps into per-cell vectors for the feature booster
        normals_feat = self._unfold2d(normals, ws=8)
        normals_v = normals_feat.squeeze(0).permute(1, 2, 0).reshape(-1, normals_feat.shape[1])
        descs_v = descs.squeeze(0).permute(1, 2, 0).reshape(-1, descs.shape[1])
        kpts_v = kpts.squeeze(0).permute(1, 2, 0).reshape(-1, kpts.shape[1])
        descs_refine = self.feature_boost(descs_v, kpts_v, normals_v)
        return descs_refine

    def forward(self, x):
        M1, K1, D1 = self.forward1(x)
        descs_refine = self.forward2(M1, K1, D1)
        return descs_refine, M1, K1, D1
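
# Illustrative end-to-end sketch (assumptions: SuperPoint-style decoding of the
# 65-channel logit map, a 0.005 detection threshold and stride-8 descriptor
# sampling; the original inference pipeline may differ). Shows how the helpers
# above can be combined to get sparse keypoints, scores and descriptors.
def _demo_sparse_extraction(model, img, top_k=1000, nms_radius=4):
    des_map, keypoint_map, d_feats = model.forward1(img)
    # decode cell logits into a per-pixel score map
    scores = torch.nn.functional.softmax(keypoint_map, dim=1)[:, :-1]
    scores = torch.nn.functional.pixel_shuffle(scores, 8).squeeze(1)   # (B, H, W)
    scores = simple_nms(scores, nms_radius)[0]                         # first image
    ys, xs = torch.nonzero(scores > 0.005, as_tuple=True)
    kpts = torch.stack([xs, ys], dim=-1).float()                       # (N, 2) as (x, y)
    kpts, kpt_scores = top_k_keypoints(kpts, scores[ys, xs], top_k)
    desc = sample_descriptors_fix_sampling(kpts[None], des_map, s=8)   # (1, 64, N)
    return kpts, kpt_scores, desc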
if __name__ == "__main__":
    img_path = os.path.join(os.path.dirname(__file__), '../assert/ref.jpg')
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (800, 608))
    img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() / 255.0
    img = img.cuda() if torch.cuda.is_available() else img
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    liftfeat_sp = LiftFeatSPModel(featureboost_config).to(device)
    des_map, keypoint_map, d_feats = liftfeat_sp.forward1(img)
    des_fine = liftfeat_sp.forward2(des_map, keypoint_map, d_feats)
    print(des_map.shape)
    print(des_fine.shape)