Spaces:
Runtime error
Runtime error
| # This file is copied from https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/original_models.py | |
| # MIT License | |
| # Copyright (c) 2022 Lorenzo Breschi | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Variable | |
| from torch.nn import functional as F | |
| import torchvision | |
| from torchvision import models | |
| import pytorch_lightning as pl | |
class LeakySoftplus(nn.Module):
    """Softplus activation with a scaled log-sigmoid term added.

    Computes ``softplus(x) + negative_slope * logsigmoid(x)`` element-wise.

    Args:
        negative_slope: multiplier applied to the ``logsigmoid`` term.
    """

    def __init__(self, negative_slope: float = 0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        """Apply the activation to ``input`` and return the result."""
        softplus_term = F.softplus(input)
        leak_term = F.logsigmoid(input) * self.negative_slope
        return softplus_term + leak_term
# Activation module passed as `nonlinearity` to the Generator layer builders.
# Alternatives that were tried and left for reference:
#   grelu = nn.Softplus()
#   grelu = LeakySoftplus(0.2)
grelu = nn.LeakyReLU(negative_slope=0.2)
#####
# The default generator currently in use:
# conv0 -> conv1 -> conv2 -> resnet_blocks -> upconv2 -> upconv1 -> conv_11 -> (conv_11_a)* -> conv_12 -> (Sigmoid)*
# There are 2 conv layers inside conv_11_a.
# "*" marks an optional stage; the model uses skip-connections.
# Note: the final activation is switched on by the `tanh` flag but is implemented as nn.Sigmoid.
class Generator(pl.LightningModule):
    """Convolutional image-to-image generator with skip-connections.

    Pipeline::

        conv0 -> conv1 -> conv2 -> resnet blocks -> upconv2 -> upconv1
              -> conv_11 -> (conv_11_a) -> conv_12 -> (Sigmoid)

    ``conv1`` and ``conv2`` use stride 2; ``upconv2`` and ``upconv1`` upsample
    by the same factors. The outputs of ``conv2``, ``conv1``, ``conv0`` and the
    raw input are concatenated back in on the way up.

    Args:
        norm_layer: ``None``, ``'batch_norm'`` or ``'instance_norm'``.
        use_bias: bias flag for ``conv0``/``conv1``/``conv2``, the residual
            convolutions, ``conv_11`` and ``conv_11_a``. The up-convolutions
            are always built without bias and ``conv_12`` always with bias.
        resnet_blocks: number of residual blocks after ``conv2``.
        tanh: when true, ``conv_12`` is followed by ``nn.Sigmoid`` (the flag
            name is historical; no tanh is applied).
        filters: six channel widths. ``filters[0..2]`` are the outputs of
            ``conv0``/``conv1``/``conv2``; ``filters[3]`` is the channel count
            expected from the residual stack and therefore has to equal
            ``filters[2]``; ``filters[4]`` is the output width of both
            up-convolutions; ``filters[5]`` is the width of ``conv_11``.
        input_channels: channels of the input image.
        output_channels: channels of the produced image.
        append_smoothers: when true, insert the two-convolution ``conv_11_a``
            stage between ``conv_11`` and ``conv_12``.
    """

    def __init__(self, norm_layer='batch_norm', use_bias=False, resnet_blocks=7, tanh=True,
                 filters=(32, 64, 128, 128, 128, 64), input_channels=3, output_channels=3,
                 append_smoothers=False):
        super().__init__()
        assert norm_layer in [None, 'batch_norm', 'instance_norm'], \
            "norm_layer should be None, 'batch_norm' or 'instance_norm', not {}".format(
                norm_layer)

        self.norm_layer = None
        if norm_layer == 'batch_norm':
            self.norm_layer = nn.BatchNorm2d
        elif norm_layer == 'instance_norm':
            self.norm_layer = nn.InstanceNorm2d

        self.use_bias = use_bias
        self.resnet_blocks = resnet_blocks
        self.append_smoothers = append_smoothers

        stride1 = 2
        stride2 = 2

        # ---- encoder ----
        self.conv0 = self.relu_layer(in_filters=input_channels, out_filters=filters[0],
                                     kernel_size=7, stride=1, padding=3,
                                     bias=self.use_bias,
                                     norm_layer=self.norm_layer,
                                     nonlinearity=grelu)
        self.conv1 = self.relu_layer(in_filters=filters[0],
                                     out_filters=filters[1],
                                     kernel_size=3, stride=stride1, padding=1,
                                     bias=self.use_bias,
                                     norm_layer=self.norm_layer,
                                     nonlinearity=grelu)
        self.conv2 = self.relu_layer(in_filters=filters[1],
                                     out_filters=filters[2],
                                     kernel_size=3, stride=stride2, padding=1,
                                     bias=self.use_bias,
                                     norm_layer=self.norm_layer,
                                     nonlinearity=grelu)

        # ---- residual stack ----
        self.resnets = nn.ModuleList()
        for _ in range(self.resnet_blocks):
            self.resnets.append(
                self.resnet_block(in_filters=filters[2],
                                  out_filters=filters[2],
                                  kernel_size=3, stride=1, padding=1,
                                  bias=self.use_bias,
                                  norm_layer=self.norm_layer,
                                  nonlinearity=grelu))

        # ---- decoder ----
        # Input widths include the concatenated skip tensors; to disable the
        # skip-connections the "+ filters[...]" terms would have to go.
        self.upconv2 = self.upconv_layer_upsample_and_conv(in_filters=filters[3] + filters[2],
                                                           out_filters=filters[4],
                                                           scale_factor=stride2,
                                                           kernel_size=3, stride=1, padding=1,
                                                           bias=self.use_bias,
                                                           norm_layer=self.norm_layer,
                                                           nonlinearity=grelu)
        self.upconv1 = self.upconv_layer_upsample_and_conv(in_filters=filters[4] + filters[1],
                                                           out_filters=filters[4],
                                                           scale_factor=stride1,
                                                           kernel_size=3, stride=1, padding=1,
                                                           bias=self.use_bias,
                                                           norm_layer=self.norm_layer,
                                                           nonlinearity=grelu)

        self.conv_11 = nn.Sequential(
            nn.Conv2d(in_channels=filters[0] + filters[4] + input_channels,
                      out_channels=filters[5],
                      kernel_size=7, stride=1, padding=3, bias=self.use_bias, padding_mode='zeros'),
            grelu
        )

        if self.append_smoothers:
            self.conv_11_a = nn.Sequential(
                nn.Conv2d(filters[5], filters[5], kernel_size=3,
                          bias=self.use_bias, padding=1, padding_mode='zeros'),
                grelu,
                # TODO: this is hard-coded to BatchNorm2d regardless of
                # `norm_layer` (original note: "replace with variable").
                nn.BatchNorm2d(num_features=filters[5]),
                nn.Conv2d(filters[5], filters[5], kernel_size=3,
                          bias=self.use_bias, padding=1, padding_mode='zeros'),
                grelu
            )

        if tanh:
            self.conv_12 = nn.Sequential(nn.Conv2d(filters[5], output_channels,
                                                   kernel_size=1, stride=1,
                                                   padding=0, bias=True, padding_mode='zeros'),
                                         nn.Sigmoid())
        else:
            self.conv_12 = nn.Conv2d(filters[5], output_channels, kernel_size=1, stride=1,
                                     padding=0, bias=True, padding_mode='zeros')

    def log_tensors(self, logger, tag, img_tensor):
        """Forward ``img_tensor`` to ``logger.experiment.add_images`` under ``tag``."""
        logger.experiment.add_images(tag, img_tensor)

    @staticmethod
    def _match_size(tensor, reference):
        """Return ``tensor`` resized to the spatial size of ``reference``.

        The tensor is returned untouched when the last two dimensions already
        agree, so inputs whose height and width are divisible by 4 follow
        exactly the same computation as before.
        """
        target = tuple(reference.shape[-2:])
        if tuple(tensor.shape[-2:]) == target:
            return tensor
        return F.interpolate(tensor, size=target, mode='nearest')

    def forward(self, input, logger=None, **kwargs):
        """Run the generator on ``input``.

        ``logger`` and ``**kwargs`` are accepted but not used here.
        The result has ``output_channels`` channels and the spatial size of
        ``input``.
        """
        output_d0 = self.conv0(input)
        output_d1 = self.conv1(output_d0)
        output_d2 = self.conv2(output_d1)

        output = output_d2
        for layer in self.resnets:
            output = layer(output) + output

        # A stride-2 conv followed by a x2 upsample does not round-trip odd
        # sizes (e.g. 534 -> 267 -> 134 -> 268), which used to make the
        # concatenations below fail. Align to the skip tensor when needed.
        output_u2 = self.upconv2(torch.cat((output, output_d2), dim=1))
        output_u2 = self._match_size(output_u2, output_d1)

        output_u1 = self.upconv1(torch.cat((output_u2, output_d1), dim=1))
        output_u1 = self._match_size(output_u1, output_d0)

        output = torch.cat((output_u1, output_d0, input), dim=1)
        output_11 = self.conv_11(output)

        if self.append_smoothers:
            output_11_a = self.conv_11_a(output_11)
        else:
            output_11_a = output_11

        return self.conv_12(output_11_a)

    def relu_layer(self, in_filters, out_filters, kernel_size, stride, padding, bias,
                   norm_layer, nonlinearity):
        """Build ``conv -> [normalization] -> [nonlinearity]`` as an ``nn.Sequential``.

        ``norm_layer`` is a class taking ``num_features`` (or a falsy value to
        skip it); ``nonlinearity`` is a module instance (or falsy to skip it).
        """
        out = nn.Sequential()
        out.add_module('conv', nn.Conv2d(in_channels=in_filters,
                                         out_channels=out_filters,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, bias=bias, padding_mode='zeros'))
        if norm_layer:
            out.add_module('normalization',
                           norm_layer(num_features=out_filters))
        if nonlinearity:
            out.add_module('nonlinearity', nonlinearity)
        return out

    def resnet_block(self, in_filters, out_filters, kernel_size, stride, padding, bias,
                     norm_layer, nonlinearity):
        """Build the body of a residual block (the skip addition happens in ``forward``).

        Layout: ``[nonlinearity] -> conv_0 -> [normalization] -> [nonlinearity] -> conv_1``.
        """
        out = nn.Sequential()
        if nonlinearity:
            out.add_module('nonlinearity_0', nonlinearity)
        out.add_module('conv_0', nn.Conv2d(in_channels=in_filters,
                                           out_channels=out_filters,
                                           kernel_size=kernel_size, stride=stride,
                                           padding=padding, bias=bias, padding_mode='zeros'))
        if norm_layer:
            out.add_module('normalization',
                           norm_layer(num_features=out_filters))
        if nonlinearity:
            out.add_module('nonlinearity_1', nonlinearity)
        # conv_1 consumes the output of conv_0, which has `out_filters`
        # channels (identical to before whenever in_filters == out_filters).
        out.add_module('conv_1', nn.Conv2d(in_channels=out_filters,
                                           out_channels=out_filters,
                                           kernel_size=kernel_size, stride=stride,
                                           padding=padding, bias=bias, padding_mode='zeros'))
        return out

    def upconv_layer_upsample_and_conv(self, in_filters, out_filters, scale_factor, kernel_size, stride, padding, bias,
                                       norm_layer, nonlinearity):
        """Build ``Upsample -> conv -> [normalization] -> [nonlinearity]``.

        NOTE: ``bias`` is accepted but not used; the convolution is always
        created with ``bias=False``. Honouring it would change the parameter
        set of models built with ``use_bias=True``, so it is left as is.
        """
        parts = [nn.Upsample(scale_factor=scale_factor),
                 nn.Conv2d(in_filters, out_filters, kernel_size,
                           stride, padding=padding, bias=False, padding_mode='zeros')
                 ]
        if norm_layer:
            parts.append(norm_layer(num_features=out_filters))
        if nonlinearity:
            parts.append(nonlinearity)
        return nn.Sequential(*parts)
#####
# Default discriminator
#####
# Activation module used by the Discriminator blocks. The earlier
# `relu = grelu` assignment was overwritten before anything read it,
# so only the effective definition is kept.
relu = nn.LeakyReLU(0.2)
class Discriminator(nn.Module):
    """Stack of strided convolutions that scores an image.

    ``forward`` returns the raw output of the last convolution; no pooling or
    squeezing is applied.

    Args:
        num_filters: channel width of the first convolution; later blocks use
            ``num_filters * min(2**i, 8)``.
        input_channels: channels of the input image.
        n_layers: number of blocks between the first and the output
            convolution (the last of them uses stride 1).
        norm_layer: ``'batch_norm'`` selects ``nn.BatchNorm2d``; any other
            value selects ``nn.InstanceNorm2d``.
        use_bias: bias flag for every convolution.
    """

    def __init__(self, num_filters=12, input_channels=3, n_layers=2,
                 norm_layer='instance_norm', use_bias=True):
        super().__init__()
        self.num_filters = num_filters
        self.input_channels = input_channels
        self.use_bias = use_bias

        if norm_layer == 'batch_norm':
            self.norm_layer = nn.BatchNorm2d
        else:
            self.norm_layer = nn.InstanceNorm2d

        self.net = self.make_net(
            n_layers, self.input_channels, 1, 4, 2, self.use_bias)

    def make_net(self, n, flt_in, flt_out=1, k=4, stride=2, bias=True):
        """Assemble the convolutional stack.

        Args:
            n: number of intermediate blocks (``conv_1`` .. ``conv_n``).
            flt_in: input channels of the first convolution.
            flt_out: output channels of the final ``conv_out`` block.
            k: kernel size of every convolution.
            stride: stride of ``conv0`` and of blocks ``conv_1`` .. ``conv_{n-1}``.
            bias: bias flag for every convolution.

        Returns:
            ``nn.Sequential`` with children ``conv0``, ``conv_1`` .. ``conv_n``
            and ``conv_out``.
        """
        padding = 1
        model = nn.Sequential()

        # First block: no normalization.
        model.add_module('conv0', self.make_block(
            flt_in, self.num_filters, k, stride, padding, bias, None, relu))

        flt_mult, flt_mult_prev = 1, 1
        # n - 1 strided blocks, doubling the width up to 8x.
        for layer_idx in range(1, n):
            flt_mult_prev = flt_mult
            flt_mult = min(2 ** layer_idx, 8)
            model.add_module('conv_%d' % (layer_idx), self.make_block(
                self.num_filters * flt_mult_prev, self.num_filters * flt_mult,
                k, stride, padding, bias, self.norm_layer, relu))

        # Last intermediate block uses stride 1.
        flt_mult_prev = flt_mult
        flt_mult = min(2 ** n, 8)
        model.add_module('conv_%d' % (n), self.make_block(
            self.num_filters * flt_mult_prev, self.num_filters * flt_mult,
            k, 1, padding, bias, self.norm_layer, relu))

        # Output block: plain convolution, no normalization or activation.
        # `flt_out` used to be ignored here in favour of a hard-coded 1.
        model.add_module('conv_out', self.make_block(
            self.num_filters * flt_mult, flt_out, k, 1, padding, bias, None, None))
        return model

    def make_block(self, flt_in, flt_out, k, stride, padding, bias, norm, relu):
        """Build ``conv -> [norm] -> [relu]``.

        ``norm`` is a class called with ``flt_out`` (or ``None`` to skip);
        ``relu`` is a module instance (or ``None`` to skip).
        """
        m = nn.Sequential()
        m.add_module('conv', nn.Conv2d(flt_in, flt_out, k,
                                       stride=stride, padding=padding, bias=bias,
                                       padding_mode='zeros'))
        if norm is not None:
            m.add_module('norm', norm(flt_out))
        if relu is not None:
            m.add_module('relu', relu)
        return m

    def forward(self, x):
        """Return ``self.net(x)`` unchanged."""
        return self.net(x)
#####
# Perceptual feature extractor (the class is named VGG19, but the backbone it loads is torchvision's squeezenet1_1)
#####
class PerceptualVGG19(nn.Module):
    """Frozen feature extractor returning flattened intermediate activations.

    Despite the name, the backbone is ``models.squeezenet1_1(pretrained=True)``.

    Args:
        feature_layers: indices into ``model.features`` whose outputs are
            collected.
        use_normalization: when true, inputs are rescaled with ``(x + 1) / 2``
            and standardized with ``self.mean`` / ``self.std`` before the
            backbone is applied.
    """

    def __init__(self, feature_layers=(0, 3, 5), use_normalization=False):
        super().__init__()
        model = models.squeezenet1_1(pretrained=True)
        model.float()
        model.eval()

        self.model = model
        self.feature_layers = feature_layers

        self.mean = torch.FloatTensor([0.485, 0.456, 0.406])
        self.mean_tensor = None

        self.std = torch.FloatTensor([0.229, 0.224, 0.225])
        self.std_tensor = None

        self.use_normalization = use_normalization

        # The extractor is never trained.
        for param in self.parameters():
            param.requires_grad = False

    def normalize(self, x):
        """Standardize ``x`` per channel, or return it untouched when disabled.

        The statistics are kept as ``(1, 3, 1, 1)`` tensors and broadcast, so
        consecutive calls may use different batch or spatial sizes. They are
        rebuilt whenever the input lives on another device.
        """
        if not self.use_normalization:
            return x

        if self.mean_tensor is None or self.mean_tensor.device != x.device:
            self.mean_tensor = self.mean.view(1, 3, 1, 1).to(x.device)
            self.std_tensor = self.std.view(1, 3, 1, 1).to(x.device)

        # presumably inputs are in [-1, 1] and are mapped to [0, 1] here -
        # TODO confirm against the caller.
        x = (x + 1) / 2
        return (x - self.mean_tensor) / self.std_tensor

    def run(self, x):
        """Apply ``model.features`` up to the last requested layer.

        Returns:
            A ``(batch, total_features)`` tensor: the outputs of every layer
            listed in ``feature_layers``, each flattened per sample and
            concatenated along dimension 1.
        """
        features = []
        h = x
        for idx in range(max(self.feature_layers) + 1):
            h = self.model.features[idx](h)
            if idx in self.feature_layers:
                # Copy before flattening so the stored features do not share
                # memory with `h` (presumably guards against in-place layers
                # later in the stack - verify against the backbone).
                features.append(h.clone().reshape(h.size(0), -1))
        return torch.cat(features, dim=1)

    def forward(self, x):
        """Normalize ``x`` (if enabled) and return the collected features."""
        return self.run(self.normalize(x))