463465810cz committed
Commit
bd633b6
1 Parent(s): d0a0da7
Files changed (50)
  1. .gitignore +2 -0
  2. README.md +56 -0
  3. VERSION +1 -0
  4. basicsr/__init__.py +6 -0
  5. basicsr/archs/__init__.py +25 -0
  6. basicsr/archs/arch_util.py +318 -0
  7. basicsr/archs/dat_arch.py +846 -0
  8. basicsr/data/__init__.py +101 -0
  9. basicsr/data/data_sampler.py +48 -0
  10. basicsr/data/data_util.py +283 -0
  11. basicsr/data/paired_image_dataset.py +135 -0
  12. basicsr/data/prefetch_dataloader.py +125 -0
  13. basicsr/data/transforms.py +179 -0
  14. basicsr/losses/__init__.py +26 -0
  15. basicsr/losses/loss_util.py +95 -0
  16. basicsr/losses/losses.py +492 -0
  17. basicsr/metrics/__init__.py +19 -0
  18. basicsr/metrics/metric_util.py +45 -0
  19. basicsr/metrics/psnr_ssim.py +128 -0
  20. basicsr/models/__init__.py +30 -0
  21. basicsr/models/base_model.py +380 -0
  22. basicsr/models/lr_scheduler.py +96 -0
  23. basicsr/models/sr_model.py +231 -0
  24. basicsr/test.py +44 -0
  25. basicsr/utils/__init__.py +30 -0
  26. basicsr/utils/dist_util.py +82 -0
  27. basicsr/utils/file_client.py +167 -0
  28. basicsr/utils/img_util.py +172 -0
  29. basicsr/utils/logger.py +213 -0
  30. basicsr/utils/matlab_functions.py +359 -0
  31. basicsr/utils/misc.py +141 -0
  32. basicsr/utils/options.py +194 -0
  33. basicsr/utils/registry.py +82 -0
  34. basicsr/version.py +5 -0
  35. datasets/README.md +2 -0
  36. experiments/README.md +2 -0
  37. experiments/pretrained_models/README.md +1 -0
  38. options/README.md +2 -0
  39. options/Test/test_DAT_2_x2.yml +93 -0
  40. options/Test/test_DAT_2_x3.yml +92 -0
  41. options/Test/test_DAT_2_x4.yml +93 -0
  42. options/Test/test_DAT_L_x2.yml +93 -0
  43. options/Test/test_DAT_L_x3.yml +92 -0
  44. options/Test/test_DAT_L_x4.yml +93 -0
  45. options/Test/test_DAT_x2.yml +93 -0
  46. options/Test/test_DAT_x3.yml +92 -0
  47. options/Test/test_DAT_x4.yml +93 -0
  48. requirements.txt +18 -0
  49. results/README.md +1 -0
  50. setup.py +166 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+
2
+ .DS_Store
README.md CHANGED
@@ -0,0 +1,56 @@
1
+ # Dual Aggregation Transformer for Image Super-Resolution
2
+
3
+ This repository provides the code for DAT (Dual Aggregation Transformer) introduced in the paper.
4
+
5
+ ## Dependencies
6
+
7
+ - Python 3.8
8
+ - PyTorch >= 1.8.0
9
+ - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
10
+
11
+ ```bash
12
+ # cd into the repository root directory 'DAT'
13
+ pip install -r requirements.txt
14
+ python setup.py develop
15
+ ```
16
+
17
+ ## TODO
18
+
19
+ * [x] Classic Image SR
20
+ * [ ] More Image SR: Lightweight Image SR, Blind Image SR, Real-World Image SR, ...
21
+
22
+ ## Test
23
+
24
+ - Download the pre-trained [models](https://ufile.io/4u0ms0h5) and place them in `experiments/pretrained_models/`.
25
+
26
+ We provide all models: DAT, DAT-L, and DAT-2 (x2, x3, x4).
27
+
28
+ - Download the [testing](https://ufile.io/6ek67nf8) datasets (Set5, Set14, BSD100, Urban100, Manga109) and place them in `datasets/`.
29
+
30
+ - Run the following scripts. The testing configuration files are in `options/Test/`. For more details about the YML options, please refer to [Configuration](https://github.com/XPixelGroup/BasicSR/blob/master/docs/Config.md).
31
+
32
+ **You can change the testing configuration in the YML files, e.g., 'test_DAT_x2.yml'.**
33
+
34
+ ```shell
35
+ # No self-ensemble
36
+ # DAT, reproduces results in Table 2 of the main paper
37
+ python basicsr/test.py -opt options/Test/test_DAT_x2.yml
38
+ python basicsr/test.py -opt options/Test/test_DAT_x3.yml
39
+ python basicsr/test.py -opt options/Test/test_DAT_x4.yml
40
+
41
+ # DAT-L, reproduces results in Table 2 of the main paper
42
+ python basicsr/test.py -opt options/Test/test_DAT_L_x2.yml
43
+ python basicsr/test.py -opt options/Test/test_DAT_L_x3.yml
44
+ python basicsr/test.py -opt options/Test/test_DAT_L_x4.yml
45
+
46
+ # DAT-2, reproduces results in Table 1 of the supplementary material
47
+ python basicsr/test.py -opt options/Test/test_DAT_2_x2.yml
48
+ python basicsr/test.py -opt options/Test/test_DAT_2_x3.yml
49
+ python basicsr/test.py -opt options/Test/test_DAT_2_x4.yml
50
+ ```
51
+
52
+ - The output is saved in the `results/` folder.
53
+
54
+ ## Acknowledgements
55
+
56
+ This code is built on [BasicSR](https://github.com/XPixelGroup/BasicSR).
VERSION ADDED
@@ -0,0 +1 @@
1
+ 1.3.5
basicsr/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .archs import *
2
+ from .data import *
3
+ from .metrics import *
4
+ from .models import *
5
+ from .test import *
6
+ from .utils import *
basicsr/archs/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from basicsr.utils import get_root_logger, scandir
6
+ from basicsr.utils.registry import ARCH_REGISTRY
7
+
8
+ __all__ = ['build_network']
9
+
10
+ # automatically scan and import arch modules for registry
11
+ # scan all the files under the 'archs' folder and collect files ending with
12
+ # '_arch.py'
13
+ arch_folder = osp.dirname(osp.abspath(__file__))
14
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
15
+ # import all the arch modules
16
+ _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
17
+
18
+
19
+ def build_network(opt):
20
+ opt = deepcopy(opt)
21
+ network_type = opt.pop('type')
22
+ net = ARCH_REGISTRY.get(network_type)(**opt)
23
+ logger = get_root_logger()
24
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
25
+ return net
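For context, a minimal sketch of how `build_network` is typically invoked from an options dict; the option keys below mirror the DAT constructor and are illustrative, not values taken from this commit's YML files.

```python
# Illustrative sketch (not part of the commit): build a registered architecture
# from an options dict. 'type' must match a class decorated with
# @ARCH_REGISTRY.register(), e.g. DAT from basicsr/archs/dat_arch.py.
from basicsr.archs import build_network

net_opt = {
    'type': 'DAT',   # registry key
    'upscale': 2,    # remaining keys are forwarded to the constructor; values here are assumptions
    'in_chans': 3,
    'img_size': 64,
}
net = build_network(net_opt)  # pops 'type', instantiates the class, and logs its name
```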
basicsr/archs/arch_util.py ADDED
@@ -0,0 +1,318 @@
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ from distutils.version import LooseVersion
7
+ from itertools import repeat
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init as init
11
+ from torch.nn.modules.batchnorm import _BatchNorm
12
+
13
+ # from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
14
+ from basicsr.utils import get_root_logger
15
+
16
+
17
+ @torch.no_grad()
18
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
19
+ """Initialize network weights.
20
+
21
+ Args:
22
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
23
+ scale (float): Scale initialized weights, especially for residual
24
+ blocks. Default: 1.
25
+ bias_fill (float): The value to fill bias. Default: 0
26
+ kwargs (dict): Other arguments for initialization function.
27
+ """
28
+ if not isinstance(module_list, list):
29
+ module_list = [module_list]
30
+ for module in module_list:
31
+ for m in module.modules():
32
+ if isinstance(m, nn.Conv2d):
33
+ init.kaiming_normal_(m.weight, **kwargs)
34
+ m.weight.data *= scale
35
+ if m.bias is not None:
36
+ m.bias.data.fill_(bias_fill)
37
+ elif isinstance(m, nn.Linear):
38
+ init.kaiming_normal_(m.weight, **kwargs)
39
+ m.weight.data *= scale
40
+ if m.bias is not None:
41
+ m.bias.data.fill_(bias_fill)
42
+ elif isinstance(m, _BatchNorm):
43
+ init.constant_(m.weight, 1)
44
+ if m.bias is not None:
45
+ m.bias.data.fill_(bias_fill)
46
+
47
+
48
+ def make_layer(basic_block, num_basic_block, **kwarg):
49
+ """Make layers by stacking the same blocks.
50
+
51
+ Args:
52
+ basic_block (nn.module): nn.module class for basic block.
53
+ num_basic_block (int): number of blocks.
54
+
55
+ Returns:
56
+ nn.Sequential: Stacked blocks in nn.Sequential.
57
+ """
58
+ layers = []
59
+ for _ in range(num_basic_block):
60
+ layers.append(basic_block(**kwarg))
61
+ return nn.Sequential(*layers)
62
+
63
+
64
+ class ResidualBlockNoBN(nn.Module):
65
+ """Residual block without BN.
66
+
67
+ It has a style of:
68
+ ---Conv-ReLU-Conv-+-
69
+ |________________|
70
+
71
+ Args:
72
+ num_feat (int): Channel number of intermediate features.
73
+ Default: 64.
74
+ res_scale (float): Residual scale. Default: 1.
75
+ pytorch_init (bool): If set to True, use pytorch default init,
76
+ otherwise, use default_init_weights. Default: False.
77
+ """
78
+
79
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
80
+ super(ResidualBlockNoBN, self).__init__()
81
+ self.res_scale = res_scale
82
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
83
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
84
+ self.relu = nn.ReLU(inplace=True)
85
+
86
+ if not pytorch_init:
87
+ default_init_weights([self.conv1, self.conv2], 0.1)
88
+
89
+ def forward(self, x):
90
+ identity = x
91
+ out = self.conv2(self.relu(self.conv1(x)))
92
+ return identity + out * self.res_scale
93
+
94
+
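A short usage sketch of `make_layer` together with `ResidualBlockNoBN` defined above; the feature size is an arbitrary example.

```python
# Illustrative sketch: stack four residual blocks and pass a dummy feature map through them.
import torch
from basicsr.archs.arch_util import ResidualBlockNoBN, make_layer

body = make_layer(ResidualBlockNoBN, 4, num_feat=64)  # nn.Sequential of 4 blocks
feat = torch.randn(1, 64, 32, 32)
out = body(feat)                                       # shape preserved: (1, 64, 32, 32)
```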
95
+ class Upsample(nn.Sequential):
96
+ """Upsample module.
97
+
98
+ Args:
99
+ scale (int): Scale factor. Supported scales: 2^n and 3.
100
+ num_feat (int): Channel number of intermediate features.
101
+ """
102
+
103
+ def __init__(self, scale, num_feat):
104
+ m = []
105
+ if (scale & (scale - 1)) == 0: # scale = 2^n
106
+ for _ in range(int(math.log(scale, 2))):
107
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
108
+ m.append(nn.PixelShuffle(2))
109
+ elif scale == 3:
110
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
111
+ m.append(nn.PixelShuffle(3))
112
+ else:
113
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
114
+ super(Upsample, self).__init__(*m)
115
+
116
+
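A quick shape check for the `Upsample` module just defined: a x4 scale is decomposed into two conv + PixelShuffle(2) stages (the sizes below are illustrative).

```python
# Illustrative sketch: x4 upsampling via two pixel-shuffle stages.
import torch
from basicsr.archs.arch_util import Upsample

up = Upsample(scale=4, num_feat=64)  # Conv(64->256) + PixelShuffle(2), applied twice
x = torch.randn(1, 64, 16, 16)
y = up(x)                            # (1, 64, 64, 64): spatial x4, channel count unchanged
```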
117
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
118
+ """Warp an image or feature map with optical flow.
119
+
120
+ Args:
121
+ x (Tensor): Tensor with size (n, c, h, w).
122
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
123
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
124
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
125
+ Default: 'zeros'.
126
+ align_corners (bool): Before pytorch 1.3, the default value is
127
+ align_corners=True. After pytorch 1.3, the default value is
128
+ align_corners=False. Here, we use True as the default.
129
+
130
+ Returns:
131
+ Tensor: Warped image or feature map.
132
+ """
133
+ assert x.size()[-2:] == flow.size()[1:3]
134
+ _, _, h, w = x.size()
135
+ # create mesh grid
136
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
137
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
138
+ grid.requires_grad = False
139
+
140
+ vgrid = grid + flow
141
+ # scale grid to [-1,1]
142
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
143
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
144
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
145
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
146
+
147
+ # TODO, what if align_corners=False
148
+ return output
149
+
150
+
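A small sanity-check sketch for `flow_warp`: with zero flow and `align_corners=True`, the output should match the input up to interpolation precision (tensor sizes are illustrative).

```python
# Illustrative sketch: zero flow is (approximately) an identity warp.
import torch
from basicsr.archs.arch_util import flow_warp

x = torch.randn(1, 3, 32, 32)
flow = torch.zeros(1, 32, 32, 2)  # (n, h, w, 2), displacement in pixels
out = flow_warp(x, flow)          # bilinear sampling on the shifted, normalized grid
print(torch.allclose(out, x, atol=1e-5))
```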
151
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
152
+ """Resize a flow according to ratio or shape.
153
+
154
+ Args:
155
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
156
+ size_type (str): 'ratio' or 'shape'.
157
+ sizes (list[int | float]): the ratio for resizing or the final output
158
+ shape.
159
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
160
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
161
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
162
+ ratio > 1.0).
163
+ 2) The order of output_size should be [out_h, out_w].
164
+ interp_mode (str): The mode of interpolation for resizing.
165
+ Default: 'bilinear'.
166
+ align_corners (bool): Whether align corners. Default: False.
167
+
168
+ Returns:
169
+ Tensor: Resized flow.
170
+ """
171
+ _, _, flow_h, flow_w = flow.size()
172
+ if size_type == 'ratio':
173
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
174
+ elif size_type == 'shape':
175
+ output_h, output_w = sizes[0], sizes[1]
176
+ else:
177
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
178
+
179
+ input_flow = flow.clone()
180
+ ratio_h = output_h / flow_h
181
+ ratio_w = output_w / flow_w
182
+ input_flow[:, 0, :, :] *= ratio_w
183
+ input_flow[:, 1, :, :] *= ratio_h
184
+ resized_flow = F.interpolate(
185
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
186
+ return resized_flow
187
+
188
+
189
+ # TODO: may write a cpp file
190
+ def pixel_unshuffle(x, scale):
191
+ """ Pixel unshuffle.
192
+
193
+ Args:
194
+ x (Tensor): Input feature with shape (b, c, hh, hw).
195
+ scale (int): Downsample ratio.
196
+
197
+ Returns:
198
+ Tensor: the pixel unshuffled feature.
199
+ """
200
+ b, c, hh, hw = x.size()
201
+ out_channel = c * (scale**2)
202
+ assert hh % scale == 0 and hw % scale == 0
203
+ h = hh // scale
204
+ w = hw // scale
205
+ x_view = x.view(b, c, h, scale, w, scale)
206
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
207
+
208
+
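A shape sketch for `pixel_unshuffle`: spatial resolution is divided by `scale` while the channel count grows by `scale**2` (sizes are illustrative).

```python
# Illustrative sketch: pixel_unshuffle trades spatial resolution for channels.
import torch
from basicsr.archs.arch_util import pixel_unshuffle

x = torch.randn(2, 3, 16, 16)
y = pixel_unshuffle(x, scale=2)  # (2, 3 * 2**2, 16 // 2, 16 // 2) == (2, 12, 8, 8)
print(y.shape)
```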
209
+ # class DCNv2Pack(ModulatedDeformConvPack):
210
+ # """Modulated deformable conv for deformable alignment.
211
+ #
212
+ # Different from the official DCNv2Pack, which generates offsets and masks
213
+ # from the preceding features, this DCNv2Pack takes another different
214
+ # features to generate offsets and masks.
215
+ #
216
+ # Ref:
217
+ # Delving Deep into Deformable Alignment in Video Super-Resolution.
218
+ # """
219
+ #
220
+ # def forward(self, x, feat):
221
+ # out = self.conv_offset(feat)
222
+ # o1, o2, mask = torch.chunk(out, 3, dim=1)
223
+ # offset = torch.cat((o1, o2), dim=1)
224
+ # mask = torch.sigmoid(mask)
225
+ #
226
+ # offset_absmean = torch.mean(torch.abs(offset))
227
+ # if offset_absmean > 50:
228
+ # logger = get_root_logger()
229
+ # logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
230
+ #
231
+ # if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
232
+ # return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
233
+ # self.dilation, mask)
234
+ # else:
235
+ # return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
236
+ # self.dilation, self.groups, self.deformable_groups)
237
+
238
+
239
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
240
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
241
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
242
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
243
+ def norm_cdf(x):
244
+ # Computes standard normal cumulative distribution function
245
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
246
+
247
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
248
+ warnings.warn(
249
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
250
+ 'The distribution of values may be incorrect.',
251
+ stacklevel=2)
252
+
253
+ with torch.no_grad():
254
+ # Values are generated by using a truncated uniform distribution and
255
+ # then using the inverse CDF for the normal distribution.
256
+ # Get upper and lower cdf values
257
+ low = norm_cdf((a - mean) / std)
258
+ up = norm_cdf((b - mean) / std)
259
+
260
+ # Uniformly fill tensor with values from [low, up], then translate to
261
+ # [2l-1, 2u-1].
262
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
263
+
264
+ # Use inverse cdf transform for normal distribution to get truncated
265
+ # standard normal
266
+ tensor.erfinv_()
267
+
268
+ # Transform to proper mean, std
269
+ tensor.mul_(std * math.sqrt(2.))
270
+ tensor.add_(mean)
271
+
272
+ # Clamp to ensure it's in the proper range
273
+ tensor.clamp_(min=a, max=b)
274
+ return tensor
275
+
276
+
277
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
278
+ r"""Fills the input Tensor with values drawn from a truncated
279
+ normal distribution.
280
+
281
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
282
+
283
+ The values are effectively drawn from the
284
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
285
+ with values outside :math:`[a, b]` redrawn until they are within
286
+ the bounds. The method used for generating the random values works
287
+ best when :math:`a \leq \text{mean} \leq b`.
288
+
289
+ Args:
290
+ tensor: an n-dimensional `torch.Tensor`
291
+ mean: the mean of the normal distribution
292
+ std: the standard deviation of the normal distribution
293
+ a: the minimum cutoff value
294
+ b: the maximum cutoff value
295
+
296
+ Examples:
297
+ >>> w = torch.empty(3, 5)
298
+ >>> nn.init.trunc_normal_(w)
299
+ """
300
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
301
+
302
+
303
+ # From PyTorch
304
+ def _ntuple(n):
305
+
306
+ def parse(x):
307
+ if isinstance(x, collections.abc.Iterable):
308
+ return x
309
+ return tuple(repeat(x, n))
310
+
311
+ return parse
312
+
313
+
314
+ to_1tuple = _ntuple(1)
315
+ to_2tuple = _ntuple(2)
316
+ to_3tuple = _ntuple(3)
317
+ to_4tuple = _ntuple(4)
318
+ to_ntuple = _ntuple
basicsr/archs/dat_arch.py ADDED
@@ -0,0 +1,846 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint as checkpoint
4
+ from torch import Tensor
5
+ from torch.nn import functional as F
6
+
7
+ from timm.models.layers import DropPath, trunc_normal_
8
+ from einops.layers.torch import Rearrange
9
+ from einops import rearrange
10
+
11
+ import math
12
+ import numpy as np
13
+
14
+ from basicsr.utils.registry import ARCH_REGISTRY
15
+
16
+
17
+ def img2windows(img, H_sp, W_sp):
18
+ """
19
+ Input: Image (B, C, H, W)
20
+ Output: Window Partition (B', N, C)
21
+ """
22
+ B, C, H, W = img.shape
23
+ img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
24
+ img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C)
25
+ return img_perm
26
+
27
+
28
+ def windows2img(img_splits_hw, H_sp, W_sp, H, W):
29
+ """
30
+ Input: Window Partition (B', N, C)
31
+ Output: Image (B, H, W, C)
32
+ """
33
+ B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
34
+
35
+ img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
36
+ img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
37
+ return img
38
+
39
+
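A shape sketch of the two window helpers above; the round trip through `img2windows` and `windows2img` recovers the spatial layout, with the caveat that the input is (B, C, H, W) while the merged output is (B, H, W, C). Sizes are illustrative.

```python
# Illustrative sketch: partition a feature map into 8x8 windows and merge it back.
import torch
from basicsr.archs.dat_arch import img2windows, windows2img

B, C, H, W = 2, 32, 16, 24
x = torch.randn(B, C, H, W)
win = img2windows(x, 8, 8)          # (B * H/8 * W/8, 8*8, C) = (12, 64, 32)
img = windows2img(win, 8, 8, H, W)  # (B, H, W, C) = (2, 16, 24, 32)
print(win.shape, img.shape)
```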
40
+ class SpatialGate(nn.Module):
41
+ """ Spatial-Gate.
42
+ Args:
43
+ dim (int): Half of input channels.
44
+ """
45
+ def __init__(self, dim):
46
+ super().__init__()
47
+ self.norm = nn.LayerNorm(dim)
48
+ self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) # DW Conv
49
+
50
+ def forward(self, x, H, W):
51
+ # Split
52
+ x1, x2 = x.chunk(2, dim = -1)
53
+ B, N, C = x.shape
54
+ x2 = self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C//2, H, W)).flatten(2).transpose(-1, -2).contiguous()
55
+
56
+ return x1 * x2
57
+
58
+
59
+ class SGFN(nn.Module):
60
+ """ Spatial-Gate Feed-Forward Network.
61
+ Args:
62
+ in_features (int): Number of input channels.
63
+ hidden_features (int | None): Number of hidden channels. Default: None
64
+ out_features (int | None): Number of output channels. Default: None
65
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
66
+ drop (float): Dropout rate. Default: 0.0
67
+ """
68
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
69
+ super().__init__()
70
+ out_features = out_features or in_features
71
+ hidden_features = hidden_features or in_features
72
+ self.fc1 = nn.Linear(in_features, hidden_features)
73
+ self.act = act_layer()
74
+ self.sg = SpatialGate(hidden_features//2)
75
+ self.fc2 = nn.Linear(hidden_features//2, out_features)
76
+ self.drop = nn.Dropout(drop)
77
+
78
+ def forward(self, x, H, W):
79
+ """
80
+ Input: x: (B, H*W, C), H, W
81
+ Output: x: (B, H*W, C)
82
+ """
83
+ x = self.fc1(x)
84
+ x = self.act(x)
85
+ x = self.drop(x)
86
+
87
+ x = self.sg(x, H, W)
88
+ x = self.drop(x)
89
+
90
+ x = self.fc2(x)
91
+ x = self.drop(x)
92
+ return x
93
+
94
+
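A shape sketch of the SGFN above: `fc1` expands to `hidden_features`, the spatial gate halves the channels, and `fc2` maps back to `out_features`; note that `hidden_features` must be even because `SpatialGate` splits it in half. Sizes are illustrative.

```python
# Illustrative sketch of the SGFN channel flow.
import torch
from basicsr.archs.dat_arch import SGFN

sgfn = SGFN(in_features=60, hidden_features=120)  # gate operates on 120 / 2 = 60 channels
x = torch.randn(1, 16 * 16, 60)                   # (B, H*W, C)
y = sgfn(x, 16, 16)                               # fc1 -> 120, SpatialGate -> 60, fc2 -> 60
print(y.shape)                                    # torch.Size([1, 256, 60])
```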
95
+ class DynamicPosBias(nn.Module):
96
+ # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
97
+ """ Dynamic Relative Position Bias.
98
+ Args:
99
+ dim (int): Number of input channels.
100
+ num_heads (int): Number of attention heads.
101
+ residual (bool): If True, use residual connections between the MLP layers.
102
+ """
103
+ def __init__(self, dim, num_heads, residual):
104
+ super().__init__()
105
+ self.residual = residual
106
+ self.num_heads = num_heads
107
+ self.pos_dim = dim // 4
108
+ self.pos_proj = nn.Linear(2, self.pos_dim)
109
+ self.pos1 = nn.Sequential(
110
+ nn.LayerNorm(self.pos_dim),
111
+ nn.ReLU(inplace=True),
112
+ nn.Linear(self.pos_dim, self.pos_dim),
113
+ )
114
+ self.pos2 = nn.Sequential(
115
+ nn.LayerNorm(self.pos_dim),
116
+ nn.ReLU(inplace=True),
117
+ nn.Linear(self.pos_dim, self.pos_dim)
118
+ )
119
+ self.pos3 = nn.Sequential(
120
+ nn.LayerNorm(self.pos_dim),
121
+ nn.ReLU(inplace=True),
122
+ nn.Linear(self.pos_dim, self.num_heads)
123
+ )
124
+ def forward(self, biases):
125
+ if self.residual:
126
+ pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
127
+ pos = pos + self.pos1(pos)
128
+ pos = pos + self.pos2(pos)
129
+ pos = self.pos3(pos)
130
+ else:
131
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
132
+ return pos
133
+
134
+
135
+ class Spatial_Attention(nn.Module):
136
+ """ Spatial Window Self-Attention.
137
+ It supports rectangle window (containing square window).
138
+ Args:
139
+ dim (int): Number of input channels.
140
+ idx (int): The index of the window shape (0 or 1).
141
+ split_size (tuple(int)): Height or Width of spatial window.
142
+ dim_out (int | None): The dimension of the attention output. Default: None
143
+ num_heads (int): Number of attention heads. Default: 6
144
+ attn_drop (float): Dropout ratio of attention weight. Default: 0.0
145
+ proj_drop (float): Dropout ratio of output. Default: 0.0
146
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set
147
+ position_bias (bool): The dynamic relative position bias. Default: True
148
+ """
149
+ def __init__(self, dim, idx, split_size=[8,8], dim_out=None, num_heads=6, attn_drop=0., proj_drop=0., qk_scale=None, position_bias=True):
150
+ super().__init__()
151
+ self.dim = dim
152
+ self.dim_out = dim_out or dim
153
+ self.split_size = split_size
154
+ self.num_heads = num_heads
155
+ self.idx = idx
156
+ self.position_bias = position_bias
157
+
158
+ head_dim = dim // num_heads
159
+ self.scale = qk_scale or head_dim ** -0.5
160
+
161
+ if idx == 0:
162
+ H_sp, W_sp = self.split_size[0], self.split_size[1]
163
+ elif idx == 1:
164
+ W_sp, H_sp = self.split_size[0], self.split_size[1]
165
+ else:
166
+ print ("ERROR MODE", idx)
167
+ exit(0)
168
+ self.H_sp = H_sp
169
+ self.W_sp = W_sp
170
+
171
+ if self.position_bias:
172
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
173
+ # generate mother-set
174
+ position_bias_h = torch.arange(1 - self.H_sp, self.H_sp)
175
+ position_bias_w = torch.arange(1 - self.W_sp, self.W_sp)
176
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
177
+ biases = biases.flatten(1).transpose(0, 1).contiguous().float()
178
+ self.register_buffer('rpe_biases', biases)
179
+
180
+ # get pair-wise relative position index for each token inside the window
181
+ coords_h = torch.arange(self.H_sp)
182
+ coords_w = torch.arange(self.W_sp)
183
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
184
+ coords_flatten = torch.flatten(coords, 1)
185
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
186
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
187
+ relative_coords[:, :, 0] += self.H_sp - 1
188
+ relative_coords[:, :, 1] += self.W_sp - 1
189
+ relative_coords[:, :, 0] *= 2 * self.W_sp - 1
190
+ relative_position_index = relative_coords.sum(-1)
191
+ self.register_buffer('relative_position_index', relative_position_index)
192
+
193
+ self.attn_drop = nn.Dropout(attn_drop)
194
+
195
+ def im2win(self, x, H, W):
196
+ B, N, C = x.shape
197
+ x = x.transpose(-2,-1).contiguous().view(B, C, H, W)
198
+ x = img2windows(x, self.H_sp, self.W_sp)
199
+ x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
200
+ return x
201
+
202
+ def forward(self, qkv, H, W, mask=None):
203
+ """
204
+ Input: qkv: (B, 3*L, C), H, W, mask: (B, N, N), N is the window size
205
+ Output: x (B, H, W, C)
206
+ """
207
+ q,k,v = qkv[0], qkv[1], qkv[2]
208
+
209
+ B, L, C = q.shape
210
+ assert L == H * W, "flatten img_tokens has wrong size"
211
+
212
+ # partition the q,k,v, image to window
213
+ q = self.im2win(q, H, W)
214
+ k = self.im2win(k, H, W)
215
+ v = self.im2win(v, H, W)
216
+
217
+ q = q * self.scale
218
+ attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N
219
+
220
+ # calculate drpe
221
+ if self.position_bias:
222
+ pos = self.pos(self.rpe_biases)
223
+ # select position bias
224
+ relative_position_bias = pos[self.relative_position_index.view(-1)].view(
225
+ self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1)
226
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
227
+ attn = attn + relative_position_bias.unsqueeze(0)
228
+
229
+ N = attn.shape[3]
230
+
231
+ # use mask for shift window
232
+ if mask is not None:
233
+ nW = mask.shape[0]
234
+ attn = attn.view(B, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
235
+ attn = attn.view(-1, self.num_heads, N, N)
236
+
237
+ attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
238
+ attn = self.attn_drop(attn)
239
+
240
+ x = (attn @ v)
241
+ x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C) # B head N N @ B head N C
242
+
243
+ # merge the window, window to image
244
+ x = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' C
245
+
246
+ return x
247
+
248
+
249
+ class Axial_Spatial_Attention(nn.Module):
250
+ """ Axial Spatial Self-Attention
251
+ Args:
252
+ dim (int): Number of input channels.
253
+ num_heads (int): Number of attention heads. Default: 6
254
+ split_size (tuple(int)): Height and Width of spatial window.
255
+ shift_size (tuple(int)): Shift size for spatial window.
256
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
257
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.
258
+ drop (float): Dropout rate. Default: 0.0
259
+ attn_drop (float): Attention dropout rate. Default: 0.0
260
+ rg_idx (int): The index of the Residual Group (RG)
261
+ b_idx (int): The index of the Block in each RG
262
+ """
263
+ def __init__(self, dim, num_heads,
264
+ reso=64, split_size=[8,8], shift_size=[1,2], qkv_bias=False, qk_scale=None,
265
+ drop=0., attn_drop=0., rg_idx=0, b_idx=0):
266
+ super().__init__()
267
+ self.dim = dim
268
+ self.num_heads = num_heads
269
+ self.split_size = split_size
270
+ self.shift_size = shift_size
271
+ self.b_idx = b_idx
272
+ self.rg_idx = rg_idx
273
+ self.patches_resolution = reso
274
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
275
+
276
+ assert 0 <= self.shift_size[0] < self.split_size[0], "shift_size[0] must be in [0, split_size[0])"
277
+ assert 0 <= self.shift_size[1] < self.split_size[1], "shift_size[1] must be in [0, split_size[1])"
278
+
279
+ self.branch_num = 2
280
+
281
+ self.proj = nn.Linear(dim, dim)
282
+ self.proj_drop = nn.Dropout(drop)
283
+
284
+ self.attns = nn.ModuleList([
285
+ Spatial_Attention(
286
+ dim//2, idx = i,
287
+ split_size=split_size, num_heads=num_heads//2, dim_out=dim//2,
288
+ qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, position_bias=True)
289
+ for i in range(self.branch_num)])
290
+
291
+ if (self.rg_idx % 2 == 0 and self.b_idx > 0 and (self.b_idx - 2) % 4 == 0) or (self.rg_idx % 2 != 0 and self.b_idx % 4 == 0):
292
+ attn_mask = self.calculate_mask(self.patches_resolution, self.patches_resolution)
293
+ self.register_buffer("attn_mask_0", attn_mask[0])
294
+ self.register_buffer("attn_mask_1", attn_mask[1])
295
+ else:
296
+ attn_mask = None
297
+ self.register_buffer("attn_mask_0", None)
298
+ self.register_buffer("attn_mask_1", None)
299
+
300
+ # Adaptive Interaction Module
301
+ self.dwconv = nn.Sequential(
302
+ nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
303
+ nn.BatchNorm2d(dim),
304
+ nn.GELU()
305
+ )
306
+ self.channel_interaction = nn.Sequential(
307
+ nn.AdaptiveAvgPool2d(1),
308
+ nn.Conv2d(dim, dim // 8, kernel_size=1),
309
+ nn.BatchNorm2d(dim // 8),
310
+ nn.GELU(),
311
+ nn.Conv2d(dim // 8, dim, kernel_size=1),
312
+ )
313
+ self.spatial_interaction = nn.Sequential(
314
+ nn.Conv2d(dim, dim // 16, kernel_size=1),
315
+ nn.BatchNorm2d(dim // 16),
316
+ nn.GELU(),
317
+ nn.Conv2d(dim // 16, 1, kernel_size=1)
318
+ )
319
+
320
+ def calculate_mask(self, H, W):
321
+ # The implementation builds on Swin Transformer code https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
322
+ # calculate attention mask for shift window
323
+ img_mask_0 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=0
324
+ img_mask_1 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=1
325
+ h_slices_0 = (slice(0, -self.split_size[0]),
326
+ slice(-self.split_size[0], -self.shift_size[0]),
327
+ slice(-self.shift_size[0], None))
328
+ w_slices_0 = (slice(0, -self.split_size[1]),
329
+ slice(-self.split_size[1], -self.shift_size[1]),
330
+ slice(-self.shift_size[1], None))
331
+
332
+ h_slices_1 = (slice(0, -self.split_size[1]),
333
+ slice(-self.split_size[1], -self.shift_size[1]),
334
+ slice(-self.shift_size[1], None))
335
+ w_slices_1 = (slice(0, -self.split_size[0]),
336
+ slice(-self.split_size[0], -self.shift_size[0]),
337
+ slice(-self.shift_size[0], None))
338
+ cnt = 0
339
+ for h in h_slices_0:
340
+ for w in w_slices_0:
341
+ img_mask_0[:, h, w, :] = cnt
342
+ cnt += 1
343
+ cnt = 0
344
+ for h in h_slices_1:
345
+ for w in w_slices_1:
346
+ img_mask_1[:, h, w, :] = cnt
347
+ cnt += 1
348
+
349
+ # calculate mask for window-0
350
+ img_mask_0 = img_mask_0.view(1, H // self.split_size[0], self.split_size[0], W // self.split_size[1], self.split_size[1], 1)
351
+ img_mask_0 = img_mask_0.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[0], self.split_size[1], 1) # nW, sw[0], sw[1], 1
352
+ mask_windows_0 = img_mask_0.view(-1, self.split_size[0] * self.split_size[1])
353
+ attn_mask_0 = mask_windows_0.unsqueeze(1) - mask_windows_0.unsqueeze(2)
354
+ attn_mask_0 = attn_mask_0.masked_fill(attn_mask_0 != 0, float(-100.0)).masked_fill(attn_mask_0 == 0, float(0.0))
355
+
356
+ # calculate mask for window-1
357
+ img_mask_1 = img_mask_1.view(1, H // self.split_size[1], self.split_size[1], W // self.split_size[0], self.split_size[0], 1)
358
+ img_mask_1 = img_mask_1.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[1], self.split_size[0], 1) # nW, sw[1], sw[0], 1
359
+ mask_windows_1 = img_mask_1.view(-1, self.split_size[1] * self.split_size[0])
360
+ attn_mask_1 = mask_windows_1.unsqueeze(1) - mask_windows_1.unsqueeze(2)
361
+ attn_mask_1 = attn_mask_1.masked_fill(attn_mask_1 != 0, float(-100.0)).masked_fill(attn_mask_1 == 0, float(0.0))
362
+
363
+ return attn_mask_0, attn_mask_1
364
+
365
+ def forward(self, x, H, W):
366
+ """
367
+ Input: x: (B, H*W, C), H, W
368
+ Output: x: (B, H*W, C)
369
+ """
370
+ B, L, C = x.shape
371
+ assert L == H * W, "flatten img_tokens has wrong size"
372
+
373
+ qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # 3, B, HW, C
374
+ # V without partition
375
+ v = qkv[2].transpose(-2,-1).contiguous().view(B, C, H, W)
376
+
377
+ # image padding
378
+ max_split_size = max(self.split_size[0], self.split_size[1])
379
+ pad_l = pad_t = 0
380
+ pad_r = (max_split_size - W % max_split_size) % max_split_size
381
+ pad_b = (max_split_size - H % max_split_size) % max_split_size
382
+
383
+ qkv = qkv.reshape(3*B, H, W, C).permute(0, 3, 1, 2) # 3B C H W
384
+ qkv = F.pad(qkv, (pad_l, pad_r, pad_t, pad_b)).reshape(3, B, C, -1).transpose(-2, -1) # l r t b
385
+ _H = pad_b + H
386
+ _W = pad_r + W
387
+ _L = _H * _W
388
+
389
+ # window-0 and window-1 on split channels [C/2, C/2]; for square windows (e.g., 8x8), window-0 and window-1 can be merged
390
+ # shift in block: (0, 4, 8, ...), (2, 6, 10, ...), (0, 4, 8, ...), (2, 6, 10, ...), ...
391
+ if (self.rg_idx % 2 == 0 and self.b_idx > 0 and (self.b_idx - 2) % 4 == 0) or (self.rg_idx % 2 != 0 and self.b_idx % 4 == 0):
392
+ qkv = qkv.view(3, B, _H, _W, C)
393
+ qkv_0 = torch.roll(qkv[:,:,:,:,:C//2], shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(2, 3))
394
+ qkv_0 = qkv_0.view(3, B, _L, C//2)
395
+ qkv_1 = torch.roll(qkv[:,:,:,:,C//2:], shifts=(-self.shift_size[1], -self.shift_size[0]), dims=(2, 3))
396
+ qkv_1 = qkv_1.view(3, B, _L, C//2)
397
+
398
+ if self.patches_resolution != _H or self.patches_resolution != _W:
399
+ mask_tmp = self.calculate_mask(_H, _W)
400
+ x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device))
401
+ x2_shift = self.attns[1](qkv_1, _H, _W, mask=mask_tmp[1].to(x.device))
402
+ else:
403
+ x1_shift = self.attns[0](qkv_0, _H, _W, mask=self.attn_mask_0)
404
+ x2_shift = self.attns[1](qkv_1, _H, _W, mask=self.attn_mask_1)
405
+
406
+ x1 = torch.roll(x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
407
+ x2 = torch.roll(x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2))
408
+ x1 = x1[:, :H, :W, :].reshape(B, L, C//2)
409
+ x2 = x2[:, :H, :W, :].reshape(B, L, C//2)
410
+ # attention output
411
+ attened_x = torch.cat([x1,x2], dim=2)
412
+
413
+ else:
414
+ x1 = self.attns[0](qkv[:,:,:,:C//2], _H, _W)[:, :H, :W, :].reshape(B, L, C//2)
415
+ x2 = self.attns[1](qkv[:,:,:,C//2:], _H, _W)[:, :H, :W, :].reshape(B, L, C//2)
416
+ # attention output
417
+ attened_x = torch.cat([x1,x2], dim=2)
418
+
419
+ # convolution output
420
+ conv_x = self.dwconv(v)
421
+
422
+ # C-Map (before sigmoid)
423
+ channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C)
424
+ # S-Map (before sigmoid)
425
+ attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W)
426
+ spatial_map = self.spatial_interaction(attention_reshape)
427
+
428
+ # C-I
429
+ attened_x = attened_x * torch.sigmoid(channel_map)
430
+ # S-I
431
+ conv_x = torch.sigmoid(spatial_map) * conv_x
432
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
433
+
434
+ x = attened_x + conv_x
435
+
436
+ x = self.proj(x)
437
+ x = self.proj_drop(x)
438
+
439
+ return x
440
+
441
+
442
+ class Axial_Channel_Attention(nn.Module):
443
+ # The implementation builds on XCiT code https://github.com/facebookresearch/xcit
444
+ """ Axial Channel Self-Attention
445
+ Args:
446
+ dim (int): Number of input channels.
447
+ num_heads (int): Number of attention heads. Default: 6
448
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
449
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.
450
+ attn_drop (float): Attention dropout rate. Default: 0.0
451
+ drop_path (float): Stochastic depth rate. Default: 0.0
452
+ """
453
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
454
+ super().__init__()
455
+ self.num_heads = num_heads
456
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
457
+
458
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
459
+ self.attn_drop = nn.Dropout(attn_drop)
460
+ self.proj = nn.Linear(dim, dim)
461
+ self.proj_drop = nn.Dropout(proj_drop)
462
+
463
+ # Adaptive Interaction Module
464
+ self.dwconv = nn.Sequential(
465
+ nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
466
+ nn.BatchNorm2d(dim),
467
+ nn.GELU()
468
+ )
469
+ self.channel_interaction = nn.Sequential(
470
+ nn.AdaptiveAvgPool2d(1),
471
+ nn.Conv2d(dim, dim // 8, kernel_size=1),
472
+ nn.BatchNorm2d(dim // 8),
473
+ nn.GELU(),
474
+ nn.Conv2d(dim // 8, dim, kernel_size=1),
475
+ )
476
+ self.spatial_interaction = nn.Sequential(
477
+ nn.Conv2d(dim, dim // 16, kernel_size=1),
478
+ nn.BatchNorm2d(dim // 16),
479
+ nn.GELU(),
480
+ nn.Conv2d(dim // 16, 1, kernel_size=1)
481
+ )
482
+
483
+ def forward(self, x, H, W):
484
+ """
485
+ Input: x: (B, H*W, C), H, W
486
+ Output: x: (B, H*W, C)
487
+ """
488
+ B, N, C = x.shape
489
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
490
+ qkv = qkv.permute(2, 0, 3, 1, 4)
491
+ q, k, v = qkv[0], qkv[1], qkv[2]
492
+
493
+ q = q.transpose(-2, -1)
494
+ k = k.transpose(-2, -1)
495
+ v = v.transpose(-2, -1)
496
+
497
+ v_ = v.reshape(B, C, N).contiguous().view(B, C, H, W)
498
+
499
+ q = torch.nn.functional.normalize(q, dim=-1)
500
+ k = torch.nn.functional.normalize(k, dim=-1)
501
+
502
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
503
+ attn = attn.softmax(dim=-1)
504
+ attn = self.attn_drop(attn)
505
+
506
+ # attention output
507
+ attened_x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
508
+
509
+ # convolution output
510
+ conv_x = self.dwconv(v_)
511
+
512
+ # C-Map (before sigmoid)
513
+ attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W)
514
+ channel_map = self.channel_interaction(attention_reshape)
515
+ # S-Map (before sigmoid)
516
+ spatial_map = self.spatial_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, N, 1)
517
+
518
+ # S-I
519
+ attened_x = attened_x * torch.sigmoid(spatial_map)
520
+ # C-I
521
+ conv_x = conv_x * torch.sigmoid(channel_map)
522
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, N, C)
523
+
524
+ x = attened_x + conv_x
525
+
526
+ x = self.proj(x)
527
+ x = self.proj_drop(x)
528
+
529
+ return x
530
+
531
+
532
+ class DATB(nn.Module):
533
+ def __init__(self, dim, num_heads, reso=64, split_size=[2,4],shift_size=[1,2], expansion_factor=4., qkv_bias=False, qk_scale=None, drop=0.,
534
+ attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rg_idx=0, b_idx=0):
535
+ super().__init__()
536
+
537
+ self.norm1 = norm_layer(dim)
538
+
539
+ if b_idx % 2 == 0:
540
+ # DSTB
541
+ self.attn = Axial_Spatial_Attention(
542
+ dim, num_heads=num_heads, reso=reso, split_size=split_size, shift_size=shift_size, qkv_bias=qkv_bias, qk_scale=qk_scale,
543
+ drop=drop, attn_drop=attn_drop, rg_idx=rg_idx, b_idx=b_idx
544
+ )
545
+ else:
546
+ # DCTB
547
+ self.attn = Axial_Channel_Attention(
548
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
549
+ proj_drop=drop
550
+ )
551
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
552
+
553
+ ffn_hidden_dim = int(dim * expansion_factor)
554
+ self.ffn = SGFN(in_features=dim, hidden_features=ffn_hidden_dim, out_features=dim, act_layer=act_layer)
555
+ self.norm2 = norm_layer(dim)
556
+
557
+ def forward(self, x, x_size):
558
+ """
559
+ Input: x: (B, H*W, C), x_size: (H, W)
560
+ Output: x: (B, H*W, C)
561
+ """
562
+ H , W = x_size
563
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
564
+ x = x + self.drop_path(self.ffn(self.norm2(x), H, W))
565
+
566
+ return x
567
+
568
+
569
+ class ResidualGroup(nn.Module):
570
+ """ ResidualGroup
571
+ Args:
572
+ dim (int): Number of input channels.
573
+ reso (int): Input resolution.
574
+ num_heads (int): Number of attention heads.
575
+ split_size (tuple(int)): Height and Width of spatial window.
576
+ expansion_factor (float): Ratio of ffn hidden dim to embedding dim.
577
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
578
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. Default: None
579
+ drop (float): Dropout rate. Default: 0
580
+ attn_drop(float): Attention dropout rate. Default: 0
581
+ drop_paths (float | None): Stochastic depth rate.
582
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
583
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
584
+ depth (int): Number of Cross Aggregation Transformer blocks in residual group.
585
+ use_chk (bool): Whether to use checkpointing to save memory.
586
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
587
+ """
588
+ def __init__( self,
589
+ dim,
590
+ reso,
591
+ num_heads,
592
+ split_size=[2,4],
593
+ expansion_factor=4.,
594
+ qkv_bias=False,
595
+ qk_scale=None,
596
+ drop=0.,
597
+ attn_drop=0.,
598
+ drop_paths=None,
599
+ act_layer=nn.GELU,
600
+ norm_layer=nn.LayerNorm,
601
+ depth=2,
602
+ use_chk=False,
603
+ resi_connection='1conv',
604
+ rg_idx=0):
605
+ super().__init__()
606
+ self.use_chk = use_chk
607
+ self.reso = reso
608
+
609
+ self.blocks = nn.ModuleList([
610
+ DATB(
611
+ dim=dim,
612
+ num_heads=num_heads,
613
+ reso = reso,
614
+ split_size = split_size,
615
+ shift_size = [split_size[0]//2, split_size[1]//2],
616
+ expansion_factor=expansion_factor,
617
+ qkv_bias=qkv_bias,
618
+ qk_scale=qk_scale,
619
+ drop=drop,
620
+ attn_drop=attn_drop,
621
+ drop_path=drop_paths[i],
622
+ act_layer=act_layer,
623
+ norm_layer=norm_layer,
624
+ rg_idx = rg_idx,
625
+ b_idx = i,
626
+ )for i in range(depth)])
627
+
628
+ if resi_connection == '1conv':
629
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
630
+ elif resi_connection == '3conv':
631
+ self.conv = nn.Sequential(
632
+ nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
633
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
634
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
635
+
636
+ def forward(self, x, x_size):
637
+ """
638
+ Input: x: (B, H*W, C), x_size: (H, W)
639
+ Output: x: (B, H*W, C)
640
+ """
641
+ H, W = x_size
642
+ res = x
643
+ for blk in self.blocks:
644
+ if self.use_chk:
645
+ x = checkpoint.checkpoint(blk, x, x_size)
646
+ else:
647
+ x = blk(x, x_size)
648
+ x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
649
+ x = self.conv(x)
650
+ x = rearrange(x, "b c h w -> b (h w) c")
651
+ x = res + x
652
+
653
+ return x
654
+
655
+
656
+ class Upsample(nn.Sequential):
657
+ """Upsample module.
658
+ Args:
659
+ scale (int): Scale factor. Supported scales: 2^n and 3.
660
+ num_feat (int): Channel number of intermediate features.
661
+ """
662
+ def __init__(self, scale, num_feat):
663
+ m = []
664
+ if (scale & (scale - 1)) == 0: # scale = 2^n
665
+ for _ in range(int(math.log(scale, 2))):
666
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
667
+ m.append(nn.PixelShuffle(2))
668
+ elif scale == 3:
669
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
670
+ m.append(nn.PixelShuffle(3))
671
+ else:
672
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
673
+ super(Upsample, self).__init__(*m)
674
+
675
+
676
+ @ARCH_REGISTRY.register()
677
+ class DAT(nn.Module):
678
+ """ Dual Aggregation Transformer
679
+ Args:
680
+ img_size (int): Input image size. Default: 64
681
+ in_chans (int): Number of input image channels. Default: 3
682
+ embed_dim (int): Patch embedding dimension. Default: 180
683
+ depth (tuple(int)): Depth of each residual group (number of DATB blocks in each RG).
684
+ split_size (tuple(int)): Height and Width of spatial window.
685
+ num_heads (tuple(int)): Number of attention heads in different residual groups.
686
+ expansion_factor (float): Ratio of SGFN hidden dim to embedding dim. Default: 4
687
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
688
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. Default: None
689
+ drop_rate (float): Dropout rate. Default: 0
690
+ attn_drop_rate (float): Attention dropout rate. Default: 0
691
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
692
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
693
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
694
+ use_chk (bool): Whether to use checkpointing to save memory.
695
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for compress artifact reduction
696
+ img_range: Image range. 1. or 255.
697
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
698
+ """
699
+ def __init__(self,
700
+ img_size=64,
701
+ in_chans=3,
702
+ embed_dim=180,
703
+ split_size=[2,4],
704
+ depth=[2,2,2,2],
705
+ num_heads=[2,2,2,2],
706
+ expansion_factor=4.,
707
+ qkv_bias=True,
708
+ qk_scale=None,
709
+ drop_rate=0.,
710
+ attn_drop_rate=0.,
711
+ drop_path_rate=0.1,
712
+ act_layer=nn.GELU,
713
+ norm_layer=nn.LayerNorm,
714
+ use_chk=False,
715
+ upscale=2,
716
+ img_range=1.,
717
+ resi_connection='1conv',
718
+ **kwargs):
719
+ super().__init__()
720
+
721
+ num_in_ch = in_chans
722
+ num_out_ch = in_chans
723
+ num_feat = 64
724
+ self.img_range = img_range
725
+ if in_chans == 3:
726
+ rgb_mean = (0.4488, 0.4371, 0.4040)
727
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
728
+ else:
729
+ self.mean = torch.zeros(1, 1, 1, 1)
730
+ self.upscale = upscale
731
+
732
+ # ------------------------- 1, Shallow Feature Extraction ------------------------- #
733
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
734
+
735
+ # ------------------------- 2, Deep Feature Extraction ------------------------- #
736
+ self.num_layers = len(depth)
737
+ self.use_chk = use_chk
738
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
739
+ heads=num_heads
740
+
741
+ self.before_RG = nn.Sequential(
742
+ Rearrange('b c h w -> b (h w) c'),
743
+ nn.LayerNorm(embed_dim)
744
+ )
745
+
746
+ curr_dim = embed_dim
747
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))] # stochastic depth decay rule
748
+
749
+ self.layers = nn.ModuleList()
750
+ for i in range(self.num_layers):
751
+ layer = ResidualGroup(
752
+ dim=embed_dim,
753
+ num_heads=heads[i],
754
+ reso=img_size,
755
+ split_size=split_size,
756
+ expansion_factor=expansion_factor,
757
+ qkv_bias=qkv_bias,
758
+ qk_scale=qk_scale,
759
+ drop=drop_rate,
760
+ attn_drop=attn_drop_rate,
761
+ drop_paths=dpr[sum(depth[:i]):sum(depth[:i + 1])],
762
+ act_layer=act_layer,
763
+ norm_layer=norm_layer,
764
+ depth=depth[i],
765
+ use_chk=use_chk,
766
+ resi_connection=resi_connection,
767
+ rg_idx=i)
768
+ self.layers.append(layer)
769
+
770
+ self.norm = norm_layer(curr_dim)
771
+ # build the last conv layer in deep feature extraction
772
+ if resi_connection == '1conv':
773
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
774
+ elif resi_connection == '3conv':
775
+ # to save parameters and memory
776
+ self.conv_after_body = nn.Sequential(
777
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
778
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
779
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
780
+
781
+ # ------------------------- 3, Reconstruction ------------------------- #
782
+ self.conv_before_upsample = nn.Sequential(
783
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
784
+ self.upsample = Upsample(upscale, num_feat)
785
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
786
+
787
+ self.apply(self._init_weights)
788
+
789
+ def _init_weights(self, m):
790
+ if isinstance(m, nn.Linear):
791
+ trunc_normal_(m.weight, std=.02)
792
+ if isinstance(m, nn.Linear) and m.bias is not None:
793
+ nn.init.constant_(m.bias, 0)
794
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d)):
795
+ nn.init.constant_(m.bias, 0)
796
+ nn.init.constant_(m.weight, 1.0)
797
+
798
+ def forward_features(self, x):
799
+ _, _, H, W = x.shape
800
+ x_size = [H, W]
801
+ x = self.before_RG(x)
802
+ for layer in self.layers:
803
+ x = layer(x, x_size)
804
+ x = self.norm(x)
805
+ x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
806
+
807
+ return x
808
+
809
+ def forward(self, x):
810
+ """
811
+ Input: x: (B, C, H, W)
812
+ """
813
+ self.mean = self.mean.type_as(x)
814
+ x = (x - self.mean) * self.img_range
815
+
816
+ x = self.conv_first(x)
817
+ x = self.conv_after_body(self.forward_features(x)) + x
818
+ x = self.conv_before_upsample(x)
819
+ x = self.conv_last(self.upsample(x))
820
+
821
+ x = x / self.img_range + self.mean
822
+ return x
823
+
824
+
825
+ if __name__ == '__main__':
826
+ upscale = 2
827
+ height = 64
828
+ width = 64
829
+ model = DAT(
830
+ upscale=2,
831
+ in_chans=3,
832
+ img_size=64,
833
+ img_range=1.,
834
+ depth=[6,6,6,6,6,6],
835
+ embed_dim=180,
836
+ num_heads=[6,6,6,6,6,6],
837
+ expansion_factor=2,
838
+ resi_connection='1conv',
839
+ split_size=[8,16],
840
+ ).cuda().eval()
841
+ print(model)
842
+ print(height, width)
843
+
844
+ x = torch.randn((1, 3, height, width)).cuda()
845
+ x = model(x)
846
+ print(x.shape)
basicsr/data/__init__.py ADDED
@@ -0,0 +1,101 @@
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from basicsr.utils import get_root_logger, scandir
12
+ from basicsr.utils.dist_util import get_dist_info
13
+ from basicsr.utils.registry import DATASET_REGISTRY
14
+
15
+ __all__ = ['build_dataset', 'build_dataloader']
16
+
17
+ # automatically scan and import dataset modules for registry
18
+ # scan all the files under the data folder with '_dataset' in file names
19
+ data_folder = osp.dirname(osp.abspath(__file__))
20
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
+ # import all the dataset modules
22
+ _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
+
24
+
25
+ def build_dataset(dataset_opt):
26
+ """Build dataset from options.
27
+
28
+ Args:
29
+ dataset_opt (dict): Configuration for dataset. It must contain:
30
+ name (str): Dataset name.
31
+ type (str): Dataset type.
32
+ """
33
+ dataset_opt = deepcopy(dataset_opt)
34
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
+ logger = get_root_logger()
36
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
37
+ return dataset
38
+
39
+
40
+ def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
+ """Build dataloader.
42
+
43
+ Args:
44
+ dataset (torch.utils.data.Dataset): Dataset.
45
+ dataset_opt (dict): Dataset options. It contains the following keys:
46
+ phase (str): 'train' or 'val'.
47
+ num_worker_per_gpu (int): Number of workers for each GPU.
48
+ batch_size_per_gpu (int): Training batch size for each GPU.
49
+ num_gpu (int): Number of GPUs. Used only in the train phase.
50
+ Default: 1.
51
+ dist (bool): Whether in distributed training. Used only in the train
52
+ phase. Default: False.
53
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
+ seed (int | None): Seed. Default: None
55
+ """
56
+ phase = dataset_opt['phase']
57
+ rank, _ = get_dist_info()
58
+ if phase == 'train':
59
+ if dist: # distributed training
60
+ batch_size = dataset_opt['batch_size_per_gpu']
61
+ num_workers = dataset_opt['num_worker_per_gpu']
62
+ else: # non-distributed training
63
+ multiplier = 1 if num_gpu == 0 else num_gpu
64
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
+ dataloader_args = dict(
67
+ dataset=dataset,
68
+ batch_size=batch_size,
69
+ shuffle=False,
70
+ num_workers=num_workers,
71
+ sampler=sampler,
72
+ drop_last=True)
73
+ if sampler is None:
74
+ dataloader_args['shuffle'] = True
75
+ dataloader_args['worker_init_fn'] = partial(
76
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
+ elif phase in ['val', 'test']: # validation
78
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
+ else:
80
+ raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
81
+
82
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
+ dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
84
+
85
+ prefetch_mode = dataset_opt.get('prefetch_mode')
86
+ if prefetch_mode == 'cpu': # CPUPrefetcher
87
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
88
+ logger = get_root_logger()
89
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
90
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
91
+ else:
92
+ # prefetch_mode=None: Normal dataloader
93
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
94
+ return torch.utils.data.DataLoader(**dataloader_args)
95
+
96
+
97
+ def worker_init_fn(worker_id, num_workers, rank, seed):
98
+ # Set the worker seed to num_workers * rank + worker_id + seed
99
+ worker_seed = num_workers * rank + worker_id + seed
100
+ np.random.seed(worker_seed)
101
+ random.seed(worker_seed)
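For context, a minimal sketch of how `build_dataset` and `build_dataloader` fit together for a test split; the dataset options follow the BasicSR `PairedImageDataset` convention and the paths are placeholders, not values taken from this commit.

```python
# Illustrative sketch (option values and paths are placeholders).
from basicsr.data import build_dataset, build_dataloader

dataset_opt = {
    'name': 'Set5',
    'type': 'PairedImageDataset',  # must be registered in DATASET_REGISTRY
    'dataroot_gt': 'datasets/Set5/HR',
    'dataroot_lq': 'datasets/Set5/LR_bicubic/X2',
    'io_backend': {'type': 'disk'},
    'phase': 'test',
    'scale': 2,
}
test_set = build_dataset(dataset_opt)
test_loader = build_dataloader(test_set, dataset_opt, num_gpu=1, dist=False, sampler=None)
```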
basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
+ class EnlargedSampler(Sampler):
7
+ """Sampler that restricts data loading to a subset of the dataset.
8
+
9
+ Modified from torch.utils.data.distributed.DistributedSampler
10
+ Support enlarging the dataset for iteration-based training, for saving
11
+ time when restarting the dataloader after each epoch.
12
+
13
+ Args:
14
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
15
+ num_replicas (int | None): Number of processes participating in
16
+ the training. It is usually the world_size.
17
+ rank (int | None): Rank of the current process within num_replicas.
18
+ ratio (int): Enlarging ratio. Default: 1.
19
+ """
20
+
21
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
22
+ self.dataset = dataset
23
+ self.num_replicas = num_replicas
24
+ self.rank = rank
25
+ self.epoch = 0
26
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27
+ self.total_size = self.num_samples * self.num_replicas
28
+
29
+ def __iter__(self):
30
+ # deterministically shuffle based on epoch
31
+ g = torch.Generator()
32
+ g.manual_seed(self.epoch)
33
+ indices = torch.randperm(self.total_size, generator=g).tolist()
34
+
35
+ dataset_size = len(self.dataset)
36
+ indices = [v % dataset_size for v in indices]
37
+
38
+ # subsample
39
+ indices = indices[self.rank:self.total_size:self.num_replicas]
40
+ assert len(indices) == self.num_samples
41
+
42
+ return iter(indices)
43
+
44
+ def __len__(self):
45
+ return self.num_samples
46
+
47
+ def set_epoch(self, epoch):
48
+ self.epoch = epoch
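A minimal usage sketch for `EnlargedSampler`, assuming the package has been installed in development mode (`python setup.py develop`) so that `basicsr` is importable. With `ratio=3` and two replicas, each rank draws `ceil(10 * 3 / 2) = 15` indices per epoch, all wrapped back into the valid range:

```python
from torch.utils.data import Dataset

from basicsr.data.data_sampler import EnlargedSampler

class ToyDataset(Dataset):
    """Ten dummy samples, just to exercise the sampler."""
    def __len__(self):
        return 10
    def __getitem__(self, idx):
        return idx

sampler = EnlargedSampler(ToyDataset(), num_replicas=2, rank=0, ratio=3)
sampler.set_epoch(0)                      # deterministic shuffle per epoch
indices = list(iter(sampler))
print(len(indices))                       # 15
print(all(0 <= i < 10 for i in indices))  # True, indices wrap around the dataset
```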
basicsr/data/data_util.py ADDED
@@ -0,0 +1,283 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from os import path as osp
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.utils import img2tensor, scandir
8
+
9
+
10
+ def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
11
+ """Generate an index list for reading `num_frames` frames from a sequence
12
+ of images.
13
+
14
+ Args:
15
+ crt_idx (int): Current center index.
16
+ max_frame_num (int): Maximum frame number of the image sequence (counting from 1).
17
+ num_frames (int): Number of frames to read.
18
+ padding (str): Padding mode, one of
19
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
20
+ Examples: current_idx = 0, num_frames = 5
21
+ The generated frame indices under different padding mode:
22
+ replicate: [0, 0, 0, 1, 2]
23
+ reflection: [2, 1, 0, 1, 2]
24
+ reflection_circle: [4, 3, 0, 1, 2]
25
+ circle: [3, 4, 0, 1, 2]
26
+
27
+ Returns:
28
+ list[int]: A list of indices.
29
+ """
30
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
31
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
32
+
33
+ max_frame_num = max_frame_num - 1 # start from 0
34
+ num_pad = num_frames // 2
35
+
36
+ indices = []
37
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
38
+ if i < 0:
39
+ if padding == 'replicate':
40
+ pad_idx = 0
41
+ elif padding == 'reflection':
42
+ pad_idx = -i
43
+ elif padding == 'reflection_circle':
44
+ pad_idx = crt_idx + num_pad - i
45
+ else:
46
+ pad_idx = num_frames + i
47
+ elif i > max_frame_num:
48
+ if padding == 'replicate':
49
+ pad_idx = max_frame_num
50
+ elif padding == 'reflection':
51
+ pad_idx = max_frame_num * 2 - i
52
+ elif padding == 'reflection_circle':
53
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
54
+ else:
55
+ pad_idx = i - num_frames
56
+ else:
57
+ pad_idx = i
58
+ indices.append(pad_idx)
59
+ return indices
60
+
61
+
62
+ def paired_paths_from_lmdb(folders, keys):
63
+ """Generate paired paths from lmdb files.
64
+
65
+ Contents of lmdb. Taking `lq.lmdb` as an example, the file structure is:
66
+
67
+ lq.lmdb
68
+ ├── data.mdb
69
+ ├── lock.mdb
70
+ ├── meta_info.txt
71
+
72
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
73
+ https://lmdb.readthedocs.io/en/release/ for more details.
74
+
75
+ The meta_info.txt is a txt file recording the meta information
76
+ of our datasets. It will be automatically created when preparing
77
+ datasets by our provided dataset tools.
78
+ Each line in the txt file records
79
+ 1)image name (with extension),
80
+ 2)image shape,
81
+ 3)compression level, separated by a white space.
82
+ Example: `baboon.png (120,125,3) 1`
83
+
84
+ We use the image name without extension as the lmdb key.
85
+ Note that we use the same key for the corresponding lq and gt images.
86
+
87
+ Args:
88
+ folders (list[str]): A list of folder path. The order of list should
89
+ be [input_folder, gt_folder].
90
+ keys (list[str]): A list of keys identifying folders. The order should
91
+ be in consistent with folders, e.g., ['lq', 'gt'].
92
+ Note that this key is different from lmdb keys.
93
+
94
+ Returns:
95
+ list[dict]: Returned path list.
96
+ """
97
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
98
+ f'But got {len(folders)}')
99
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
100
+ input_folder, gt_folder = folders
101
+ input_key, gt_key = keys
102
+
103
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
104
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
105
+ f'format. But received {input_key}: {input_folder}; '
106
+ f'{gt_key}: {gt_folder}')
107
+ # ensure that the two meta_info files are the same
108
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
109
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
110
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
111
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
112
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
113
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
114
+ else:
115
+ paths = []
116
+ for lmdb_key in sorted(input_lmdb_keys):
117
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
118
+ return paths
119
+
120
+
121
+ def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
122
+ """Generate paired paths from a meta information file.
123
+
124
+ Each line in the meta information file contains the image names and
125
+ image shape (usually for gt), separated by a white space.
126
+
127
+ Example of a meta information file:
128
+ ```
129
+ 0001_s001.png (480,480,3)
130
+ 0001_s002.png (480,480,3)
131
+ ```
132
+
133
+ Args:
134
+ folders (list[str]): A list of folder path. The order of list should
135
+ be [input_folder, gt_folder].
136
+ keys (list[str]): A list of keys identifying folders. The order should
137
+ be in consistent with folders, e.g., ['lq', 'gt'].
138
+ meta_info_file (str): Path to the meta information file.
139
+ filename_tmpl (str): Template for each filename. Note that the
140
+ template excludes the file extension. Usually the filename_tmpl is
141
+ for files in the input folder.
142
+
143
+ Returns:
144
+ list[dict]: Returned path list.
145
+ """
146
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
147
+ f'But got {len(folders)}')
148
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
149
+ input_folder, gt_folder = folders
150
+ input_key, gt_key = keys
151
+
152
+ with open(meta_info_file, 'r') as fin:
153
+ gt_names = [line.strip().split(' ')[0] for line in fin]
154
+
155
+ paths = []
156
+ for gt_name in gt_names:
157
+ basename, ext = osp.splitext(osp.basename(gt_name))
158
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
159
+ input_path = osp.join(input_folder, input_name)
160
+ gt_path = osp.join(gt_folder, gt_name)
161
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
162
+ return paths
163
+
164
+
165
+ def paired_paths_from_folder(folders, keys, filename_tmpl, task):
166
+ """Generate paired paths from folders.
167
+
168
+ Args:
169
+ folders (list[str]): A list of folder path. The order of list should
170
+ be [input_folder, gt_folder].
171
+ keys (list[str]): A list of keys identifying folders. The order should
172
+ be in consistent with folders, e.g., ['lq', 'gt'].
173
+ filename_tmpl (str): Template for each filename. Note that the
174
+ template excludes the file extension. Usually the filename_tmpl is
175
+ for files in the input folder.
+ task (str): Task type. If 'CAR', the input (LQ) filenames use the '.jpg' extension.
176
+
177
+ Returns:
178
+ list[dict]: Returned path list.
179
+ """
180
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
181
+ f'But got {len(folders)}')
182
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
183
+ input_folder, gt_folder = folders
184
+ input_key, gt_key = keys
185
+
186
+ input_paths = list(scandir(input_folder))
187
+ gt_paths = list(scandir(gt_folder))
188
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
189
+ f'{len(input_paths)}, {len(gt_paths)}.')
190
+ paths = []
191
+ for gt_path in gt_paths:
192
+ basename, ext = osp.splitext(osp.basename(gt_path))
193
+ if task == "CAR":
194
+ input_name = f'{filename_tmpl.format(basename)}.jpg'
195
+ else:
196
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
197
+ input_path = osp.join(input_folder, input_name)
198
+ assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
199
+ gt_path = osp.join(gt_folder, gt_path)
200
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
201
+ return paths
202
+
203
+
204
+ def paths_from_folder(folder):
205
+ """Generate paths from folder.
206
+
207
+ Args:
208
+ folder (str): Folder path.
209
+
210
+ Returns:
211
+ list[str]: Returned path list.
212
+ """
213
+
214
+ paths = list(scandir(folder))
215
+ paths = [osp.join(folder, path) for path in paths]
216
+ return paths
217
+
218
+
219
+ def paths_from_lmdb(folder):
220
+ """Generate paths from lmdb.
221
+
222
+ Args:
223
+ folder (str): Folder path.
224
+
225
+ Returns:
226
+ list[str]: Returned path list.
227
+ """
228
+ if not folder.endswith('.lmdb'):
229
+ raise ValueError(f'Folder {folder} should be in lmdb format.')
230
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
231
+ paths = [line.split('.')[0] for line in fin]
232
+ return paths
233
+
234
+
235
+ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
236
+ """Generate Gaussian kernel used in `duf_downsample`.
237
+
238
+ Args:
239
+ kernel_size (int): Kernel size. Default: 13.
240
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
241
+
242
+ Returns:
243
+ np.array: The Gaussian kernel.
244
+ """
245
+ from scipy.ndimage import filters as filters
246
+ kernel = np.zeros((kernel_size, kernel_size))
247
+ # set element at the middle to one, a dirac delta
248
+ kernel[kernel_size // 2, kernel_size // 2] = 1
249
+ # gaussian-smooth the dirac, resulting in a gaussian filter
250
+ return filters.gaussian_filter(kernel, sigma)
251
+
252
+
253
+ def duf_downsample(x, kernel_size=13, scale=4):
254
+ """Downsampling with the Gaussian kernel used in the DUF official code.
255
+
256
+ Args:
257
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
258
+ kernel_size (int): Kernel size. Default: 13.
259
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
260
+ Default: 4.
261
+
262
+ Returns:
263
+ Tensor: DUF downsampled frames.
264
+ """
265
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
266
+
267
+ squeeze_flag = False
268
+ if x.ndim == 4:
269
+ squeeze_flag = True
270
+ x = x.unsqueeze(0)
271
+ b, t, c, h, w = x.size()
272
+ x = x.view(-1, 1, h, w)
273
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
274
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
275
+
276
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
277
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
278
+ x = F.conv2d(x, gaussian_filter, stride=scale)
279
+ x = x[:, :, 2:-2, 2:-2]
280
+ x = x.view(b, t, c, x.size(2), x.size(3))
281
+ if squeeze_flag:
282
+ x = x.squeeze(0)
283
+ return x
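The padding modes of `generate_frame_indices` can be checked directly against the examples in its docstring; this small sketch (assuming `basicsr` is importable) reproduces them for a 100-frame sequence with the window centred on frame 0:

```python
from basicsr.data.data_util import generate_frame_indices

for padding in ('replicate', 'reflection', 'reflection_circle', 'circle'):
    print(f'{padding:18s}', generate_frame_indices(0, 100, 5, padding=padding))
# replicate          [0, 0, 0, 1, 2]
# reflection         [2, 1, 0, 1, 2]
# reflection_circle  [4, 3, 0, 1, 2]
# circle             [3, 4, 0, 1, 2]
```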
basicsr/data/paired_image_dataset.py ADDED
@@ -0,0 +1,135 @@
1
+ from torch.utils import data as data
2
+ from torchvision.transforms.functional import normalize
3
+
4
+ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
5
+ from basicsr.data.transforms import augment, paired_random_crop
6
+ from basicsr.utils import FileClient, imfrombytes, img2tensor
7
+ from basicsr.utils.matlab_functions import bgr2ycbcr
8
+ from basicsr.utils.registry import DATASET_REGISTRY
9
+
10
+ import numpy as np
11
+
12
+ @DATASET_REGISTRY.register()
13
+ class PairedImageDataset(data.Dataset):
14
+ """Paired image dataset for image restoration.
15
+
16
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
17
+
18
+ There are three modes:
19
+ 1. 'lmdb': Use lmdb files.
20
+ If opt['io_backend'] == lmdb.
21
+ 2. 'meta_info_file': Use meta information file to generate paths.
22
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
23
+ 3. 'folder': Scan folders to generate paths.
24
+ The rest.
25
+
26
+ Args:
27
+ opt (dict): Config for train datasets. It contains the following keys:
28
+ dataroot_gt (str): Data root path for gt.
29
+ dataroot_lq (str): Data root path for lq.
30
+ meta_info_file (str): Path for meta information file.
31
+ io_backend (dict): IO backend type and other kwarg.
32
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
33
+ Default: '{}'.
34
+ gt_size (int): Cropped patched size for gt patches.
35
+ use_hflip (bool): Use horizontal flips.
36
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
37
+
38
+ scale (bool): Scale, which will be added automatically.
39
+ phase (str): 'train' or 'val'.
40
+ """
41
+
42
+ def __init__(self, opt):
43
+ super(PairedImageDataset, self).__init__()
44
+ self.opt = opt
45
+ # file client (io backend)
46
+ self.file_client = None
47
+ self.io_backend_opt = opt['io_backend']
48
+ self.mean = opt['mean'] if 'mean' in opt else None
49
+ self.task = opt['task'] if 'task' in opt else None
50
+ self.std = opt['std'] if 'std' in opt else None
51
+
52
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
53
+ if 'filename_tmpl' in opt:
54
+ self.filename_tmpl = opt['filename_tmpl']
55
+ else:
56
+ self.filename_tmpl = '{}'
57
+
58
+ if self.io_backend_opt['type'] == 'lmdb':
59
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
60
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
61
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
62
+ elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
63
+ self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
64
+ self.opt['meta_info_file'], self.filename_tmpl)
65
+ else:
66
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl, self.task)
67
+
68
+ def __getitem__(self, index):
69
+ if self.file_client is None:
70
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
71
+
72
+ scale = self.opt['scale']
73
+
74
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
75
+
76
+ if self.task == 'CAR':
77
+ # image range: [0, 255], int., H W 1
78
+ gt_path = self.paths[index]['gt_path']
79
+ img_bytes = self.file_client.get(gt_path, 'gt')
80
+ img_gt = imfrombytes(img_bytes, flag='grayscale', float32=False)
81
+ lq_path = self.paths[index]['lq_path']
82
+ img_bytes = self.file_client.get(lq_path, 'lq')
83
+ img_lq = imfrombytes(img_bytes, flag='grayscale', float32=False)
84
+ img_gt = np.expand_dims(img_gt, axis=2).astype(np.float32) / 255.
85
+ img_lq = np.expand_dims(img_lq, axis=2).astype(np.float32) / 255.
86
+
87
+ elif self.task == 'Color-DN':
88
+ gt_path = self.paths[index]['gt_path']
89
+ lq_path = gt_path
90
+ img_bytes = self.file_client.get(gt_path, 'gt')
91
+ img_gt = imfrombytes(img_bytes, float32=True)
92
+ if self.opt['phase'] != 'train':
93
+ np.random.seed(seed=0)
94
+ img_lq = img_gt + np.random.normal(0, self.noise/255., img_gt.shape)
95
+
96
+ else:
97
+ # image range: [0, 1], float32., H W 3
98
+ gt_path = self.paths[index]['gt_path']
99
+ img_bytes = self.file_client.get(gt_path, 'gt')
100
+ img_gt = imfrombytes(img_bytes, float32=True)
101
+ lq_path = self.paths[index]['lq_path']
102
+ img_bytes = self.file_client.get(lq_path, 'lq')
103
+ img_lq = imfrombytes(img_bytes, float32=True)
104
+
105
+ # augmentation for training
106
+ if self.opt['phase'] == 'train':
107
+ gt_size = self.opt['gt_size']
108
+ # random crop
109
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
110
+ # flip, rotation
111
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
112
+
113
+ # color space transform
114
+ if 'color' in self.opt and self.opt['color'] == 'y':
115
+ img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
116
+ img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
117
+
118
+ # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
119
+ # TODO: It is better to update the datasets, rather than force to crop
120
+ if self.opt['phase'] != 'train':
121
+ img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
122
+
123
+ # BGR to RGB, HWC to CHW, numpy to tensor
124
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
125
+ # normalize
126
+ if self.mean is not None or self.std is not None:
127
+ normalize(img_lq, self.mean, self.std, inplace=True)
128
+ normalize(img_gt, self.mean, self.std, inplace=True)
129
+
130
+ # print(img_lq.shape,img_gt.shape,img_lq.min(),img_gt.min(),img_lq.max(),img_gt.max(),lq_path,gt_path)
131
+
132
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
133
+
134
+ def __len__(self):
135
+ return len(self.paths)
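A sketch of instantiating `PairedImageDataset` in folder mode for x2 testing. The `dataroot_*` paths and the `filename_tmpl` below are placeholders following the layout suggested by the test YMLs; adjust them to your own data:

```python
from basicsr.data.paired_image_dataset import PairedImageDataset

opt = {
    'phase': 'val',
    'scale': 2,
    'io_backend': {'type': 'disk'},
    'dataroot_gt': 'datasets/Set5/HR',              # placeholder paths
    'dataroot_lq': 'datasets/Set5/LR_bicubic/X2',
    'filename_tmpl': '{}x2',  # GT 'baby.png' pairs with LQ 'babyx2.png'
}
dataset = PairedImageDataset(opt)
sample = dataset[0]
print(sample['lq'].shape, sample['gt'].shape, sample['lq_path'])
```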
basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ class PrefetchGenerator(threading.Thread):
8
+ """A general prefetch generator.
9
+
10
+ Ref:
11
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12
+
13
+ Args:
14
+ generator: Python generator.
15
+ num_prefetch_queue (int): Length of the prefetch queue.
16
+ """
17
+
18
+ def __init__(self, generator, num_prefetch_queue):
19
+ threading.Thread.__init__(self)
20
+ self.queue = Queue.Queue(num_prefetch_queue)
21
+ self.generator = generator
22
+ self.daemon = True
23
+ self.start()
24
+
25
+ def run(self):
26
+ for item in self.generator:
27
+ self.queue.put(item)
28
+ self.queue.put(None)
29
+
30
+ def __next__(self):
31
+ next_item = self.queue.get()
32
+ if next_item is None:
33
+ raise StopIteration
34
+ return next_item
35
+
36
+ def __iter__(self):
37
+ return self
38
+
39
+
40
+ class PrefetchDataLoader(DataLoader):
41
+ """Prefetch version of dataloader.
42
+
43
+ Ref:
44
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45
+
46
+ TODO:
47
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48
+ ddp.
49
+
50
+ Args:
51
+ num_prefetch_queue (int): Length of the prefetch queue.
52
+ kwargs (dict): Other arguments for dataloader.
53
+ """
54
+
55
+ def __init__(self, num_prefetch_queue, **kwargs):
56
+ self.num_prefetch_queue = num_prefetch_queue
57
+ super(PrefetchDataLoader, self).__init__(**kwargs)
58
+
59
+ def __iter__(self):
60
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61
+
62
+
63
+ class CPUPrefetcher():
64
+ """CPU prefetcher.
65
+
66
+ Args:
67
+ loader: Dataloader.
68
+ """
69
+
70
+ def __init__(self, loader):
71
+ self.ori_loader = loader
72
+ self.loader = iter(loader)
73
+
74
+ def next(self):
75
+ try:
76
+ return next(self.loader)
77
+ except StopIteration:
78
+ return None
79
+
80
+ def reset(self):
81
+ self.loader = iter(self.ori_loader)
82
+
83
+
84
+ class CUDAPrefetcher():
85
+ """CUDA prefetcher.
86
+
87
+ Ref:
88
+ https://github.com/NVIDIA/apex/issues/304#
89
+
90
+ It may consume more GPU memory.
91
+
92
+ Args:
93
+ loader: Dataloader.
94
+ opt (dict): Options.
95
+ """
96
+
97
+ def __init__(self, loader, opt):
98
+ self.ori_loader = loader
99
+ self.loader = iter(loader)
100
+ self.opt = opt
101
+ self.stream = torch.cuda.Stream()
102
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103
+ self.preload()
104
+
105
+ def preload(self):
106
+ try:
107
+ self.batch = next(self.loader) # self.batch is a dict
108
+ except StopIteration:
109
+ self.batch = None
110
+ return None
111
+ # put tensors to gpu
112
+ with torch.cuda.stream(self.stream):
113
+ for k, v in self.batch.items():
114
+ if torch.is_tensor(v):
115
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
116
+
117
+ def next(self):
118
+ torch.cuda.current_stream().wait_stream(self.stream)
119
+ batch = self.batch
120
+ self.preload()
121
+ return batch
122
+
123
+ def reset(self):
124
+ self.loader = iter(self.ori_loader)
125
+ self.preload()
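A typical consumption loop for the prefetchers above: `next()` returns one batch (or `None` when the epoch is exhausted) and `reset()` restarts iteration. The sketch uses `CPUPrefetcher` on a toy loader; `CUDAPrefetcher` is driven the same way but expects dict batches (as produced by `PairedImageDataset`) and an `opt` dict containing `num_gpu`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from basicsr.data.prefetch_dataloader import CPUPrefetcher

loader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=2)
prefetcher = CPUPrefetcher(loader)

for epoch in range(2):
    prefetcher.reset()             # restart iteration each epoch
    batch = prefetcher.next()
    while batch is not None:       # next() returns None when the loader is exhausted
        # ... training / validation step on `batch` ...
        batch = prefetcher.next()
```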
basicsr/data/transforms.py ADDED
@@ -0,0 +1,179 @@
1
+ import cv2
2
+ import random
3
+ import torch
4
+
5
+
6
+ def mod_crop(img, scale):
7
+ """Mod crop images, used during testing.
8
+
9
+ Args:
10
+ img (ndarray): Input image.
11
+ scale (int): Scale factor.
12
+
13
+ Returns:
14
+ ndarray: Result image.
15
+ """
16
+ img = img.copy()
17
+ if img.ndim in (2, 3):
18
+ h, w = img.shape[0], img.shape[1]
19
+ h_remainder, w_remainder = h % scale, w % scale
20
+ img = img[:h - h_remainder, :w - w_remainder, ...]
21
+ else:
22
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
23
+ return img
24
+
25
+
26
+ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
27
+ """Paired random crop. Support Numpy array and Tensor inputs.
28
+
29
+ It crops lists of lq and gt images with corresponding locations.
30
+
31
+ Args:
32
+ img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
33
+ should have the same shape. If the input is an ndarray, it will
34
+ be transformed to a list containing itself.
35
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
36
+ should have the same shape. If the input is an ndarray, it will
37
+ be transformed to a list containing itself.
38
+ gt_patch_size (int): GT patch size.
39
+ scale (int): Scale factor.
40
+ gt_path (str): Path to ground-truth. Default: None.
41
+
42
+ Returns:
43
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
44
+ only have one element, just return ndarray.
45
+ """
46
+
47
+ if not isinstance(img_gts, list):
48
+ img_gts = [img_gts]
49
+ if not isinstance(img_lqs, list):
50
+ img_lqs = [img_lqs]
51
+
52
+ # determine input type: Numpy array or Tensor
53
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
54
+
55
+ if input_type == 'Tensor':
56
+ h_lq, w_lq = img_lqs[0].size()[-2:]
57
+ h_gt, w_gt = img_gts[0].size()[-2:]
58
+ else:
59
+ h_lq, w_lq = img_lqs[0].shape[0:2]
60
+ h_gt, w_gt = img_gts[0].shape[0:2]
61
+ lq_patch_size = gt_patch_size // scale
62
+
63
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
64
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
65
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
66
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
67
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
68
+ f'({lq_patch_size}, {lq_patch_size}). '
69
+ f'Please remove {gt_path}.')
70
+
71
+ # randomly choose top and left coordinates for lq patch
72
+ top = random.randint(0, h_lq - lq_patch_size)
73
+ left = random.randint(0, w_lq - lq_patch_size)
74
+
75
+ # crop lq patch
76
+ if input_type == 'Tensor':
77
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
78
+ else:
79
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
80
+
81
+ # crop corresponding gt patch
82
+ top_gt, left_gt = int(top * scale), int(left * scale)
83
+ if input_type == 'Tensor':
84
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
85
+ else:
86
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
87
+ if len(img_gts) == 1:
88
+ img_gts = img_gts[0]
89
+ if len(img_lqs) == 1:
90
+ img_lqs = img_lqs[0]
91
+ return img_gts, img_lqs
92
+
93
+
94
+ def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
95
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
96
+
97
+ We use vertical flip and transpose for rotation implementation.
98
+ All the images in the list use the same augmentation.
99
+
100
+ Args:
101
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
102
+ is an ndarray, it will be transformed to a list.
103
+ hflip (bool): Horizontal flip. Default: True.
104
+ rotation (bool): Rotation. Default: True.
105
+ flows (list[ndarray]): Flows to be augmented. If the input is an
106
+ ndarray, it will be transformed to a list.
107
+ Dimension is (h, w, 2). Default: None.
108
+ return_status (bool): Return the status of flip and rotation.
109
+ Default: False.
110
+
111
+ Returns:
112
+ list[ndarray] | ndarray: Augmented images and flows. If returned
113
+ results only have one element, just return ndarray.
114
+
115
+ """
116
+ hflip = hflip and random.random() < 0.5
117
+ vflip = rotation and random.random() < 0.5
118
+ rot90 = rotation and random.random() < 0.5
119
+
120
+ def _augment(img):
121
+ if hflip: # horizontal
122
+ cv2.flip(img, 1, img)
123
+ if vflip: # vertical
124
+ cv2.flip(img, 0, img)
125
+ if rot90:
126
+ img = img.transpose(1, 0, 2)
127
+ return img
128
+
129
+ def _augment_flow(flow):
130
+ if hflip: # horizontal
131
+ cv2.flip(flow, 1, flow)
132
+ flow[:, :, 0] *= -1
133
+ if vflip: # vertical
134
+ cv2.flip(flow, 0, flow)
135
+ flow[:, :, 1] *= -1
136
+ if rot90:
137
+ flow = flow.transpose(1, 0, 2)
138
+ flow = flow[:, :, [1, 0]]
139
+ return flow
140
+
141
+ if not isinstance(imgs, list):
142
+ imgs = [imgs]
143
+ imgs = [_augment(img) for img in imgs]
144
+ if len(imgs) == 1:
145
+ imgs = imgs[0]
146
+
147
+ if flows is not None:
148
+ if not isinstance(flows, list):
149
+ flows = [flows]
150
+ flows = [_augment_flow(flow) for flow in flows]
151
+ if len(flows) == 1:
152
+ flows = flows[0]
153
+ return imgs, flows
154
+ else:
155
+ if return_status:
156
+ return imgs, (hflip, vflip, rot90)
157
+ else:
158
+ return imgs
159
+
160
+
161
+ def img_rotate(img, angle, center=None, scale=1.0):
162
+ """Rotate image.
163
+
164
+ Args:
165
+ img (ndarray): Image to be rotated.
166
+ angle (float): Rotation angle in degrees. Positive values mean
167
+ counter-clockwise rotation.
168
+ center (tuple[int]): Rotation center. If the center is None,
169
+ initialize it as the center of the image. Default: None.
170
+ scale (float): Isotropic scale factor. Default: 1.0.
171
+ """
172
+ (h, w) = img.shape[:2]
173
+
174
+ if center is None:
175
+ center = (w // 2, h // 2)
176
+
177
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
178
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
179
+ return rotated_img
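A quick sketch of the paired crop plus augmentation pipeline on random arrays (assuming `basicsr` is importable): the GT and LQ crops stay aligned, and the same flip/rotation is applied to both images:

```python
import numpy as np

from basicsr.data.transforms import augment, paired_random_crop

rng = np.random.default_rng(0)
scale = 2
img_gt = rng.random((64, 64, 3), dtype=np.float32)  # HWC, range [0, 1]
img_lq = rng.random((32, 32, 3), dtype=np.float32)  # GT is `scale` times larger

# Crop a 48x48 GT patch and the matching 24x24 LQ patch at the same location.
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_patch_size=48, scale=scale)
# Apply the same random flip / rotation to both.
img_gt, img_lq = augment([img_gt, img_lq], hflip=True, rotation=True)
print(img_gt.shape, img_lq.shape)  # (48, 48, 3) (24, 24, 3)
```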
basicsr/losses/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils import get_root_logger
4
+ from basicsr.utils.registry import LOSS_REGISTRY
5
+ from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, WeightedTVLoss, g_path_regularize,
6
+ gradient_penalty_loss, r1_penalty)
7
+
8
+ __all__ = [
9
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'GANLoss', 'gradient_penalty_loss',
10
+ 'r1_penalty', 'g_path_regularize'
11
+ ]
12
+
13
+
14
+ def build_loss(opt):
15
+ """Build loss from options.
16
+
17
+ Args:
18
+ opt (dict): Configuration. It must contain:
19
+ type (str): Model type.
20
+ """
21
+ opt = deepcopy(opt)
22
+ loss_type = opt.pop('type')
23
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
24
+ logger = get_root_logger()
25
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
26
+ return loss
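A sketch of the options-driven construction: `build_loss` receives the same dict that appears under the loss entry of a BasicSR-style training YML (here L1, the usual choice for classic SR), pops the `type` key, and looks the class up in the registry:

```python
import torch

from basicsr.losses import build_loss

cri_pix = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})

pred = torch.rand(1, 3, 8, 8)
target = torch.rand(1, 3, 8, 8)
print(cri_pix(pred, target))  # scalar tensor, loss_weight already applied
```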
basicsr/losses/loss_util.py ADDED
@@ -0,0 +1,95 @@
1
+ import functools
2
+ from torch.nn import functional as F
3
+
4
+
5
+ def reduce_loss(loss, reduction):
6
+ """Reduce loss as specified.
7
+
8
+ Args:
9
+ loss (Tensor): Elementwise loss tensor.
10
+ reduction (str): Options are 'none', 'mean' and 'sum'.
11
+
12
+ Returns:
13
+ Tensor: Reduced loss tensor.
14
+ """
15
+ reduction_enum = F._Reduction.get_enum(reduction)
16
+ # none: 0, elementwise_mean:1, sum: 2
17
+ if reduction_enum == 0:
18
+ return loss
19
+ elif reduction_enum == 1:
20
+ return loss.mean()
21
+ else:
22
+ return loss.sum()
23
+
24
+
25
+ def weight_reduce_loss(loss, weight=None, reduction='mean'):
26
+ """Apply element-wise weight and reduce loss.
27
+
28
+ Args:
29
+ loss (Tensor): Element-wise loss.
30
+ weight (Tensor): Element-wise weights. Default: None.
31
+ reduction (str): Same as built-in losses of PyTorch. Options are
32
+ 'none', 'mean' and 'sum'. Default: 'mean'.
33
+
34
+ Returns:
35
+ Tensor: Loss values.
36
+ """
37
+ # if weight is specified, apply element-wise weight
38
+ if weight is not None:
39
+ assert weight.dim() == loss.dim()
40
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41
+ loss = loss * weight
42
+
43
+ # if weight is not specified or reduction is sum, just reduce the loss
44
+ if weight is None or reduction == 'sum':
45
+ loss = reduce_loss(loss, reduction)
46
+ # if reduction is mean, then compute mean over weight region
47
+ elif reduction == 'mean':
48
+ if weight.size(1) > 1:
49
+ weight = weight.sum()
50
+ else:
51
+ weight = weight.sum() * loss.size(1)
52
+ loss = loss.sum() / weight
53
+
54
+ return loss
55
+
56
+
57
+ def weighted_loss(loss_func):
58
+ """Create a weighted version of a given loss function.
59
+
60
+ To use this decorator, the loss function must have the signature like
61
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
62
+ element-wise loss without any reduction. This decorator will add weight
63
+ and reduction arguments to the function. The decorated function will have
64
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
65
+ **kwargs)`.
66
+
67
+ :Example:
68
+
69
+ >>> import torch
70
+ >>> @weighted_loss
71
+ >>> def l1_loss(pred, target):
72
+ >>> return (pred - target).abs()
73
+
74
+ >>> pred = torch.Tensor([0, 2, 3])
75
+ >>> target = torch.Tensor([1, 1, 1])
76
+ >>> weight = torch.Tensor([1, 0, 1])
77
+
78
+ >>> l1_loss(pred, target)
79
+ tensor(1.3333)
80
+ >>> l1_loss(pred, target, weight)
81
+ tensor(1.5000)
82
+ >>> l1_loss(pred, target, reduction='none')
83
+ tensor([1., 1., 2.])
84
+ >>> l1_loss(pred, target, weight, reduction='sum')
85
+ tensor(3.)
86
+ """
87
+
88
+ @functools.wraps(loss_func)
89
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90
+ # get element-wise loss
91
+ loss = loss_func(pred, target, **kwargs)
92
+ loss = weight_reduce_loss(loss, weight, reduction)
93
+ return loss
94
+
95
+ return wrapper
basicsr/losses/losses.py ADDED
@@ -0,0 +1,492 @@
1
+ import math
2
+ import torch
3
+ from torch import autograd as autograd
4
+ from torch import nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ # from basicsr.archs.vgg_arch import VGGFeatureExtractor
8
+ from basicsr.utils.registry import LOSS_REGISTRY
9
+ from .loss_util import weighted_loss
10
+
11
+ _reduction_modes = ['none', 'mean', 'sum']
12
+
13
+
14
+ @weighted_loss
15
+ def l1_loss(pred, target):
16
+ return F.l1_loss(pred, target, reduction='none')
17
+
18
+
19
+ @weighted_loss
20
+ def mse_loss(pred, target):
21
+ return F.mse_loss(pred, target, reduction='none')
22
+
23
+
24
+ @weighted_loss
25
+ def charbonnier_loss(pred, target, eps=1e-12):
26
+ return torch.sqrt((pred - target)**2 + eps)
27
+
28
+
29
+ @LOSS_REGISTRY.register()
30
+ class L1Loss(nn.Module):
31
+ """L1 (mean absolute error, MAE) loss.
32
+
33
+ Args:
34
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
35
+ reduction (str): Specifies the reduction to apply to the output.
36
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
37
+ """
38
+
39
+ def __init__(self, loss_weight=1.0, reduction='mean'):
40
+ super(L1Loss, self).__init__()
41
+ if reduction not in ['none', 'mean', 'sum']:
42
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
43
+
44
+ self.loss_weight = loss_weight
45
+ self.reduction = reduction
46
+
47
+ def forward(self, pred, target, weight=None, **kwargs):
48
+ """
49
+ Args:
50
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
51
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
52
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
53
+ """
54
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
55
+
56
+
57
+ @LOSS_REGISTRY.register()
58
+ class MSELoss(nn.Module):
59
+ """MSE (L2) loss.
60
+
61
+ Args:
62
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
63
+ reduction (str): Specifies the reduction to apply to the output.
64
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
65
+ """
66
+
67
+ def __init__(self, loss_weight=1.0, reduction='mean'):
68
+ super(MSELoss, self).__init__()
69
+ if reduction not in ['none', 'mean', 'sum']:
70
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
71
+
72
+ self.loss_weight = loss_weight
73
+ self.reduction = reduction
74
+
75
+ def forward(self, pred, target, weight=None, **kwargs):
76
+ """
77
+ Args:
78
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
79
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
80
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
81
+ """
82
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
83
+
84
+
85
+ @LOSS_REGISTRY.register()
86
+ class CharbonnierLoss(nn.Module):
87
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
88
+ variant of L1Loss).
89
+
90
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
91
+ Super-Resolution".
92
+
93
+ Args:
94
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
95
+ reduction (str): Specifies the reduction to apply to the output.
96
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
97
+ eps (float): A value used to control the curvature near zero. Default: 1e-12.
98
+ """
99
+
100
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
101
+ super(CharbonnierLoss, self).__init__()
102
+ if reduction not in ['none', 'mean', 'sum']:
103
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
104
+
105
+ self.loss_weight = loss_weight
106
+ self.reduction = reduction
107
+ self.eps = eps
108
+
109
+ def forward(self, pred, target, weight=None, **kwargs):
110
+ """
111
+ Args:
112
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
113
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
114
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
115
+ """
116
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
117
+
118
+
119
+ @LOSS_REGISTRY.register()
120
+ class WeightedTVLoss(L1Loss):
121
+ """Weighted TV loss.
122
+
123
+ Args:
124
+ loss_weight (float): Loss weight. Default: 1.0.
125
+ """
126
+
127
+ def __init__(self, loss_weight=1.0, reduction='mean'):
128
+ if reduction not in ['mean', 'sum']:
129
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
130
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)
131
+
132
+ def forward(self, pred, weight=None):
133
+ if weight is None:
134
+ y_weight = None
135
+ x_weight = None
136
+ else:
137
+ y_weight = weight[:, :, :-1, :]
138
+ x_weight = weight[:, :, :, :-1]
139
+
140
+ y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
141
+ x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
142
+
143
+ loss = x_diff + y_diff
144
+
145
+ return loss
146
+
147
+
148
+ # @LOSS_REGISTRY.register()
149
+ # class PerceptualLoss(nn.Module):
150
+ # """Perceptual loss with commonly used style loss.
151
+ #
152
+ # Args:
153
+ # layer_weights (dict): The weight for each layer of vgg feature.
154
+ # Here is an example: {'conv5_4': 1.}, which means the conv5_4
155
+ # feature layer (before relu5_4) will be extracted with weight
156
+ # 1.0 in calculating losses.
157
+ # vgg_type (str): The type of vgg network used as feature extractor.
158
+ # Default: 'vgg19'.
159
+ # use_input_norm (bool): If True, normalize the input image in vgg.
160
+ # Default: True.
161
+ # range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
162
+ # Default: False.
163
+ # perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
164
+ # loss will be calculated and the loss will multiplied by the
165
+ # weight. Default: 1.0.
166
+ # style_weight (float): If `style_weight > 0`, the style loss will be
167
+ # calculated and the loss will multiplied by the weight.
168
+ # Default: 0.
169
+ # criterion (str): Criterion used for perceptual loss. Default: 'l1'.
170
+ # """
171
+ #
172
+ # def __init__(self,
173
+ # layer_weights,
174
+ # vgg_type='vgg19',
175
+ # use_input_norm=True,
176
+ # range_norm=False,
177
+ # perceptual_weight=1.0,
178
+ # style_weight=0.,
179
+ # criterion='l1'):
180
+ # super(PerceptualLoss, self).__init__()
181
+ # self.perceptual_weight = perceptual_weight
182
+ # self.style_weight = style_weight
183
+ # self.layer_weights = layer_weights
184
+ # self.vgg = VGGFeatureExtractor(
185
+ # layer_name_list=list(layer_weights.keys()),
186
+ # vgg_type=vgg_type,
187
+ # use_input_norm=use_input_norm,
188
+ # range_norm=range_norm)
189
+ #
190
+ # self.criterion_type = criterion
191
+ # if self.criterion_type == 'l1':
192
+ # self.criterion = torch.nn.L1Loss()
193
+ # elif self.criterion_type == 'l2':
194
+ # self.criterion = torch.nn.L2loss()
195
+ # elif self.criterion_type == 'fro':
196
+ # self.criterion = None
197
+ # else:
198
+ # raise NotImplementedError(f'{criterion} criterion has not been supported.')
199
+ #
200
+ # def forward(self, x, gt):
201
+ # """Forward function.
202
+ #
203
+ # Args:
204
+ # x (Tensor): Input tensor with shape (n, c, h, w).
205
+ # gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
206
+ #
207
+ # Returns:
208
+ # Tensor: Forward results.
209
+ # """
210
+ # # extract vgg features
211
+ # x_features = self.vgg(x)
212
+ # gt_features = self.vgg(gt.detach())
213
+ #
214
+ # # calculate perceptual loss
215
+ # if self.perceptual_weight > 0:
216
+ # percep_loss = 0
217
+ # for k in x_features.keys():
218
+ # if self.criterion_type == 'fro':
219
+ # percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
220
+ # else:
221
+ # percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
222
+ # percep_loss *= self.perceptual_weight
223
+ # else:
224
+ # percep_loss = None
225
+ #
226
+ # # calculate style loss
227
+ # if self.style_weight > 0:
228
+ # style_loss = 0
229
+ # for k in x_features.keys():
230
+ # if self.criterion_type == 'fro':
231
+ # style_loss += torch.norm(
232
+ # self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
233
+ # else:
234
+ # style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
235
+ # gt_features[k])) * self.layer_weights[k]
236
+ # style_loss *= self.style_weight
237
+ # else:
238
+ # style_loss = None
239
+ #
240
+ # return percep_loss, style_loss
241
+ #
242
+ # def _gram_mat(self, x):
243
+ # """Calculate Gram matrix.
244
+ #
245
+ # Args:
246
+ # x (torch.Tensor): Tensor with shape of (n, c, h, w).
247
+ #
248
+ # Returns:
249
+ # torch.Tensor: Gram matrix.
250
+ # """
251
+ # n, c, h, w = x.size()
252
+ # features = x.view(n, c, w * h)
253
+ # features_t = features.transpose(1, 2)
254
+ # gram = features.bmm(features_t) / (c * h * w)
255
+ # return gram
256
+
257
+
258
+ @LOSS_REGISTRY.register()
259
+ class GANLoss(nn.Module):
260
+ """Define GAN loss.
261
+
262
+ Args:
263
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus', 'hinge'.
264
+ real_label_val (float): The value for real label. Default: 1.0.
265
+ fake_label_val (float): The value for fake label. Default: 0.0.
266
+ loss_weight (float): Loss weight. Default: 1.0.
267
+ Note that loss_weight is only for generators; and it is always 1.0
268
+ for discriminators.
269
+ """
270
+
271
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
272
+ super(GANLoss, self).__init__()
273
+ self.gan_type = gan_type
274
+ self.loss_weight = loss_weight
275
+ self.real_label_val = real_label_val
276
+ self.fake_label_val = fake_label_val
277
+
278
+ if self.gan_type == 'vanilla':
279
+ self.loss = nn.BCEWithLogitsLoss()
280
+ elif self.gan_type == 'lsgan':
281
+ self.loss = nn.MSELoss()
282
+ elif self.gan_type == 'wgan':
283
+ self.loss = self._wgan_loss
284
+ elif self.gan_type == 'wgan_softplus':
285
+ self.loss = self._wgan_softplus_loss
286
+ elif self.gan_type == 'hinge':
287
+ self.loss = nn.ReLU()
288
+ else:
289
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
290
+
291
+ def _wgan_loss(self, input, target):
292
+ """wgan loss.
293
+
294
+ Args:
295
+ input (Tensor): Input tensor.
296
+ target (bool): Target label.
297
+
298
+ Returns:
299
+ Tensor: wgan loss.
300
+ """
301
+ return -input.mean() if target else input.mean()
302
+
303
+ def _wgan_softplus_loss(self, input, target):
304
+ """wgan loss with softplus. Softplus is a smooth approximation to the
305
+ ReLU function.
306
+
307
+ In StyleGAN2, it is called:
308
+ Logistic loss for discriminator;
309
+ Non-saturating loss for generator.
310
+
311
+ Args:
312
+ input (Tensor): Input tensor.
313
+ target (bool): Target label.
314
+
315
+ Returns:
316
+ Tensor: wgan loss.
317
+ """
318
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
319
+
320
+ def get_target_label(self, input, target_is_real):
321
+ """Get target label.
322
+
323
+ Args:
324
+ input (Tensor): Input tensor.
325
+ target_is_real (bool): Whether the target is real or fake.
326
+
327
+ Returns:
328
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
329
+ return Tensor.
330
+ """
331
+
332
+ if self.gan_type in ['wgan', 'wgan_softplus']:
333
+ return target_is_real
334
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
335
+ return input.new_ones(input.size()) * target_val
336
+
337
+ def forward(self, input, target_is_real, is_disc=False):
338
+ """
339
+ Args:
340
+ input (Tensor): The input for the loss module, i.e., the network
341
+ prediction.
342
+ target_is_real (bool): Whether the target is real or fake.
343
+ is_disc (bool): Whether the loss for discriminators or not.
344
+ Default: False.
345
+
346
+ Returns:
347
+ Tensor: GAN loss value.
348
+ """
349
+ target_label = self.get_target_label(input, target_is_real)
350
+ if self.gan_type == 'hinge':
351
+ if is_disc: # for discriminators in hinge-gan
352
+ input = -input if target_is_real else input
353
+ loss = self.loss(1 + input).mean()
354
+ else: # for generators in hinge-gan
355
+ loss = -input.mean()
356
+ else: # other gan types
357
+ loss = self.loss(input, target_label)
358
+
359
+ # loss_weight is always 1.0 for discriminators
360
+ return loss if is_disc else loss * self.loss_weight
361
+
362
+
363
+ @LOSS_REGISTRY.register()
364
+ class MultiScaleGANLoss(GANLoss):
365
+ """
366
+ MultiScaleGANLoss accepts a list of predictions
367
+ """
368
+
369
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
370
+ super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
371
+
372
+ def forward(self, input, target_is_real, is_disc=False):
373
+ """
374
+ The input is a list of tensors, or a list of (a list of tensors)
375
+ """
376
+ if isinstance(input, list):
377
+ loss = 0
378
+ for pred_i in input:
379
+ if isinstance(pred_i, list):
380
+ # Only compute GAN loss for the last layer
381
+ # in case of multiscale feature matching
382
+ pred_i = pred_i[-1]
383
+ # Safe operation: 0-dim tensor calling self.mean() does nothing
384
+ loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
385
+ loss += loss_tensor
386
+ return loss / len(input)
387
+ else:
388
+ return super().forward(input, target_is_real, is_disc)
389
+
390
+
391
+ def r1_penalty(real_pred, real_img):
392
+ """R1 regularization for discriminator. The core idea is to
393
+ penalize the gradient on real data alone: when the
394
+ generator distribution produces the true data distribution
395
+ and the discriminator is equal to 0 on the data manifold, the
396
+ gradient penalty ensures that the discriminator cannot create
397
+ a non-zero gradient orthogonal to the data manifold without
398
+ suffering a loss in the GAN game.
399
+
400
+ Ref:
401
+ Eq. 9 in Which training methods for GANs do actually converge.
402
+ """
403
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
404
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
405
+ return grad_penalty
406
+
407
+
408
+ def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
409
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
410
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
411
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
412
+
413
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
414
+
415
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
416
+
417
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
418
+
419
+
420
+ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
421
+ """Calculate gradient penalty for wgan-gp.
422
+
423
+ Args:
424
+ discriminator (nn.Module): Network for the discriminator.
425
+ real_data (Tensor): Real input data.
426
+ fake_data (Tensor): Fake input data.
427
+ weight (Tensor): Weight tensor. Default: None.
428
+
429
+ Returns:
430
+ Tensor: A tensor for gradient penalty.
431
+ """
432
+
433
+ batch_size = real_data.size(0)
434
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
435
+
436
+ # interpolate between real_data and fake_data
437
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
438
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
439
+
440
+ disc_interpolates = discriminator(interpolates)
441
+ gradients = autograd.grad(
442
+ outputs=disc_interpolates,
443
+ inputs=interpolates,
444
+ grad_outputs=torch.ones_like(disc_interpolates),
445
+ create_graph=True,
446
+ retain_graph=True,
447
+ only_inputs=True)[0]
448
+
449
+ if weight is not None:
450
+ gradients = gradients * weight
451
+
452
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
453
+ if weight is not None:
454
+ gradients_penalty /= torch.mean(weight)
455
+
456
+ return gradients_penalty
457
+
458
+
459
+ @LOSS_REGISTRY.register()
460
+ class GANFeatLoss(nn.Module):
461
+ """Define feature matching loss for GANs.
462
+
463
+ Args:
464
+ criterion (str): Support 'l1', 'l2', 'charbonnier'.
465
+ loss_weight (float): Loss weight. Default: 1.0.
466
+ reduction (str): Specifies the reduction to apply to the output.
467
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
468
+ """
469
+
470
+ def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'):
471
+ super(GANFeatLoss, self).__init__()
472
+ if criterion == 'l1':
473
+ self.loss_op = L1Loss(loss_weight, reduction)
474
+ elif criterion == 'l2':
475
+ self.loss_op = MSELoss(loss_weight, reduction)
476
+ elif criterion == 'charbonnier':
477
+ self.loss_op = CharbonnierLoss(loss_weight, reduction)
478
+ else:
479
+ raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier')
480
+
481
+ self.loss_weight = loss_weight
482
+
483
+ def forward(self, pred_fake, pred_real):
484
+ num_d = len(pred_fake)
485
+ loss = 0
486
+ for i in range(num_d): # for each discriminator
487
+ # last output is the final prediction, exclude it
488
+ num_intermediate_outputs = len(pred_fake[i]) - 1
489
+ for j in range(num_intermediate_outputs): # for each layer output
490
+ unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach())
491
+ loss += unweighted_loss / num_d
492
+ return loss * self.loss_weight
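A small usage sketch for `GANLoss`, showing the `is_disc` switch: for the generator the fake predictions are labelled as real and `loss_weight` is applied; for the discriminator `loss_weight` is ignored, as noted in the docstring:

```python
import torch

from basicsr.losses.losses import GANLoss

cri_gan = GANLoss(gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0, loss_weight=0.1)
fake_pred = torch.randn(4, 1)   # discriminator logits on generated images

l_g_gan = cri_gan(fake_pred, target_is_real=True, is_disc=False)   # generator update
l_d_fake = cri_gan(fake_pred, target_is_real=False, is_disc=True)  # discriminator update
print(l_g_gan, l_d_fake)
```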
basicsr/metrics/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils.registry import METRIC_REGISTRY
4
+ from .psnr_ssim import calculate_psnr, calculate_ssim
5
+
6
+ __all__ = ['calculate_psnr', 'calculate_ssim']
7
+
8
+
9
+ def calculate_metric(data, opt):
10
+ """Calculate metric from data and options.
11
+
12
+ Args:
13
+ opt (dict): Configuration. It must contain:
14
+ type (str): Metric type.
15
+ """
16
+ opt = deepcopy(opt)
17
+ metric_type = opt.pop('type')
18
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
19
+ return metric
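`calculate_metric` mirrors the `metrics:` block of a test YML: the `type` key selects a registered metric function and the remaining keys become its keyword arguments. A self-contained sketch on random images:

```python
import numpy as np

from basicsr.metrics import calculate_metric

img = np.random.rand(32, 32, 3) * 255
img2 = np.clip(img + np.random.randn(32, 32, 3), 0, 255)

opt = {'type': 'calculate_psnr', 'crop_border': 2, 'test_y_channel': False}
print(calculate_metric({'img': img, 'img2': img2}, opt))
```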
basicsr/metrics/metric_util.py ADDED
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+
3
+ from basicsr.utils.matlab_functions import bgr2ycbcr
4
+
5
+
6
+ def reorder_image(img, input_order='HWC'):
7
+ """Reorder images to 'HWC' order.
8
+
9
+ If the input_order is (h, w), return (h, w, 1);
10
+ If the input_order is (c, h, w), return (h, w, c);
11
+ If the input_order is (h, w, c), return as it is.
12
+
13
+ Args:
14
+ img (ndarray): Input image.
15
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
16
+ If the input image shape is (h, w), input_order will not have
17
+ effects. Default: 'HWC'.
18
+
19
+ Returns:
20
+ ndarray: reordered image.
21
+ """
22
+
23
+ if input_order not in ['HWC', 'CHW']:
24
+ raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
25
+ if len(img.shape) == 2:
26
+ img = img[..., None]
27
+ if input_order == 'CHW':
28
+ img = img.transpose(1, 2, 0)
29
+ return img
30
+
31
+
32
+ def to_y_channel(img):
33
+ """Change to Y channel of YCbCr.
34
+
35
+ Args:
36
+ img (ndarray): Images with range [0, 255].
37
+
38
+ Returns:
39
+ (ndarray): Images with range [0, 255] (float type) without round.
40
+ """
41
+ img = img.astype(np.float32) / 255.
42
+ if img.ndim == 3 and img.shape[2] == 3:
43
+ img = bgr2ycbcr(img, y_only=True)
44
+ img = img[..., None]
45
+ return img * 255.
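A quick sketch of the two helpers above: `reorder_image` moves a CHW array to HWC, and `to_y_channel` converts a [0, 255] BGR image to its (unrounded) Y channel, also in [0, 255]:

```python
import numpy as np

from basicsr.metrics.metric_util import reorder_image, to_y_channel

chw = np.random.rand(3, 16, 16) * 255
hwc = reorder_image(chw, input_order='CHW')  # -> (16, 16, 3)
y = to_y_channel(hwc)                        # -> (16, 16, 1), float, range [0, 255]
print(hwc.shape, y.shape)
```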
basicsr/metrics/psnr_ssim.py ADDED
@@ -0,0 +1,128 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from basicsr.metrics.metric_util import reorder_image, to_y_channel
5
+ from basicsr.utils.registry import METRIC_REGISTRY
6
+
7
+
8
+ @METRIC_REGISTRY.register()
9
+ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
10
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
11
+
12
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13
+
14
+ Args:
15
+ img (ndarray): Images with range [0, 255].
16
+ img2 (ndarray): Images with range [0, 255].
17
+ crop_border (int): Cropped pixels in each edge of an image. These
18
+ pixels are not involved in the PSNR calculation.
19
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
20
+ Default: 'HWC'.
21
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
22
+
23
+ Returns:
24
+ float: psnr result.
25
+ """
26
+
27
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
28
+ if input_order not in ['HWC', 'CHW']:
29
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
30
+ img = reorder_image(img, input_order=input_order)
31
+ img2 = reorder_image(img2, input_order=input_order)
32
+ img = img.astype(np.float64)
33
+ img2 = img2.astype(np.float64)
34
+
35
+ if crop_border != 0:
36
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
37
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
38
+
39
+ if test_y_channel:
40
+ img = to_y_channel(img)
41
+ img2 = to_y_channel(img2)
42
+
43
+ mse = np.mean((img - img2)**2)
44
+ if mse == 0:
45
+ return float('inf')
46
+ return 20. * np.log10(255. / np.sqrt(mse))
47
+
48
+
49
+ def _ssim(img, img2):
50
+ """Calculate SSIM (structural similarity) for one channel images.
51
+
52
+ It is called by func:`calculate_ssim`.
53
+
54
+ Args:
55
+ img (ndarray): Images with range [0, 255] with order 'HWC'.
56
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
57
+
58
+ Returns:
59
+ float: ssim result.
60
+ """
61
+
62
+ c1 = (0.01 * 255)**2
63
+ c2 = (0.03 * 255)**2
64
+
65
+ img = img.astype(np.float64)
66
+ img2 = img2.astype(np.float64)
67
+ kernel = cv2.getGaussianKernel(11, 1.5)
68
+ window = np.outer(kernel, kernel.transpose())
69
+
70
+ mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]
71
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
72
+ mu1_sq = mu1**2
73
+ mu2_sq = mu2**2
74
+ mu1_mu2 = mu1 * mu2
75
+ sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
76
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
77
+ sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
78
+
79
+ ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
80
+ return ssim_map.mean()
81
+
82
+
83
+ @METRIC_REGISTRY.register()
84
+ def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
85
+ """Calculate SSIM (structural similarity).
86
+
87
+ Ref:
88
+ Image quality assessment: From error visibility to structural similarity
89
+
90
+ The results are the same as that of the official released MATLAB code in
91
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
92
+
93
+ For three-channel images, SSIM is calculated for each channel and then
94
+ averaged.
95
+
96
+ Args:
97
+ img (ndarray): Images with range [0, 255].
98
+ img2 (ndarray): Images with range [0, 255].
99
+ crop_border (int): Cropped pixels in each edge of an image. These
100
+ pixels are not involved in the SSIM calculation.
101
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
102
+ Default: 'HWC'.
103
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
104
+
105
+ Returns:
106
+ float: ssim result.
107
+ """
108
+
109
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
110
+ if input_order not in ['HWC', 'CHW']:
111
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
112
+ img = reorder_image(img, input_order=input_order)
113
+ img2 = reorder_image(img2, input_order=input_order)
114
+ img = img.astype(np.float64)
115
+ img2 = img2.astype(np.float64)
116
+
117
+ if crop_border != 0:
118
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
119
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
120
+
121
+ if test_y_channel:
122
+ img = to_y_channel(img)
123
+ img2 = to_y_channel(img2)
124
+
125
+ ssims = []
126
+ for i in range(img.shape[2]):
127
+ ssims.append(_ssim(img[..., i], img2[..., i]))
128
+ return np.array(ssims).mean()
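For reference, a minimal usage sketch of the two metrics registered above. It assumes the package has been installed with `python setup.py develop`; the synthetic arrays and noise level are placeholders, not data from the repository.

```python
import numpy as np

from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim

# Synthetic ground-truth and "restored" images in [0, 255], HWC order.
gt = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
sr = np.clip(gt + np.random.normal(0, 5, gt.shape), 0, 255)

# crop_border trims 4 pixels on every side before the metrics are computed.
psnr = calculate_psnr(sr, gt, crop_border=4, input_order='HWC', test_y_channel=False)
ssim = calculate_ssim(sr, gt, crop_border=4, input_order='HWC', test_y_channel=False)
print(f'PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}')
```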
basicsr/models/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from basicsr.utils import get_root_logger, scandir
6
+ from basicsr.utils.registry import MODEL_REGISTRY
7
+
8
+ __all__ = ['build_model']
9
+
10
+ # automatically scan and import model modules for registry
11
+ # scan all the files under the 'models' folder and collect files ending with
12
+ # '_model.py'
13
+ model_folder = osp.dirname(osp.abspath(__file__))
14
+ model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
15
+ # import all the model modules
16
+ _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
17
+
18
+
19
+ def build_model(opt):
20
+ """Build model from options.
21
+
22
+ Args:
23
+ opt (dict): Configuration. It must contain:
24
+ model_type (str): Model type.
25
+ """
26
+ opt = deepcopy(opt)
27
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
28
+ logger = get_root_logger()
29
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
30
+ return model
basicsr/models/base_model.py ADDED
@@ -0,0 +1,380 @@
1
+ import os
2
+ import time
3
+ import torch
4
+ from collections import OrderedDict
5
+ from copy import deepcopy
6
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
7
+
8
+ from basicsr.models import lr_scheduler as lr_scheduler
9
+ from basicsr.utils import get_root_logger
10
+ from basicsr.utils.dist_util import master_only
11
+
12
+
13
+ class BaseModel():
14
+ """Base model."""
15
+
16
+ def __init__(self, opt):
17
+ self.opt = opt
18
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
19
+ self.is_train = opt['is_train']
20
+ self.schedulers = []
21
+ self.optimizers = []
22
+
23
+ def feed_data(self, data):
24
+ pass
25
+
26
+ def optimize_parameters(self):
27
+ pass
28
+
29
+ def get_current_visuals(self):
30
+ pass
31
+
32
+ def save(self, epoch, current_iter):
33
+ """Save networks and training state."""
34
+ pass
35
+
36
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
37
+ """Validation function.
38
+
39
+ Args:
40
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
41
+ current_iter (int): Current iteration.
42
+ tb_logger (tensorboard logger): Tensorboard logger.
43
+ save_img (bool): Whether to save images. Default: False.
44
+ """
45
+ if self.opt['dist']:
46
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
47
+ else:
48
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
49
+
50
+ def _initialize_best_metric_results(self, dataset_name):
51
+ """Initialize the best metric results dict for recording the best metric value and iteration."""
52
+ if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
53
+ return
54
+ elif not hasattr(self, 'best_metric_results'):
55
+ self.best_metric_results = dict()
56
+
57
+ # add a dataset record
58
+ record = dict()
59
+ for metric, content in self.opt['val']['metrics'].items():
60
+ better = content.get('better', 'higher')
61
+ init_val = float('-inf') if better == 'higher' else float('inf')
62
+ record[metric] = dict(better=better, val=init_val, iter=-1)
63
+ self.best_metric_results[dataset_name] = record
64
+
65
+ def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
66
+ if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
67
+ if val >= self.best_metric_results[dataset_name][metric]['val']:
68
+ self.best_metric_results[dataset_name][metric]['val'] = val
69
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
70
+ else:
71
+ if val <= self.best_metric_results[dataset_name][metric]['val']:
72
+ self.best_metric_results[dataset_name][metric]['val'] = val
73
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
74
+
75
+ def model_ema(self, decay=0.999):
76
+ net_g = self.get_bare_model(self.net_g)
77
+
78
+ net_g_params = dict(net_g.named_parameters())
79
+ net_g_ema_params = dict(self.net_g_ema.named_parameters())
80
+
81
+ for k in net_g_ema_params.keys():
82
+ net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
83
+
84
+ def get_current_log(self):
85
+ return self.log_dict
86
+
87
+ def model_to_device(self, net):
88
+ """Model to device. It also warps models with DistributedDataParallel
89
+ or DataParallel.
90
+
91
+ Args:
92
+ net (nn.Module)
93
+ """
94
+ net = net.to(self.device)
95
+ if self.opt['dist']:
96
+ find_unused_parameters = self.opt.get('find_unused_parameters', False)
97
+ net = DistributedDataParallel(
98
+ net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
99
+ elif self.opt['num_gpu'] > 1:
100
+ net = DataParallel(net)
101
+ return net
102
+
103
+ def get_optimizer(self, optim_type, params, lr, **kwargs):
104
+ if optim_type == 'Adam':
105
+ optimizer = torch.optim.Adam(params, lr, **kwargs)
106
+ else:
107
+ raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
108
+ return optimizer
109
+
110
+ def setup_schedulers(self):
111
+ """Set up schedulers."""
112
+ train_opt = self.opt['train']
113
+ scheduler_type = train_opt['scheduler'].pop('type')
114
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
115
+ for optimizer in self.optimizers:
116
+ self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
117
+ elif scheduler_type == 'CosineAnnealingRestartLR':
118
+ for optimizer in self.optimizers:
119
+ self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
120
+ else:
121
+ raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
122
+
123
+ def get_bare_model(self, net):
124
+ """Get bare model, especially under wrapping with
125
+ DistributedDataParallel or DataParallel.
126
+ """
127
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
128
+ net = net.module
129
+ return net
130
+
131
+ @master_only
132
+ def print_network(self, net):
133
+ """Print the str and parameter number of a network.
134
+
135
+ Args:
136
+ net (nn.Module)
137
+ """
138
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
139
+ net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
140
+ else:
141
+ net_cls_str = f'{net.__class__.__name__}'
142
+
143
+ net = self.get_bare_model(net)
144
+ net_str = str(net)
145
+ net_params = sum(map(lambda x: x.numel(), net.parameters()))
146
+
147
+ logger = get_root_logger()
148
+ logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
149
+ logger.info(net_str)
150
+
151
+ def _set_lr(self, lr_groups_l):
152
+ """Set learning rate for warmup.
153
+
154
+ Args:
155
+ lr_groups_l (list): List for lr_groups, each for an optimizer.
156
+ """
157
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
158
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
159
+ param_group['lr'] = lr
160
+
161
+ def _get_init_lr(self):
162
+ """Get the initial lr, which is set by the scheduler.
163
+ """
164
+ init_lr_groups_l = []
165
+ for optimizer in self.optimizers:
166
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
167
+ return init_lr_groups_l
168
+
169
+ def update_learning_rate(self, current_iter, warmup_iter=-1):
170
+ """Update learning rate.
171
+
172
+ Args:
173
+ current_iter (int): Current iteration.
174
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
175
+ Default: -1.
176
+ """
177
+ if current_iter > 1:
178
+ for scheduler in self.schedulers:
179
+ scheduler.step()
180
+ # set up warm-up learning rate
181
+ if current_iter < warmup_iter:
182
+ # get initial lr for each group
183
+ init_lr_g_l = self._get_init_lr()
184
+ # modify warming-up learning rates
185
+ # currently only support linearly warm up
186
+ warm_up_lr_l = []
187
+ for init_lr_g in init_lr_g_l:
188
+ warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
189
+ # set learning rate
190
+ self._set_lr(warm_up_lr_l)
191
+
192
+ def get_current_learning_rate(self):
193
+ return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
194
+
195
+ @master_only
196
+ def save_network(self, net, net_label, current_iter, param_key='params'):
197
+ """Save networks.
198
+
199
+ Args:
200
+ net (nn.Module | list[nn.Module]): Network(s) to be saved.
201
+ net_label (str): Network label.
202
+ current_iter (int): Current iter number.
203
+ param_key (str | list[str]): The parameter key(s) to save network.
204
+ Default: 'params'.
205
+ """
206
+ if current_iter == -1:
207
+ current_iter = 'latest'
208
+ save_filename = f'{net_label}_{current_iter}.pth'
209
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
210
+
211
+ net = net if isinstance(net, list) else [net]
212
+ param_key = param_key if isinstance(param_key, list) else [param_key]
213
+ assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
214
+
215
+ save_dict = {}
216
+ for net_, param_key_ in zip(net, param_key):
217
+ net_ = self.get_bare_model(net_)
218
+ state_dict = net_.state_dict()
219
+ for key, param in state_dict.items():
220
+ if key.startswith('module.'): # remove unnecessary 'module.'
221
+ key = key[7:]
222
+ state_dict[key] = param.cpu()
223
+ save_dict[param_key_] = state_dict
224
+
225
+ # avoid occasional writing errors
226
+ retry = 3
227
+ while retry > 0:
228
+ try:
229
+ torch.save(save_dict, save_path)
230
+ except Exception as e:
231
+ logger = get_root_logger()
232
+ logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
233
+ time.sleep(1)
234
+ else:
235
+ break
236
+ finally:
237
+ retry -= 1
238
+ if retry == 0:
239
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
240
+ # raise IOError(f'Cannot save {save_path}.')
241
+
242
+ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
243
+ """Print keys with different name or different size when loading models.
244
+
245
+ 1. Print keys with different names.
246
+ 2. If strict=False, print the same key but with different tensor size.
247
+ It also ignores these keys with different sizes (they are not loaded).
248
+
249
+ Args:
250
+ crt_net (torch model): Current network.
251
+ load_net (dict): Loaded network.
252
+ strict (bool): Whether strictly loaded. Default: True.
253
+ """
254
+ crt_net = self.get_bare_model(crt_net)
255
+ crt_net = crt_net.state_dict()
256
+ crt_net_keys = set(crt_net.keys())
257
+ load_net_keys = set(load_net.keys())
258
+
259
+ logger = get_root_logger()
260
+ if crt_net_keys != load_net_keys:
261
+ logger.warning('Current net - loaded net:')
262
+ for v in sorted(list(crt_net_keys - load_net_keys)):
263
+ logger.warning(f' {v}')
264
+ logger.warning('Loaded net - current net:')
265
+ for v in sorted(list(load_net_keys - crt_net_keys)):
266
+ logger.warning(f' {v}')
267
+
268
+ # check the size for the same keys
269
+ if not strict:
270
+ common_keys = crt_net_keys & load_net_keys
271
+ for k in common_keys:
272
+ if crt_net[k].size() != load_net[k].size():
273
+ logger.warning(f'Size different, ignore [{k}]: crt_net: '
274
+ f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
275
+ load_net[k + '.ignore'] = load_net.pop(k)
276
+
277
+ def load_network(self, net, load_path, strict=True, param_key='params'):
278
+ """Load network.
279
+
280
+ Args:
281
+ load_path (str): The path of networks to be loaded.
282
+ net (nn.Module): Network.
283
+ strict (bool): Whether strictly loaded.
284
+ param_key (str): The parameter key of loaded network. If set to
285
+ None, use the root 'path'.
286
+ Default: 'params'.
287
+ """
288
+ logger = get_root_logger()
289
+ net = self.get_bare_model(net)
290
+ load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
291
+ if param_key is not None:
292
+ if param_key not in load_net and 'params' in load_net:
293
+ param_key = 'params'
294
+ logger.info('Loading: params_ema does not exist, use params.')
295
+ load_net = load_net[param_key]
296
+ logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
297
+ # remove unnecessary 'module.'
298
+ for k, v in deepcopy(load_net).items():
299
+ if k.startswith('module.'):
300
+ load_net[k[7:]] = v
301
+ load_net.pop(k)
302
+ self._print_different_keys_loading(net, load_net, strict)
303
+ net.load_state_dict(load_net, strict=strict)
304
+
305
+ @master_only
306
+ def save_training_state(self, epoch, current_iter):
307
+ """Save training states during training, which will be used for
308
+ resuming.
309
+
310
+ Args:
311
+ epoch (int): Current epoch.
312
+ current_iter (int): Current iteration.
313
+ """
314
+ if current_iter != -1:
315
+ state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
316
+ for o in self.optimizers:
317
+ state['optimizers'].append(o.state_dict())
318
+ for s in self.schedulers:
319
+ state['schedulers'].append(s.state_dict())
320
+ save_filename = f'{current_iter}.state'
321
+ save_path = os.path.join(self.opt['path']['training_states'], save_filename)
322
+
323
+ # avoid occasional writing errors
324
+ retry = 3
325
+ while retry > 0:
326
+ try:
327
+ torch.save(state, save_path)
328
+ except Exception as e:
329
+ logger = get_root_logger()
330
+ logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
331
+ time.sleep(1)
332
+ else:
333
+ break
334
+ finally:
335
+ retry -= 1
336
+ if retry == 0:
337
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
338
+ # raise IOError(f'Cannot save {save_path}.')
339
+
340
+ def resume_training(self, resume_state):
341
+ """Reload the optimizers and schedulers for resumed training.
342
+
343
+ Args:
344
+ resume_state (dict): Resume state.
345
+ """
346
+ resume_optimizers = resume_state['optimizers']
347
+ resume_schedulers = resume_state['schedulers']
348
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
349
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
350
+ for i, o in enumerate(resume_optimizers):
351
+ self.optimizers[i].load_state_dict(o)
352
+ for i, s in enumerate(resume_schedulers):
353
+ self.schedulers[i].load_state_dict(s)
354
+
355
+ def reduce_loss_dict(self, loss_dict):
356
+ """reduce loss dict.
357
+
358
+ In distributed training, it averages the losses among different GPUs.
359
+
360
+ Args:
361
+ loss_dict (OrderedDict): Loss dict.
362
+ """
363
+ with torch.no_grad():
364
+ if self.opt['dist']:
365
+ keys = []
366
+ losses = []
367
+ for name, value in loss_dict.items():
368
+ keys.append(name)
369
+ losses.append(value)
370
+ losses = torch.stack(losses, 0)
371
+ torch.distributed.reduce(losses, dst=0)
372
+ if self.opt['rank'] == 0:
373
+ losses /= self.opt['world_size']
374
+ loss_dict = {key: loss for key, loss in zip(keys, losses)}
375
+
376
+ log_dict = OrderedDict()
377
+ for name, value in loss_dict.items():
378
+ log_dict[name] = value.mean().item()
379
+
380
+ return log_dict
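`model_ema` above keeps an exponential moving average of the generator weights, updating each EMA parameter as `ema = decay * ema + (1 - decay) * current`. Below is a standalone sketch of that update rule on plain tensors; the names are illustrative only, not taken from the codebase.

```python
import torch

def ema_update(ema_params, params, decay=0.999):
    # Same rule as BaseModel.model_ema: ema <- decay * ema + (1 - decay) * current
    for k in ema_params:
        ema_params[k].mul_(decay).add_(params[k], alpha=1 - decay)

ema = {'w': torch.zeros(3)}
cur = {'w': torch.ones(3)}
for _ in range(10):
    ema_update(ema, cur)
print(ema['w'])  # roughly 0.01 after 10 steps with decay=0.999
```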
basicsr/models/lr_scheduler.py ADDED
@@ -0,0 +1,96 @@
1
+ import math
2
+ from collections import Counter
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+
6
+ class MultiStepRestartLR(_LRScheduler):
7
+ """ MultiStep with restarts learning rate scheme.
8
+
9
+ Args:
10
+ optimizer (torch.nn.optimizer): Torch optimizer.
11
+ milestones (list): Iterations that will decrease learning rate.
12
+ gamma (float): Decrease ratio. Default: 0.1.
13
+ restarts (list): Restart iterations. Default: [0].
14
+ restart_weights (list): Restart weights at each restart iteration.
15
+ Default: [1].
16
+ last_epoch (int): Used in _LRScheduler. Default: -1.
17
+ """
18
+
19
+ def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
20
+ self.milestones = Counter(milestones)
21
+ self.gamma = gamma
22
+ self.restarts = restarts
23
+ self.restart_weights = restart_weights
24
+ assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
25
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
26
+
27
+ def get_lr(self):
28
+ if self.last_epoch in self.restarts:
29
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
30
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
31
+ if self.last_epoch not in self.milestones:
32
+ return [group['lr'] for group in self.optimizer.param_groups]
33
+ return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
34
+
35
+
36
+ def get_position_from_periods(iteration, cumulative_period):
37
+ """Get the position from a period list.
38
+
39
+ It will return the index of the right-closest number in the period list.
40
+ For example, the cumulative_period = [100, 200, 300, 400],
41
+ if iteration == 50, return 0;
42
+ if iteration == 210, return 2;
43
+ if iteration == 300, return 2.
44
+
45
+ Args:
46
+ iteration (int): Current iteration.
47
+ cumulative_period (list[int]): Cumulative period list.
48
+
49
+ Returns:
50
+ int: The position of the right-closest number in the period list.
51
+ """
52
+ for i, period in enumerate(cumulative_period):
53
+ if iteration <= period:
54
+ return i
55
+
56
+
57
+ class CosineAnnealingRestartLR(_LRScheduler):
58
+ """ Cosine annealing with restarts learning rate scheme.
59
+
60
+ An example of config:
61
+ periods = [10, 10, 10, 10]
62
+ restart_weights = [1, 0.5, 0.5, 0.5]
63
+ eta_min=1e-7
64
+
65
+ It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th iterations, the
66
+ scheduler will restart with the weights in restart_weights.
67
+
68
+ Args:
69
+ optimizer (torch.nn.optimizer): Torch optimizer.
70
+ periods (list): Period for each cosine annealing cycle.
71
+ restart_weights (list): Restart weights at each restart iteration.
72
+ Default: [1].
73
+ eta_min (float): The minimum lr. Default: 0.
74
+ last_epoch (int): Used in _LRScheduler. Default: -1.
75
+ """
76
+
77
+ def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
78
+ self.periods = periods
79
+ self.restart_weights = restart_weights
80
+ self.eta_min = eta_min
81
+ assert (len(self.periods) == len(
82
+ self.restart_weights)), 'periods and restart_weights should have the same length.'
83
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
84
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
85
+
86
+ def get_lr(self):
87
+ idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
88
+ current_weight = self.restart_weights[idx]
89
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
90
+ current_period = self.periods[idx]
91
+
92
+ return [
93
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
94
+ (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
95
+ for base_lr in self.base_lrs
96
+ ]
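A short sketch of driving `CosineAnnealingRestartLR` with the example configuration from its docstring (four 10-iteration periods with restarts at iterations 10, 20 and 30); the dummy parameter and learning rate below are placeholders.

```python
import torch

from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

param = torch.nn.Parameter(torch.zeros(1))          # dummy parameter to build an optimizer
optimizer = torch.optim.Adam([param], lr=2e-4)
scheduler = CosineAnnealingRestartLR(
    optimizer, periods=[10, 10, 10, 10], restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

for it in range(40):
    optimizer.step()
    scheduler.step()
    # lr follows a cosine curve within each period and is re-scaled by the
    # matching restart weight at iterations 10, 20 and 30.
    print(it, optimizer.param_groups[0]['lr'])
```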
basicsr/models/sr_model.py ADDED
@@ -0,0 +1,231 @@
1
+ import torch
2
+ from collections import OrderedDict
3
+ from os import path as osp
4
+ from tqdm import tqdm
5
+
6
+ from basicsr.archs import build_network
7
+ from basicsr.losses import build_loss
8
+ from basicsr.metrics import calculate_metric
9
+ from basicsr.utils import get_root_logger, imwrite, tensor2img
10
+ from basicsr.utils.registry import MODEL_REGISTRY
11
+ from .base_model import BaseModel
12
+
13
+
14
+ @MODEL_REGISTRY.register()
15
+ class SRModel(BaseModel):
16
+ """Base SR model for single image super-resolution."""
17
+
18
+ def __init__(self, opt):
19
+ super(SRModel, self).__init__(opt)
20
+
21
+ # define network
22
+ self.net_g = build_network(opt['network_g'])
23
+ self.net_g = self.model_to_device(self.net_g)
24
+ self.print_network(self.net_g)
25
+
26
+ # load pretrained models
27
+ load_path = self.opt['path'].get('pretrain_network_g', None)
28
+ if load_path is not None:
29
+ param_key = self.opt['path'].get('param_key_g', 'params')
30
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
31
+
32
+ if self.is_train:
33
+ self.init_training_settings()
34
+
35
+ def init_training_settings(self):
36
+ self.net_g.train()
37
+ train_opt = self.opt['train']
38
+
39
+ self.ema_decay = train_opt.get('ema_decay', 0)
40
+ if self.ema_decay > 0:
41
+ logger = get_root_logger()
42
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
43
+ # define network net_g with Exponential Moving Average (EMA)
44
+ # net_g_ema is used only for testing on one GPU and saving
45
+ # There is no need to wrap with DistributedDataParallel
46
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
47
+ # load pretrained model
48
+ load_path = self.opt['path'].get('pretrain_network_g', None)
49
+ if load_path is not None:
50
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
51
+ else:
52
+ self.model_ema(0) # copy net_g weight
53
+ self.net_g_ema.eval()
54
+
55
+ # define losses
56
+ if train_opt.get('pixel_opt'):
57
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
58
+ else:
59
+ self.cri_pix = None
60
+
61
+ if train_opt.get('perceptual_opt'):
62
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
63
+ else:
64
+ self.cri_perceptual = None
65
+
66
+ if self.cri_pix is None and self.cri_perceptual is None:
67
+ raise ValueError('Both pixel and perceptual losses are None.')
68
+
69
+ # set up optimizers and schedulers
70
+ self.setup_optimizers()
71
+ self.setup_schedulers()
72
+
73
+ def setup_optimizers(self):
74
+ train_opt = self.opt['train']
75
+ optim_params = []
76
+ for k, v in self.net_g.named_parameters():
77
+ if v.requires_grad:
78
+ optim_params.append(v)
79
+ else:
80
+ logger = get_root_logger()
81
+ logger.warning(f'Params {k} will not be optimized.')
82
+
83
+ optim_type = train_opt['optim_g'].pop('type')
84
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
85
+ self.optimizers.append(self.optimizer_g)
86
+
87
+ def feed_data(self, data):
88
+ self.lq = data['lq'].to(self.device)
89
+ if 'gt' in data:
90
+ self.gt = data['gt'].to(self.device)
91
+
92
+ def optimize_parameters(self, current_iter):
93
+ self.optimizer_g.zero_grad()
94
+ self.output = self.net_g(self.lq)
95
+
96
+ l_total = 0
97
+ loss_dict = OrderedDict()
98
+ # pixel loss
99
+ if self.cri_pix:
100
+ l_pix = self.cri_pix(self.output, self.gt)
101
+ l_total += l_pix
102
+ loss_dict['l_pix'] = l_pix
103
+ # perceptual loss
104
+ if self.cri_perceptual:
105
+ l_percep, l_style = self.cri_perceptual(self.output, self.gt)
106
+ if l_percep is not None:
107
+ l_total += l_percep
108
+ loss_dict['l_percep'] = l_percep
109
+ if l_style is not None:
110
+ l_total += l_style
111
+ loss_dict['l_style'] = l_style
112
+
113
+ l_total.backward()
114
+ self.optimizer_g.step()
115
+
116
+ self.log_dict = self.reduce_loss_dict(loss_dict)
117
+
118
+ if self.ema_decay > 0:
119
+ self.model_ema(decay=self.ema_decay)
120
+
121
+ def test(self):
122
+ if hasattr(self, 'net_g_ema'):
123
+ self.net_g_ema.eval()
124
+ with torch.no_grad():
125
+ self.output = self.net_g_ema(self.lq)
126
+ else:
127
+ self.net_g.eval()
128
+ with torch.no_grad():
129
+ self.output = self.net_g(self.lq)
130
+ self.net_g.train()
131
+
132
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
133
+ if self.opt['rank'] == 0:
134
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
135
+
136
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
137
+ dataset_name = dataloader.dataset.opt['name']
138
+ with_metrics = self.opt['val'].get('metrics') is not None
139
+ use_pbar = self.opt['val'].get('pbar', False)
140
+
141
+ if with_metrics:
142
+ if not hasattr(self, 'metric_results'): # only execute in the first run
143
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
144
+ # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
145
+ self._initialize_best_metric_results(dataset_name)
146
+ # zero self.metric_results
147
+ if with_metrics:
148
+ self.metric_results = {metric: 0 for metric in self.metric_results}
149
+
150
+ metric_data = dict()
151
+ if use_pbar:
152
+ pbar = tqdm(total=len(dataloader), unit='image')
153
+
154
+ for idx, val_data in enumerate(dataloader):
155
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
156
+ self.feed_data(val_data)
157
+ self.test()
158
+
159
+ visuals = self.get_current_visuals()
160
+ sr_img = tensor2img([visuals['result']])
161
+ metric_data['img'] = sr_img
162
+ if 'gt' in visuals:
163
+ gt_img = tensor2img([visuals['gt']])
164
+ metric_data['img2'] = gt_img
165
+ del self.gt
166
+
167
+ # tentative for out of GPU memory
168
+ del self.lq
169
+ del self.output
170
+ torch.cuda.empty_cache()
171
+
172
+ if save_img:
173
+ if self.opt['is_train']:
174
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
175
+ f'{img_name}_{current_iter}.png')
176
+ else:
177
+ if self.opt['val']['suffix']:
178
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
179
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
180
+ else:
181
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
182
+ f'{img_name}_{self.opt["name"]}.png')
183
+ imwrite(sr_img, save_img_path)
184
+
185
+ if with_metrics:
186
+ # calculate metrics
187
+ for name, opt_ in self.opt['val']['metrics'].items():
188
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
189
+ if use_pbar:
190
+ pbar.update(1)
191
+ pbar.set_description(f'Test {img_name}')
192
+ if use_pbar:
193
+ pbar.close()
194
+
195
+ if with_metrics:
196
+ for metric in self.metric_results.keys():
197
+ self.metric_results[metric] /= (idx + 1)
198
+ # update the best metric result
199
+ self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
200
+
201
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
202
+
203
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
204
+ log_str = f'Validation {dataset_name}\n'
205
+ for metric, value in self.metric_results.items():
206
+ log_str += f'\t # {metric}: {value:.4f}'
207
+ if hasattr(self, 'best_metric_results'):
208
+ log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
209
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
210
+ log_str += '\n'
211
+
212
+ logger = get_root_logger()
213
+ logger.info(log_str)
214
+ if tb_logger:
215
+ for metric, value in self.metric_results.items():
216
+ tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
217
+
218
+ def get_current_visuals(self):
219
+ out_dict = OrderedDict()
220
+ out_dict['lq'] = self.lq.detach().cpu()
221
+ out_dict['result'] = self.output.detach().cpu()
222
+ if hasattr(self, 'gt'):
223
+ out_dict['gt'] = self.gt.detach().cpu()
224
+ return out_dict
225
+
226
+ def save(self, epoch, current_iter):
227
+ if hasattr(self, 'net_g_ema'):
228
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
229
+ else:
230
+ self.save_network(self.net_g, 'net_g', current_iter)
231
+ self.save_training_state(epoch, current_iter)
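`nondist_validation` above reads its behaviour from the `val` section of the options. Below is a hypothetical sketch of that structure as a Python dict; the key names mirror the YML files under `options/Test/`, while the concrete values are made up.

```python
# Hypothetical `val` options consumed by SRModel.nondist_validation.
val_opt = {
    'save_img': True,
    'suffix': None,  # when empty, saved images fall back to the experiment name
    'pbar': True,
    'metrics': {
        # each entry is forwarded to calculate_metric; 'better' defaults to 'higher'
        'psnr': {'type': 'calculate_psnr', 'crop_border': 2, 'test_y_channel': True, 'better': 'higher'},
        'ssim': {'type': 'calculate_ssim', 'crop_border': 2, 'test_y_channel': True},
    },
}
```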
basicsr/test.py ADDED
@@ -0,0 +1,44 @@
1
+ import logging
2
+ import torch
3
+ from os import path as osp
4
+
5
+ from basicsr.data import build_dataloader, build_dataset
6
+ from basicsr.models import build_model
7
+ from basicsr.utils import get_root_logger, get_time_str, make_exp_dirs
8
+ from basicsr.utils.options import dict2str, parse_options
9
+
10
+
11
+ def test_pipeline(root_path):
12
+ # parse options, set distributed setting, set random seed
13
+ opt, _ = parse_options(root_path, is_train=False)
14
+
15
+ torch.backends.cudnn.benchmark = True
16
+ # torch.backends.cudnn.deterministic = True
17
+
18
+ # mkdir and initialize loggers
19
+ make_exp_dirs(opt)
20
+ log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
21
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
22
+ logger.info(dict2str(opt))
23
+
24
+ # create test dataset and dataloader
25
+ test_loaders = []
26
+ for _, dataset_opt in sorted(opt['datasets'].items()):
27
+ test_set = build_dataset(dataset_opt)
28
+ test_loader = build_dataloader(
29
+ test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
30
+ logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
31
+ test_loaders.append(test_loader)
32
+
33
+ # create model
34
+ model = build_model(opt)
35
+
36
+ for test_loader in test_loaders:
37
+ test_set_name = test_loader.dataset.opt['name']
38
+ logger.info(f'Testing {test_set_name}...')
39
+ model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
40
+
41
+
42
+ if __name__ == '__main__':
43
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
44
+ test_pipeline(root_path)
basicsr/utils/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ from .file_client import FileClient
2
+ from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
3
+ from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
4
+ from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
5
+
6
+ __all__ = [
7
+ # file_client.py
8
+ 'FileClient',
9
+ # img_util.py
10
+ 'img2tensor',
11
+ 'tensor2img',
12
+ 'imfrombytes',
13
+ 'imwrite',
14
+ 'crop_border',
15
+ # logger.py
16
+ 'MessageLogger',
17
+ 'AvgTimer',
18
+ 'init_tb_logger',
19
+ 'init_wandb_logger',
20
+ 'get_root_logger',
21
+ 'get_env_info',
22
+ # misc.py
23
+ 'set_random_seed',
24
+ 'get_time_str',
25
+ 'mkdir_and_rename',
26
+ 'make_exp_dirs',
27
+ 'scandir',
28
+ 'check_resume',
29
+ 'sizeof_fmt',
30
+ ]
basicsr/utils/dist_util.py ADDED
@@ -0,0 +1,82 @@
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.multiprocessing as mp
8
+
9
+
10
+ def init_dist(launcher, backend='nccl', **kwargs):
11
+ if mp.get_start_method(allow_none=True) is None:
12
+ mp.set_start_method('spawn')
13
+ if launcher == 'pytorch':
14
+ _init_dist_pytorch(backend, **kwargs)
15
+ elif launcher == 'slurm':
16
+ _init_dist_slurm(backend, **kwargs)
17
+ else:
18
+ raise ValueError(f'Invalid launcher type: {launcher}')
19
+
20
+
21
+ def _init_dist_pytorch(backend, **kwargs):
22
+ rank = int(os.environ['RANK'])
23
+ num_gpus = torch.cuda.device_count()
24
+ torch.cuda.set_device(rank % num_gpus)
25
+ dist.init_process_group(backend=backend, **kwargs)
26
+
27
+
28
+ def _init_dist_slurm(backend, port=None):
29
+ """Initialize slurm distributed training environment.
30
+
31
+ If argument ``port`` is not specified, then the master port will be system
32
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33
+ environment variable, then a default port ``29500`` will be used.
34
+
35
+ Args:
36
+ backend (str): Backend of torch.distributed.
37
+ port (int, optional): Master port. Defaults to None.
38
+ """
39
+ proc_id = int(os.environ['SLURM_PROCID'])
40
+ ntasks = int(os.environ['SLURM_NTASKS'])
41
+ node_list = os.environ['SLURM_NODELIST']
42
+ num_gpus = torch.cuda.device_count()
43
+ torch.cuda.set_device(proc_id % num_gpus)
44
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
45
+ # specify master port
46
+ if port is not None:
47
+ os.environ['MASTER_PORT'] = str(port)
48
+ elif 'MASTER_PORT' in os.environ:
49
+ pass # use MASTER_PORT in the environment variable
50
+ else:
51
+ # 29500 is torch.distributed default port
52
+ os.environ['MASTER_PORT'] = '29500'
53
+ os.environ['MASTER_ADDR'] = addr
54
+ os.environ['WORLD_SIZE'] = str(ntasks)
55
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
56
+ os.environ['RANK'] = str(proc_id)
57
+ dist.init_process_group(backend=backend)
58
+
59
+
60
+ def get_dist_info():
61
+ if dist.is_available():
62
+ initialized = dist.is_initialized()
63
+ else:
64
+ initialized = False
65
+ if initialized:
66
+ rank = dist.get_rank()
67
+ world_size = dist.get_world_size()
68
+ else:
69
+ rank = 0
70
+ world_size = 1
71
+ return rank, world_size
72
+
73
+
74
+ def master_only(func):
75
+
76
+ @functools.wraps(func)
77
+ def wrapper(*args, **kwargs):
78
+ rank, _ = get_dist_info()
79
+ if rank == 0:
80
+ return func(*args, **kwargs)
81
+
82
+ return wrapper
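The `master_only` decorator above is how the rest of the codebase restricts logging and checkpointing to rank 0. A minimal sketch, assuming a single non-distributed process so `get_dist_info()` returns `(0, 1)`:

```python
from basicsr.utils.dist_util import get_dist_info, master_only

@master_only
def log_once(msg):
    # executes only on rank 0; on all other ranks the call is a no-op
    print(msg)

rank, world_size = get_dist_info()
log_once(f'hello from rank {rank} of {world_size}')
```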
basicsr/utils/file_client.py ADDED
@@ -0,0 +1,167 @@
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2
+ from abc import ABCMeta, abstractmethod
3
+
4
+
5
+ class BaseStorageBackend(metaclass=ABCMeta):
6
+ """Abstract class of storage backends.
7
+
8
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
9
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10
+ as texts.
11
+ """
12
+
13
+ @abstractmethod
14
+ def get(self, filepath):
15
+ pass
16
+
17
+ @abstractmethod
18
+ def get_text(self, filepath):
19
+ pass
20
+
21
+
22
+ class MemcachedBackend(BaseStorageBackend):
23
+ """Memcached storage backend.
24
+
25
+ Attributes:
26
+ server_list_cfg (str): Config file for memcached server list.
27
+ client_cfg (str): Config file for memcached client.
28
+ sys_path (str | None): Additional path to be appended to `sys.path`.
29
+ Default: None.
30
+ """
31
+
32
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33
+ if sys_path is not None:
34
+ import sys
35
+ sys.path.append(sys_path)
36
+ try:
37
+ import mc
38
+ except ImportError:
39
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
40
+
41
+ self.server_list_cfg = server_list_cfg
42
+ self.client_cfg = client_cfg
43
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
44
+ # mc.pyvector serves as a pointer to a memory cache
45
+ self._mc_buffer = mc.pyvector()
46
+
47
+ def get(self, filepath):
48
+ filepath = str(filepath)
49
+ import mc
50
+ self._client.Get(filepath, self._mc_buffer)
51
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
52
+ return value_buf
53
+
54
+ def get_text(self, filepath):
55
+ raise NotImplementedError
56
+
57
+
58
+ class HardDiskBackend(BaseStorageBackend):
59
+ """Raw hard disks storage backend."""
60
+
61
+ def get(self, filepath):
62
+ filepath = str(filepath)
63
+ with open(filepath, 'rb') as f:
64
+ value_buf = f.read()
65
+ return value_buf
66
+
67
+ def get_text(self, filepath):
68
+ filepath = str(filepath)
69
+ with open(filepath, 'r') as f:
70
+ value_buf = f.read()
71
+ return value_buf
72
+
73
+
74
+ class LmdbBackend(BaseStorageBackend):
75
+ """Lmdb storage backend.
76
+
77
+ Args:
78
+ db_paths (str | list[str]): Lmdb database paths.
79
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
80
+ readonly (bool, optional): Lmdb environment parameter. If True,
81
+ disallow any write operations. Default: True.
82
+ lock (bool, optional): Lmdb environment parameter. If False, when
83
+ concurrent access occurs, do not lock the database. Default: False.
84
+ readahead (bool, optional): Lmdb environment parameter. If False,
85
+ disable the OS filesystem readahead mechanism, which may improve
86
+ random read performance when a database is larger than RAM.
87
+ Default: False.
88
+
89
+ Attributes:
90
+ db_paths (list): Lmdb database path.
91
+ _client (list): A list of several lmdb envs.
92
+ """
93
+
94
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
95
+ try:
96
+ import lmdb
97
+ except ImportError:
98
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
99
+
100
+ if isinstance(client_keys, str):
101
+ client_keys = [client_keys]
102
+
103
+ if isinstance(db_paths, list):
104
+ self.db_paths = [str(v) for v in db_paths]
105
+ elif isinstance(db_paths, str):
106
+ self.db_paths = [str(db_paths)]
107
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
108
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
109
+
110
+ self._client = {}
111
+ for client, path in zip(client_keys, self.db_paths):
112
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
113
+
114
+ def get(self, filepath, client_key):
115
+ """Get values according to the filepath from one lmdb named client_key.
116
+
117
+ Args:
118
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
119
+ client_key (str): Used for distinguishing different lmdb envs.
120
+ """
121
+ filepath = str(filepath)
122
+ assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
123
+ client = self._client[client_key]
124
+ with client.begin(write=False) as txn:
125
+ value_buf = txn.get(filepath.encode('ascii'))
126
+ return value_buf
127
+
128
+ def get_text(self, filepath):
129
+ raise NotImplementedError
130
+
131
+
132
+ class FileClient(object):
133
+ """A general file client to access files in different backend.
134
+
135
+ The client loads a file or text in a specified backend from its path
136
+ and returns it as a binary file. It can also register other backend
137
+ accessors with a given name and backend class.
138
+
139
+ Attributes:
140
+ backend (str): The storage backend type. Options are "disk",
141
+ "memcached" and "lmdb".
142
+ client (:obj:`BaseStorageBackend`): The backend object.
143
+ """
144
+
145
+ _backends = {
146
+ 'disk': HardDiskBackend,
147
+ 'memcached': MemcachedBackend,
148
+ 'lmdb': LmdbBackend,
149
+ }
150
+
151
+ def __init__(self, backend='disk', **kwargs):
152
+ if backend not in self._backends:
153
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
154
+ f' are {list(self._backends.keys())}')
155
+ self.backend = backend
156
+ self.client = self._backends[backend](**kwargs)
157
+
158
+ def get(self, filepath, client_key='default'):
159
+ # client_key is used only for lmdb, where different fileclients have
160
+ # different lmdb environments.
161
+ if self.backend == 'lmdb':
162
+ return self.client.get(filepath, client_key)
163
+ else:
164
+ return self.client.get(filepath)
165
+
166
+ def get_text(self, filepath):
167
+ return self.client.get_text(filepath)
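A brief usage sketch of `FileClient`; the file and database paths below are placeholders, not files shipped with the repository.

```python
from basicsr.utils.file_client import FileClient

# Disk backend: get() returns raw bytes, get_text() returns a str.
client = FileClient(backend='disk')
img_bytes = client.get('datasets/Set5/HR/baby.png')  # placeholder path

# Lmdb backend: one environment per client key (paths are placeholders).
# lmdb_client = FileClient(backend='lmdb',
#                          db_paths=['datasets/DIV2K_train_HR.lmdb'], client_keys=['gt'])
# img_bytes = lmdb_client.get('0001', client_key='gt')
```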
basicsr/utils/img_util.py ADDED
@@ -0,0 +1,172 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from torchvision.utils import make_grid
7
+
8
+
9
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
10
+ """Numpy array to tensor.
11
+
12
+ Args:
13
+ imgs (list[ndarray] | ndarray): Input images.
14
+ bgr2rgb (bool): Whether to change bgr to rgb.
15
+ float32 (bool): Whether to change to float32.
16
+
17
+ Returns:
18
+ list[tensor] | tensor: Tensor images. If returned results only have
19
+ one element, just return tensor.
20
+ """
21
+
22
+ def _totensor(img, bgr2rgb, float32):
23
+ if img.shape[2] == 3 and bgr2rgb:
24
+ if img.dtype == 'float64':
25
+ img = img.astype('float32')
26
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
27
+ img = torch.from_numpy(img.transpose(2, 0, 1))
28
+ if float32:
29
+ img = img.float()
30
+ return img
31
+
32
+ if isinstance(imgs, list):
33
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
34
+ else:
35
+ return _totensor(imgs, bgr2rgb, float32)
36
+
37
+
38
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
39
+ """Convert torch Tensors into image numpy arrays.
40
+
41
+ After clamping to [min, max], values will be normalized to [0, 1].
42
+
43
+ Args:
44
+ tensor (Tensor or list[Tensor]): Accept shapes:
45
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
46
+ 2) 3D Tensor of shape (3/1 x H x W);
47
+ 3) 2D Tensor of shape (H x W).
48
+ Tensor channel should be in RGB order.
49
+ rgb2bgr (bool): Whether to change rgb to bgr.
50
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
51
+ to uint8 type with range [0, 255]; otherwise, float type with
52
+ range [0, 1]. Default: ``np.uint8``.
53
+ min_max (tuple[int]): min and max values for clamp.
54
+
55
+ Returns:
56
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
57
+ shape (H x W). The channel order is BGR.
58
+ """
59
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61
+
62
+ if torch.is_tensor(tensor):
63
+ tensor = [tensor]
64
+ result = []
65
+ for _tensor in tensor:
66
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
67
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
68
+
69
+ n_dim = _tensor.dim()
70
+ if n_dim == 4:
71
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
72
+ img_np = img_np.transpose(1, 2, 0)
73
+ if rgb2bgr:
74
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
75
+ elif n_dim == 3:
76
+ img_np = _tensor.numpy()
77
+ img_np = img_np.transpose(1, 2, 0)
78
+ if img_np.shape[2] == 1: # gray image
79
+ img_np = np.squeeze(img_np, axis=2)
80
+ else:
81
+ if rgb2bgr:
82
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
83
+ elif n_dim == 2:
84
+ img_np = _tensor.numpy()
85
+ else:
86
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
87
+ if out_type == np.uint8:
88
+ # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
89
+ img_np = (img_np * 255.0).round()
90
+ img_np = img_np.astype(out_type)
91
+ result.append(img_np)
92
+ if len(result) == 1:
93
+ result = result[0]
94
+ return result
95
+
96
+
97
+ def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
98
+ """This implementation is slightly faster than tensor2img.
99
+ It now only supports torch tensor with shape (1, c, h, w).
100
+
101
+ Args:
102
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
103
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
104
+ min_max (tuple[int]): min and max values for clamp.
105
+ """
106
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
107
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
108
+ output = output.type(torch.uint8).cpu().numpy()
109
+ if rgb2bgr:
110
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
111
+ return output
112
+
113
+
114
+ def imfrombytes(content, flag='color', float32=False):
115
+ """Read an image from bytes.
116
+
117
+ Args:
118
+ content (bytes): Image bytes got from files or other streams.
119
+ flag (str): Flags specifying the color type of a loaded image,
120
+ candidates are `color`, `grayscale` and `unchanged`.
121
+ float32 (bool): Whether to change to float32., If True, will also norm
122
+ to [0, 1]. Default: False.
123
+
124
+ Returns:
125
+ ndarray: Loaded image array.
126
+ """
127
+ img_np = np.frombuffer(content, np.uint8)
128
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
129
+ img = cv2.imdecode(img_np, imread_flags[flag])
130
+ if float32:
131
+ img = img.astype(np.float32) / 255.
132
+ return img
133
+
134
+
135
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
136
+ """Write image to file.
137
+
138
+ Args:
139
+ img (ndarray): Image array to be written.
140
+ file_path (str): Image file path.
141
+ params (None or list): Same as opencv's :func:`imwrite` interface.
142
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
143
+ whether to create it automatically.
144
+
145
+ Returns:
146
+ bool: Successful or not.
147
+ """
148
+ if auto_mkdir:
149
+ dir_name = os.path.abspath(os.path.dirname(file_path))
150
+ os.makedirs(dir_name, exist_ok=True)
151
+ ok = cv2.imwrite(file_path, img, params)
152
+ if not ok:
153
+ raise IOError('Failed in writing images.')
154
+
155
+
156
+ def crop_border(imgs, crop_border):
157
+ """Crop borders of images.
158
+
159
+ Args:
160
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
161
+ crop_border (int): Crop border for each end of height and width.
162
+
163
+ Returns:
164
+ list[ndarray]: Cropped images.
165
+ """
166
+ if crop_border == 0:
167
+ return imgs
168
+ else:
169
+ if isinstance(imgs, list):
170
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
171
+ else:
172
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
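A round-trip sketch tying the helpers above together (image bytes to a BGR ndarray, to an RGB tensor, back to a BGR uint8 image); the input and output paths are placeholders.

```python
from basicsr.utils.img_util import img2tensor, imfrombytes, imwrite, tensor2img

with open('datasets/Set5/LR/baby_x2.png', 'rb') as f:        # placeholder path
    img = imfrombytes(f.read(), float32=True)                 # HWC, BGR, range [0, 1]

tensor = img2tensor(img, bgr2rgb=True, float32=True)          # CHW, RGB, float32
restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))   # HWC, BGR, uint8
imwrite(restored, 'results/baby_roundtrip.png')               # creates results/ if needed
```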
basicsr/utils/logger.py ADDED
@@ -0,0 +1,213 @@
1
+ import datetime
2
+ import logging
3
+ import time
4
+
5
+ from .dist_util import get_dist_info, master_only
6
+
7
+ initialized_logger = {}
8
+
9
+
10
+ class AvgTimer():
11
+
12
+ def __init__(self, window=200):
13
+ self.window = window # average window
14
+ self.current_time = 0
15
+ self.total_time = 0
16
+ self.count = 0
17
+ self.avg_time = 0
18
+ self.start()
19
+
20
+ def start(self):
21
+ self.start_time = self.tic = time.time()
22
+
23
+ def record(self):
24
+ self.count += 1
25
+ self.toc = time.time()
26
+ self.current_time = self.toc - self.tic
27
+ self.total_time += self.current_time
28
+ # calculate average time
29
+ self.avg_time = self.total_time / self.count
30
+
31
+ # reset
32
+ if self.count > self.window:
33
+ self.count = 0
34
+ self.total_time = 0
35
+
36
+ self.tic = time.time()
37
+
38
+ def get_current_time(self):
39
+ return self.current_time
40
+
41
+ def get_avg_time(self):
42
+ return self.avg_time
43
+
44
+
45
+ class MessageLogger():
46
+ """Message logger for printing.
47
+
48
+ Args:
49
+ opt (dict): Config. It contains the following keys:
50
+ name (str): Exp name.
51
+ logger (dict): Contains 'print_freq' (str) for logger interval.
52
+ train (dict): Contains 'total_iter' (int) for total iters.
53
+ use_tb_logger (bool): Use tensorboard logger.
54
+ start_iter (int): Start iter. Default: 1.
55
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
56
+ """
57
+
58
+ def __init__(self, opt, start_iter=1, tb_logger=None):
59
+ self.exp_name = opt['name']
60
+ self.interval = opt['logger']['print_freq']
61
+ self.start_iter = start_iter
62
+ self.max_iters = opt['train']['total_iter']
63
+ self.use_tb_logger = opt['logger']['use_tb_logger']
64
+ self.tb_logger = tb_logger
65
+ self.start_time = time.time()
66
+ self.logger = get_root_logger()
67
+
68
+ def reset_start_time(self):
69
+ self.start_time = time.time()
70
+
71
+ @master_only
72
+ def __call__(self, log_vars):
73
+ """Format logging message.
74
+
75
+ Args:
76
+ log_vars (dict): It contains the following keys:
77
+ epoch (int): Epoch number.
78
+ iter (int): Current iter.
79
+ lrs (list): List for learning rates.
80
+
81
+ time (float): Iter time.
82
+ data_time (float): Data time for each iter.
83
+ """
84
+ # epoch, iter, learning rates
85
+ epoch = log_vars.pop('epoch')
86
+ current_iter = log_vars.pop('iter')
87
+ lrs = log_vars.pop('lrs')
88
+
89
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
90
+ for v in lrs:
91
+ message += f'{v:.3e},'
92
+ message += ')] '
93
+
94
+ # time and estimated time
95
+ if 'time' in log_vars.keys():
96
+ iter_time = log_vars.pop('time')
97
+ data_time = log_vars.pop('data_time')
98
+
99
+ total_time = time.time() - self.start_time
100
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
101
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
102
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
103
+ message += f'[eta: {eta_str}, '
104
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
105
+
106
+ # other items, especially losses
107
+ for k, v in log_vars.items():
108
+ message += f'{k}: {v:.4e} '
109
+ # tensorboard logger
110
+ if self.use_tb_logger and 'debug' not in self.exp_name:
111
+ if k.startswith('l_'):
112
+ self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
113
+ else:
114
+ self.tb_logger.add_scalar(k, v, current_iter)
115
+ self.logger.info(message)
116
+
117
+
118
+ @master_only
119
+ def init_tb_logger(log_dir):
120
+ from torch.utils.tensorboard import SummaryWriter
121
+ tb_logger = SummaryWriter(log_dir=log_dir)
122
+ return tb_logger
123
+
124
+
125
+ @master_only
126
+ def init_wandb_logger(opt):
127
+ """We now only use wandb to sync tensorboard log."""
128
+ import wandb
129
+ logger = get_root_logger()
130
+
131
+ project = opt['logger']['wandb']['project']
132
+ resume_id = opt['logger']['wandb'].get('resume_id')
133
+ if resume_id:
134
+ wandb_id = resume_id
135
+ resume = 'allow'
136
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
137
+ else:
138
+ wandb_id = wandb.util.generate_id()
139
+ resume = 'never'
140
+
141
+ wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
142
+
143
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
144
+
145
+
146
+ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
147
+ """Get the root logger.
148
+
149
+ The logger will be initialized if it has not been initialized. By default a
150
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
151
+ also be added.
152
+
153
+ Args:
154
+ logger_name (str): root logger name. Default: 'basicsr'.
155
+ log_file (str | None): The log filename. If specified, a FileHandler
156
+ will be added to the root logger.
157
+ log_level (int): The root logger level. Note that only the process of
158
+ rank 0 is affected, while other processes will set the level to
159
+ "Error" and be silent most of the time.
160
+
161
+ Returns:
162
+ logging.Logger: The root logger.
163
+ """
164
+ logger = logging.getLogger(logger_name)
165
+ # if the logger has been initialized, just return it
166
+ if logger_name in initialized_logger:
167
+ return logger
168
+
169
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
170
+ stream_handler = logging.StreamHandler()
171
+ stream_handler.setFormatter(logging.Formatter(format_str))
172
+ logger.addHandler(stream_handler)
173
+ logger.propagate = False
174
+ rank, _ = get_dist_info()
175
+ if rank != 0:
176
+ logger.setLevel('ERROR')
177
+ elif log_file is not None:
178
+ logger.setLevel(log_level)
179
+ # add file handler
180
+ file_handler = logging.FileHandler(log_file, 'w')
181
+ file_handler.setFormatter(logging.Formatter(format_str))
182
+ file_handler.setLevel(log_level)
183
+ logger.addHandler(file_handler)
184
+ initialized_logger[logger_name] = True
185
+ return logger
186
+
187
+
188
+ def get_env_info():
189
+ """Get environment information.
190
+
191
+ Currently, only log the software version.
192
+ """
193
+ import torch
194
+ import torchvision
195
+
196
+ from basicsr.version import __version__
197
+ msg = r"""
198
+ ____ _ _____ ____
199
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
200
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
201
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
202
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
203
+ ______ __ __ __ __
204
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
205
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
206
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
207
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
208
+ """
209
+ msg += ('\nVersion Information: '
210
+ f'\n\tBasicSR: {__version__}'
211
+ f'\n\tPyTorch: {torch.__version__}'
212
+ f'\n\tTorchVision: {torchvision.__version__}')
213
+ return msg
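A minimal sketch of setting up the shared logger defined above; the log filename is arbitrary, and a stream plus a file handler are attached on rank 0 only.

```python
import logging

from basicsr.utils.logger import get_env_info, get_root_logger

# The first call attaches the handlers; later calls return the same 'basicsr' logger.
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO,
                         log_file='demo_basicsr.log')  # arbitrary filename in the working directory
logger.info(get_env_info())
logger.info('Logging is set up.')
```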
basicsr/utils/matlab_functions.py ADDED
@@ -0,0 +1,359 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def cubic(x):
7
+ """cubic function used for calculate_weights_indices."""
8
+ absx = torch.abs(x)
9
+ absx2 = absx**2
10
+ absx3 = absx**3
11
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
12
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
13
+ (absx <= 2)).type_as(absx))
14
+
15
+
16
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
17
+ """Calculate weights and indices, used for imresize function.
18
+
19
+ Args:
20
+ in_length (int): Input length.
21
+ out_length (int): Output length.
22
+ scale (float): Scale factor.
23
+ kernel_width (int): Kernel width.
24
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
25
+ """
26
+
27
+ if (scale < 1) and antialiasing:
28
+ # Use a modified kernel (larger kernel width) to simultaneously
29
+ # interpolate and antialias
30
+ kernel_width = kernel_width / scale
31
+
32
+ # Output-space coordinates
33
+ x = torch.linspace(1, out_length, out_length)
34
+
35
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
36
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
37
+ # space maps to 1.5 in input space.
38
+ u = x / scale + 0.5 * (1 - 1 / scale)
39
+
40
+ # What is the left-most pixel that can be involved in the computation?
41
+ left = torch.floor(u - kernel_width / 2)
42
+
43
+ # What is the maximum number of pixels that can be involved in the
44
+ # computation? Note: it's OK to use an extra pixel here; if the
45
+ # corresponding weights are all zero, it will be eliminated at the end
46
+ # of this function.
47
+ p = math.ceil(kernel_width) + 2
48
+
49
+ # The indices of the input pixels involved in computing the k-th output
50
+ # pixel are in row k of the indices matrix.
51
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
52
+ out_length, p)
53
+
54
+ # The weights used to compute the k-th output pixel are in row k of the
55
+ # weights matrix.
56
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
57
+
58
+ # apply cubic kernel
59
+ if (scale < 1) and antialiasing:
60
+ weights = scale * cubic(distance_to_center * scale)
61
+ else:
62
+ weights = cubic(distance_to_center)
63
+
64
+ # Normalize the weights matrix so that each row sums to 1.
65
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
66
+ weights = weights / weights_sum.expand(out_length, p)
67
+
68
+ # If a column in weights is all zero, get rid of it. only consider the
69
+ # first and last column.
70
+ weights_zero_tmp = torch.sum((weights == 0), 0)
71
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
72
+ indices = indices.narrow(1, 1, p - 2)
73
+ weights = weights.narrow(1, 1, p - 2)
74
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
75
+ indices = indices.narrow(1, 0, p - 2)
76
+ weights = weights.narrow(1, 0, p - 2)
77
+ weights = weights.contiguous()
78
+ indices = indices.contiguous()
79
+ sym_len_s = -indices.min() + 1
80
+ sym_len_e = indices.max() - in_length
81
+ indices = indices + sym_len_s - 1
82
+ return weights, indices, int(sym_len_s), int(sym_len_e)
83
+
84
+
85
+ @torch.no_grad()
86
+ def imresize(img, scale, antialiasing=True):
87
+ """imresize function same as MATLAB.
88
+
89
+ It now only supports bicubic.
90
+ The same scale applies for both height and width.
91
+
92
+ Args:
93
+ img (Tensor | Numpy array):
94
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
95
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
96
+ scale (float): Scale factor. The same scale applies for both height
97
+ and width.
98
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
99
+ Default: True.
100
+
101
+ Returns:
102
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
103
+ """
104
+ squeeze_flag = False
105
+ if type(img).__module__ == np.__name__: # numpy type
106
+ numpy_type = True
107
+ if img.ndim == 2:
108
+ img = img[:, :, None]
109
+ squeeze_flag = True
110
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
111
+ else:
112
+ numpy_type = False
113
+ if img.ndim == 2:
114
+ img = img.unsqueeze(0)
115
+ squeeze_flag = True
116
+
117
+ in_c, in_h, in_w = img.size()
118
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
119
+ kernel_width = 4
120
+ kernel = 'cubic'
121
+
122
+ # get weights and indices
123
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
124
+ antialiasing)
125
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
126
+ antialiasing)
127
+ # process H dimension
128
+ # symmetric copying
129
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
130
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
131
+
132
+ sym_patch = img[:, :sym_len_hs, :]
133
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
134
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
135
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
136
+
137
+ sym_patch = img[:, -sym_len_he:, :]
138
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
139
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
140
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
141
+
142
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
143
+ kernel_width = weights_h.size(1)
144
+ for i in range(out_h):
145
+ idx = int(indices_h[i][0])
146
+ for j in range(in_c):
147
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
148
+
149
+ # process W dimension
150
+ # symmetric copying
151
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
152
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
153
+
154
+ sym_patch = out_1[:, :, :sym_len_ws]
155
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
156
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
157
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
158
+
159
+ sym_patch = out_1[:, :, -sym_len_we:]
160
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
161
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
162
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
163
+
164
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
165
+ kernel_width = weights_w.size(1)
166
+ for i in range(out_w):
167
+ idx = int(indices_w[i][0])
168
+ for j in range(in_c):
169
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
170
+
171
+ if squeeze_flag:
172
+ out_2 = out_2.squeeze(0)
173
+ if numpy_type:
174
+ out_2 = out_2.numpy()
175
+ if not squeeze_flag:
176
+ out_2 = out_2.transpose(1, 2, 0)
177
+
178
+ return out_2
179
+
180
+
181
+ def rgb2ycbcr(img, y_only=False):
182
+ """Convert a RGB image to YCbCr image.
183
+
184
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
185
+ It implements the ITU-R BT.601 conversion for standard-definition
186
+ television. See more details in
187
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
188
+
189
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
190
+ In OpenCV, it implements a JPEG conversion. See more details in
191
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
192
+
193
+ Args:
194
+ img (ndarray): The input image. It accepts:
195
+ 1. np.uint8 type with range [0, 255];
196
+ 2. np.float32 type with range [0, 1].
197
+ y_only (bool): Whether to only return Y channel. Default: False.
198
+
199
+ Returns:
200
+ ndarray: The converted YCbCr image. The output image has the same type
201
+ and range as input image.
202
+ """
203
+ img_type = img.dtype
204
+ img = _convert_input_type_range(img)
205
+ if y_only:
206
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
207
+ else:
208
+ out_img = np.matmul(
209
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
210
+ out_img = _convert_output_type_range(out_img, img_type)
211
+ return out_img
212
+
213
+
214
+ def bgr2ycbcr(img, y_only=False):
215
+ """Convert a BGR image to YCbCr image.
216
+
217
+ The bgr version of rgb2ycbcr.
218
+ It implements the ITU-R BT.601 conversion for standard-definition
219
+ television. See more details in
220
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
221
+
222
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
223
+ In OpenCV, it implements a JPEG conversion. See more details in
224
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
225
+
226
+ Args:
227
+ img (ndarray): The input image. It accepts:
228
+ 1. np.uint8 type with range [0, 255];
229
+ 2. np.float32 type with range [0, 1].
230
+ y_only (bool): Whether to only return Y channel. Default: False.
231
+
232
+ Returns:
233
+ ndarray: The converted YCbCr image. The output image has the same type
234
+ and range as input image.
235
+ """
236
+ img_type = img.dtype
237
+ img = _convert_input_type_range(img)
238
+ if y_only:
239
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
240
+ else:
241
+ out_img = np.matmul(
242
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
243
+ out_img = _convert_output_type_range(out_img, img_type)
244
+ return out_img
245
+
246
+
247
+ def ycbcr2rgb(img):
248
+ """Convert a YCbCr image to RGB image.
249
+
250
+ This function produces the same results as Matlab's ycbcr2rgb function.
251
+ It implements the ITU-R BT.601 conversion for standard-definition
252
+ television. See more details in
253
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
254
+
255
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
256
+ In OpenCV, it implements a JPEG conversion. See more details in
257
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
258
+
259
+ Args:
260
+ img (ndarray): The input image. It accepts:
261
+ 1. np.uint8 type with range [0, 255];
262
+ 2. np.float32 type with range [0, 1].
263
+
264
+ Returns:
265
+ ndarray: The converted RGB image. The output image has the same type
266
+ and range as input image.
267
+ """
268
+ img_type = img.dtype
269
+ img = _convert_input_type_range(img) * 255
270
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
271
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
272
+ out_img = _convert_output_type_range(out_img, img_type)
273
+ return out_img
274
+
275
+
276
+ def ycbcr2bgr(img):
277
+ """Convert a YCbCr image to BGR image.
278
+
279
+ The bgr version of ycbcr2rgb.
280
+ It implements the ITU-R BT.601 conversion for standard-definition
281
+ television. See more details in
282
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
283
+
284
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
285
+ In OpenCV, it implements a JPEG conversion. See more details in
286
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
287
+
288
+ Args:
289
+ img (ndarray): The input image. It accepts:
290
+ 1. np.uint8 type with range [0, 255];
291
+ 2. np.float32 type with range [0, 1].
292
+
293
+ Returns:
294
+ ndarray: The converted BGR image. The output image has the same type
295
+ and range as input image.
296
+ """
297
+ img_type = img.dtype
298
+ img = _convert_input_type_range(img) * 255
299
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
300
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
301
+ out_img = _convert_output_type_range(out_img, img_type)
302
+ return out_img
303
+
304
+
305
+ def _convert_input_type_range(img):
306
+ """Convert the type and range of the input image.
307
+
308
+ It converts the input image to np.float32 type and range of [0, 1].
309
+ It is mainly used for pre-processing the input image in colorspace
310
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
311
+
312
+ Args:
313
+ img (ndarray): The input image. It accepts:
314
+ 1. np.uint8 type with range [0, 255];
315
+ 2. np.float32 type with range [0, 1].
316
+
317
+ Returns:
318
+ (ndarray): The converted image with type of np.float32 and range of
319
+ [0, 1].
320
+ """
321
+ img_type = img.dtype
322
+ img = img.astype(np.float32)
323
+ if img_type == np.float32:
324
+ pass
325
+ elif img_type == np.uint8:
326
+ img /= 255.
327
+ else:
328
+ raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
329
+ return img
330
+
331
+
332
+ def _convert_output_type_range(img, dst_type):
333
+ """Convert the type and range of the image according to dst_type.
334
+
335
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
336
+ images will be converted to np.uint8 type with range [0, 255]. If
337
+ `dst_type` is np.float32, it converts the image to np.float32 type with
338
+ range [0, 1].
339
+ It is mainly used for post-processing images in colorspace conversion
340
+ functions such as rgb2ycbcr and ycbcr2rgb.
341
+
342
+ Args:
343
+ img (ndarray): The image to be converted with np.float32 type and
344
+ range [0, 255].
345
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
346
+ converts the image to np.uint8 type with range [0, 255]. If
347
+ dst_type is np.float32, it converts the image to np.float32 type
348
+ with range [0, 1].
349
+
350
+ Returns:
351
+ (ndarray): The converted image with desired type and range.
352
+ """
353
+ if dst_type not in (np.uint8, np.float32):
354
+ raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
355
+ if dst_type == np.uint8:
356
+ img = img.round()
357
+ else:
358
+ img /= 255.
359
+ return img.astype(dst_type)
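A minimal usage sketch for the MATLAB-style helpers above; the input here is a synthetic array rather than a dataset image:

```python
# Illustrative sketch; the input image is synthetic random data.
import numpy as np
from basicsr.utils.matlab_functions import imresize, rgb2ycbcr

hr = np.random.rand(256, 256, 3).astype(np.float32)   # HWC, range [0, 1]
lr = imresize(hr, scale=0.25, antialiasing=True)       # MATLAB-compatible bicubic x4 downsampling
y = rgb2ycbcr(lr.clip(0, 1), y_only=True)               # Y channel, as used for Y-channel PSNR/SSIM
print(lr.shape, y.shape)                                 # (64, 64, 3) (64, 64)
```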
basicsr/utils/misc.py ADDED
@@ -0,0 +1,141 @@
1
+ import numpy as np
2
+ import os
3
+ import random
4
+ import time
5
+ import torch
6
+ from os import path as osp
7
+
8
+ from .dist_util import master_only
9
+
10
+
11
+ def set_random_seed(seed):
12
+ """Set random seeds."""
13
+ random.seed(seed)
14
+ np.random.seed(seed)
15
+ torch.manual_seed(seed)
16
+ torch.cuda.manual_seed(seed)
17
+ torch.cuda.manual_seed_all(seed)
18
+
19
+
20
+ def get_time_str():
21
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
22
+
23
+
24
+ def mkdir_and_rename(path):
25
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
26
+
27
+ Args:
28
+ path (str): Folder path.
29
+ """
30
+ if osp.exists(path):
31
+ new_name = path + '_archived_' + get_time_str()
32
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
33
+ os.rename(path, new_name)
34
+ os.makedirs(path, exist_ok=True)
35
+
36
+
37
+ @master_only
38
+ def make_exp_dirs(opt):
39
+ """Make dirs for experiments."""
40
+ path_opt = opt['path'].copy()
41
+ if opt['is_train']:
42
+ mkdir_and_rename(path_opt.pop('experiments_root'))
43
+ else:
44
+ mkdir_and_rename(path_opt.pop('results_root'))
45
+ for key, path in path_opt.items():
46
+ if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
47
+ continue
48
+ else:
49
+ os.makedirs(path, exist_ok=True)
50
+
51
+
52
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
53
+ """Scan a directory to find the interested files.
54
+
55
+ Args:
56
+ dir_path (str): Path of the directory.
57
+ suffix (str | tuple(str), optional): File suffix that we are
58
+ interested in. Default: None.
59
+ recursive (bool, optional): If set to True, recursively scan the
60
+ directory. Default: False.
61
+ full_path (bool, optional): If set to True, include the dir_path.
62
+ Default: False.
63
+
64
+ Returns:
65
+ A generator for all the interested files with relative paths.
66
+ """
67
+
68
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
69
+ raise TypeError('"suffix" must be a string or tuple of strings')
70
+
71
+ root = dir_path
72
+
73
+ def _scandir(dir_path, suffix, recursive):
74
+ for entry in os.scandir(dir_path):
75
+ if not entry.name.startswith('.') and entry.is_file():
76
+ if full_path:
77
+ return_path = entry.path
78
+ else:
79
+ return_path = osp.relpath(entry.path, root)
80
+
81
+ if suffix is None:
82
+ yield return_path
83
+ elif return_path.endswith(suffix):
84
+ yield return_path
85
+ else:
86
+ if recursive:
87
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
88
+ else:
89
+ continue
90
+
91
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
92
+
93
+
94
+ def check_resume(opt, resume_iter):
95
+ """Check resume states and pretrain_network paths.
96
+
97
+ Args:
98
+ opt (dict): Options.
99
+ resume_iter (int): Resume iteration.
100
+ """
101
+ if opt['path']['resume_state']:
102
+ # get all the networks
103
+ networks = [key for key in opt.keys() if key.startswith('network_')]
104
+ flag_pretrain = False
105
+ for network in networks:
106
+ if opt['path'].get(f'pretrain_{network}') is not None:
107
+ flag_pretrain = True
108
+ if flag_pretrain:
109
+ print('pretrain_network path will be ignored during resuming.')
110
+ # set pretrained model paths
111
+ for network in networks:
112
+ name = f'pretrain_{network}'
113
+ basename = network.replace('network_', '')
114
+ if opt['path'].get('ignore_resume_networks') is None or (network
115
+ not in opt['path']['ignore_resume_networks']):
116
+ opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
117
+ print(f"Set {name} to {opt['path'][name]}")
118
+
119
+ # change param_key to params in resume
120
+ param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
121
+ for param_key in param_keys:
122
+ if opt['path'][param_key] == 'params_ema':
123
+ opt['path'][param_key] = 'params'
124
+ print(f'Set {param_key} to params')
125
+
126
+
127
+ def sizeof_fmt(size, suffix='B'):
128
+ """Get human readable file size.
129
+
130
+ Args:
131
+ size (int): File size.
132
+ suffix (str): Suffix. Default: 'B'.
133
+
134
+ Return:
135
+ str: Formatted file size.
136
+ """
137
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
138
+ if abs(size) < 1024.0:
139
+ return f'{size:3.1f} {unit}{suffix}'
140
+ size /= 1024.0
141
+ return f'{size:3.1f} Y{suffix}'
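A small sketch of how `scandir` and `sizeof_fmt` are typically combined; the directory assumes the testing datasets were placed as described in `datasets/README.md`:

```python
# Illustrative sketch; assumes Set5 HR images exist under datasets/benchmark/Set5/HR.
import os
from basicsr.utils.misc import scandir, sizeof_fmt

for img_path in scandir('datasets/benchmark/Set5/HR', suffix='.png', full_path=True):
    print(f'{img_path}: {sizeof_fmt(os.path.getsize(img_path))}')
```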
basicsr/utils/options.py ADDED
@@ -0,0 +1,194 @@
1
+ import argparse
2
+ import random
3
+ import torch
4
+ import yaml
5
+ from collections import OrderedDict
6
+ from os import path as osp
7
+
8
+ from basicsr.utils import set_random_seed
9
+ from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
10
+
11
+
12
+ def ordered_yaml():
13
+ """Support OrderedDict for yaml.
14
+
15
+ Returns:
16
+ yaml Loader and Dumper.
17
+ """
18
+ try:
19
+ from yaml import CDumper as Dumper
20
+ from yaml import CLoader as Loader
21
+ except ImportError:
22
+ from yaml import Dumper, Loader
23
+
24
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
25
+
26
+ def dict_representer(dumper, data):
27
+ return dumper.represent_dict(data.items())
28
+
29
+ def dict_constructor(loader, node):
30
+ return OrderedDict(loader.construct_pairs(node))
31
+
32
+ Dumper.add_representer(OrderedDict, dict_representer)
33
+ Loader.add_constructor(_mapping_tag, dict_constructor)
34
+ return Loader, Dumper
35
+
36
+
37
+ def dict2str(opt, indent_level=1):
38
+ """dict to string for printing options.
39
+
40
+ Args:
41
+ opt (dict): Option dict.
42
+ indent_level (int): Indent level. Default: 1.
43
+
44
+ Return:
45
+ (str): Option string for printing.
46
+ """
47
+ msg = '\n'
48
+ for k, v in opt.items():
49
+ if isinstance(v, dict):
50
+ msg += ' ' * (indent_level * 2) + k + ':['
51
+ msg += dict2str(v, indent_level + 1)
52
+ msg += ' ' * (indent_level * 2) + ']\n'
53
+ else:
54
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
55
+ return msg
56
+
57
+
58
+ def _postprocess_yml_value(value):
59
+ # None
60
+ if value == '~' or value.lower() == 'none':
61
+ return None
62
+ # bool
63
+ if value.lower() == 'true':
64
+ return True
65
+ elif value.lower() == 'false':
66
+ return False
67
+ # !!float number
68
+ if value.startswith('!!float'):
69
+ return float(value.replace('!!float', ''))
70
+ # number
71
+ if value.isdigit():
72
+ return int(value)
73
+ elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
74
+ return float(value)
75
+ # list
76
+ if value.startswith('['):
77
+ return eval(value)
78
+ # str
79
+ return value
80
+
81
+
82
+ def parse_options(root_path, is_train=True):
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
85
+ parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
86
+ parser.add_argument('--auto_resume', action='store_true')
87
+ parser.add_argument('--debug', action='store_true')
88
+ parser.add_argument('--local_rank', type=int, default=0)
89
+ parser.add_argument(
90
+ '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
91
+ args = parser.parse_args()
92
+
93
+ # parse yml to dict
94
+ with open(args.opt, mode='r') as f:
95
+ opt = yaml.load(f, Loader=ordered_yaml()[0])
96
+
97
+ # distributed settings
98
+ if args.launcher == 'none':
99
+ opt['dist'] = False
100
+ print('Disable distributed.', flush=True)
101
+ else:
102
+ opt['dist'] = True
103
+ if args.launcher == 'slurm' and 'dist_params' in opt:
104
+ init_dist(args.launcher, **opt['dist_params'])
105
+ else:
106
+ init_dist(args.launcher)
107
+ opt['rank'], opt['world_size'] = get_dist_info()
108
+
109
+ # random seed
110
+ seed = opt.get('manual_seed')
111
+ if seed is None:
112
+ seed = random.randint(1, 10000)
113
+ opt['manual_seed'] = seed
114
+ set_random_seed(seed + opt['rank'])
115
+
116
+ # force to update yml options
117
+ if args.force_yml is not None:
118
+ for entry in args.force_yml:
119
+ # now do not support creating new keys
120
+ keys, value = entry.split('=')
121
+ keys, value = keys.strip(), value.strip()
122
+ value = _postprocess_yml_value(value)
123
+ eval_str = 'opt'
124
+ for key in keys.split(':'):
125
+ eval_str += f'["{key}"]'
126
+ eval_str += '=value'
127
+ # using exec function
128
+ exec(eval_str)
129
+
130
+ opt['auto_resume'] = args.auto_resume
131
+ opt['is_train'] = is_train
132
+
133
+ # debug setting
134
+ if args.debug and not opt['name'].startswith('debug'):
135
+ opt['name'] = 'debug_' + opt['name']
136
+
137
+ if opt['num_gpu'] == 'auto':
138
+ opt['num_gpu'] = torch.cuda.device_count()
139
+
140
+ # datasets
141
+ for phase, dataset in opt['datasets'].items():
142
+ # for multiple datasets, e.g., val_1, val_2; test_1, test_2
143
+ phase = phase.split('_')[0]
144
+ dataset['phase'] = phase
145
+ if 'scale' in opt:
146
+ dataset['scale'] = opt['scale']
147
+ if dataset.get('dataroot_gt') is not None:
148
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
149
+ if dataset.get('dataroot_lq') is not None:
150
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
151
+
152
+ # paths
153
+ for key, val in opt['path'].items():
154
+ if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
155
+ opt['path'][key] = osp.expanduser(val)
156
+
157
+ if is_train:
158
+ experiments_root = osp.join(root_path, 'experiments', opt['name'])
159
+ opt['path']['experiments_root'] = experiments_root
160
+ opt['path']['models'] = osp.join(experiments_root, 'models')
161
+ opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
162
+ opt['path']['log'] = experiments_root
163
+ opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
164
+
165
+ # change some options for debug mode
166
+ if 'debug' in opt['name']:
167
+ if 'val' in opt:
168
+ opt['val']['val_freq'] = 8
169
+ opt['logger']['print_freq'] = 1
170
+ opt['logger']['save_checkpoint_freq'] = 8
171
+ else: # test
172
+ results_root = osp.join(root_path, 'results', opt['name'])
173
+ opt['path']['results_root'] = results_root
174
+ opt['path']['log'] = results_root
175
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
176
+
177
+ return opt, args
178
+
179
+
180
+ @master_only
181
+ def copy_opt_file(opt_file, experiments_root):
182
+ # copy the yml file to the experiment root
183
+ import sys
184
+ import time
185
+ from shutil import copyfile
186
+ cmd = ' '.join(sys.argv)
187
+ filename = osp.join(experiments_root, osp.basename(opt_file))
188
+ copyfile(opt_file, filename)
189
+
190
+ with open(filename, 'r+') as f:
191
+ lines = f.readlines()
192
+ lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
193
+ f.seek(0)
194
+ f.writelines(lines)
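A minimal sketch of loading one of the test configs with the `ordered_yaml` loader and pretty-printing it via `dict2str`; the option file path assumes the layout under `options/Test/`:

```python
# Illustrative sketch; assumes the repo layout with options/Test/ present.
import yaml
from basicsr.utils.options import dict2str, ordered_yaml

with open('options/Test/test_DAT_x2.yml', mode='r') as f:
    opt = yaml.load(f, Loader=ordered_yaml()[0])   # keeps key order as written in the YML
print(dict2str(opt))
```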
basicsr/utils/registry.py ADDED
@@ -0,0 +1,82 @@
1
+ # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2
+
3
+
4
+ class Registry():
5
+ """
6
+ The registry that provides name -> object mapping, to support third-party
7
+ users' custom modules.
8
+
9
+ To create a registry (e.g. a backbone registry):
10
+
11
+ .. code-block:: python
12
+
13
+ BACKBONE_REGISTRY = Registry('BACKBONE')
14
+
15
+ To register an object:
16
+
17
+ .. code-block:: python
18
+
19
+ @BACKBONE_REGISTRY.register()
20
+ class MyBackbone():
21
+ ...
22
+
23
+ Or:
24
+
25
+ .. code-block:: python
26
+
27
+ BACKBONE_REGISTRY.register(MyBackbone)
28
+ """
29
+
30
+ def __init__(self, name):
31
+ """
32
+ Args:
33
+ name (str): the name of this registry
34
+ """
35
+ self._name = name
36
+ self._obj_map = {}
37
+
38
+ def _do_register(self, name, obj):
39
+ assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
40
+ f"in '{self._name}' registry!")
41
+ self._obj_map[name] = obj
42
+
43
+ def register(self, obj=None):
44
+ """
45
+ Register the given object under the name `obj.__name__`.
46
+ Can be used as either a decorator or not.
47
+ See docstring of this class for usage.
48
+ """
49
+ if obj is None:
50
+ # used as a decorator
51
+ def deco(func_or_class):
52
+ name = func_or_class.__name__
53
+ self._do_register(name, func_or_class)
54
+ return func_or_class
55
+
56
+ return deco
57
+
58
+ # used as a function call
59
+ name = obj.__name__
60
+ self._do_register(name, obj)
61
+
62
+ def get(self, name):
63
+ ret = self._obj_map.get(name)
64
+ if ret is None:
65
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
66
+ return ret
67
+
68
+ def __contains__(self, name):
69
+ return name in self._obj_map
70
+
71
+ def __iter__(self):
72
+ return iter(self._obj_map.items())
73
+
74
+ def keys(self):
75
+ return self._obj_map.keys()
76
+
77
+
78
+ DATASET_REGISTRY = Registry('dataset')
79
+ ARCH_REGISTRY = Registry('arch')
80
+ MODEL_REGISTRY = Registry('model')
81
+ LOSS_REGISTRY = Registry('loss')
82
+ METRIC_REGISTRY = Registry('metric')
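A brief sketch of how these registries are used to plug in a new component; `MyToyArch` is a hypothetical example, not part of this commit:

```python
# Illustrative sketch; MyToyArch is a hypothetical architecture used only to show the API.
import torch
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY

@ARCH_REGISTRY.register()
class MyToyArch(nn.Module):
    def __init__(self, num_feat=64):
        super().__init__()
        self.conv = nn.Conv2d(3, num_feat, 3, 1, 1)

    def forward(self, x):
        return self.conv(x)

model = ARCH_REGISTRY.get('MyToyArch')(num_feat=32)   # name -> class lookup, then instantiate
out = model(torch.rand(1, 3, 8, 8))
print(out.shape)   # torch.Size([1, 32, 8, 8])
```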
basicsr/version.py ADDED
@@ -0,0 +1,5 @@
1
+ # GENERATED VERSION FILE
2
+ # TIME: Thu Sep 22 07:20:35 2022
3
+ __version__ = '1.3.5'
4
+ __gitsha__ = 'cbc9a18'
5
+ version_info = (1, 3, 5)
datasets/README.md ADDED
@@ -0,0 +1,2 @@
1
+ Download the [testing](https://ufile.io/6ek67nf8) datasets and place them here.
2
+
experiments/README.md ADDED
@@ -0,0 +1,2 @@
1
+ Download the pre-trained [models](https://ufile.io/4u0ms0h5) and place them in 'pretrained_models'.
2
+
experiments/pretrained_models/README.md ADDED
@@ -0,0 +1 @@
1
+ Place pretrained models here.
options/README.md ADDED
@@ -0,0 +1,2 @@
1
+ For more information about testing configuration, please refer to [Configuration](https://github.com/XPixelGroup/BasicSR/blob/master/docs/Config.md).
2
+
options/Test/test_DAT_2_x2.yml ADDED
@@ -0,0 +1,93 @@
1
+ # general settings
2
+ name: test_DAT_2_x2
3
+ model_type: SRModel
4
+ scale: 2
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
15
+ filename_tmpl: '{}x2'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2
25
+ filename_tmpl: '{}x2'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2
35
+ filename_tmpl: '{}x2'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2
45
+ filename_tmpl: '{}x2'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2
55
+ filename_tmpl: '{}_LRBI_x2'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 2
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ split_size: [8,32]
68
+ depth: [6,6,6,6,6,6]
69
+ embed_dim: 180
70
+ num_heads: [6,6,6,6,6,6]
71
+ expansion_factor: 2
72
+ resi_connection: '1conv'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_2_x2.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: False
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 2
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 2
93
+ test_y_channel: True
options/Test/test_DAT_2_x3.yml ADDED
@@ -0,0 +1,92 @@
1
+ # general settings
2
+ name: test_DAT_2_x3
3
+ model_type: SRModel
4
+ scale: 3
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
15
+ filename_tmpl: '{}x3'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X3
25
+ filename_tmpl: '{}x3'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X3
35
+ filename_tmpl: '{}x3'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X3
45
+ filename_tmpl: '{}x3'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X3
55
+ filename_tmpl: '{}_LRBI_x3'
56
+ io_backend:
57
+ type: disk
58
+
59
+ # network structures
60
+ network_g:
61
+ type: DAT
62
+ upscale: 3
63
+ in_chans: 3
64
+ img_size: 64
65
+ img_range: 1.
66
+ split_size: [8,32]
67
+ depth: [6,6,6,6,6,6]
68
+ embed_dim: 180
69
+ num_heads: [6,6,6,6,6,6]
70
+ expansion_factor: 2
71
+ resi_connection: '1conv'
72
+
73
+ # path
74
+ path:
75
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_2_x3.pth
76
+ strict_load_g: True
77
+
78
+ # validation settings
79
+ val:
80
+ save_img: False
81
+ suffix: ~ # add suffix to saved images, if None, use exp name
82
+ use_chop: False
83
+
84
+ metrics:
85
+ psnr: # metric name, can be arbitrary
86
+ type: calculate_psnr
87
+ crop_border: 3
88
+ test_y_channel: True
89
+ ssim:
90
+ type: calculate_ssim
91
+ crop_border: 3
92
+ test_y_channel: True
options/Test/test_DAT_2_x4.yml ADDED
@@ -0,0 +1,93 @@
1
+ # general settings
2
+ name: test_DAT_2_x4
3
+ model_type: SRModel
4
+ scale: 4
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
15
+ filename_tmpl: '{}x4'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
25
+ filename_tmpl: '{}x4'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
35
+ filename_tmpl: '{}x4'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
45
+ filename_tmpl: '{}x4'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
55
+ filename_tmpl: '{}_LRBI_x4'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 4
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ split_size: [8,32]
68
+ depth: [6,6,6,6,6,6]
69
+ embed_dim: 180
70
+ num_heads: [6,6,6,6,6,6]
71
+ expansion_factor: 2
72
+ resi_connection: '1conv'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_2_x4.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: False
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 4
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 4
93
+ test_y_channel: True
options/Test/test_DAT_L_x2.yml ADDED
@@ -0,0 +1,93 @@
1
+ # general settings
2
+ name: test_DAT_L_x2
3
+ model_type: SRModel
4
+ scale: 2
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
15
+ filename_tmpl: '{}x2'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2
25
+ filename_tmpl: '{}x2'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2
35
+ filename_tmpl: '{}x2'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2
45
+ filename_tmpl: '{}x2'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2
55
+ filename_tmpl: '{}_LRBI_x2'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 2
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ split_size: [8,32]
68
+ depth: [6,6,6,6,6,6]
69
+ embed_dim: 180
70
+ num_heads: [6,6,6,6,6,6]
71
+ expansion_factor: 4
72
+ resi_connection: '1conv'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_L_x2.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: False
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 2
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 2
93
+ test_y_channel: True
options/Test/test_DAT_L_x3.yml ADDED
@@ -0,0 +1,92 @@
1
+ # general settings
2
+ name: test_DAT_L_x3
3
+ model_type: SRModel
4
+ scale: 3
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
15
+ filename_tmpl: '{}x3'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X3
25
+ filename_tmpl: '{}x3'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X3
35
+ filename_tmpl: '{}x3'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X3
45
+ filename_tmpl: '{}x3'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X3
55
+ filename_tmpl: '{}_LRBI_x3'
56
+ io_backend:
57
+ type: disk
58
+
59
+ # network structures
60
+ network_g:
61
+ type: DAT
62
+ upscale: 3
63
+ in_chans: 3
64
+ img_size: 64
65
+ img_range: 1.
66
+ split_size: [8,32]
67
+ depth: [6,6,6,6,6,6]
68
+ embed_dim: 180
69
+ num_heads: [6,6,6,6,6,6]
70
+ expansion_factor: 4
71
+ resi_connection: '1conv'
72
+
73
+ # path
74
+ path:
75
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_L_x3.pth
76
+ strict_load_g: True
77
+
78
+ # validation settings
79
+ val:
80
+ save_img: False
81
+ suffix: ~ # add suffix to saved images, if None, use exp name
82
+ use_chop: False
83
+
84
+ metrics:
85
+ psnr: # metric name, can be arbitrary
86
+ type: calculate_psnr
87
+ crop_border: 3
88
+ test_y_channel: True
89
+ ssim:
90
+ type: calculate_ssim
91
+ crop_border: 3
92
+ test_y_channel: True
options/Test/test_DAT_L_x4.yml ADDED
@@ -0,0 +1,93 @@
1
+ # general settings
2
+ name: test_DAT_L_x4
3
+ model_type: SRModel
4
+ scale: 4
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
15
+ filename_tmpl: '{}x4'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
25
+ filename_tmpl: '{}x4'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
35
+ filename_tmpl: '{}x4'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
45
+ filename_tmpl: '{}x4'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
55
+ filename_tmpl: '{}_LRBI_x4'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 4
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ split_size: [8,32]
68
+ depth: [6,6,6,6,6,6]
69
+ embed_dim: 180
70
+ num_heads: [6,6,6,6,6,6]
71
+ expansion_factor: 4
72
+ resi_connection: '1conv'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_L_x4.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: False
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 4
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 4
93
+ test_y_channel: True
options/Test/test_DAT_x2.yml ADDED
@@ -0,0 +1,93 @@
1
+ # general settings
2
+ name: test_DAT_x2
3
+ model_type: SRModel
4
+ scale: 2
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
15
+ filename_tmpl: '{}x2'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2
25
+ filename_tmpl: '{}x2'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2
35
+ filename_tmpl: '{}x2'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2
45
+ filename_tmpl: '{}x2'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2
55
+ filename_tmpl: '{}_LRBI_x2'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 2
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ split_size: [8,16]
68
+ depth: [6,6,6,6,6,6]
69
+ embed_dim: 180
70
+ num_heads: [6,6,6,6,6,6]
71
+ expansion_factor: 2
72
+ resi_connection: '1conv'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_x2.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: False
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 2
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 2
93
+ test_y_channel: True
options/Test/test_DAT_x3.yml ADDED
@@ -0,0 +1,92 @@
1
+ # general settings
2
+ name: test_DAT_x3
3
+ model_type: SRModel
4
+ scale: 3
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
15
+ filename_tmpl: '{}x3'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X3
25
+ filename_tmpl: '{}x3'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X3
35
+ filename_tmpl: '{}x3'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X3
45
+ filename_tmpl: '{}x3'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X3
55
+ filename_tmpl: '{}_LRBI_x3'
56
+ io_backend:
57
+ type: disk
58
+
59
+ # network structures
60
+ network_g:
61
+ type: DAT
62
+ upscale: 3
63
+ in_chans: 3
64
+ img_size: 64
65
+ img_range: 1.
66
+ split_size: [8,16]
67
+ depth: [6,6,6,6,6,6]
68
+ embed_dim: 180
69
+ num_heads: [6,6,6,6,6,6]
70
+ expansion_factor: 2
71
+ resi_connection: '1conv'
72
+
73
+ # path
74
+ path:
75
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_x3.pth
76
+ strict_load_g: True
77
+
78
+ # validation settings
79
+ val:
80
+ save_img: False
81
+ suffix: ~ # add suffix to saved images, if None, use exp name
82
+ use_chop: False
83
+
84
+ metrics:
85
+ psnr: # metric name, can be arbitrary
86
+ type: calculate_psnr
87
+ crop_border: 3
88
+ test_y_channel: True
89
+ ssim:
90
+ type: calculate_ssim
91
+ crop_border: 3
92
+ test_y_channel: True
options/Test/test_DAT_x4.yml ADDED
@@ -0,0 +1,93 @@
1
+ # general settings
2
+ name: test_DAT_x4
3
+ model_type: SRModel
4
+ scale: 4
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
15
+ filename_tmpl: '{}x4'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
25
+ filename_tmpl: '{}x4'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
35
+ filename_tmpl: '{}x4'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
45
+ filename_tmpl: '{}x4'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
55
+ filename_tmpl: '{}_LRBI_x4'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 4
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ split_size: [8,16]
68
+ depth: [6,6,6,6,6,6]
69
+ embed_dim: 180
70
+ num_heads: [6,6,6,6,6,6]
71
+ expansion_factor: 2
72
+ resi_connection: '1conv'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_x4.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: False
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 4
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 4
93
+ test_y_channel: True
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ addict
2
+ future
3
+ lmdb
4
+ numpy>=1.17
5
+ opencv-python
6
+ Pillow
7
+ pyyaml
8
+ requests
9
+ scikit-image
10
+ scipy
11
+ tb-nightly
12
+ torch>=1.7
13
+ torchvision
14
+ tqdm
15
+ yapf
16
+ timm
17
+ einops
18
+ h5py
results/README.md ADDED
@@ -0,0 +1 @@
1
+ The testing results.
setup.py ADDED
@@ -0,0 +1,166 @@
1
+ #!/usr/bin/env python
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+ import os
6
+ import subprocess
7
+ import time
8
+ import torch
9
+ from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
10
+
11
+ version_file = 'basicsr/version.py'
12
+
13
+
14
+ def readme():
15
+ with open('README.md', encoding='utf-8') as f:
16
+ content = f.read()
17
+ return content
18
+
19
+
20
+ def get_git_hash():
21
+
22
+ def _minimal_ext_cmd(cmd):
23
+ # construct minimal environment
24
+ env = {}
25
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
26
+ v = os.environ.get(k)
27
+ if v is not None:
28
+ env[k] = v
29
+ # LANGUAGE is used on win32
30
+ env['LANGUAGE'] = 'C'
31
+ env['LANG'] = 'C'
32
+ env['LC_ALL'] = 'C'
33
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
34
+ return out
35
+
36
+ try:
37
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
38
+ sha = out.strip().decode('ascii')
39
+ except OSError:
40
+ sha = 'unknown'
41
+
42
+ return sha
43
+
44
+
45
+ def get_hash():
46
+ if os.path.exists('.git'):
47
+ sha = get_git_hash()[:7]
48
+ # currently ignore this
49
+ # elif os.path.exists(version_file):
50
+ # try:
51
+ # from basicsr.version import __version__
52
+ # sha = __version__.split('+')[-1]
53
+ # except ImportError:
54
+ # raise ImportError('Unable to get git version')
55
+ else:
56
+ sha = 'unknown'
57
+
58
+ return sha
59
+
60
+
61
+ def write_version_py():
62
+ content = """# GENERATED VERSION FILE
63
+ # TIME: {}
64
+ __version__ = '{}'
65
+ __gitsha__ = '{}'
66
+ version_info = ({})
67
+ """
68
+ sha = get_hash()
69
+ with open('VERSION', 'r') as f:
70
+ SHORT_VERSION = f.read().strip()
71
+ VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
72
+
73
+ version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
74
+ with open(version_file, 'w') as f:
75
+ f.write(version_file_str)
76
+
77
+
78
+ def get_version():
79
+ with open(version_file, 'r') as f:
80
+ exec(compile(f.read(), version_file, 'exec'))
81
+ return locals()['__version__']
82
+
83
+
84
+ def make_cuda_ext(name, module, sources, sources_cuda=None):
85
+ if sources_cuda is None:
86
+ sources_cuda = []
87
+ define_macros = []
88
+ extra_compile_args = {'cxx': []}
89
+
90
+ if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
91
+ define_macros += [('WITH_CUDA', None)]
92
+ extension = CUDAExtension
93
+ extra_compile_args['nvcc'] = [
94
+ '-D__CUDA_NO_HALF_OPERATORS__',
95
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
96
+ '-D__CUDA_NO_HALF2_OPERATORS__',
97
+ ]
98
+ sources += sources_cuda
99
+ else:
100
+ print(f'Compiling {name} without CUDA')
101
+ extension = CppExtension
102
+
103
+ return extension(
104
+ name=f'{module}.{name}',
105
+ sources=[os.path.join(*module.split('.'), p) for p in sources],
106
+ define_macros=define_macros,
107
+ extra_compile_args=extra_compile_args)
108
+
109
+
110
+ def get_requirements(filename='requirements.txt'):
111
+ here = os.path.dirname(os.path.realpath(__file__))
112
+ with open(os.path.join(here, filename), 'r') as f:
113
+ requires = [line.replace('\n', '') for line in f.readlines()]
114
+ return requires
115
+
116
+
117
+ if __name__ == '__main__':
118
+ cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext
119
+ if cuda_ext == 'True':
120
+ ext_modules = [
121
+ make_cuda_ext(
122
+ name='deform_conv_ext',
123
+ module='basicsr.ops.dcn',
124
+ sources=['src/deform_conv_ext.cpp'],
125
+ sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
126
+ make_cuda_ext(
127
+ name='fused_act_ext',
128
+ module='basicsr.ops.fused_act',
129
+ sources=['src/fused_bias_act.cpp'],
130
+ sources_cuda=['src/fused_bias_act_kernel.cu']),
131
+ make_cuda_ext(
132
+ name='upfirdn2d_ext',
133
+ module='basicsr.ops.upfirdn2d',
134
+ sources=['src/upfirdn2d.cpp'],
135
+ sources_cuda=['src/upfirdn2d_kernel.cu']),
136
+ ]
137
+ else:
138
+ ext_modules = []
139
+
140
+ write_version_py()
141
+ setup(
142
+ name='basicsr',
143
+ version=get_version(),
144
+ description='Open Source Image and Video Super-Resolution Toolbox',
145
+ long_description=readme(),
146
+ long_description_content_type='text/markdown',
147
+ author='Xintao Wang',
148
+ author_email='[email protected]',
149
+ keywords='computer vision, restoration, super resolution',
150
+ url='https://github.com/xinntao/BasicSR',
151
+ include_package_data=True,
152
+ packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
153
+ classifiers=[
154
+ 'Development Status :: 4 - Beta',
155
+ 'License :: OSI Approved :: Apache Software License',
156
+ 'Operating System :: OS Independent',
157
+ 'Programming Language :: Python :: 3',
158
+ 'Programming Language :: Python :: 3.7',
159
+ 'Programming Language :: Python :: 3.8',
160
+ ],
161
+ license='Apache License 2.0',
162
+ setup_requires=['cython', 'numpy'],
163
+ install_requires=get_requirements(),
164
+ ext_modules=ext_modules,
165
+ cmdclass={'build_ext': BuildExtension},
166
+ zip_safe=False)