ameerazam08's picture
Upload folder using huggingface_hub
03da825 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class DeepLabV3Decoder(nn.Sequential):
    """Decoder head: an ``ASPP`` module followed by a 3x3 convolution and ReLU.

    Args:
        in_channels: Channel count of the feature map fed to the ASPP module.
        out_channels: Channel count produced by the ASPP module and by the
            trailing 3x3 convolution; also stored as ``self.out_channels``.
        atrous_rates: Dilation rates forwarded unchanged to ``ASPP``.
    """

    def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
        # Layers are created in pipeline order and registered under the
        # default nn.Sequential indices "0", "1", "2".
        aspp = ASPP(in_channels, out_channels, atrous_rates)
        refine = nn.Conv2d(
            out_channels, out_channels, 3, padding=1, bias=True
        )
        super().__init__(aspp, refine, nn.ReLU())
        self.out_channels = out_channels

    def forward(self, *features):
        """Run the head on the last positional feature map; the rest are ignored."""
        last_feature = features[-1]
        return super().forward(last_feature)
class ASPPConv(nn.Sequential):
    """Dilated 3x3 convolution (with bias) followed by ReLU.

    Padding is set equal to the dilation, so with the 3x3 kernel and the
    default stride of 1 the output keeps the input's spatial size.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        dilation: Dilation rate, also used as the padding.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=True,
        )
        super().__init__(atrous_conv, nn.ReLU())
class ASPPSeparableConv(nn.Sequential):
    """Dilated 3x3 ``SeparableConv2d`` (with bias) followed by ReLU.

    Padding is set equal to the dilation, mirroring ``ASPPConv``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        dilation: Dilation rate, also used as the padding.
    """

    def __init__(self, in_channels, out_channels, dilation):
        separable_conv = SeparableConv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=True,
        )
        super().__init__(separable_conv, nn.ReLU())
class ASPPPooling(nn.Sequential):
    """Global-average-pool branch: pool to 1x1, 1x1 conv, ReLU, upsample back.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        global_pool = nn.AdaptiveAvgPool2d(1)
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        super().__init__(global_pool, pointwise, nn.ReLU())

    def forward(self, x):
        """Return the pooled features resized to the spatial size of ``x``."""
        # Remember the last two dimensions before pooling collapses them.
        spatial_size = x.shape[-2:]
        # nn.Sequential.forward applies the registered layers in order.
        pooled = super().forward(x)
        return F.interpolate(
            pooled, size=spatial_size, mode="bilinear", align_corners=False
        )
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling built from parallel branches.

    The branches, in order, are: a 1x1 convolution + ReLU, one dilated
    convolution branch per entry of ``atrous_rates``, and an ``ASPPPooling``
    branch. Their outputs are concatenated along the channel dimension and
    projected back to ``out_channels`` with a 1x1 convolution + ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by every branch and by the
            final projection.
        atrous_rates: Iterable of dilation rates, one dilated branch each.
            Any number of rates is accepted; with three rates the module
            layout is the same as the previous fixed three-rate version.
        separable: If true, dilated branches use ``ASPPSeparableConv``
            instead of ``ASPPConv``.
    """

    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
        super().__init__()
        branch_cls = ASPPSeparableConv if separable else ASPPConv

        # Construction order (1x1, dilated branches, pooling, projection) is
        # kept so that ModuleList indices stay stable.
        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=True),
                nn.ReLU(),
            )
        ]
        branches.extend(
            branch_cls(in_channels, out_channels, rate) for rate in atrous_rates
        )
        branches.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(branches)

        # The projection input width follows the actual number of branches
        # rather than a hard-coded 5.
        self.project = nn.Sequential(
            nn.Conv2d(
                len(self.convs) * out_channels,
                out_channels,
                kernel_size=1,
                bias=True,
            ),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply every branch to ``x``, concatenate on dim 1, and project."""
        branch_outputs = [branch(x) for branch in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))
class SeparableConv2d(nn.Sequential):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: Number of input channels; the depthwise stage uses one
            group per input channel and keeps this channel count.
        out_channels: Number of channels produced by the pointwise stage.
        kernel_size: Kernel size of the depthwise stage.
        stride: Stride of the depthwise stage.
        padding: Padding of the depthwise stage.
        dilation: Dilation of the depthwise stage.
        bias: Whether the pointwise stage has a bias; the depthwise stage
            never has one.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        super().__init__(
            # Depthwise stage: spatial filtering applied per channel.
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=in_channels,
                bias=False,
            ),
            # Pointwise stage: 1x1 convolution mixing channels.
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
        )