# (Hugging Face Spaces page residue: "Running on T4" runtime banner — not code.)
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
class slam(nn.Module):
    """Spatial attention module: gates each channel of a feature map.

    Each channel's spatial map is resized to (spatial_dim, spatial_dim),
    flattened, and scored by an MLP with a sigmoid, producing one gate value
    in (0, 1) per channel. The input feature is multiplied by its gates, so
    the output has exactly the input's shape.
    """

    def __init__(self, spatial_dim):
        super(slam, self).__init__()
        self.spatial_dim = spatial_dim
        # MLP that maps one flattened spatial map -> a scalar gate in (0, 1).
        self.linear = nn.Sequential(
            nn.Linear(spatial_dim ** 2, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, feature):
        n, c, h, w = feature.shape
        # BUG FIX: the original guard tested only `h`, so a non-square input
        # with h == spatial_dim but w != spatial_dim skipped the resize and
        # crashed inside nn.Linear (flattened size != spatial_dim**2).
        # Check both spatial dimensions.
        if h != self.spatial_dim or w != self.spatial_dim:
            x = F.interpolate(feature, size=(self.spatial_dim, self.spatial_dim),
                              mode="bilinear", align_corners=True)
        else:
            x = feature
        x = x.view(n, c, -1)        # (n, c, spatial_dim**2)
        x = self.linear(x)          # (n, c, 1): one gate per channel
        x = x.unsqueeze(dim=3)      # (n, c, 1, 1) so it broadcasts spatially
        out = x.expand_as(feature) * feature
        return out
class to_map(nn.Module):
    """1x1-conv + sigmoid head: projects a feature map to one channel in (0, 1)."""

    def __init__(self, channels):
        super(to_map, self).__init__()
        projection = nn.Conv2d(
            in_channels=channels,
            out_channels=1,
            kernel_size=1,
            stride=1,
        )
        self.to_map = nn.Sequential(projection, nn.Sigmoid())

    def forward(self, feature):
        return self.to_map(feature)
class conv_bn_relu(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU; defaults give a shape-preserving 3x3 conv."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
        super(conv_bn_relu, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Same pipeline as before, expressed as a single chained call.
        return self.relu(self.bn(self.conv(x)))
class up_conv_bn_relu(nn.Module):
    """Bilinear-upsample to (up_size, up_size), then Conv2d -> BatchNorm2d -> ReLU.

    NOTE: the misspelled keyword ``kernal_size`` and the attribute names are
    kept as-is for caller and state_dict compatibility.
    """

    def __init__(self, up_size, in_channels, out_channels=64, kernal_size=1, padding=0, stride=1):
        super(up_conv_bn_relu, self).__init__()
        self.upSample = nn.Upsample(
            size=(up_size, up_size), mode="bilinear", align_corners=True
        )
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernal_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        # Resize first, then project/normalize/activate — one chained expression.
        return self.act(self.bn(self.conv(self.upSample(x))))
class ICNet(nn.Module):
    """Two-branch image-complexity network.

    A detail branch (``b1_*``) runs ResNet-18 stages on the full-resolution
    input while a context branch (``b2_*``) runs deeper ResNet-18 stages on a
    bilinearly downsampled copy; each stage is followed by ``slam`` attention.
    Both branches are upsampled to a common (size1 // 8) grid and concatenated,
    then fed to a pixel-wise complexity-map head and a scalar score head.

    forward(x1) expects a square batch of shape (n, 3, size1, size1) and
    returns ``(score, cly_map)``: ``score`` has shape (n,) with values in
    (0, 1); ``cly_map`` has shape (n, 1, size1 // 8, size1 // 8).
    """

    def __init__(self, is_pretrain=True, size1=512, size2=256):
        super(ICNet, self).__init__()
        # Legacy torchvision `pretrained=` API kept deliberately — switching to
        # the newer `weights=` enum would break older torchvision versions.
        resnet18Pretrained1 = torchvision.models.resnet18(pretrained=is_pretrain)
        resnet18Pretrained2 = torchvision.models.resnet18(pretrained=is_pretrain)
        self.size1 = size1  # detail-branch input resolution (square)
        self.size2 = size2  # context-branch input resolution (square)

        ## detail branch: ResNet-18 conv1..layer2 on the full-size input
        self.b1_1 = nn.Sequential(*list(resnet18Pretrained1.children())[:5])
        self.b1_1_slam = slam(32)
        self.b1_2 = list(resnet18Pretrained1.children())[5]
        self.b1_2_slam = slam(32)

        ## context branch: ResNet-18 conv1..layer4 on the downsampled input
        self.b2_1 = nn.Sequential(*list(resnet18Pretrained2.children())[:5])
        self.b2_1_slam = slam(32)
        self.b2_2 = list(resnet18Pretrained2.children())[5]
        self.b2_2_slam = slam(32)
        self.b2_3 = list(resnet18Pretrained2.children())[6]
        self.b2_3_slam = slam(16)
        self.b2_4 = list(resnet18Pretrained2.children())[7]
        self.b2_4_slam = slam(8)

        ## upsample both branches to the common (size1 // 8) grid
        self.upsize = size1 // 8
        self.up1 = up_conv_bn_relu(up_size=self.upsize, in_channels=128, out_channels=256)
        self.up2 = up_conv_bn_relu(up_size=self.upsize, in_channels=512, out_channels=256)

        ## map prediction head
        self.to_map_f = conv_bn_relu(256 * 2, 256 * 2)
        self.to_map_f_slam = slam(32)
        self.to_map = to_map(256 * 2)

        ## score prediction head
        self.to_score_f = conv_bn_relu(256 * 2, 256 * 2)
        self.to_score_f_slam = slam(32)
        self.head = nn.Sequential(
            nn.Linear(256 * 2, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x1):
        # NOTE(review): assert is stripped under `python -O`; kept (rather than
        # raising ValueError) so the failure type callers see is unchanged.
        assert x1.shape[2] == x1.shape[3] == self.size1
        # Downsampled copy for the context branch.
        x2 = F.interpolate(x1, size=(self.size2, self.size2), mode="bilinear",
                           align_corners=True)
        x1 = self.b1_2_slam(self.b1_2(self.b1_1_slam(self.b1_1(x1))))
        x2 = self.b2_2_slam(self.b2_2(self.b2_1_slam(self.b2_1(x2))))
        x2 = self.b2_4_slam(self.b2_4(self.b2_3_slam(self.b2_3(x2))))
        x1 = self.up1(x1)
        x2 = self.up2(x2)
        x_cat = torch.cat((x1, x2), dim=1)  # (n, 512, upsize, upsize)
        cly_map = self.to_map(self.to_map_f_slam(self.to_map_f(x_cat)))
        score_feature = self.to_score_f_slam(self.to_score_f(x_cat))
        score_feature = self.avgpool(score_feature)  # (n, 512, 1, 1)
        # BUG FIX: the original `.squeeze()` also collapsed the batch dimension
        # when n == 1, feeding (512,) to the Linear head and returning a 0-d
        # score. Flatten only the trailing dims so the batch axis survives.
        score_feature = torch.flatten(score_feature, start_dim=1)  # (n, 512)
        score = self.head(score_feature)  # (n, 1)
        # Squeeze only the feature dim: shape (n,) even for a batch of one.
        score = score.squeeze(dim=-1)
        return score, cly_map