Upload 7 files

- deepfillv2/LICENSE +21 -0
- deepfillv2/__init__.py +1 -0
- deepfillv2/network.py +666 -0
- deepfillv2/network_module.py +596 -0
- deepfillv2/network_utils.py +79 -0
- deepfillv2/test_dataset.py +47 -0
- deepfillv2/utils.py +145 -0
deepfillv2/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Qiang Wen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
deepfillv2/__init__.py
ADDED
@@ -0,0 +1 @@
(single blank line)
deepfillv2/network.py
ADDED
@@ -0,0 +1,666 @@
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision

from deepfillv2.network_module import *


def weights_init(net, init_type="kaiming", init_gain=0.02):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    """

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and classname.find("Conv") != -1:
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
        elif classname.find("BatchNorm2d") != -1:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)
        elif classname.find("Linear") != -1:
            init.normal_(m.weight, 0, 0.01)
            init.constant_(m.bias, 0)

    # Apply the initialization function <init_func>
    net.apply(init_func)

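The snippet below is not part of the upload; it is a minimal sketch of how weights_init might be exercised on a small module, with arbitrary layer shapes.

import torch.nn as nn
from deepfillv2.network import weights_init

# A toy module with a Conv and a BatchNorm layer, both matched by init_func above.
toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
weights_init(toy, init_type="kaiming")                  # Kaiming-normal conv weights
weights_init(toy, init_type="normal", init_gain=0.02)   # or re-initialize with N(0, 0.02)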
# -----------------------------------------------
# Generator
# -----------------------------------------------
# Input: masked image + mask
# Output: filled image
class GatedGenerator(nn.Module):
    def __init__(self, opt):
        super(GatedGenerator, self).__init__()
        self.coarse = nn.Sequential(
            # encoder
            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 2, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 2, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            # Bottleneck
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 2, dilation=2,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 4, dilation=4,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 8, dilation=8,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 16, dilation=16,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            # decoder
            TransposeGatedConv2d(opt.latent_channels * 4, opt.latent_channels * 2, 3, 1, 1,
                                 pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            TransposeGatedConv2d(opt.latent_channels * 2, opt.latent_channels, 3, 1, 1,
                                 pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels // 2, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels // 2, opt.out_channels, 3, 1, 1,
                        pad_type=opt.pad_type, activation="none", norm=opt.norm),
            nn.Tanh(),
        )

        self.refine_conv = nn.Sequential(
            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels, 3, 2, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 2, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 2, dilation=2,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 4, dilation=4,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 8, dilation=8,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 16, dilation=16,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
        )
        self.refine_atten_1 = nn.Sequential(
            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels, 3, 2, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 2, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation="relu", norm=opt.norm),
        )
        self.refine_atten_2 = nn.Sequential(
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
        )
        self.refine_combine = nn.Sequential(
            GatedConv2d(opt.latent_channels * 8, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            TransposeGatedConv2d(opt.latent_channels * 4, opt.latent_channels * 2, 3, 1, 1,
                                 pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            TransposeGatedConv2d(opt.latent_channels * 2, opt.latent_channels, 3, 1, 1,
                                 pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels // 2, 3, 1, 1,
                        pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels // 2, opt.out_channels, 3, 1, 1,
                        pad_type=opt.pad_type, activation="none", norm=opt.norm),
            nn.Tanh(),
        )

        use_cuda = opt.use_cuda

        self.context_attention = ContextualAttention(
            ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=use_cuda
        )

    def forward(self, img, mask):
        # img: entire img
        # mask: 1 for mask region; 0 for unmask region
        # Coarse
        first_masked_img = img * (1 - mask) + mask
        first_in = torch.cat((first_masked_img, mask), dim=1)  # in: [B, 4, H, W]
        first_out = self.coarse(first_in)  # out: [B, 3, H, W]
        first_out = nn.functional.interpolate(
            first_out, (img.shape[2], img.shape[3]), recompute_scale_factor=False
        )
        # Refinement
        second_masked_img = img * (1 - mask) + first_out * mask
        second_in = torch.cat([second_masked_img, mask], dim=1)
        refine_conv = self.refine_conv(second_in)
        refine_atten = self.refine_atten_1(second_in)
        mask_s = nn.functional.interpolate(
            mask, (refine_atten.shape[2], refine_atten.shape[3]), recompute_scale_factor=False
        )
        refine_atten = self.context_attention(refine_atten, refine_atten, mask_s)
        refine_atten = self.refine_atten_2(refine_atten)
        second_out = torch.cat([refine_conv, refine_atten], dim=1)
        second_out = self.refine_combine(second_out)
        second_out = nn.functional.interpolate(
            second_out, (img.shape[2], img.shape[3]), recompute_scale_factor=False
        )
        return first_out, second_out

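For orientation, here is a minimal inference sketch (not part of the commit). The option values are assumptions standing in for the project's real config; in_channels is 4 because the network sees the masked image concatenated with the mask.

from types import SimpleNamespace
import torch
from deepfillv2.network import GatedGenerator

opt = SimpleNamespace(in_channels=4, out_channels=3, latent_channels=64,
                      pad_type="zero", activation="elu", norm="none", use_cuda=False)
gen = GatedGenerator(opt).eval()

img = torch.rand(1, 3, 256, 256)       # image tensor in [0, 1]
mask = torch.zeros(1, 1, 256, 256)
mask[:, :, 96:160, 96:160] = 1.0       # 1 marks the hole to fill
with torch.no_grad():
    coarse, refined = gen(img, mask)
inpainted = img * (1 - mask) + refined * mask   # composite the refined hole back into the image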
# -----------------------------------------------
# Discriminator
# -----------------------------------------------
# Input: generated image / ground truth and mask
# Output: patch-based score map (8 x 8 for 256 x 256 inputs, see forward)
class PatchDiscriminator(nn.Module):
    def __init__(self, opt):
        super(PatchDiscriminator, self).__init__()
        # Down sampling
        self.block1 = Conv2dLayer(opt.in_channels, opt.latent_channels, 7, 1, 3,
                                  pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block2 = Conv2dLayer(opt.latent_channels, opt.latent_channels * 2, 4, 2, 1,
                                  pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block3 = Conv2dLayer(opt.latent_channels * 2, opt.latent_channels * 4, 4, 2, 1,
                                  pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block4 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1,
                                  pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block5 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1,
                                  pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block6 = Conv2dLayer(opt.latent_channels * 4, 1, 4, 2, 1,
                                  pad_type=opt.pad_type, activation="none", norm="none", sn=True)

    def forward(self, img, mask):
        # The input should contain 4 channels: the (reconstructed) image concatenated with the mask
        x = torch.cat((img, mask), 1)
        x = self.block1(x)  # out: [B, 64, 256, 256]
        x = self.block2(x)  # out: [B, 128, 128, 128]
        x = self.block3(x)  # out: [B, 256, 64, 64]
        x = self.block4(x)  # out: [B, 256, 32, 32]
        x = self.block5(x)  # out: [B, 256, 16, 16]
        x = self.block6(x)  # out: [B, 1, 8, 8]
        return x

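A quick shape check (illustrative only; the option values are assumptions): for a 256 x 256 input the discriminator returns one score per patch cell, matching the comments in forward.

import torch
from types import SimpleNamespace
from deepfillv2.network import PatchDiscriminator

opt = SimpleNamespace(in_channels=4, latent_channels=64, pad_type="zero",
                      activation="lrelu", norm="none")
disc = PatchDiscriminator(opt)
img = torch.rand(1, 3, 256, 256)
mask = torch.rand(1, 1, 256, 256)
print(disc(img, mask).shape)  # torch.Size([1, 1, 8, 8])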
+
|
| 633 |
+
|
| 634 |
+
# ----------------------------------------
|
| 635 |
+
# Perceptual Network
|
| 636 |
+
# ----------------------------------------
|
| 637 |
+
# VGG-16 conv4_3 features
|
| 638 |
+
class PerceptualNet(nn.Module):
|
| 639 |
+
def __init__(self):
|
| 640 |
+
super(PerceptualNet, self).__init__()
|
| 641 |
+
block = [
|
| 642 |
+
torchvision.models.vgg16(pretrained=True).features[:15].eval()
|
| 643 |
+
]
|
| 644 |
+
for p in block[0]:
|
| 645 |
+
p.requires_grad = False
|
| 646 |
+
self.block = torch.nn.ModuleList(block)
|
| 647 |
+
self.transform = torch.nn.functional.interpolate
|
| 648 |
+
self.register_buffer(
|
| 649 |
+
"mean", torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
| 650 |
+
)
|
| 651 |
+
self.register_buffer(
|
| 652 |
+
"std", torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
def forward(self, x):
|
| 656 |
+
x = (x - self.mean) / self.std
|
| 657 |
+
x = self.transform(
|
| 658 |
+
x,
|
| 659 |
+
mode="bilinear",
|
| 660 |
+
size=(224, 224),
|
| 661 |
+
align_corners=False,
|
| 662 |
+
recompute_scale_factor=False,
|
| 663 |
+
)
|
| 664 |
+
for block in self.block:
|
| 665 |
+
x = block(x)
|
| 666 |
+
return x
|
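PerceptualNet is used as a fixed feature extractor for a perceptual loss; a minimal sketch (downloads the pretrained VGG-16 weights on first use):

import torch
from deepfillv2.network import PerceptualNet

percep = PerceptualNet().eval()
pred = torch.rand(1, 3, 256, 256)
target = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    loss = torch.nn.functional.l1_loss(percep(pred), percep(target))
print(loss.item())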
deepfillv2/network_module.py
ADDED
@@ -0,0 +1,596 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Parameter

from deepfillv2.network_utils import *


# -----------------------------------------------
# Normal ConvBlock
# -----------------------------------------------
class Conv2dLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 pad_type="zero", activation="elu", norm="none", sn=False):
        super(Conv2dLayer, self).__init__()
        # Initialize the padding scheme
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Initialize the normalization type
        if norm == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == "ln":
            self.norm = LayerNorm(out_channels)
        elif norm == "none":
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Initialize the activation function
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "elu":
            self.activation = nn.ELU(inplace=True)
        elif activation == "selu":
            self.activation = nn.SELU(inplace=True)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Initialize the convolution layers
        if sn:
            self.conv2d = SpectralNorm(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)
            )
        else:
            self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                                    padding=0, dilation=dilation)

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x


class TransposeConv2dLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 pad_type="zero", activation="lrelu", norm="none", sn=False, scale_factor=2):
        super(TransposeConv2dLayer, self).__init__()
        # Initialize the conv scheme
        self.scale_factor = scale_factor
        self.conv2d = Conv2dLayer(in_channels, out_channels, kernel_size, stride, padding,
                                  dilation, pad_type, activation, norm, sn)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest",
                          recompute_scale_factor=False)
        x = self.conv2d(x)
        return x


# -----------------------------------------------
# Gated ConvBlock
# -----------------------------------------------
class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 pad_type="reflect", activation="elu", norm="none", sn=False):
        super(GatedConv2d, self).__init__()
        # Initialize the padding scheme
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Initialize the normalization type
        if norm == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == "ln":
            self.norm = LayerNorm(out_channels)
        elif norm == "none":
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Initialize the activation function
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "elu":
            self.activation = nn.ELU()
        elif activation == "selu":
            self.activation = nn.SELU(inplace=True)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Initialize the convolution layers: one for the features, one for the gate
        if sn:
            self.conv2d = SpectralNorm(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)
            )
            self.mask_conv2d = SpectralNorm(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)
            )
        else:
            self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                                    padding=0, dilation=dilation)
            self.mask_conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                                         padding=0, dilation=dilation)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.pad(x)
        conv = self.conv2d(x)
        mask = self.mask_conv2d(x)
        gated_mask = self.sigmoid(mask)
        if self.activation:
            conv = self.activation(conv)
        x = conv * gated_mask
        return x

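A gated convolution computes output = activation(conv(x)) * sigmoid(mask_conv(x)), so the layer learns a soft per-pixel gate alongside the features. A small sketch with arbitrary values:

import torch
from deepfillv2.network_module import GatedConv2d

layer = GatedConv2d(4, 8, kernel_size=3, stride=1, padding=1, pad_type="zero", activation="elu")
x = torch.rand(1, 4, 64, 64)
y = layer(x)
print(y.shape)  # torch.Size([1, 8, 64, 64]); padding=1 keeps the spatial size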
class TransposeGatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 pad_type="zero", activation="lrelu", norm="none", sn=True, scale_factor=2):
        super(TransposeGatedConv2d, self).__init__()
        # Initialize the conv scheme
        self.scale_factor = scale_factor
        self.gated_conv2d = GatedConv2d(in_channels, out_channels, kernel_size, stride, padding,
                                        dilation, pad_type, activation, norm, sn)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest",
                          recompute_scale_factor=False)
        x = self.gated_conv2d(x)
        return x


# ----------------------------------------
# Layer Norm
# ----------------------------------------
class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-8, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = Parameter(torch.Tensor(num_features).uniform_())
            self.beta = Parameter(torch.zeros(num_features))

    def forward(self, x):
        # layer norm
        shape = [-1] + [1] * (x.dim() - 1)  # for 4d input: [-1, 1, 1, 1]
        if x.size(0) == 1:
            # These two lines run much faster (in PyTorch 0.4) than the per-sample branch below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)
        # if it is learnable
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)  # for 4d input: [1, -1, 1, 1]
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


# -----------------------------------------------
# SpectralNorm
# -----------------------------------------------
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    def __init__(self, module, name="weight", power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

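Each forward pass through this wrapper runs one power iteration and rescales the wrapped weight by its estimated largest singular value, so after a few passes that value should approach 1. A rough check (assumes a recent PyTorch with torch.linalg.svdvals):

import torch
from deepfillv2.network_module import SpectralNorm

conv = SpectralNorm(torch.nn.Conv2d(16, 32, 3, padding=1))
x = torch.rand(2, 16, 32, 32)
for _ in range(20):                 # each call runs one power iteration
    _ = conv(x)
w = conv.module.weight              # normalized weight set by _update_u_v
print(torch.linalg.svdvals(w.view(w.shape[0], -1))[0].item())  # close to 1.0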
class ContextualAttention(nn.Module):
    def __init__(self, ksize=3, stride=1, rate=1, fuse_k=3, softmax_scale=10,
                 fuse=True, use_cuda=True, device_ids=None):
        super(ContextualAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.rate = rate
        self.fuse_k = fuse_k
        self.softmax_scale = softmax_scale
        self.fuse = fuse
        self.use_cuda = use_cuda
        self.device_ids = device_ids

    def forward(self, f, b, mask=None):
        """Contextual attention layer implementation.
        Contextual attention was first introduced in:
        Generative Image Inpainting with Contextual Attention, Yu et al.
        Args:
            f: Input feature to match (foreground).
            b: Input feature to match against (background).
            mask: Input mask for b, indicating patches that are not available.
            ksize: Kernel size for contextual attention.
            stride: Stride for extracting patches from b.
            rate: Dilation for matching.
            softmax_scale: Scaled softmax for attention.
        Returns:
            torch.Tensor: output
        """
        # get shapes
        raw_int_fs = list(f.size())  # b*c*h*w
        raw_int_bs = list(b.size())  # b*c*h*w

        # extract patches from background with stride and rate
        kernel = 2 * self.rate
        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(
            b, ksizes=[kernel, kernel],
            strides=[self.rate * self.stride, self.rate * self.stride],
            rates=[1, 1], padding="same",
        )  # [N, C*k*k, L]
        # raw_shape: [N, C, k, k, L], e.g. [4, 192, 4, 4, 1024]
        raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling foreground option: downscale both foreground and
        # background for matching, and use the original background for reconstruction.
        f = F.interpolate(f, scale_factor=1.0 / self.rate, mode="nearest",
                          recompute_scale_factor=False)
        b = F.interpolate(b, scale_factor=1.0 / self.rate, mode="nearest",
                          recompute_scale_factor=False)
        int_fs = list(f.size())  # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(f, 1, dim=0)  # split tensors along the batch dimension
        # w shape: [N, C*k*k, L]
        w = extract_image_patches(b, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1], padding="same")
        # w shape: [N, C, k, k, L]
        w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)

        # process mask
        mask = F.interpolate(mask, scale_factor=1.0 / self.rate, mode="nearest",
                             recompute_scale_factor=False)
        int_ms = list(mask.size())
        # m shape: [N, C*k*k, L]
        m = extract_image_patches(mask, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1], padding="same")
        # m shape: [N, C, k, k, L]
        m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
        m = m.permute(0, 4, 1, 2, 3)  # m shape: [N, L, C, k, k]
        m = m[0]  # m shape: [L, C, k, k]
        # mm shape: [L, 1, 1, 1]
        mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.0).to(torch.float32)
        mm = mm.permute(1, 0, 2, 3)  # mm shape: [1, L, 1, 1]

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale  # to fit the PyTorch tensor image value range
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()

        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            """
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            """
            # conv for compare
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]  # [L, C, k, k]
            max_wi = torch.sqrt(
                reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True)
            )
            wi_normed = wi / max_wi
            # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W]
            # conv implementation for fusing scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                yi = yi.view(1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)  # (B=1, C=1, H=32*32, W=32*32)
                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3])  # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3])
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
                yi = yi.permute(0, 2, 1, 4, 3).contiguous()
                yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax to match
            yi = yi * mm
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mm  # [1, L, H, W]

            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W

            if int_bs != int_fs:
                # Normalize the offset value to match the foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat([offset // int_fs[3], offset % int_fs[3]], dim=1)  # 1*2*H*W

            # deconv for patch pasting
            wi_center = raw_wi[0]
            # yi = F.pad(yi, [0, 1, 0, 1])  # may need conv_transpose "same" padding here
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.0  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)  # back to the mini-batch
        y = y.contiguous().view(raw_int_fs)

        return y
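ContextualAttention fills hole features by borrowing the most similar background patches. It can be exercised on dummy feature maps on the CPU (use_cuda=False); the shapes below are arbitrary:

import torch
from deepfillv2.network_module import ContextualAttention

attn = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3,
                           softmax_scale=10, fuse=True, use_cuda=False)
feat = torch.rand(1, 32, 64, 64)        # feature map from the attention branch
mask = torch.zeros(1, 1, 64, 64)
mask[:, :, 16:48, 16:48] = 1.0          # 1 marks the hole region
out = attn(feat, feat, mask)
print(out.shape)                        # torch.Size([1, 32, 64, 64])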
deepfillv2/network_utils.py
ADDED
@@ -0,0 +1,79 @@
# for contextual attention
import torch


def extract_image_patches(images, ksizes, strides, rates, padding="same"):
    """
    Extract patches from images and put them in the "C" output dimension.
    :param images: a 4-D Tensor with shape [batch, channels, in_rows, in_cols]
    :param ksizes: [ksize_rows, ksize_cols], the size of the sliding window for
                   each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :param padding: "same" or "valid"
    :return: a Tensor of shape [N, C*k*k, L]
    """
    assert len(images.size()) == 4
    assert padding in ["same", "valid"]
    batch_size, channel, height, width = images.size()

    if padding == "same":
        images = same_padding(images, ksizes, strides, rates)
    elif padding == "valid":
        pass
    else:
        raise NotImplementedError(
            'Unsupported padding type: {}. Only "same" or "valid" are supported.'.format(padding)
        )

    unfold = torch.nn.Unfold(kernel_size=ksizes, dilation=rates, padding=0, stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks


def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.0)
    padding_left = int(padding_cols / 2.0)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images


def reduce_mean(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x


def reduce_std(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x


def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
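With padding="same" the unfold produces one patch per spatial location, so L equals H * W; for example:

import torch
from deepfillv2.network_utils import extract_image_patches

x = torch.rand(2, 16, 32, 32)
patches = extract_image_patches(x, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding="same")
print(patches.shape)  # torch.Size([2, 144, 1024]): C*k*k = 16*3*3, L = 32*32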
deepfillv2/test_dataset.py
ADDED
@@ -0,0 +1,47 @@
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from config import *


class InpaintDataset(Dataset):
    def __init__(self):
        self.imglist = [INIMAGE]
        self.masklist = [MASKIMAGE]
        self.setsize = RESIZE_TO

    def __len__(self):
        return len(self.imglist)

    def __getitem__(self, index):
        # image
        img = cv2.imread(self.imglist[index])
        mask = cv2.imread(self.masklist[index])[:, :, 0]
        ## COMMENTING FOR NOW
        # h, w = mask.shape
        # # img = cv2.resize(img, (w, h))
        img = cv2.resize(img, self.setsize)
        mask = cv2.resize(mask, self.setsize)
        ##
        # find the minimum bounding rectangle in the mask
        """
        contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cidx, cnt in enumerate(contours):
            (x, y, w, h) = cv2.boundingRect(cnt)
            mask[y:y+h, x:x+w] = 255
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
        mask = torch.from_numpy(mask.astype(np.float32) / 255.0).unsqueeze(0).contiguous()
        return img, mask
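The dataset can be wrapped in a standard DataLoader; this assumes config.py defines INIMAGE and MASKIMAGE (file paths) and RESIZE_TO, e.g. (256, 256):

from torch.utils.data import DataLoader
from deepfillv2.test_dataset import InpaintDataset

loader = DataLoader(InpaintDataset(), batch_size=1, shuffle=False)
img, mask = next(iter(loader))
print(img.shape, mask.shape)  # e.g. torch.Size([1, 3, 256, 256]) torch.Size([1, 1, 256, 256])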
deepfillv2/utils.py
ADDED
@@ -0,0 +1,145 @@
import os
import numpy as np
import cv2
import torch
from deepfillv2 import network
import skimage

from config import GPU_DEVICE


# ----------------------------------------
# Network
# ----------------------------------------
def create_generator(opt):
    # Initialize the networks
    generator = network.GatedGenerator(opt)
    print("-- Generator is created! --")
    network.weights_init(generator, init_type=opt.init_type, init_gain=opt.init_gain)
    print("-- Initialized generator with %s type --" % opt.init_type)
    return generator


def create_discriminator(opt):
    # Initialize the networks
    discriminator = network.PatchDiscriminator(opt)
    print("-- Discriminator is created! --")
    network.weights_init(discriminator, init_type=opt.init_type, init_gain=opt.init_gain)
    print("-- Initialized discriminator with %s type --" % opt.init_type)
    return discriminator


def create_perceptualnet():
    # Get the first 15 layers of VGG-16, i.e. up to conv3_3
    perceptualnet = network.PerceptualNet()
    print("-- Perceptual network is created! --")
    return perceptualnet


# ----------------------------------------
# PATH processing
# ----------------------------------------
def text_readlines(filename):
    # Try to read a txt file and return a list; return [] if the file cannot be opened.
    try:
        file = open(filename, "r")
    except IOError:
        error = []
        return error
    content = file.readlines()
    # This for loop deletes the end-of-line characters (like \n)
    for i in range(len(content)):
        content[i] = content[i][: len(content[i]) - 1]
    file.close()
    return content


def savetxt(name, loss_log):
    np_loss_log = np.array(loss_log)
    np.savetxt(name, np_loss_log)


def get_files(path, mask=False):
    # read a folder, return the complete paths
    ret = []
    for root, dirs, files in os.walk(path):
        for filespath in files:
            if filespath == ".DS_Store":
                continue
            ret.append(os.path.join(root, filespath))
    return ret


def get_names(path):
    # read a folder, return the image names
    ret = []
    for root, dirs, files in os.walk(path):
        for filespath in files:
            ret.append(filespath)
    return ret


def text_save(content, filename, mode="a"):
    # Try to save a list variable to a txt file, one item per line.
    file = open(filename, mode)
    for i in range(len(content)):
        file.write(str(content[i]) + "\n")
    file.close()


def check_path(path):
    if not os.path.exists(path):
        os.makedirs(path)


# ----------------------------------------
# Validation and Sample at training
# ----------------------------------------
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt=255):
    # Save images one by one
    for i in range(len(img_list)):
        img = img_list[i]
        # Recover normalization: scale back to the 0-255 range
        img = img * 255
        # Process a copy so the data of img is not destroyed
        img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].to("cpu").numpy()
        img_copy = np.clip(img_copy, 0, pixel_max_cnt)
        img_copy = img_copy.astype(np.uint8)
        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
        # Save to the given path
        save_img_path = os.path.join(sample_folder, sample_name)
        cv2.imwrite(save_img_path, img_copy)


def psnr(pred, target, pixel_max_cnt=255):
    mse = torch.mul(target - pred, target - pred)
    rmse_avg = (torch.mean(mse).item()) ** 0.5
    p = 20 * np.log10(pixel_max_cnt / rmse_avg)
    return p


def grey_psnr(pred, target, pixel_max_cnt=255):
    pred = torch.sum(pred, dim=0)
    target = torch.sum(target, dim=0)
    mse = torch.mul(target - pred, target - pred)
    rmse_avg = (torch.mean(mse).item()) ** 0.5
    p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
    return p


def ssim(pred, target):
    pred = pred.clone().data.permute(0, 2, 3, 1).to(GPU_DEVICE).numpy()
    target = target.clone().data.permute(0, 2, 3, 1).to(GPU_DEVICE).numpy()
    target = target[0]
    pred = pred[0]
    ssim = skimage.measure.compare_ssim(target, pred, multichannel=True)
    return ssim
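Putting the pieces together, an inference pass could look roughly like the sketch below. The option values and the checkpoint path are assumptions (the actual entry point is not part of this upload), and config.py plus the listed dependencies must be importable:

import torch
from types import SimpleNamespace
from torch.utils.data import DataLoader

from deepfillv2 import utils
from deepfillv2.test_dataset import InpaintDataset

opt = SimpleNamespace(in_channels=4, out_channels=3, latent_channels=64,
                      pad_type="zero", activation="elu", norm="none",
                      init_type="kaiming", init_gain=0.02, use_cuda=False)
generator = utils.create_generator(opt).eval()
# generator.load_state_dict(torch.load("deepfillv2_pretrained.pth"))  # hypothetical checkpoint path

img, mask = next(iter(DataLoader(InpaintDataset(), batch_size=1)))
with torch.no_grad():
    _, refined = generator(img, mask)
inpainted = img * (1 - mask) + refined * mask
utils.save_sample_png("samples", "result.png", [inpainted], ["result"])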