Chen Zheng committed on
Commit
ca7246c
·
1 Parent(s): 9c116e0

Former-commit-id: 038cc4452f7c2bf937e142adebf95abffa91de55

.gitignore CHANGED
@@ -1,2 +1,135 @@
 
1
 
2
- .DS_Store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
 
3
+ # ignored folders
4
+ datasets/*
5
+ experiments/*
6
+ results/*
7
+ tb_logger/*
8
+ wandb/*
9
+ tmp/*
10
+ visual/*
11
+
12
+ docs/api
13
+ scripts/__init__.py
14
+
15
+ *.DS_Store
16
+ .idea
17
+
18
+ # ignored files
19
+ version.py
20
+
21
+ # ignored files with suffix
22
+ *.html
23
+ *.png
24
+ *.jpeg
25
+ *.jpg
26
+ *.gif
27
+ *.pth
28
+ *.zip
29
+
30
+ # template
31
+
32
+ # Byte-compiled / optimized / DLL files
33
+ __pycache__/
34
+ *.py[cod]
35
+ *$py.class
36
+
37
+ # C extensions
38
+ *.so
39
+
40
+ # Distribution / packaging
41
+ .Python
42
+ build/
43
+ develop-eggs/
44
+ dist/
45
+ downloads/
46
+ eggs/
47
+ .eggs/
48
+ lib/
49
+ lib64/
50
+ parts/
51
+ sdist/
52
+ var/
53
+ wheels/
54
+ *.egg-info/
55
+ .installed.cfg
56
+ *.egg
57
+ MANIFEST
58
+
59
+ # PyInstaller
60
+ # Usually these files are written by a python script from a template
61
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
62
+ *.manifest
63
+ *.spec
64
+
65
+ # Installer logs
66
+ pip-log.txt
67
+ pip-delete-this-directory.txt
68
+
69
+ # Unit test / coverage reports
70
+ htmlcov/
71
+ .tox/
72
+ .coverage
73
+ .coverage.*
74
+ .cache
75
+ nosetests.xml
76
+ coverage.xml
77
+ *.cover
78
+ .hypothesis/
79
+ .pytest_cache/
80
+
81
+ # Translations
82
+ *.mo
83
+ *.pot
84
+
85
+ # Django stuff:
86
+ *.log
87
+ local_settings.py
88
+ db.sqlite3
89
+
90
+ # Flask stuff:
91
+ instance/
92
+ .webassets-cache
93
+
94
+ # Scrapy stuff:
95
+ .scrapy
96
+
97
+ # Sphinx documentation
98
+ docs/_build/
99
+
100
+ # PyBuilder
101
+ target/
102
+
103
+ # Jupyter Notebook
104
+ .ipynb_checkpoints
105
+
106
+ # pyenv
107
+ .python-version
108
+
109
+ # celery beat schedule file
110
+ celerybeat-schedule
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
basicsr/archs/dat_arch.py CHANGED
@@ -673,6 +673,30 @@ class Upsample(nn.Sequential):
673
  super(Upsample, self).__init__(*m)
674
 
675
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
  @ARCH_REGISTRY.register()
677
  class DAT(nn.Module):
678
  """ Dual Aggregation Transformer
@@ -715,6 +739,7 @@ class DAT(nn.Module):
715
  upscale=2,
716
  img_range=1.,
717
  resi_connection='1conv',
 
718
  **kwargs):
719
  super().__init__()
720
 
@@ -728,6 +753,7 @@ class DAT(nn.Module):
728
  else:
729
  self.mean = torch.zeros(1, 1, 1, 1)
730
  self.upscale = upscale
 
731
 
732
  # ------------------------- 1, Shallow Feature Extraction ------------------------- #
733
  self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
@@ -779,10 +805,16 @@ class DAT(nn.Module):
779
  nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
780
 
781
  # ------------------------- 3, Reconstruction ------------------------- #
782
- self.conv_before_upsample = nn.Sequential(
 
 
783
  nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
784
- self.upsample = Upsample(upscale, num_feat)
785
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
 
 
 
 
786
 
787
  self.apply(self._init_weights)
788
 
@@ -813,10 +845,17 @@ class DAT(nn.Module):
813
  self.mean = self.mean.type_as(x)
814
  x = (x - self.mean) * self.img_range
815
 
816
- x = self.conv_first(x)
817
- x = self.conv_after_body(self.forward_features(x)) + x
818
- x = self.conv_before_upsample(x)
819
- x = self.conv_last(self.upsample(x))
 
 
 
 
 
 
 
820
 
821
  x = x / self.img_range + self.mean
822
  return x
 
673
  super(Upsample, self).__init__(*m)
674
 
675
 
676
+ class UpsampleOneStep(nn.Sequential):
677
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
678
+ Used in lightweight SR to save parameters.
679
+
680
+ Args:
681
+ scale (int): Scale factor. Supported scales: 2^n and 3.
682
+ num_feat (int): Channel number of intermediate features.
683
+
684
+ """
685
+
686
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
687
+ self.num_feat = num_feat
688
+ self.input_resolution = input_resolution
689
+ m = []
690
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
691
+ m.append(nn.PixelShuffle(scale))
692
+ super(UpsampleOneStep, self).__init__(*m)
693
+
694
+ def flops(self):
695
+ h, w = self.input_resolution
696
+ flops = h * w * self.num_feat * 3 * 9
697
+ return flops
698
+
699
+
700
  @ARCH_REGISTRY.register()
701
  class DAT(nn.Module):
702
  """ Dual Aggregation Transformer
 
739
  upscale=2,
740
  img_range=1.,
741
  resi_connection='1conv',
742
+ upsampler='pixelshuffle',
743
  **kwargs):
744
  super().__init__()
745
 
 
753
  else:
754
  self.mean = torch.zeros(1, 1, 1, 1)
755
  self.upscale = upscale
756
+ self.upsampler = upsampler
757
 
758
  # ------------------------- 1, Shallow Feature Extraction ------------------------- #
759
  self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
 
805
  nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
806
 
807
  # ------------------------- 3, Reconstruction ------------------------- #
808
+ if self.upsampler == 'pixelshuffle':
809
+ # for classical SR
810
+ self.conv_before_upsample = nn.Sequential(
811
  nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
812
+ self.upsample = Upsample(upscale, num_feat)
813
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
814
+ elif self.upsampler == 'pixelshuffledirect':
815
+ # for lightweight SR (to save parameters)
816
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
817
+ (img_size, img_size))
818
 
819
  self.apply(self._init_weights)
820
 
 
845
  self.mean = self.mean.type_as(x)
846
  x = (x - self.mean) * self.img_range
847
 
848
+ if self.upsampler == 'pixelshuffle':
849
+ # for image SR
850
+ x = self.conv_first(x)
851
+ x = self.conv_after_body(self.forward_features(x)) + x
852
+ x = self.conv_before_upsample(x)
853
+ x = self.conv_last(self.upsample(x))
854
+ elif self.upsampler == 'pixelshuffledirect':
855
+ # for lightweight SR
856
+ x = self.conv_first(x)
857
+ x = self.conv_after_body(self.forward_features(x)) + x
858
+ x = self.upsample(x)
859
 
860
  x = x / self.img_range + self.mean
861
  return x
options/Test/test_DAT_light_x2.yml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: test_DAT_light_x2
3
+ model_type: DATModle
4
+ scale: 2
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
15
+ filename_tmpl: '{}x2'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2
25
+ filename_tmpl: '{}x2'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2
35
+ filename_tmpl: '{}x2'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2
45
+ filename_tmpl: '{}x2'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2
55
+ filename_tmpl: '{}_LRBI_x2'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 2
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ depth: [18]
68
+ embed_dim: 60
69
+ num_heads: [6]
70
+ expansion_factor: 2
71
+ resi_connection: '3conv'
72
+ split_size: [8,32]
73
+ upsampler: 'pixelshuffledirect'
74
+
75
+ # path
76
+ path:
77
+ pretrain_network_g: experiments/pretrained_models/DAT-light/DAT_light_x2.pth
78
+ strict_load_g: True
79
+
80
+ # validation settings
81
+ val:
82
+ save_img: True
83
+ suffix: ~ # add suffix to saved images, if None, use exp name
84
+ use_chop: False # True to save memory, if img too large
85
+
86
+ metrics:
87
+ psnr: # metric name, can be arbitrary
88
+ type: calculate_psnr
89
+ crop_border: 2
90
+ test_y_channel: True
91
+ ssim:
92
+ type: calculate_ssim
93
+ crop_border: 2
94
+ test_y_channel: True
options/Test/test_DAT_light_x3.yml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: test_DAT_light_x3
3
+ model_type: DATModle
4
+ scale: 3
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
15
+ filename_tmpl: '{}x3'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2nd test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X3
25
+ filename_tmpl: '{}x3'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3rd test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X3
35
+ filename_tmpl: '{}x3'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4th test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X3
45
+ filename_tmpl: '{}x3'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5th test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X3
55
+ filename_tmpl: '{}_LRBI_x3'
56
+ io_backend:
57
+ type: disk
58
+
59
+ # network structures
60
+ network_g:
61
+ type: DAT
62
+ upscale: 3
63
+ in_chans: 3
64
+ img_size: 64
65
+ img_range: 1.
66
+ depth: [18]
67
+ embed_dim: 60
68
+ num_heads: [6]
69
+ expansion_factor: 2
70
+ resi_connection: '3conv'
71
+ split_size: [8,32]
72
+ upsampler: 'pixelshuffledirect'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT-light/DAT_light_x3.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: True
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False # True to save memory, if img too large
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 3
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 3
93
+ test_y_channel: True
options/Test/test_DAT_light_x4.yml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: test_DAT_light_x4
3
+ model_type: DATModle
4
+ scale: 4
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
15
+ filename_tmpl: '{}x4'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2st test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
25
+ filename_tmpl: '{}x4'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3st test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
35
+ filename_tmpl: '{}x4'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4st test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
45
+ filename_tmpl: '{}x4'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5st test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
55
+ filename_tmpl: '{}_LRBI_x4'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 4
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ depth: [18]
68
+ embed_dim: 60
69
+ num_heads: [6]
70
+ expansion_factor: 2
71
+ resi_connection: '3conv'
72
+ split_size: [8,32]
73
+ upsampler: 'pixelshuffledirect'
74
+
75
+ # path
76
+ path:
77
+ pretrain_network_g: experiments/pretrained_models/DAT-light/DAT_light_x4.pth
78
+ strict_load_g: True
79
+
80
+ # validation settings
81
+ val:
82
+ save_img: True
83
+ suffix: ~ # add suffix to saved images, if None, use exp name
84
+ use_chop: False # True to save memory, if img too large
85
+
86
+ metrics:
87
+ psnr: # metric name, can be arbitrary
88
+ type: calculate_psnr
89
+ crop_border: 4
90
+ test_y_channel: True
91
+ ssim:
92
+ type: calculate_ssim
93
+ crop_border: 4
94
+ test_y_channel: True
options/Train/train_DAT_2_x3.yml CHANGED
@@ -91,7 +91,7 @@ val:
91
  metrics:
92
  psnr: # metric name, can be arbitrary
93
  type: calculate_psnr
94
- crop_border: 4
95
  test_y_channel: True
96
 
97
  # logging settings
 
91
  metrics:
92
  psnr: # metric name, can be arbitrary
93
  type: calculate_psnr
94
+ crop_border: 3
95
  test_y_channel: True
96
 
97
  # logging settings
options/Train/train_DAT_2_x4.yml CHANGED
@@ -1,5 +1,5 @@
1
  # general settings
2
- name: test_DAT_2_x4
3
  model_type: DATModle
4
  scale: 4
5
  num_gpu: auto
 
1
  # general settings
2
+ name: train_DAT_2_x4
3
  model_type: DATModle
4
  scale: 4
5
  num_gpu: auto
options/Train/train_DAT_S_x3.yml CHANGED
@@ -91,7 +91,7 @@ val:
91
  metrics:
92
  psnr: # metric name, can be arbitrary
93
  type: calculate_psnr
94
- crop_border: 4
95
  test_y_channel: True
96
 
97
  # logging settings
 
91
  metrics:
92
  psnr: # metric name, can be arbitrary
93
  type: calculate_psnr
94
+ crop_border: 3
95
  test_y_channel: True
96
 
97
  # logging settings
options/Train/train_DAT_S_x4.yml CHANGED
@@ -1,5 +1,5 @@
1
  # general settings
2
- name: test_DAT_S_x4
3
  model_type: DATModle
4
  scale: 4
5
  num_gpu: auto
 
1
  # general settings
2
+ name: train_DAT_S_x4
3
  model_type: DATModle
4
  scale: 4
5
  num_gpu: auto
options/Train/train_DAT_light_x2.yml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_DAT_light_x2
3
+ model_type: DATModle
4
+ scale: 2
5
+ num_gpu: auto
6
+ manual_seed: 10
7
+
8
+ # dataset and data loader settings
9
+ datasets:
10
+ train:
11
+ task: SR
12
+ name: DF2K
13
+ type: PairedImageDataset
14
+ dataroot_gt: datasets/DF2K/HR
15
+ dataroot_lq: datasets/DF2K/LR_bicubic/X2
16
+ filename_tmpl: '{}x2'
17
+ io_backend:
18
+ type: disk
19
+
20
+ gt_size: 128
21
+ use_hflip: True
22
+ use_rot: True
23
+
24
+ # data loader
25
+ use_shuffle: True
26
+ num_worker_per_gpu: 12
27
+ batch_size_per_gpu: 8
28
+ dataset_enlarge_ratio: 100
29
+ prefetch_mode: ~
30
+
31
+ val:
32
+ task: SR
33
+ name: Set5
34
+ type: PairedImageDataset
35
+ dataroot_gt: datasets/benchmark/Set5/HR
36
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
37
+ filename_tmpl: '{}x2'
38
+ io_backend:
39
+ type: disk
40
+
41
+ # network structures
42
+ network_g:
43
+ type: DAT
44
+ upscale: 2
45
+ in_chans: 3
46
+ img_size: 64
47
+ img_range: 1.
48
+ depth: [18]
49
+ embed_dim: 60
50
+ num_heads: [6]
51
+ expansion_factor: 2
52
+ resi_connection: '3conv'
53
+ split_size: [8,32]
54
+ upsampler: 'pixelshuffledirect'
55
+
56
+ # path
57
+ path:
58
+ pretrain_network_g: ~
59
+ strict_load_g: True
60
+ resume_state: ~
61
+
62
+ # training settings
63
+ train:
64
+ optim_g:
65
+ type: Adam
66
+ lr: !!float 2e-4
67
+ weight_decay: 0
68
+ betas: [0.9, 0.99]
69
+
70
+ scheduler:
71
+ type: MultiStepLR
72
+ milestones: [250000, 400000, 450000, 475000]
73
+ gamma: 0.5
74
+
75
+ total_iter: 500000
76
+ warmup_iter: -1 # no warm up
77
+
78
+ # losses
79
+ pixel_opt:
80
+ type: L1Loss
81
+ loss_weight: 1.0
82
+ reduction: mean
83
+
84
+ # validation settings
85
+ val:
86
+ val_freq: !!float 5e3
87
+ save_img: False
88
+
89
+ metrics:
90
+ psnr: # metric name, can be arbitrary
91
+ type: calculate_psnr
92
+ crop_border: 2
93
+ test_y_channel: True
94
+
95
+ # logging settings
96
+ logger:
97
+ print_freq: 200
98
+ save_checkpoint_freq: !!float 5e3
99
+ use_tb_logger: True
100
+ wandb:
101
+ project: ~
102
+ resume_id: ~
103
+
104
+ # dist training settings
105
+ dist_params:
106
+ backend: nccl
107
+ port: 29500
options/Train/train_DAT_light_x3.yml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_DAT_light_x3
3
+ model_type: DATModle
4
+ scale: 3
5
+ num_gpu: auto
6
+ manual_seed: 10
7
+
8
+ # dataset and data loader settings
9
+ datasets:
10
+ train:
11
+ task: SR
12
+ name: DF2K
13
+ type: PairedImageDataset
14
+ dataroot_gt: datasets/DF2K/HR
15
+ dataroot_lq: datasets/DF2K/LR_bicubic/X3
16
+ filename_tmpl: '{}x3'
17
+ io_backend:
18
+ type: disk
19
+
20
+ gt_size: 192
21
+ use_hflip: True
22
+ use_rot: True
23
+
24
+ # data loader
25
+ use_shuffle: True
26
+ num_worker_per_gpu: 12
27
+ batch_size_per_gpu: 8
28
+ dataset_enlarge_ratio: 100
29
+ prefetch_mode: ~
30
+
31
+ val:
32
+ task: SR
33
+ name: Set5
34
+ type: PairedImageDataset
35
+ dataroot_gt: datasets/benchmark/Set5/HR
36
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
37
+ filename_tmpl: '{}x3'
38
+ io_backend:
39
+ type: disk
40
+
41
+ # network structures
42
+ network_g:
43
+ type: DAT
44
+ upscale: 3
45
+ in_chans: 3
46
+ img_size: 64
47
+ img_range: 1.
48
+ depth: [18]
49
+ embed_dim: 60
50
+ num_heads: [6]
51
+ expansion_factor: 2
52
+ resi_connection: '3conv'
53
+ split_size: [8,32]
54
+ upsampler: 'pixelshuffledirect'
55
+
56
+ # path
57
+ path:
58
+ pretrain_network_g: experiments/pretrained_models/DAT-light/DAT_light_x2.pth
59
+ strict_load_g: False
60
+ resume_state: ~
61
+
62
+ # training settings
63
+ train:
64
+ optim_g:
65
+ type: Adam
66
+ lr: !!float 2e-4
67
+ weight_decay: 0
68
+ betas: [0.9, 0.99]
69
+
70
+ scheduler:
71
+ type: MultiStepLR
72
+ milestones: [250000, 400000, 450000, 475000]
73
+ gamma: 0.5
74
+
75
+ total_iter: 500000
76
+ warmup_iter: -1 # no warm up
77
+
78
+ # losses
79
+ pixel_opt:
80
+ type: L1Loss
81
+ loss_weight: 1.0
82
+ reduction: mean
83
+
84
+ # validation settings
85
+ val:
86
+ val_freq: !!float 5e3
87
+ save_img: False
88
+
89
+ metrics:
90
+ psnr: # metric name, can be arbitrary
91
+ type: calculate_psnr
92
+ crop_border: 3
93
+ test_y_channel: True
94
+
95
+ # logging settings
96
+ logger:
97
+ print_freq: 200
98
+ save_checkpoint_freq: !!float 5e3
99
+ use_tb_logger: True
100
+ wandb:
101
+ project: ~
102
+ resume_id: ~
103
+
104
+ # dist training settings
105
+ dist_params:
106
+ backend: nccl
107
+ port: 29500
options/Train/train_DAT_light_x4.yml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_DAT_light_x4
3
+ model_type: DATModle
4
+ scale: 4
5
+ num_gpu: auto
6
+ manual_seed: 10
7
+
8
+ # dataset and data loader settings
9
+ datasets:
10
+ train:
11
+ task: SR
12
+ name: DF2K
13
+ type: PairedImageDataset
14
+ dataroot_gt: datasets/DF2K/HR
15
+ dataroot_lq: datasets/DF2K/LR_bicubic/X4
16
+ filename_tmpl: '{}x4'
17
+ io_backend:
18
+ type: disk
19
+
20
+ gt_size: 256
21
+ use_hflip: true
22
+ use_rot: true
23
+
24
+ # data loader
25
+ use_shuffle: True
26
+ num_worker_per_gpu: 12
27
+ batch_size_per_gpu: 8
28
+ dataset_enlarge_ratio: 100
29
+ prefetch_mode: ~
30
+
31
+ val:
32
+ task: SR
33
+ name: Set5
34
+ type: PairedImageDataset
35
+ dataroot_gt: datasets/benchmark/Set5/HR
36
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
37
+ filename_tmpl: '{}x4'
38
+ io_backend:
39
+ type: disk
40
+
41
+
42
+ # network structures
43
+ network_g:
44
+ type: DAT
45
+ upscale: 4
46
+ in_chans: 3
47
+ img_size: 64
48
+ img_range: 1.
49
+ depth: [18]
50
+ embed_dim: 60
51
+ num_heads: [6]
52
+ expansion_factor: 2
53
+ resi_connection: '3conv'
54
+ split_size: [8,32]
55
+ upsampler: 'pixelshuffledirect'
56
+
57
+ # path
58
+ path:
59
+ pretrain_network_g: experiments/pretrained_models/DAT-light/DAT_light_x2.pth
60
+ strict_load_g: False
61
+ resume_state: ~
62
+
63
+ # training settings
64
+ train:
65
+ optim_g:
66
+ type: Adam
67
+ lr: !!float 2e-4
68
+ weight_decay: 0
69
+ betas: [0.9, 0.99]
70
+
71
+ scheduler:
72
+ type: MultiStepLR
73
+ milestones: [250000, 400000, 450000, 475000]
74
+ gamma: 0.5
75
+
76
+ total_iter: 500000
77
+ warmup_iter: -1 # no warm up
78
+
79
+ # losses
80
+ pixel_opt:
81
+ type: L1Loss
82
+ loss_weight: 1.0
83
+ reduction: mean
84
+
85
+ # validation settings
86
+ val:
87
+ val_freq: !!float 5e3
88
+ save_img: False
89
+
90
+ metrics:
91
+ psnr: # metric name, can be arbitrary
92
+ type: calculate_psnr
93
+ crop_border: 4
94
+ test_y_channel: True
95
+
96
+ # logging settings
97
+ logger:
98
+ print_freq: 200
99
+ save_checkpoint_freq: !!float 5e3
100
+ use_tb_logger: True
101
+ wandb:
102
+ project: ~
103
+ resume_id: ~
104
+
105
+ # dist training settings
106
+ dist_params:
107
+ backend: nccl
108
+ port: 29500
options/Train/train_DAT_x3.yml CHANGED
@@ -91,7 +91,7 @@ val:
91
  metrics:
92
  psnr: # metric name, can be arbitrary
93
  type: calculate_psnr
94
- crop_border: 4
95
  test_y_channel: True
96
 
97
  # logging settings
 
91
  metrics:
92
  psnr: # metric name, can be arbitrary
93
  type: calculate_psnr
94
+ crop_border: 3
95
  test_y_channel: True
96
 
97
  # logging settings
options/Train/train_DAT_x4.yml CHANGED
@@ -1,5 +1,5 @@
1
  # general settings
2
- name: test_DAT_x4
3
  model_type: DATModle
4
  scale: 4
5
  num_gpu: auto
 
1
  # general settings
2
+ name: train_DAT_x4
3
  model_type: DATModle
4
  scale: 4
5
  num_gpu: auto