from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3) -> BatchNorm2d -> ReLU`` stages.

    Both convolutions use ``kernel_size=3`` with ``padding=1`` (default
    stride), so the spatial size of the input is preserved; only the channel
    count changes.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions. Falls back to
            ``out_channels`` when ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mid_channels: Optional[int] = None,
    ):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels
        # The convolutions carry no bias term; each is followed directly by a
        # BatchNorm2d layer.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply both conv/norm/activation stages to ``x``."""
        return self.conv(x)


class Down(nn.Module):
    """Downscaling stage: ``MaxPool2d(2)`` followed by a :class:`DoubleConv`.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Pool ``x`` with a 2x2 window, then run the double convolution."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling stage: 2x upsampling followed by a :class:`DoubleConv`.

    There is no skip-connection input; the block transforms a single tensor.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        bilinear: If ``True``, upsample with parameter-free bilinear
            interpolation and use ``in_channels // 2`` as the intermediate
            width of the double convolution. If ``False`` (default), upsample
            with a learned ``ConvTranspose2d`` (kernel 2, stride 2) that keeps
            the channel count at ``in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = False):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Upsample ``x`` by a factor of two, then run the double convolution."""
        x = self.up(x)
        return self.conv(x)


class Encoder(nn.Module):
    """Convolutional encoder that maps an image to a ``z_channels`` feature map.

    Layer layout, stored in order in ``self.encoder``:

    1. ``DoubleConv(in_channels, channels)``
    2. one stage per entry of ``channels_mult``: a :class:`Down` for every
       level except the last, a :class:`DoubleConv` for the last level
    3. a 1x1 ``Conv2d`` projecting to ``z_channels``

    The input is therefore max-pooled ``len(channels_mult) - 1`` times.

    Args:
        z_channels: Channels of the encoder output.
        in_channels: Channels of the input.
        channels: Base channel width; level ``i`` outputs
            ``channels * channels_mult[i]`` channels.
        channels_mult: Per-level width multipliers. Must be non-empty.
        **ignore_kwargs: Accepted and discarded, so a shared config dict can be
            splatted into the constructor.

    Raises:
        ValueError: If ``channels_mult`` is empty.
    """

    def __init__(
        self,
        z_channels: int,
        in_channels: int,
        channels: int,
        channels_mult: list[int],
        **ignore_kwargs,
    ):
        super().__init__()
        num_resolutions = len(channels_mult)
        if num_resolutions == 0:
            # Without at least one level there is no width to feed the final
            # projection; fail early with a clear message.
            raise ValueError("channels_mult must contain at least one multiplier")

        self.encoder = nn.ModuleList()
        self.encoder.append(DoubleConv(in_channels, channels))

        # Each level consumes the previous level's output width; the first
        # level consumes the stem's ``channels``.
        block_in = channels
        for i_level, mult in enumerate(channels_mult):
            block_out = channels * mult
            if i_level != num_resolutions - 1:
                self.encoder.append(Down(block_in, block_out))
            else:
                # The last level keeps the resolution (no pooling).
                self.encoder.append(DoubleConv(block_in, block_out))
            block_in = block_out

        self.encoder.append(nn.Conv2d(block_in, z_channels, kernel_size=(1, 1)))

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every encoder layer in order."""
        for layer in self.encoder:
            x = layer(x)
        return x


class Decoder(nn.Module):
    """Convolutional decoder that maps a ``z_channels`` feature map to an image.

    Layer layout:

    1. a 1x1 ``Conv2d`` from ``z_channels`` to
       ``channels * channels_mult[-1]`` (first entry of ``self.decoder``)
    2. one stage per entry of ``channels_mult``, visited from last to first:
       an :class:`Up` for every level except level 0, a :class:`DoubleConv`
       for level 0
    3. ``self.final_conv``: a 1x1 ``Conv2d`` to ``out_channels``

    The input is therefore upsampled ``len(channels_mult) - 1`` times.

    Args:
        z_channels: Channels of the decoder input.
        out_channels: Channels of the decoder output.
        channels: Base channel width; level ``i`` outputs
            ``channels * channels_mult[i]`` channels.
        channels_mult: Per-level width multipliers. Must be non-empty.
        **ignore_kwargs: Accepted and discarded, so a shared config dict can be
            splatted into the constructor.

    Raises:
        ValueError: If ``channels_mult`` is empty.
    """

    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        channels: int,
        channels_mult: list[int],
        **ignore_kwargs,
    ):
        super().__init__()
        num_resolutions = len(channels_mult)
        if num_resolutions == 0:
            raise ValueError("channels_mult must contain at least one multiplier")

        self.decoder = nn.ModuleList()
        block_in = channels * channels_mult[num_resolutions - 1]
        self.decoder.append(nn.Conv2d(z_channels, block_in, kernel_size=(1, 1)))

        for i_level in reversed(range(num_resolutions)):
            block_out = channels * channels_mult[i_level]
            if i_level != 0:
                self.decoder.append(Up(block_in, block_out))
            else:
                # Level 0 keeps the resolution (no upsampling).
                self.decoder.append(DoubleConv(block_in, block_out))
            # The next stage consumes what this stage produced.
            block_in = block_out

        self.final_conv = nn.Conv2d(block_in, out_channels, kernel_size=1)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every decoder layer, then the final 1x1 conv."""
        for layer in self.decoder:
            x = layer(x)
        return self.final_conv(x)