463465810cz commited on
Commit
8cb8316
·
1 Parent(s): 3cda643

Former-commit-id: 1d9f9df3885a24e94f638ac328082683d0ceb8b8

README.md CHANGED
@@ -58,10 +58,11 @@ Download training and testing datasets and put them into the corresponding folde
58
 
59
  | Method | Params (M) | FLOPs (G) | Dataset | PSNR (dB) | SSIM | Model Zoo | Visual Results |
60
  | :----- | :--------: | :-------: | :------: | :-------: | :----: | :----------------------------------------------------------: | :----------------------------------------------------------: |
61
- | DAT-S | 11.21 | 203.3 | Urban100 | 27.68 | 0.8300 | [Google Drive](https://drive.google.com/drive/folders/1hb77nOTpCo9iU_jmg_izHOPRvPJujRiL?usp=drive_link) | [Google Drive](https://drive.google.com/file/d/1W-CeN2Z0e1r0rOdc3t-GcGrRV-qTGdub/view?usp=drive_link) |
62
- | DAT | 14.80 | 275.8 | Urban100 | 27.87 | 0.8343 | [Google Drive](https://drive.google.com/drive/folders/1eZqgQEBQ69Vzf8afrPkvL27JHubW6o0t?usp=drive_link) | [Google Drive](https://drive.google.com/file/d/1B4zJsZaiVsu009ilTh81BV7-8Hr98BI2/view?usp=drive_link) |
 
63
 
64
- The performance is reported on Urban100 (x4, SR). The test input size of FLOPs is 128 x 128.
65
 
66
  ## Training
67
 
@@ -79,6 +80,11 @@ The performance is reported on Urban100 (x4, SR). The test input size of FLOPs i
79
  python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x2.yml --launcher pytorch
80
  python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x3.yml --launcher pytorch
81
  python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x4.yml --launcher pytorch
 
 
 
 
 
82
  ```
83
 
84
  - The training experiment is in `experiments/`.
@@ -87,9 +93,9 @@ The performance is reported on Urban100 (x4, SR). The test input size of FLOPs i
87
 
88
  - Download the pre-trained [models](https://drive.google.com/drive/folders/1iBdf_-LVZuz_PAbFtuxSKd_11RL1YKxM?usp=drive_link) and place them in `experiments/pretrained_models/`.
89
 
90
- We provide pre-trained models for image SR: DAT-S and DAT (x2, x3, x4).
91
 
92
- - Download [testing](https://ufile.io/6ek67nf8) (Set5, Set14, BSD100, Urban100, Manga109) datasets, place them in `datasets/`.
93
 
94
  - Run the following scripts. The testing configuration is in `options/test/`.
95
 
@@ -104,6 +110,11 @@ The performance is reported on Urban100 (x4, SR). The test input size of FLOPs i
104
  python basicsr/test.py -opt options/Test/test_DAT_x2.yml
105
  python basicsr/test.py -opt options/Test/test_DAT_x3.yml
106
  python basicsr/test.py -opt options/Test/test_DAT_x4.yml
 
 
 
 
 
107
  ```
108
 
109
  - The output is in `results/`.
@@ -120,13 +131,26 @@ We achieved state-of-the-art performance. Detailed results can be found in the p
120
  <p align="center">
121
  <img width="900" src="figs/Table-1.png">
122
  </p>
 
123
 
 
 
 
124
 
125
  - visual comparison (x4) in the main paper
126
 
127
  <p align="center">
128
  <img width="900" src="figs/Figure-1.png">
129
  </p>
 
 
 
 
 
 
 
 
 
130
 
131
 
132
  - </details>
 
58
 
59
  | Method | Params (M) | FLOPs (G) | Dataset | PSNR (dB) | SSIM | Model Zoo | Visual Results |
60
  | :----- | :--------: | :-------: | :------: | :-------: | :----: | :----------------------------------------------------------: | :----------------------------------------------------------: |
61
+ | DAT-S | 11.21 | 203.3 | Urban100 | 27.68 | 0.8300 | [Google Drive](https://drive.google.com/drive/folders/1hM0v3fUg5u6GjkI7dduxShyGgGfEwQXO?usp=drive_link) | [Google Drive](https://drive.google.com/file/d/1x1ixMswxw5w-zeZ_Rap5Nk4Tr46MIjAw/view?usp=drive_link) |
62
+ | DAT | 14.80 | 275.8 | Urban100 | 27.87 | 0.8343 | [Google Drive](https://drive.google.com/drive/folders/14VG5mw5ie8RrR4jjypeHynXDZYWL8w-r?usp=drive_link) | [Google Drive](https://drive.google.com/file/d/1K43CTsXpoX5St5fed4kEW9gu2KMR6hLu/view?usp=drive_link) |
63
+ | DAT-2 | 11.21 | 216.93 | Urban100 | 27.86 | 0.8341 | [Google Drive](https://drive.google.com/drive/folders/1yV9LMhr2tYM_eHEIVY4Jw9X3bWGgorbD?usp=drive_link) | [Google Drive](https://drive.google.com/file/d/1TQRZIg8at5HX87OCu3GYytZhYGperkuN/view?usp=drive_link) |
64
 
65
+ The performance is reported on Urban100 (x4). The test input size of FLOPs is 128 x 128.
66
 
67
  ## Training
68
 
 
80
  python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x2.yml --launcher pytorch
81
  python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x3.yml --launcher pytorch
82
  python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x4.yml --launcher pytorch
83
+
84
+ # DAT-2, input=64x64, 4 GPUs
85
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_2_x2.yml --launcher pytorch
86
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_2_x3.yml --launcher pytorch
87
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_2_x4.yml --launcher pytorch
88
  ```
89
 
90
  - The training experiment is in `experiments/`.
 
93
 
94
  - Download the pre-trained [models](https://drive.google.com/drive/folders/1iBdf_-LVZuz_PAbFtuxSKd_11RL1YKxM?usp=drive_link) and place them in `experiments/pretrained_models/`.
95
 
96
+ We provide pre-trained models for image SR: DAT-S, DAT, and DAT-2 (x2, x3, x4).
97
 
98
+ - Download [testing](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing) (Set5, Set14, BSD100, Urban100, Manga109) datasets, place them in `datasets/`.
99
 
100
  - Run the following scripts. The testing configuration is in `options/test/`.
101
 
 
110
  python basicsr/test.py -opt options/Test/test_DAT_x2.yml
111
  python basicsr/test.py -opt options/Test/test_DAT_x3.yml
112
  python basicsr/test.py -opt options/Test/test_DAT_x4.yml
113
+
114
+ # DAT-2, reproduces results in Table 1 of the supplementary material
115
+ python basicsr/test.py -opt options/Test/test_DAT_2_x2.yml
116
+ python basicsr/test.py -opt options/Test/test_DAT_2_x3.yml
117
+ python basicsr/test.py -opt options/Test/test_DAT_2_x4.yml
118
  ```
119
 
120
  - The output is in `results/`.
 
131
  <p align="center">
132
  <img width="900" src="figs/Table-1.png">
133
  </p>
134
+ - results in Table 1 of the supplementary material
135
 
136
+ <p align="center">
137
+ <img width="900" src="figs/Table-2.png">
138
+ </p>
139
 
140
  - visual comparison (x4) in the main paper
141
 
142
  <p align="center">
143
  <img width="900" src="figs/Figure-1.png">
144
  </p>
145
+ - visual comparison (x4) in the supplementary material
146
+
147
+ <p align="center">
148
+ <img width="900" src="figs/Figure-2.png">
149
+ <img width="900" src="figs/Figure-3.png">
150
+ <img width="900" src="figs/Figure-4.png">
151
+ <img width="900" src="figs/Figure-5.png">
152
+ </p>
153
+
154
 
155
 
156
  - </details>
basicsr/archs/dat_arch.py CHANGED
@@ -297,7 +297,6 @@ class Axial_Spatial_Attention(nn.Module):
297
  self.register_buffer("attn_mask_0", None)
298
  self.register_buffer("attn_mask_1", None)
299
 
300
- # Adaptive Interaction Module
301
  self.dwconv = nn.Sequential(
302
  nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
303
  nn.BatchNorm2d(dim),
@@ -419,6 +418,7 @@ class Axial_Spatial_Attention(nn.Module):
419
  # convolution output
420
  conv_x = self.dwconv(v)
421
 
 
422
  # C-Map (before sigmoid)
423
  channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C)
424
  # S-Map (before sigmoid)
@@ -460,7 +460,6 @@ class Axial_Channel_Attention(nn.Module):
460
  self.proj = nn.Linear(dim, dim)
461
  self.proj_drop = nn.Dropout(proj_drop)
462
 
463
- # Adaptive Interaction Module
464
  self.dwconv = nn.Sequential(
465
  nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
466
  nn.BatchNorm2d(dim),
@@ -509,6 +508,7 @@ class Axial_Channel_Attention(nn.Module):
509
  # convolution output
510
  conv_x = self.dwconv(v_)
511
 
 
512
  # C-Map (before sigmoid)
513
  attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W)
514
  channel_map = self.channel_interaction(attention_reshape)
 
297
  self.register_buffer("attn_mask_0", None)
298
  self.register_buffer("attn_mask_1", None)
299
 
 
300
  self.dwconv = nn.Sequential(
301
  nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
302
  nn.BatchNorm2d(dim),
 
418
  # convolution output
419
  conv_x = self.dwconv(v)
420
 
421
+ # Adaptive Interaction Module (AIM)
422
  # C-Map (before sigmoid)
423
  channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C)
424
  # S-Map (before sigmoid)
 
460
  self.proj = nn.Linear(dim, dim)
461
  self.proj_drop = nn.Dropout(proj_drop)
462
 
 
463
  self.dwconv = nn.Sequential(
464
  nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
465
  nn.BatchNorm2d(dim),
 
508
  # convolution output
509
  conv_x = self.dwconv(v_)
510
 
511
+ # Adaptive Interaction Module (AIM)
512
  # C-Map (before sigmoid)
513
  attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W)
514
  channel_map = self.channel_interaction(attention_reshape)
basicsr/train.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import math
4
+ import time
5
+ import torch
6
+ from os import path as osp
7
+
8
+ from basicsr.data import build_dataloader, build_dataset
9
+ from basicsr.data.data_sampler import EnlargedSampler
10
+ from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
11
+ from basicsr.models import build_model
12
+ from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
13
+ init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
14
+ from basicsr.utils.options import copy_opt_file, dict2str, parse_options
15
+
16
+
17
+ def init_tb_loggers(opt):
18
+ # initialize wandb logger before tensorboard logger to allow proper sync
19
+ if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
20
+ is not None) and ('debug' not in opt['name']):
21
+ assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
22
+ init_wandb_logger(opt)
23
+ tb_logger = None
24
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
25
+ tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
26
+ return tb_logger
27
+
28
+
29
+ def create_train_val_dataloader(opt, logger):
30
+ # create train and val dataloaders
31
+ train_loader, val_loaders = None, []
32
+ for phase, dataset_opt in opt['datasets'].items():
33
+ if phase == 'train':
34
+ dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
35
+ train_set = build_dataset(dataset_opt)
36
+ train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
37
+ train_loader = build_dataloader(
38
+ train_set,
39
+ dataset_opt,
40
+ num_gpu=opt['num_gpu'],
41
+ dist=opt['dist'],
42
+ sampler=train_sampler,
43
+ seed=opt['manual_seed'])
44
+
45
+ num_iter_per_epoch = math.ceil(
46
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
47
+ total_iters = int(opt['train']['total_iter'])
48
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
49
+ logger.info('Training statistics:'
50
+ f'\n\tNumber of train images: {len(train_set)}'
51
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
52
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
53
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
54
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
55
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
56
+ elif phase.split('_')[0] == 'val':
57
+ val_set = build_dataset(dataset_opt)
58
+ val_loader = build_dataloader(
59
+ val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
60
+ logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
61
+ val_loaders.append(val_loader)
62
+ else:
63
+ raise ValueError(f'Dataset phase {phase} is not recognized.')
64
+
65
+ return train_loader, train_sampler, val_loaders, total_epochs, total_iters
66
+
67
+
68
+ def load_resume_state(opt):
69
+ resume_state_path = None
70
+ if opt['auto_resume']:
71
+ state_path = osp.join('experiments', opt['name'], 'training_states')
72
+ if osp.isdir(state_path):
73
+ states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
74
+ if len(states) != 0:
75
+ states = [float(v.split('.state')[0]) for v in states]
76
+ resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
77
+ opt['path']['resume_state'] = resume_state_path
78
+ else:
79
+ if opt['path'].get('resume_state'):
80
+ resume_state_path = opt['path']['resume_state']
81
+
82
+ if resume_state_path is None:
83
+ resume_state = None
84
+ else:
85
+ device_id = torch.cuda.current_device()
86
+ resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
87
+ check_resume(opt, resume_state['iter'])
88
+ return resume_state
89
+
90
+
91
+ def train_pipeline(root_path):
92
+ # parse options, set distributed setting, set ramdom seed
93
+ opt, args = parse_options(root_path, is_train=True)
94
+ opt['root_path'] = root_path
95
+
96
+ torch.backends.cudnn.benchmark = True
97
+ # torch.backends.cudnn.deterministic = True
98
+
99
+ # load resume states if necessary
100
+ resume_state = load_resume_state(opt)
101
+ # mkdir for experiments and logger
102
+ if resume_state is None:
103
+ make_exp_dirs(opt)
104
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
105
+ mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
106
+
107
+ # copy the yml file to the experiment root
108
+ copy_opt_file(args.opt, opt['path']['experiments_root'])
109
+
110
+ # WARNING: should not use get_root_logger in the above codes, including the called functions
111
+ # Otherwise the logger will not be properly initialized
112
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
113
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
114
+ logger.info(get_env_info())
115
+ logger.info(dict2str(opt))
116
+ # initialize wandb and tb loggers
117
+ tb_logger = init_tb_loggers(opt)
118
+
119
+ # create train and validation dataloaders
120
+ result = create_train_val_dataloader(opt, logger)
121
+ train_loader, train_sampler, val_loaders, total_epochs, total_iters = result
122
+
123
+ # create model
124
+ model = build_model(opt)
125
+ if resume_state: # resume training
126
+ model.resume_training(resume_state) # handle optimizers and schedulers
127
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
128
+ start_epoch = resume_state['epoch']
129
+ current_iter = resume_state['iter']
130
+ else:
131
+ start_epoch = 0
132
+ current_iter = 0
133
+
134
+ # create message logger (formatted outputs)
135
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
136
+
137
+ # dataloader prefetcher
138
+ prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
139
+ if prefetch_mode is None or prefetch_mode == 'cpu':
140
+ prefetcher = CPUPrefetcher(train_loader)
141
+ elif prefetch_mode == 'cuda':
142
+ prefetcher = CUDAPrefetcher(train_loader, opt)
143
+ logger.info(f'Use {prefetch_mode} prefetch dataloader')
144
+ if opt['datasets']['train'].get('pin_memory') is not True:
145
+ raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
146
+ else:
147
+ raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
148
+
149
+ # training
150
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
151
+ data_timer, iter_timer = AvgTimer(), AvgTimer()
152
+ start_time = time.time()
153
+
154
+ for epoch in range(start_epoch, total_epochs + 1):
155
+ train_sampler.set_epoch(epoch)
156
+ prefetcher.reset()
157
+ train_data = prefetcher.next()
158
+
159
+ while train_data is not None:
160
+ data_timer.record()
161
+
162
+ current_iter += 1
163
+ if current_iter > total_iters:
164
+ break
165
+ # update learning rate
166
+ model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
167
+ # training
168
+ model.feed_data(train_data)
169
+ model.optimize_parameters(current_iter)
170
+ iter_timer.record()
171
+ if current_iter == 1:
172
+ # reset start time in msg_logger for more accurate eta_time
173
+ # not work in resume mode
174
+ msg_logger.reset_start_time()
175
+ # log
176
+ if current_iter % opt['logger']['print_freq'] == 0:
177
+ log_vars = {'epoch': epoch, 'iter': current_iter}
178
+ log_vars.update({'lrs': model.get_current_learning_rate()})
179
+ log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
180
+ log_vars.update(model.get_current_log())
181
+ msg_logger(log_vars)
182
+
183
+ # save models and training states
184
+ if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
185
+ logger.info('Saving models and training states.')
186
+ model.save(epoch, current_iter)
187
+
188
+ # validation
189
+ if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
190
+ if len(val_loaders) > 1:
191
+ logger.warning('Multiple validation datasets are *only* supported by SRModel.')
192
+ for val_loader in val_loaders:
193
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
194
+
195
+ data_timer.start()
196
+ iter_timer.start()
197
+ train_data = prefetcher.next()
198
+ # end of iter
199
+
200
+ # end of epoch
201
+
202
+ consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
203
+ logger.info(f'End of training. Time consumed: {consumed_time}')
204
+ logger.info('Save the latest model.')
205
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
206
+ if opt.get('val') is not None:
207
+ for val_loader in val_loaders:
208
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
209
+ if tb_logger:
210
+ tb_logger.close()
211
+
212
+
213
+ if __name__ == '__main__':
214
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
215
+ train_pipeline(root_path)
basicsr/version.py CHANGED
@@ -1,5 +1,5 @@
1
  # GENERATED VERSION FILE
2
- # TIME: Thu Sep 22 07:20:35 2022
3
  __version__ = '1.3.5'
4
- __gitsha__ = 'cbc9a18'
5
  version_info = (1, 3, 5)
 
1
  # GENERATED VERSION FILE
2
+ # TIME: Mon Jul 17 01:59:53 2023
3
  __version__ = '1.3.5'
4
+ __gitsha__ = '29e57e3'
5
  version_info = (1, 3, 5)
datasets/README.md CHANGED
@@ -1,2 +1,46 @@
1
- Dwonload the [testing](https://ufile.io/6ek67nf8) datasets and place them here.
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ For training and testing, the directory structure is as follows:
2
 
3
+ ```shell
4
+ |-- datasets
5
+ # train
6
+ |-- DF2K
7
+ |-- HR
8
+ |-- LR_bicubic
9
+ |-- X2
10
+ |-- X3
11
+ |-- X4
12
+ # test
13
+ |-- benchmark
14
+ |-- Set5
15
+ |-- HR
16
+ |-- LR_bicubic
17
+ |-- X2
18
+ |-- X3
19
+ |-- X4
20
+ |-- Set14
21
+ |-- HR
22
+ |-- LR_bicubic
23
+ |-- X2
24
+ |-- X3
25
+ |-- X4
26
+ |-- B100
27
+ |-- HR
28
+ |-- LR_bicubic
29
+ |-- X2
30
+ |-- X3
31
+ |-- X4
32
+ |-- Urban100
33
+ |-- HR
34
+ |-- LR_bicubic
35
+ |-- X2
36
+ |-- X3
37
+ |-- X4
38
+ |-- Manga109
39
+ |-- HR
40
+ |-- LR_bicubic
41
+ |-- X2
42
+ |-- X3
43
+ |-- X4
44
+ ```
45
+
46
+ You can download the complete datasets we have collected.
experiments/README.md CHANGED
@@ -1,2 +1 @@
1
- Dwonload the pre-trained [models](https://ufile.io/rf58x0s9) and place them in `pretrained_models`.
2
-
 
1
+ Place pretrained models in `pretrained_models`.
 
figs/Figure-2.png ADDED

Git LFS Details

  • SHA256: bca8431490641478e71d106c83458e00ee1c1cf6315ca03318d24c6ebbd48246
  • Pointer size: 132 Bytes
  • Size of remote file: 2.65 MB
figs/Figure-3.png ADDED

Git LFS Details

  • SHA256: 3e08d373a4a965a1da58b0b8f67273afa2e90c231f56ad2e875febb5e1c1b9d0
  • Pointer size: 132 Bytes
  • Size of remote file: 3.27 MB
figs/Figure-4.png ADDED

Git LFS Details

  • SHA256: 1c86421092316c09c7ff5e21a4f2d59775c501cf61b7cdd70c66546b39959971
  • Pointer size: 132 Bytes
  • Size of remote file: 3.05 MB
figs/Figure-5.png ADDED

Git LFS Details

  • SHA256: 18e604aca7929f426aca7ab53df0ff38b488b04f19c54c90a1c6d730c0c605ef
  • Pointer size: 132 Bytes
  • Size of remote file: 3.12 MB
figs/Table-2.png ADDED

Git LFS Details

  • SHA256: be734be1fe8df4ae76804df09621b1859f1f93134fb9d458b279ac2ceda8d811
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
options/README.md DELETED
@@ -1,2 +0,0 @@
1
- For more information about testing configuration, please refer to [Configuration](https://github.com/XPixelGroup/BasicSR/blob/master/docs/Config.md).
2
-
 
 
 
options/Test/test_DAT_2_x2.yml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: test_DAT_2_x2
3
+ model_type: SRModel
4
+ scale: 2
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
15
+ filename_tmpl: '{}x2'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2st test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2
25
+ filename_tmpl: '{}x2'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3st test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2
35
+ filename_tmpl: '{}x2'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4st test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2
45
+ filename_tmpl: '{}x2'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5st test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2
55
+ filename_tmpl: '{}_LRBI_x2'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 2
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ split_size: [8,32]
68
+ depth: [6,6,6,6,6,6]
69
+ embed_dim: 180
70
+ num_heads: [6,6,6,6,6,6]
71
+ expansion_factor: 2
72
+ resi_connection: '1conv'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x2.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: True
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 2
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 2
93
+ test_y_channel: True
options/Test/test_DAT_2_x3.yml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: test_DAT_2_x3
3
+ model_type: SRModel
4
+ scale: 3
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
15
+ filename_tmpl: '{}x3'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2st test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X3
25
+ filename_tmpl: '{}x3'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3st test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X3
35
+ filename_tmpl: '{}x3'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4st test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X3
45
+ filename_tmpl: '{}x3'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5st test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X3
55
+ filename_tmpl: '{}_LRBI_x3'
56
+ io_backend:
57
+ type: disk
58
+
59
+ # network structures
60
+ network_g:
61
+ type: DAT
62
+ upscale: 3
63
+ in_chans: 3
64
+ img_size: 64
65
+ img_range: 1.
66
+ split_size: [8,32]
67
+ depth: [6,6,6,6,6,6]
68
+ embed_dim: 180
69
+ num_heads: [6,6,6,6,6,6]
70
+ expansion_factor: 2
71
+ resi_connection: '1conv'
72
+
73
+ # path
74
+ path:
75
+ pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x3.pth
76
+ strict_load_g: True
77
+
78
+ # validation settings
79
+ val:
80
+ save_img: True
81
+ suffix: ~ # add suffix to saved images, if None, use exp name
82
+ use_chop: False
83
+
84
+ metrics:
85
+ psnr: # metric name, can be arbitrary
86
+ type: calculate_psnr
87
+ crop_border: 3
88
+ test_y_channel: True
89
+ ssim:
90
+ type: calculate_ssim
91
+ crop_border: 3
92
+ test_y_channel: True
options/Test/test_DAT_2_x4.yml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: test_DAT_2_x4
3
+ model_type: SRModel
4
+ scale: 4
5
+ num_gpu: 1
6
+ manual_seed: 10
7
+
8
+ datasets:
9
+ test_1: # the 1st test dataset
10
+ task: SR
11
+ name: Set5
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/benchmark/Set5/HR
14
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
15
+ filename_tmpl: '{}x4'
16
+ io_backend:
17
+ type: disk
18
+
19
+ test_2: # the 2st test dataset
20
+ task: SR
21
+ name: Set14
22
+ type: PairedImageDataset
23
+ dataroot_gt: datasets/benchmark/Set14/HR
24
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
25
+ filename_tmpl: '{}x4'
26
+ io_backend:
27
+ type: disk
28
+
29
+ test_3: # the 3st test dataset
30
+ task: SR
31
+ name: B100
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/benchmark/B100/HR
34
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
35
+ filename_tmpl: '{}x4'
36
+ io_backend:
37
+ type: disk
38
+
39
+ test_4: # the 4st test dataset
40
+ task: SR
41
+ name: Urban100
42
+ type: PairedImageDataset
43
+ dataroot_gt: datasets/benchmark/Urban100/HR
44
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
45
+ filename_tmpl: '{}x4'
46
+ io_backend:
47
+ type: disk
48
+
49
+ test_5: # the 5st test dataset
50
+ task: SR
51
+ name: Manga109
52
+ type: PairedImageDataset
53
+ dataroot_gt: datasets/benchmark/Manga109/HR
54
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
55
+ filename_tmpl: '{}_LRBI_x4'
56
+ io_backend:
57
+ type: disk
58
+
59
+
60
+ # network structures
61
+ network_g:
62
+ type: DAT
63
+ upscale: 4
64
+ in_chans: 3
65
+ img_size: 64
66
+ img_range: 1.
67
+ split_size: [8,32]
68
+ depth: [6,6,6,6,6,6]
69
+ embed_dim: 180
70
+ num_heads: [6,6,6,6,6,6]
71
+ expansion_factor: 2
72
+ resi_connection: '1conv'
73
+
74
+ # path
75
+ path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x4.pth
77
+ strict_load_g: True
78
+
79
+ # validation settings
80
+ val:
81
+ save_img: True
82
+ suffix: ~ # add suffix to saved images, if None, use exp name
83
+ use_chop: False
84
+
85
+ metrics:
86
+ psnr: # metric name, can be arbitrary
87
+ type: calculate_psnr
88
+ crop_border: 4
89
+ test_y_channel: True
90
+ ssim:
91
+ type: calculate_ssim
92
+ crop_border: 4
93
+ test_y_channel: True
options/Test/test_DAT_S_x2.yml CHANGED
@@ -73,12 +73,12 @@ network_g:
73
 
74
  # path
75
  path:
76
- pretrain_network_g: experiments/pretrained_models/DAT/DAT_S_x2.pth
77
  strict_load_g: True
78
 
79
  # validation settings
80
  val:
81
- save_img: False
82
  suffix: ~ # add suffix to saved images, if None, use exp name
83
  use_chop: False
84
 
 
73
 
74
  # path
75
  path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT-S/DAT_S_x2.pth
77
  strict_load_g: True
78
 
79
  # validation settings
80
  val:
81
+ save_img: True
82
  suffix: ~ # add suffix to saved images, if None, use exp name
83
  use_chop: False
84
 
options/Test/{test_DAT_S_x3.yml.yml → test_DAT_S_x3.yml} RENAMED
@@ -72,12 +72,12 @@ network_g:
72
 
73
  # path
74
  path:
75
- pretrain_network_g: experiments/pretrained_models/DAT/DAT_S_x3.pth
76
  strict_load_g: True
77
 
78
  # validation settings
79
  val:
80
- save_img: False
81
  suffix: ~ # add suffix to saved images, if None, use exp name
82
  use_chop: False
83
 
 
72
 
73
  # path
74
  path:
75
+ pretrain_network_g: experiments/pretrained_models/DAT-S/DAT_S_x3.pth
76
  strict_load_g: True
77
 
78
  # validation settings
79
  val:
80
+ save_img: True
81
  suffix: ~ # add suffix to saved images, if None, use exp name
82
  use_chop: False
83
 
options/Test/test_DAT_S_x4.yml CHANGED
@@ -73,12 +73,12 @@ network_g:
73
 
74
  # path
75
  path:
76
- pretrain_network_g: experiments/pretrained_models/DAT/DAT_S_x4.pth
77
  strict_load_g: True
78
 
79
  # validation settings
80
  val:
81
- save_img: False
82
  suffix: ~ # add suffix to saved images, if None, use exp name
83
  use_chop: False
84
 
 
73
 
74
  # path
75
  path:
76
+ pretrain_network_g: experiments/pretrained_models/DAT-S/DAT_S_x4.pth
77
  strict_load_g: True
78
 
79
  # validation settings
80
  val:
81
+ save_img: True
82
  suffix: ~ # add suffix to saved images, if None, use exp name
83
  use_chop: False
84
 
options/Test/test_DAT_x2.yml CHANGED
@@ -78,7 +78,7 @@ path:
78
 
79
  # validation settings
80
  val:
81
- save_img: False
82
  suffix: ~ # add suffix to saved images, if None, use exp name
83
  use_chop: False
84
 
 
78
 
79
  # validation settings
80
  val:
81
+ save_img: True
82
  suffix: ~ # add suffix to saved images, if None, use exp name
83
  use_chop: False
84
 
options/Test/test_DAT_x3.yml CHANGED
@@ -77,7 +77,7 @@ path:
77
 
78
  # validation settings
79
  val:
80
- save_img: False
81
  suffix: ~ # add suffix to saved images, if None, use exp name
82
  use_chop: False
83
 
 
77
 
78
  # validation settings
79
  val:
80
+ save_img: True
81
  suffix: ~ # add suffix to saved images, if None, use exp name
82
  use_chop: False
83
 
options/Test/test_DAT_x4.yml CHANGED
@@ -78,7 +78,7 @@ path:
78
 
79
  # validation settings
80
  val:
81
- save_img: False
82
  suffix: ~ # add suffix to saved images, if None, use exp name
83
  use_chop: False
84
 
 
78
 
79
  # validation settings
80
  val:
81
+ save_img: True
82
  suffix: ~ # add suffix to saved images, if None, use exp name
83
  use_chop: False
84
 
options/Train/train_DAT_2_x2.yml ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_DAT_2_x2
3
+ model_type: SRModel
4
+ scale: 2
5
+ num_gpu: auto
6
+ manual_seed: 10
7
+
8
+ # dataset and data loader settings
9
+ datasets:
10
+ train:
11
+ task: SR
12
+ name: DF2K
13
+ type: PairedImageDataset
14
+ dataroot_gt: datasets/DF2K/HR
15
+ dataroot_lq: datasets/DF2K/LR_bicubic/X2
16
+ filename_tmpl: '{}x2'
17
+ io_backend:
18
+ type: disk
19
+
20
+ gt_size: 128
21
+ use_hflip: True
22
+ use_rot: True
23
+
24
+ # data loader
25
+ use_shuffle: True
26
+ num_worker_per_gpu: 12
27
+ batch_size_per_gpu: 8
28
+ dataset_enlarge_ratio: 100
29
+ prefetch_mode: ~
30
+
31
+ val:
32
+ task: SR
33
+ name: Set5
34
+ type: PairedImageDataset
35
+ dataroot_gt: datasets/benchmark/Set5/HR
36
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
37
+ filename_tmpl: '{}x2'
38
+ io_backend:
39
+ type: disk
40
+
41
+ # network structures
42
+ network_g:
43
+ type: DAT
44
+ upscale: 2
45
+ in_chans: 3
46
+ img_size: 64
47
+ img_range: 1.
48
+ split_size: [8,32]
49
+ depth: [6,6,6,6,6,6]
50
+ embed_dim: 180
51
+ num_heads: [6,6,6,6,6,6]
52
+ expansion_factor: 2
53
+ resi_connection: '1conv'
54
+
55
+ # path
56
+ path:
57
+ pretrain_network_g: ~
58
+ strict_load_g: True
59
+ resume_state: ~
60
+
61
+ # training settings
62
+ train:
63
+ optim_g:
64
+ type: Adam
65
+ lr: !!float 2e-4
66
+ weight_decay: 0
67
+ betas: [0.9, 0.99]
68
+
69
+ scheduler:
70
+ type: MultiStepLR
71
+ milestones: [250000, 400000, 450000, 475000]
72
+ gamma: 0.5
73
+
74
+ total_iter: 500000
75
+ warmup_iter: -1 # no warm up
76
+
77
+ # losses
78
+ pixel_opt:
79
+ type: L1Loss
80
+ loss_weight: 1.0
81
+ reduction: mean
82
+
83
+ # validation settings
84
+ val:
85
+ val_freq: !!float 5e3
86
+ save_img: False
87
+
88
+ metrics:
89
+ psnr: # metric name, can be arbitrary
90
+ type: calculate_psnr
91
+ crop_border: 2
92
+ test_y_channel: True
93
+
94
+ # logging settings
95
+ logger:
96
+ print_freq: 200
97
+ save_checkpoint_freq: !!float 5e3
98
+ use_tb_logger: True
99
+ wandb:
100
+ project: ~
101
+ resume_id: ~
102
+
103
+ # dist training settings
104
+ dist_params:
105
+ backend: nccl
106
+ port: 29500
options/Train/train_DAT_2_x3.yml ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_DAT_2_x3
3
+ model_type: SRModel
4
+ scale: 3
5
+ num_gpu: auto
6
+ manual_seed: 10
7
+
8
+ # dataset and data loader settings
9
+ datasets:
10
+ train:
11
+ task: SR
12
+ name: DF2K
13
+ type: PairedImageDataset
14
+ dataroot_gt: datasets/DF2K/HR
15
+ dataroot_lq: datasets/DF2K/LR_bicubic/X3
16
+ filename_tmpl: '{}x3'
17
+ io_backend:
18
+ type: disk
19
+
20
+ gt_size: 192
21
+ use_hflip: True
22
+ use_rot: True
23
+
24
+ # data loader
25
+ use_shuffle: True
26
+ num_worker_per_gpu: 12
27
+ batch_size_per_gpu: 8
28
+ dataset_enlarge_ratio: 100
29
+ prefetch_mode: ~
30
+
31
+ val:
32
+ task: SR
33
+ name: Set5
34
+ type: PairedImageDataset
35
+ dataroot_gt: datasets/benchmark/Set5/HR
36
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
37
+ filename_tmpl: '{}x3'
38
+ io_backend:
39
+ type: disk
40
+
41
+ # network structures
42
+ network_g:
43
+ type: DAT
44
+ upscale: 2
45
+ in_chans: 3
46
+ img_size: 64
47
+ img_range: 1.
48
+ split_size: [8,32]
49
+ depth: [6,6,6,6,6,6]
50
+ embed_dim: 180
51
+ num_heads: [6,6,6,6,6,6]
52
+ expansion_factor: 2
53
+ resi_connection: '1conv'
54
+
55
+ # path
56
+ path:
57
+ pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x2.pth # save half of training time if we finetune from x2 and halve initial lr.
58
+ strict_load_g: False
59
+ resume_state: ~
60
+
61
+ # training settings
62
+ train:
63
+ optim_g:
64
+ type: Adam
65
+ # lr: !!float 2e-4
66
+ lr: !!float 1e-4
67
+ weight_decay: 0
68
+ betas: [0.9, 0.99]
69
+
70
+ scheduler:
71
+ type: MultiStepLR
72
+ # milestones: [ 250000, 400000, 450000, 475000 ]
73
+ milestones: [ 125000, 200000, 225000, 237500 ]
74
+ gamma: 0.5
75
+
76
+ # total_iter: 500000
77
+ total_iter: 250000
78
+ warmup_iter: -1 # no warm up
79
+
80
+ # losses
81
+ pixel_opt:
82
+ type: L1Loss
83
+ loss_weight: 1.0
84
+ reduction: mean
85
+
86
+ # validation settings
87
+ val:
88
+ val_freq: !!float 5e3
89
+ save_img: False
90
+
91
+ metrics:
92
+ psnr: # metric name, can be arbitrary
93
+ type: calculate_psnr
94
+ crop_border: 4
95
+ test_y_channel: True
96
+
97
+ # logging settings
98
+ logger:
99
+ print_freq: 200
100
+ save_checkpoint_freq: !!float 5e3
101
+ use_tb_logger: True
102
+ wandb:
103
+ project: ~
104
+ resume_id: ~
105
+
106
+ # dist training settings
107
+ dist_params:
108
+ backend: nccl
109
+ port: 29500
options/Train/train_DAT_2_x4.yml ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: test_DAT_2_x4
3
+ model_type: SRModel
4
+ scale: 4
5
+ num_gpu: auto
6
+ manual_seed: 10
7
+
8
+ # dataset and data loader settings
9
+ datasets:
10
+ train:
11
+ task: SR
12
+ name: DF2K
13
+ type: PairedImageDataset
14
+ dataroot_gt: datasets/DF2K/HR
15
+ dataroot_lq: datasets/DF2K/LR_bicubic/X4
16
+ filename_tmpl: '{}x4'
17
+ io_backend:
18
+ type: disk
19
+
20
+ gt_size: 256
21
+ use_hflip: true
22
+ use_rot: true
23
+
24
+ # data loader
25
+ use_shuffle: True
26
+ num_worker_per_gpu: 12
27
+ batch_size_per_gpu: 8
28
+ dataset_enlarge_ratio: 100
29
+ prefetch_mode: ~
30
+
31
+ val:
32
+ task: SR
33
+ name: Set5
34
+ type: PairedImageDataset
35
+ dataroot_gt: datasets/benchmark/Set5/HR
36
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
37
+ filename_tmpl: '{}x4'
38
+ io_backend:
39
+ type: disk
40
+
41
+
42
+ # network structures
43
+ network_g:
44
+ type: DAT
45
+ upscale: 4
46
+ in_chans: 3
47
+ img_size: 64
48
+ img_range: 1.
49
+ split_size: [8,32]
50
+ depth: [6,6,6,6,6,6]
51
+ embed_dim: 180
52
+ num_heads: [6,6,6,6,6,6]
53
+ expansion_factor: 2
54
+ resi_connection: '1conv'
55
+
56
+ # path
57
+ path:
58
+ pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x2.pth # save half of training time if we finetune from x2 and halve initial lr.
59
+ strict_load_g: False
60
+ resume_state: ~
61
+
62
+ # training settings
63
+ train:
64
+ optim_g:
65
+ type: Adam
66
+ # lr: !!float 2e-4
67
+ lr: !!float 1e-4
68
+ weight_decay: 0
69
+ betas: [0.9, 0.99]
70
+
71
+ scheduler:
72
+ type: MultiStepLR
73
+ # milestones: [ 250000, 400000, 450000, 475000 ]
74
+ milestones: [ 125000, 200000, 225000, 237500 ]
75
+ gamma: 0.5
76
+
77
+ # total_iter: 500000
78
+ total_iter: 250000
79
+ warmup_iter: -1 # no warm up
80
+
81
+ # losses
82
+ pixel_opt:
83
+ type: L1Loss
84
+ loss_weight: 1.0
85
+ reduction: mean
86
+
87
+ # validation settings
88
+ val:
89
+ val_freq: !!float 5e3
90
+ save_img: False
91
+
92
+ metrics:
93
+ psnr: # metric name, can be arbitrary
94
+ type: calculate_psnr
95
+ crop_border: 4
96
+ test_y_channel: True
97
+
98
+ # logging settings
99
+ logger:
100
+ print_freq: 200
101
+ save_checkpoint_freq: !!float 5e3
102
+ use_tb_logger: True
103
+ wandb:
104
+ project: ~
105
+ resume_id: ~
106
+
107
+ # dist training settings
108
+ dist_params:
109
+ backend: nccl
110
+ port: 29500
options/Train/{train_DAT_S_x3.yml.yml → train_DAT_S_x3.yml} RENAMED
File without changes
options/Train/train_DAT_x4.yml CHANGED
@@ -1,5 +1,5 @@
1
  # general settings
2
- name: test_DAT_S_x4
3
  model_type: SRModel
4
  scale: 4
5
  num_gpu: auto
@@ -55,7 +55,7 @@ network_g:
55
 
56
  # path
57
  path:
58
- pretrain_network_g: experiments/pretrained_models/DAT-S/DAT_S_x2.pth # save half of training time if we finetune from x2 and halve initial lr.
59
  strict_load_g: False
60
  resume_state: ~
61
 
 
1
  # general settings
2
+ name: test_DAT_x4
3
  model_type: SRModel
4
  scale: 4
5
  num_gpu: auto
 
55
 
56
  # path
57
  path:
58
+ pretrain_network_g: experiments/pretrained_models/DAT/DAT_x2.pth # save half of training time if we finetune from x2 and halve initial lr.
59
  strict_load_g: False
60
  resume_state: ~
61