Commit 8cb8316 by 463465810cz
Parent(s): 3cda643
ICCV 2023
Former-commit-id: 1d9f9df3885a24e94f638ac328082683d0ceb8b8
- README.md +29 -5
- basicsr/archs/dat_arch.py +2 -2
- basicsr/train.py +215 -0
- basicsr/version.py +2 -2
- datasets/README.md +45 -1
- experiments/README.md +1 -2
- figs/Figure-2.png +3 -0
- figs/Figure-3.png +3 -0
- figs/Figure-4.png +3 -0
- figs/Figure-5.png +3 -0
- figs/Table-2.png +3 -0
- options/README.md +0 -2
- options/Test/test_DAT_2_x2.yml +93 -0
- options/Test/test_DAT_2_x3.yml +92 -0
- options/Test/test_DAT_2_x4.yml +93 -0
- options/Test/test_DAT_S_x2.yml +2 -2
- options/Test/{test_DAT_S_x3.yml.yml → test_DAT_S_x3.yml} +2 -2
- options/Test/test_DAT_S_x4.yml +2 -2
- options/Test/test_DAT_x2.yml +1 -1
- options/Test/test_DAT_x3.yml +1 -1
- options/Test/test_DAT_x4.yml +1 -1
- options/Train/train_DAT_2_x2.yml +106 -0
- options/Train/train_DAT_2_x3.yml +109 -0
- options/Train/train_DAT_2_x4.yml +110 -0
- options/Train/{train_DAT_S_x3.yml.yml → train_DAT_S_x3.yml} +0 -0
- options/Train/train_DAT_x4.yml +2 -2
README.md
CHANGED
|
@@ -58,10 +58,11 @@ Download training and testing datasets and put them into the corresponding folde
|
|
| 58 |
|
| 59 |
| Method | Params (M) | FLOPs (G) | Dataset | PSNR (dB) | SSIM | Model Zoo | Visual Results |
|
| 60 |
| :----- | :--------: | :-------: | :------: | :-------: | :----: | :----------------------------------------------------------: | :----------------------------------------------------------: |
|
| 61 |
-
| DAT-S | 11.21 | 203.3 | Urban100 | 27.68 | 0.8300 | [Google Drive](https://drive.google.com/drive/folders/
|
| 62 |
-
| DAT | 14.80 | 275.8 | Urban100 | 27.87 | 0.8343 | [Google Drive](https://drive.google.com/drive/folders/
|
|
|
|
| 63 |
|
| 64 |
-
The performance is reported on Urban100 (x4
|
| 65 |
|
| 66 |
## Training
|
| 67 |
|
|
@@ -79,6 +80,11 @@ The performance is reported on Urban100 (x4, SR). The test input size of FLOPs i
|
|
| 79 |
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x2.yml --launcher pytorch
|
| 80 |
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x3.yml --launcher pytorch
|
| 81 |
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x4.yml --launcher pytorch
|
| 82 |
```
|
| 83 |
|
| 84 |
- The training experiment is in `experiments/`.
|
|
@@ -87,9 +93,9 @@ The performance is reported on Urban100 (x4, SR). The test input size of FLOPs i
|
|
| 87 |
|
| 88 |
- Download the pre-trained [models](https://drive.google.com/drive/folders/1iBdf_-LVZuz_PAbFtuxSKd_11RL1YKxM?usp=drive_link) and place them in `experiments/pretrained_models/`.
|
| 89 |
|
| 90 |
-
We provide pre-trained models for image SR: DAT-S and DAT (x2, x3, x4).
|
| 91 |
|
| 92 |
-
- Download [testing](https://
|
| 93 |
|
| 94 |
- Run the following scripts. The testing configuration is in `options/test/`.
|
| 95 |
|
|
@@ -104,6 +110,11 @@ The performance is reported on Urban100 (x4, SR). The test input size of FLOPs i
|
|
| 104 |
python basicsr/test.py -opt options/Test/test_DAT_x2.yml
|
| 105 |
python basicsr/test.py -opt options/Test/test_DAT_x3.yml
|
| 106 |
python basicsr/test.py -opt options/Test/test_DAT_x4.yml
|
| 107 |
```
|
| 108 |
|
| 109 |
- The output is in `results/`.
|
|
@@ -120,13 +131,26 @@ We achieved state-of-the-art performance. Detailed results can be found in the p
|
|
| 120 |
<p align="center">
|
| 121 |
<img width="900" src="figs/Table-1.png">
|
| 122 |
</p>
|
|
|
|
| 123 |
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
- visual comparison (x4) in the main paper
|
| 126 |
|
| 127 |
<p align="center">
|
| 128 |
<img width="900" src="figs/Figure-1.png">
|
| 129 |
</p>
|
| 130 |
|
| 131 |
|
| 132 |
- </details>
|
|
|
|
| 58 |
|
| 59 |
| Method | Params (M) | FLOPs (G) | Dataset | PSNR (dB) | SSIM | Model Zoo | Visual Results |
|
| 60 |
| :----- | :--------: | :-------: | :------: | :-------: | :----: | :----------------------------------------------------------: | :----------------------------------------------------------: |
|
| 61 |
+
| DAT-S | 11.21 | 203.3 | Urban100 | 27.68 | 0.8300 | [Google Drive](https://drive.google.com/drive/folders/1hM0v3fUg5u6GjkI7dduxShyGgGfEwQXO?usp=drive_link) | [Google Drive](https://drive.google.com/file/d/1x1ixMswxw5w-zeZ_Rap5Nk4Tr46MIjAw/view?usp=drive_link) |
|
| 62 |
+
| DAT | 14.80 | 275.8 | Urban100 | 27.87 | 0.8343 | [Google Drive](https://drive.google.com/drive/folders/14VG5mw5ie8RrR4jjypeHynXDZYWL8w-r?usp=drive_link) | [Google Drive](https://drive.google.com/file/d/1K43CTsXpoX5St5fed4kEW9gu2KMR6hLu/view?usp=drive_link) |
|
| 63 |
+
| DAT-2 | 11.21 | 216.93 | Urban100 | 27.86 | 0.8341 | [Google Drive](https://drive.google.com/drive/folders/1yV9LMhr2tYM_eHEIVY4Jw9X3bWGgorbD?usp=drive_link) | [Google Drive](https://drive.google.com/file/d/1TQRZIg8at5HX87OCu3GYytZhYGperkuN/view?usp=drive_link) |
|
| 64 |
|
| 65 |
+
The performance is reported on Urban100 (x4). The test input size of FLOPs is 128 x 128.
|
| 66 |
|
| 67 |
## Training
|
| 68 |
|
|
|
|
| 80 |
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x2.yml --launcher pytorch
|
| 81 |
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x3.yml --launcher pytorch
|
| 82 |
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_x4.yml --launcher pytorch
|
| 83 |
+
|
| 84 |
+
# DAT-2, input=64x64, 4 GPUs
|
| 85 |
+
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_2_x2.yml --launcher pytorch
|
| 86 |
+
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_2_x3.yml --launcher pytorch
|
| 87 |
+
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/Train/train_DAT_2_x4.yml --launcher pytorch
|
| 88 |
```
|
| 89 |
|
| 90 |
- The training experiment is in `experiments/`.
|
|
|
|
| 93 |
|
| 94 |
- Download the pre-trained [models](https://drive.google.com/drive/folders/1iBdf_-LVZuz_PAbFtuxSKd_11RL1YKxM?usp=drive_link) and place them in `experiments/pretrained_models/`.
|
| 95 |
|
| 96 |
+
We provide pre-trained models for image SR: DAT-S, DAT, and DAT-2 (x2, x3, x4).
|
| 97 |
|
| 98 |
+
- Download [testing](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing) (Set5, Set14, BSD100, Urban100, Manga109) datasets, place them in `datasets/`.
|
| 99 |
|
| 100 |
- Run the following scripts. The testing configuration is in `options/test/`.
|
| 101 |
|
|
|
|
| 110 |
python basicsr/test.py -opt options/Test/test_DAT_x2.yml
|
| 111 |
python basicsr/test.py -opt options/Test/test_DAT_x3.yml
|
| 112 |
python basicsr/test.py -opt options/Test/test_DAT_x4.yml
|
| 113 |
+
|
| 114 |
+
# DAT-2, reproduces results in Table 1 of the supplementary material
|
| 115 |
+
python basicsr/test.py -opt options/Test/test_DAT_2_x2.yml
|
| 116 |
+
python basicsr/test.py -opt options/Test/test_DAT_2_x3.yml
|
| 117 |
+
python basicsr/test.py -opt options/Test/test_DAT_2_x4.yml
|
| 118 |
```
|
| 119 |
|
| 120 |
- The output is in `results/`.
|
|
|
|
| 131 |
<p align="center">
|
| 132 |
<img width="900" src="figs/Table-1.png">
|
| 133 |
</p>
|
| 134 |
+
- results in Table 1 of the supplementary material
|
| 135 |
|
| 136 |
+
<p align="center">
|
| 137 |
+
<img width="900" src="figs/Table-2.png">
|
| 138 |
+
</p>
|
| 139 |
|
| 140 |
- visual comparison (x4) in the main paper
|
| 141 |
|
| 142 |
<p align="center">
|
| 143 |
<img width="900" src="figs/Figure-1.png">
|
| 144 |
</p>
|
| 145 |
+
- visual comparison (x4) in the supplementary material
|
| 146 |
+
|
| 147 |
+
<p align="center">
|
| 148 |
+
<img width="900" src="figs/Figure-2.png">
|
| 149 |
+
<img width="900" src="figs/Figure-3.png">
|
| 150 |
+
<img width="900" src="figs/Figure-4.png">
|
| 151 |
+
<img width="900" src="figs/Figure-5.png">
|
| 152 |
+
</p>
|
| 153 |
+
|
| 154 |
|
| 155 |
|
| 156 |
- </details>
|
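The README hunk above adds DAT-2 train and test commands that differ only in the scale suffix of the option file. A minimal sketch of that pattern, assuming the same launcher flags shown in the diff; the helper names `dat2_train_commands` and `dat2_test_commands` are hypothetical and not part of the repository:

```python
from typing import List

# Scales for which the commit adds DAT-2 option files.
SCALES = (2, 3, 4)


def dat2_train_commands(nproc_per_node: int = 4, master_port: int = 4321) -> List[str]:
    """Build the distributed training commands listed in the README hunk."""
    return [
        "python -m torch.distributed.launch "
        f"--nproc_per_node={nproc_per_node} --master_port={master_port} "
        f"basicsr/train.py -opt options/Train/train_DAT_2_x{s}.yml --launcher pytorch"
        for s in SCALES
    ]


def dat2_test_commands() -> List[str]:
    """Build the single-process test commands listed in the README hunk."""
    return [f"python basicsr/test.py -opt options/Test/test_DAT_2_x{s}.yml" for s in SCALES]


if __name__ == "__main__":
    for cmd in dat2_train_commands() + dat2_test_commands():
        print(cmd)
```

The defaults mirror the 4-GPU setting named in the README comment; other GPU counts would need a matching `--nproc_per_node` value.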
basicsr/archs/dat_arch.py
CHANGED
|
@@ -297,7 +297,6 @@ class Axial_Spatial_Attention(nn.Module):
|
|
| 297 |
self.register_buffer("attn_mask_0", None)
|
| 298 |
self.register_buffer("attn_mask_1", None)
|
| 299 |
|
| 300 |
-
# Adaptive Interaction Module
|
| 301 |
self.dwconv = nn.Sequential(
|
| 302 |
nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
|
| 303 |
nn.BatchNorm2d(dim),
|
|
@@ -419,6 +418,7 @@ class Axial_Spatial_Attention(nn.Module):
|
|
| 419 |
# convolution output
|
| 420 |
conv_x = self.dwconv(v)
|
| 421 |
|
|
|
|
| 422 |
# C-Map (before sigmoid)
|
| 423 |
channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C)
|
| 424 |
# S-Map (before sigmoid)
|
|
@@ -460,7 +460,6 @@ class Axial_Channel_Attention(nn.Module):
|
|
| 460 |
self.proj = nn.Linear(dim, dim)
|
| 461 |
self.proj_drop = nn.Dropout(proj_drop)
|
| 462 |
|
| 463 |
-
# Adaptive Interaction Module
|
| 464 |
self.dwconv = nn.Sequential(
|
| 465 |
nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
|
| 466 |
nn.BatchNorm2d(dim),
|
|
@@ -509,6 +508,7 @@ class Axial_Channel_Attention(nn.Module):
|
|
| 509 |
# convolution output
|
| 510 |
conv_x = self.dwconv(v_)
|
| 511 |
|
|
|
|
| 512 |
# C-Map (before sigmoid)
|
| 513 |
attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W)
|
| 514 |
channel_map = self.channel_interaction(attention_reshape)
|
|
|
|
| 297 |
self.register_buffer("attn_mask_0", None)
|
| 298 |
self.register_buffer("attn_mask_1", None)
|
| 299 |
|
|
|
|
| 300 |
self.dwconv = nn.Sequential(
|
| 301 |
nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
|
| 302 |
nn.BatchNorm2d(dim),
|
|
|
|
| 418 |
# convolution output
|
| 419 |
conv_x = self.dwconv(v)
|
| 420 |
|
| 421 |
+
# Adaptive Interaction Module (AIM)
|
| 422 |
# C-Map (before sigmoid)
|
| 423 |
channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C)
|
| 424 |
# S-Map (before sigmoid)
|
|
|
|
| 460 |
self.proj = nn.Linear(dim, dim)
|
| 461 |
self.proj_drop = nn.Dropout(proj_drop)
|
| 462 |
|
|
|
|
| 463 |
self.dwconv = nn.Sequential(
|
| 464 |
nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),
|
| 465 |
nn.BatchNorm2d(dim),
|
|
|
|
| 508 |
# convolution output
|
| 509 |
conv_x = self.dwconv(v_)
|
| 510 |
|
| 511 |
+
# Adaptive Interaction Module (AIM)
|
| 512 |
# C-Map (before sigmoid)
|
| 513 |
attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W)
|
| 514 |
channel_map = self.channel_interaction(attention_reshape)
|
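The two AIM hunks reshape tensors with `permute(...).view(B, 1, C)` and `transpose(-2, -1).view(B, C, H, W)`. The shape bookkeeping can be sketched without PyTorch. This is only an illustration under assumptions: the output of `channel_interaction` is taken to be `(B, C, 1, 1)` and `attened_x` to be `(B, H*W, C)`, neither of which is stated in the diff, and the helper functions are hypothetical.

```python
from math import prod
from typing import Sequence, Tuple


def permute_shape(shape: Sequence[int], order: Sequence[int]) -> Tuple[int, ...]:
    """Shape after reordering the axes, as torch's permute would."""
    return tuple(shape[i] for i in order)


def view_shape(shape: Sequence[int], new_shape: Sequence[int]) -> Tuple[int, ...]:
    """Shape after a view; the element count must be preserved."""
    if prod(shape) != prod(new_shape):
        raise ValueError(f"cannot view {tuple(shape)} as {tuple(new_shape)}")
    return tuple(new_shape)


B, C, H, W = 2, 180, 8, 8  # C matches embed_dim in the test configs; B, H, W are arbitrary

# Spatial branch: assumed (B, C, 1, 1) -> permute(0, 2, 3, 1) -> view(B, 1, C)
pooled = (B, C, 1, 1)
channel_map_shape = view_shape(permute_shape(pooled, (0, 2, 3, 1)), (B, 1, C))

# Channel branch: assumed (B, H*W, C) -> transpose(-2, -1) -> view(B, C, H, W)
attended = (B, H * W, C)
attention_reshape_shape = view_shape(permute_shape(attended, (0, 2, 1)), (B, C, H, W))

print(channel_map_shape, attention_reshape_shape)
```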
basicsr/train.py
ADDED
|
@@ -0,0 +1,215 @@
|
| 1 |
+
import datetime
|
| 2 |
+
import logging
|
| 3 |
+
import math
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
from os import path as osp
|
| 7 |
+
|
| 8 |
+
from basicsr.data import build_dataloader, build_dataset
|
| 9 |
+
from basicsr.data.data_sampler import EnlargedSampler
|
| 10 |
+
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
|
| 11 |
+
from basicsr.models import build_model
|
| 12 |
+
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
|
| 13 |
+
init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
|
| 14 |
+
from basicsr.utils.options import copy_opt_file, dict2str, parse_options
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def init_tb_loggers(opt):
|
| 18 |
+
# initialize wandb logger before tensorboard logger to allow proper sync
|
| 19 |
+
if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
|
| 20 |
+
is not None) and ('debug' not in opt['name']):
|
| 21 |
+
assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
|
| 22 |
+
init_wandb_logger(opt)
|
| 23 |
+
tb_logger = None
|
| 24 |
+
if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
|
| 25 |
+
tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
|
| 26 |
+
return tb_logger
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_train_val_dataloader(opt, logger):
|
| 30 |
+
# create train and val dataloaders
|
| 31 |
+
train_loader, val_loaders = None, []
|
| 32 |
+
for phase, dataset_opt in opt['datasets'].items():
|
| 33 |
+
if phase == 'train':
|
| 34 |
+
dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
|
| 35 |
+
train_set = build_dataset(dataset_opt)
|
| 36 |
+
train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
|
| 37 |
+
train_loader = build_dataloader(
|
| 38 |
+
train_set,
|
| 39 |
+
dataset_opt,
|
| 40 |
+
num_gpu=opt['num_gpu'],
|
| 41 |
+
dist=opt['dist'],
|
| 42 |
+
sampler=train_sampler,
|
| 43 |
+
seed=opt['manual_seed'])
|
| 44 |
+
|
| 45 |
+
num_iter_per_epoch = math.ceil(
|
| 46 |
+
len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
|
| 47 |
+
total_iters = int(opt['train']['total_iter'])
|
| 48 |
+
total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
|
| 49 |
+
logger.info('Training statistics:'
|
| 50 |
+
f'\n\tNumber of train images: {len(train_set)}'
|
| 51 |
+
f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
|
| 52 |
+
f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
|
| 53 |
+
f'\n\tWorld size (gpu number): {opt["world_size"]}'
|
| 54 |
+
f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
|
| 55 |
+
f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
|
| 56 |
+
elif phase.split('_')[0] == 'val':
|
| 57 |
+
val_set = build_dataset(dataset_opt)
|
| 58 |
+
val_loader = build_dataloader(
|
| 59 |
+
val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
|
| 60 |
+
logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
|
| 61 |
+
val_loaders.append(val_loader)
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f'Dataset phase {phase} is not recognized.')
|
| 64 |
+
|
| 65 |
+
return train_loader, train_sampler, val_loaders, total_epochs, total_iters
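The epoch arithmetic in `create_train_val_dataloader` can be checked in isolation. The sketch below reproduces the two `math.ceil` expressions from the function; the input numbers are hypothetical and not taken from any option file in this commit.

```python
import math
from typing import Tuple


def schedule(num_images: int, enlarge_ratio: int, batch_size_per_gpu: int,
             world_size: int, total_iters: int) -> Tuple[int, int]:
    """Return (iterations per epoch, total epochs) as computed in train.py."""
    num_iter_per_epoch = math.ceil(
        num_images * enlarge_ratio / (batch_size_per_gpu * world_size))
    total_epochs = math.ceil(total_iters / num_iter_per_epoch)
    return num_iter_per_epoch, total_epochs


# Hypothetical setting: 3450 images, no enlargement, batch 8 per GPU on 4 GPUs.
print(schedule(3450, 1, 8, 4, 500000))
```

Because training is driven by `total_iter`, the epoch count is a derived quantity: a larger global batch lowers the iterations per epoch and raises the number of epochs needed to reach the same iteration budget.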
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def load_resume_state(opt):
|
| 69 |
+
resume_state_path = None
|
| 70 |
+
if opt['auto_resume']:
|
| 71 |
+
state_path = osp.join('experiments', opt['name'], 'training_states')
|
| 72 |
+
if osp.isdir(state_path):
|
| 73 |
+
states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
|
| 74 |
+
if len(states) != 0:
|
| 75 |
+
states = [float(v.split('.state')[0]) for v in states]
|
| 76 |
+
resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
|
| 77 |
+
opt['path']['resume_state'] = resume_state_path
|
| 78 |
+
else:
|
| 79 |
+
if opt['path'].get('resume_state'):
|
| 80 |
+
resume_state_path = opt['path']['resume_state']
|
| 81 |
+
|
| 82 |
+
if resume_state_path is None:
|
| 83 |
+
resume_state = None
|
| 84 |
+
else:
|
| 85 |
+
device_id = torch.cuda.current_device()
|
| 86 |
+
resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
|
| 87 |
+
check_resume(opt, resume_state['iter'])
|
| 88 |
+
return resume_state
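With `auto_resume` enabled, `load_resume_state` picks the state file with the largest iteration number by parsing the file names as numbers rather than comparing them as strings. A standalone sketch of that selection step; the function name `latest_state_name` and the sample file names are hypothetical:

```python
from typing import List, Optional


def latest_state_name(state_files: List[str]) -> Optional[str]:
    """Pick the '<iter>.state' file with the highest iteration, as train.py does."""
    if not state_files:
        return None
    iters = [float(name.split('.state')[0]) for name in state_files]
    return f'{max(iters):.0f}.state'


files = ['5000.state', '10000.state', '250000.state']
print(latest_state_name(files))  # numeric comparison
print(max(files))                # plain string comparison, for contrast
```

The numeric conversion matters because a lexicographic `max` over the same names would return `5000.state`.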
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def train_pipeline(root_path):
|
| 92 |
+
# parse options, set distributed setting, set random seed
|
| 93 |
+
opt, args = parse_options(root_path, is_train=True)
|
| 94 |
+
opt['root_path'] = root_path
|
| 95 |
+
|
| 96 |
+
torch.backends.cudnn.benchmark = True
|
| 97 |
+
# torch.backends.cudnn.deterministic = True
|
| 98 |
+
|
| 99 |
+
# load resume states if necessary
|
| 100 |
+
resume_state = load_resume_state(opt)
|
| 101 |
+
# mkdir for experiments and logger
|
| 102 |
+
if resume_state is None:
|
| 103 |
+
make_exp_dirs(opt)
|
| 104 |
+
if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
|
| 105 |
+
mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
|
| 106 |
+
|
| 107 |
+
# copy the yml file to the experiment root
|
| 108 |
+
copy_opt_file(args.opt, opt['path']['experiments_root'])
|
| 109 |
+
|
| 110 |
+
# WARNING: do not call get_root_logger in the code above, including the called functions
|
| 111 |
+
# Otherwise the logger will not be properly initialized
|
| 112 |
+
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
|
| 113 |
+
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
|
| 114 |
+
logger.info(get_env_info())
|
| 115 |
+
logger.info(dict2str(opt))
|
| 116 |
+
# initialize wandb and tb loggers
|
| 117 |
+
tb_logger = init_tb_loggers(opt)
|
| 118 |
+
|
| 119 |
+
# create train and validation dataloaders
|
| 120 |
+
result = create_train_val_dataloader(opt, logger)
|
| 121 |
+
train_loader, train_sampler, val_loaders, total_epochs, total_iters = result
|
| 122 |
+
|
| 123 |
+
# create model
|
| 124 |
+
model = build_model(opt)
|
| 125 |
+
if resume_state: # resume training
|
| 126 |
+
model.resume_training(resume_state) # handle optimizers and schedulers
|
| 127 |
+
logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
|
| 128 |
+
start_epoch = resume_state['epoch']
|
| 129 |
+
current_iter = resume_state['iter']
|
| 130 |
+
else:
|
| 131 |
+
start_epoch = 0
|
| 132 |
+
current_iter = 0
|
| 133 |
+
|
| 134 |
+
# create message logger (formatted outputs)
|
| 135 |
+
msg_logger = MessageLogger(opt, current_iter, tb_logger)
|
| 136 |
+
|
| 137 |
+
# dataloader prefetcher
|
| 138 |
+
prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
|
| 139 |
+
if prefetch_mode is None or prefetch_mode == 'cpu':
|
| 140 |
+
prefetcher = CPUPrefetcher(train_loader)
|
| 141 |
+
elif prefetch_mode == 'cuda':
|
| 142 |
+
prefetcher = CUDAPrefetcher(train_loader, opt)
|
| 143 |
+
logger.info(f'Use {prefetch_mode} prefetch dataloader')
|
| 144 |
+
if opt['datasets']['train'].get('pin_memory') is not True:
|
| 145 |
+
raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
|
| 146 |
+
else:
|
| 147 |
+
raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
|
| 148 |
+
|
| 149 |
+
# training
|
| 150 |
+
logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
|
| 151 |
+
data_timer, iter_timer = AvgTimer(), AvgTimer()
|
| 152 |
+
start_time = time.time()
|
| 153 |
+
|
| 154 |
+
for epoch in range(start_epoch, total_epochs + 1):
|
| 155 |
+
train_sampler.set_epoch(epoch)
|
| 156 |
+
prefetcher.reset()
|
| 157 |
+
train_data = prefetcher.next()
|
| 158 |
+
|
| 159 |
+
while train_data is not None:
|
| 160 |
+
data_timer.record()
|
| 161 |
+
|
| 162 |
+
current_iter += 1
|
| 163 |
+
if current_iter > total_iters:
|
| 164 |
+
break
|
| 165 |
+
# update learning rate
|
| 166 |
+
model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
|
| 167 |
+
# training
|
| 168 |
+
model.feed_data(train_data)
|
| 169 |
+
model.optimize_parameters(current_iter)
|
| 170 |
+
iter_timer.record()
|
| 171 |
+
if current_iter == 1:
|
| 172 |
+
# reset start time in msg_logger for more accurate eta_time
|
| 173 |
+
# does not work in resume mode
|
| 174 |
+
msg_logger.reset_start_time()
|
| 175 |
+
# log
|
| 176 |
+
if current_iter % opt['logger']['print_freq'] == 0:
|
| 177 |
+
log_vars = {'epoch': epoch, 'iter': current_iter}
|
| 178 |
+
log_vars.update({'lrs': model.get_current_learning_rate()})
|
| 179 |
+
log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
|
| 180 |
+
log_vars.update(model.get_current_log())
|
| 181 |
+
msg_logger(log_vars)
|
| 182 |
+
|
| 183 |
+
# save models and training states
|
| 184 |
+
if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
|
| 185 |
+
logger.info('Saving models and training states.')
|
| 186 |
+
model.save(epoch, current_iter)
|
| 187 |
+
|
| 188 |
+
# validation
|
| 189 |
+
if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
|
| 190 |
+
if len(val_loaders) > 1:
|
| 191 |
+
logger.warning('Multiple validation datasets are *only* supported by SRModel.')
|
| 192 |
+
for val_loader in val_loaders:
|
| 193 |
+
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
|
| 194 |
+
|
| 195 |
+
data_timer.start()
|
| 196 |
+
iter_timer.start()
|
| 197 |
+
train_data = prefetcher.next()
|
| 198 |
+
# end of iter
|
| 199 |
+
|
| 200 |
+
# end of epoch
|
| 201 |
+
|
| 202 |
+
consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
|
| 203 |
+
logger.info(f'End of training. Time consumed: {consumed_time}')
|
| 204 |
+
logger.info('Save the latest model.')
|
| 205 |
+
model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
|
| 206 |
+
if opt.get('val') is not None:
|
| 207 |
+
for val_loader in val_loaders:
|
| 208 |
+
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
|
| 209 |
+
if tb_logger:
|
| 210 |
+
tb_logger.close()
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
if __name__ == '__main__':
|
| 214 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
| 215 |
+
train_pipeline(root_path)
|
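The training loop in `train_pipeline` iterates over `range(start_epoch, total_epochs + 1)` and leaves the inner loop once `current_iter` exceeds `total_iters`. The control flow can be simulated without a model or dataloader. This is a simplified sketch: the inner `for` stands in for the prefetcher, and `simulate_loop` is a hypothetical name.

```python
from typing import Tuple


def simulate_loop(start_epoch: int, start_iter: int, total_epochs: int,
                  total_iters: int, iters_per_epoch: int) -> Tuple[int, int]:
    """Return (optimizer steps taken, final value of current_iter)."""
    current_iter = start_iter
    trained_steps = 0
    for _epoch in range(start_epoch, total_epochs + 1):
        for _batch in range(iters_per_epoch):  # stands in for prefetcher.next()
            current_iter += 1
            if current_iter > total_iters:
                break  # leaves only the inner loop, as in train.py
            trained_steps += 1  # feed_data + optimize_parameters would run here
    return trained_steps, current_iter


print(simulate_loop(0, 0, 3, 250, 108))
```

In this simulation the number of optimizer steps equals `total_iters`, while the counter itself ends slightly above it because the increment happens before the check in each remaining epoch.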
basicsr/version.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# GENERATED VERSION FILE
|
| 2 |
-
# TIME:
|
| 3 |
__version__ = '1.3.5'
|
| 4 |
-
__gitsha__ = '
|
| 5 |
version_info = (1, 3, 5)
|
|
|
|
| 1 |
# GENERATED VERSION FILE
|
| 2 |
+
# TIME: Mon Jul 17 01:59:53 2023
|
| 3 |
__version__ = '1.3.5'
|
| 4 |
+
__gitsha__ = '29e57e3'
|
| 5 |
version_info = (1, 3, 5)
|
datasets/README.md
CHANGED
|
@@ -1,2 +1,46 @@
|
|
| 1 |
-
|
| 2 |
|
| 1 |
+
For training and testing, the directory structure is as follows:
|
| 2 |
|
| 3 |
+
```shell
|
| 4 |
+
|-- datasets
|
| 5 |
+
# train
|
| 6 |
+
|-- DF2K
|
| 7 |
+
|-- HR
|
| 8 |
+
|-- LR_bicubic
|
| 9 |
+
|-- X2
|
| 10 |
+
|-- X3
|
| 11 |
+
|-- X4
|
| 12 |
+
# test
|
| 13 |
+
|-- benchmark
|
| 14 |
+
|-- Set5
|
| 15 |
+
|-- HR
|
| 16 |
+
|-- LR_bicubic
|
| 17 |
+
|-- X2
|
| 18 |
+
|-- X3
|
| 19 |
+
|-- X4
|
| 20 |
+
|-- Set14
|
| 21 |
+
|-- HR
|
| 22 |
+
|-- LR_bicubic
|
| 23 |
+
|-- X2
|
| 24 |
+
|-- X3
|
| 25 |
+
|-- X4
|
| 26 |
+
|-- B100
|
| 27 |
+
|-- HR
|
| 28 |
+
|-- LR_bicubic
|
| 29 |
+
|-- X2
|
| 30 |
+
|-- X3
|
| 31 |
+
|-- X4
|
| 32 |
+
|-- Urban100
|
| 33 |
+
|-- HR
|
| 34 |
+
|-- LR_bicubic
|
| 35 |
+
|-- X2
|
| 36 |
+
|-- X3
|
| 37 |
+
|-- X4
|
| 38 |
+
|-- Manga109
|
| 39 |
+
|-- HR
|
| 40 |
+
|-- LR_bicubic
|
| 41 |
+
|-- X2
|
| 42 |
+
|-- X3
|
| 43 |
+
|-- X4
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
You can download the complete datasets we have collected.
|
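The directory tree added to `datasets/README.md` can be turned into a quick layout check before training or testing. A minimal sketch, assuming `DF2K` and `benchmark` sit directly under `datasets/` (the benchmark paths match the test configs in this commit; the `DF2K` nesting is inferred from the tree, whose indentation was lost). The function names are hypothetical.

```python
from pathlib import Path
from typing import List

SCALES = ("X2", "X3", "X4")
BENCHMARKS = ("Set5", "Set14", "B100", "Urban100", "Manga109")


def expected_dirs(root: Path) -> List[Path]:
    """List every HR and LR_bicubic/X* directory named in datasets/README.md."""
    bases = [root / "DF2K"] + [root / "benchmark" / name for name in BENCHMARKS]
    dirs: List[Path] = []
    for base in bases:
        dirs.append(base / "HR")
        dirs.extend(base / "LR_bicubic" / scale for scale in SCALES)
    return dirs


def missing_dirs(root: Path) -> List[Path]:
    """Return the expected directories that do not exist under root."""
    return [d for d in expected_dirs(root) if not d.is_dir()]


if __name__ == "__main__":
    for d in missing_dirs(Path("datasets")):
        print(f"missing: {d}")
```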
experiments/README.md
CHANGED
|
@@ -1,2 +1 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
| 1 |
+
Place pretrained models in `pretrained_models`.
|
|
|
figs/Figure-2.png
ADDED
|
Git LFS Details
|
figs/Figure-3.png
ADDED
|
Git LFS Details
|
figs/Figure-4.png
ADDED
|
Git LFS Details
|
figs/Figure-5.png
ADDED
|
Git LFS Details
|
figs/Table-2.png
ADDED
|
Git LFS Details
|
options/README.md
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
For more information about testing configuration, please refer to [Configuration](https://github.com/XPixelGroup/BasicSR/blob/master/docs/Config.md).
|
| 2 |
-
|
|
|
|
|
|
|
|
|
options/Test/test_DAT_2_x2.yml
ADDED
|
@@ -0,0 +1,93 @@
|
| 1 |
+
# general settings
|
| 2 |
+
name: test_DAT_2_x2
|
| 3 |
+
model_type: SRModel
|
| 4 |
+
scale: 2
|
| 5 |
+
num_gpu: 1
|
| 6 |
+
manual_seed: 10
|
| 7 |
+
|
| 8 |
+
datasets:
|
| 9 |
+
test_1: # the 1st test dataset
|
| 10 |
+
task: SR
|
| 11 |
+
name: Set5
|
| 12 |
+
type: PairedImageDataset
|
| 13 |
+
dataroot_gt: datasets/benchmark/Set5/HR
|
| 14 |
+
dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
|
| 15 |
+
filename_tmpl: '{}x2'
|
| 16 |
+
io_backend:
|
| 17 |
+
type: disk
|
| 18 |
+
|
| 19 |
+
test_2:  # the 2nd test dataset
|
| 20 |
+
task: SR
|
| 21 |
+
name: Set14
|
| 22 |
+
type: PairedImageDataset
|
| 23 |
+
dataroot_gt: datasets/benchmark/Set14/HR
|
| 24 |
+
dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2
|
| 25 |
+
filename_tmpl: '{}x2'
|
| 26 |
+
io_backend:
|
| 27 |
+
type: disk
|
| 28 |
+
|
| 29 |
+
test_3:  # the 3rd test dataset
|
| 30 |
+
task: SR
|
| 31 |
+
name: B100
|
| 32 |
+
type: PairedImageDataset
|
| 33 |
+
dataroot_gt: datasets/benchmark/B100/HR
|
| 34 |
+
dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2
|
| 35 |
+
filename_tmpl: '{}x2'
|
| 36 |
+
io_backend:
|
| 37 |
+
type: disk
|
| 38 |
+
|
| 39 |
+
test_4:  # the 4th test dataset
|
| 40 |
+
task: SR
|
| 41 |
+
name: Urban100
|
| 42 |
+
type: PairedImageDataset
|
| 43 |
+
dataroot_gt: datasets/benchmark/Urban100/HR
|
| 44 |
+
dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2
|
| 45 |
+
filename_tmpl: '{}x2'
|
| 46 |
+
io_backend:
|
| 47 |
+
type: disk
|
| 48 |
+
|
| 49 |
+
test_5:  # the 5th test dataset
|
| 50 |
+
task: SR
|
| 51 |
+
name: Manga109
|
| 52 |
+
type: PairedImageDataset
|
| 53 |
+
dataroot_gt: datasets/benchmark/Manga109/HR
|
| 54 |
+
dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2
|
| 55 |
+
filename_tmpl: '{}_LRBI_x2'
|
| 56 |
+
io_backend:
|
| 57 |
+
type: disk
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# network structures
|
| 61 |
+
network_g:
|
| 62 |
+
type: DAT
|
| 63 |
+
upscale: 2
|
| 64 |
+
in_chans: 3
|
| 65 |
+
img_size: 64
|
| 66 |
+
img_range: 1.
|
| 67 |
+
split_size: [8,32]
|
| 68 |
+
depth: [6,6,6,6,6,6]
|
| 69 |
+
embed_dim: 180
|
| 70 |
+
num_heads: [6,6,6,6,6,6]
|
| 71 |
+
expansion_factor: 2
|
| 72 |
+
resi_connection: '1conv'
|
| 73 |
+
|
| 74 |
+
# path
|
| 75 |
+
path:
|
| 76 |
+
pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x2.pth
|
| 77 |
+
strict_load_g: True
|
| 78 |
+
|
| 79 |
+
# validation settings
|
| 80 |
+
val:
|
| 81 |
+
save_img: True
|
| 82 |
+
suffix: ~ # add suffix to saved images, if None, use exp name
|
| 83 |
+
use_chop: False
|
| 84 |
+
|
| 85 |
+
metrics:
|
| 86 |
+
psnr: # metric name, can be arbitrary
|
| 87 |
+
type: calculate_psnr
|
| 88 |
+
crop_border: 2
|
| 89 |
+
test_y_channel: True
|
| 90 |
+
ssim:
|
| 91 |
+
type: calculate_ssim
|
| 92 |
+
crop_border: 2
|
| 93 |
+
test_y_channel: True
|
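Each dataset entry in the config pairs a ground-truth folder with a low-quality folder through `filename_tmpl`. The sketch below shows how such a template is commonly expanded: the ground-truth base name is substituted into the template and the extension is kept. This is an assumption about the loader's behaviour, not something shown in the diff, and the image names are hypothetical.

```python
from os import path as osp


def lq_filename(gt_filename: str, filename_tmpl: str) -> str:
    """Derive the low-quality file name from a ground-truth file name."""
    basename, ext = osp.splitext(osp.basename(gt_filename))
    return f"{filename_tmpl.format(basename)}{ext}"


# Templates as they appear in test_DAT_2_x2.yml
print(lq_filename("datasets/benchmark/Set5/HR/baby.png", "{}x2"))
print(lq_filename("datasets/benchmark/Manga109/HR/sample.png", "{}_LRBI_x2"))
```

Manga109 is the only entry with a different template (`{}_LRBI_x2`), so its low-resolution files are expected to follow a different naming scheme from the other four benchmarks.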
options/Test/test_DAT_2_x3.yml
ADDED
|
@@ -0,0 +1,92 @@
|
| 1 |
+
# general settings
|
| 2 |
+
name: test_DAT_2_x3
|
| 3 |
+
model_type: SRModel
|
| 4 |
+
scale: 3
|
| 5 |
+
num_gpu: 1
|
| 6 |
+
manual_seed: 10
|
| 7 |
+
|
| 8 |
+
datasets:
|
| 9 |
+
test_1: # the 1st test dataset
|
| 10 |
+
task: SR
|
| 11 |
+
name: Set5
|
| 12 |
+
type: PairedImageDataset
|
| 13 |
+
dataroot_gt: datasets/benchmark/Set5/HR
|
| 14 |
+
dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
|
| 15 |
+
filename_tmpl: '{}x3'
|
| 16 |
+
io_backend:
|
| 17 |
+
type: disk
|
| 18 |
+
|
| 19 |
+
test_2:  # the 2nd test dataset
|
| 20 |
+
task: SR
|
| 21 |
+
name: Set14
|
| 22 |
+
type: PairedImageDataset
|
| 23 |
+
dataroot_gt: datasets/benchmark/Set14/HR
|
| 24 |
+
dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X3
|
| 25 |
+
filename_tmpl: '{}x3'
|
| 26 |
+
io_backend:
|
| 27 |
+
type: disk
|
| 28 |
+
|
| 29 |
+
test_3:  # the 3rd test dataset
|
| 30 |
+
task: SR
|
| 31 |
+
name: B100
|
| 32 |
+
type: PairedImageDataset
|
| 33 |
+
dataroot_gt: datasets/benchmark/B100/HR
|
| 34 |
+
dataroot_lq: datasets/benchmark/B100/LR_bicubic/X3
|
| 35 |
+
filename_tmpl: '{}x3'
|
| 36 |
+
io_backend:
|
| 37 |
+
type: disk
|
| 38 |
+
|
| 39 |
+
test_4:  # the 4th test dataset
|
| 40 |
+
task: SR
|
| 41 |
+
name: Urban100
|
| 42 |
+
type: PairedImageDataset
|
| 43 |
+
dataroot_gt: datasets/benchmark/Urban100/HR
|
| 44 |
+
dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X3
|
| 45 |
+
filename_tmpl: '{}x3'
|
| 46 |
+
io_backend:
|
| 47 |
+
type: disk
|
| 48 |
+
|
| 49 |
+
test_5:  # the 5th test dataset
|
| 50 |
+
task: SR
|
| 51 |
+
name: Manga109
|
| 52 |
+
type: PairedImageDataset
|
| 53 |
+
dataroot_gt: datasets/benchmark/Manga109/HR
|
| 54 |
+
dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X3
|
| 55 |
+
filename_tmpl: '{}_LRBI_x3'
|
| 56 |
+
io_backend:
|
| 57 |
+
type: disk
|
| 58 |
+
|
| 59 |
+
# network structures
|
| 60 |
+
network_g:
|
| 61 |
+
type: DAT
|
| 62 |
+
upscale: 3
|
| 63 |
+
in_chans: 3
|
| 64 |
+
img_size: 64
|
| 65 |
+
img_range: 1.
|
| 66 |
+
split_size: [8,32]
|
| 67 |
+
depth: [6,6,6,6,6,6]
|
| 68 |
+
embed_dim: 180
|
| 69 |
+
num_heads: [6,6,6,6,6,6]
|
| 70 |
+
expansion_factor: 2
|
| 71 |
+
resi_connection: '1conv'
|
| 72 |
+
|
| 73 |
+
# path
|
| 74 |
+
path:
|
| 75 |
+
pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x3.pth
|
| 76 |
+
strict_load_g: True
|
| 77 |
+
|
| 78 |
+
# validation settings
|
| 79 |
+
val:
|
| 80 |
+
save_img: True
|
| 81 |
+
suffix: ~ # add suffix to saved images, if None, use exp name
|
| 82 |
+
use_chop: False
|
| 83 |
+
|
| 84 |
+
metrics:
|
| 85 |
+
psnr: # metric name, can be arbitrary
|
| 86 |
+
type: calculate_psnr
|
| 87 |
+
crop_border: 3
|
| 88 |
+
test_y_channel: True
|
| 89 |
+
ssim:
|
| 90 |
+
type: calculate_ssim
|
| 91 |
+
crop_border: 3
|
| 92 |
+
test_y_channel: True
|
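The x2 and x3 test configs differ only in the fields that depend on the scale factor. A short sketch that collects those fields in one place, using the values visible in the two files above; `scale_fields` is a hypothetical helper, not part of the repository.

```python
from typing import Dict, Union


def scale_fields(scale: int, dataset: str = "Set5") -> Dict[str, Union[int, str]]:
    """Return the scale-dependent entries of a DAT-2 test config for one dataset."""
    # Manga109 uses a different low-resolution file naming template.
    tmpl = f"{{}}_LRBI_x{scale}" if dataset == "Manga109" else f"{{}}x{scale}"
    return {
        "name": f"test_DAT_2_x{scale}",
        "scale": scale,
        "upscale": scale,       # network_g.upscale
        "crop_border": scale,   # used by both the psnr and ssim metrics
        "dataroot_lq": f"datasets/benchmark/{dataset}/LR_bicubic/X{scale}",
        "filename_tmpl": tmpl,
        "pretrain_network_g": f"experiments/pretrained_models/DAT-2/DAT_2_x{scale}.pth",
    }


if __name__ == "__main__":
    for key, value in scale_fields(3, "Manga109").items():
        print(f"{key}: {value}")
```

The remaining network settings shown in both files (`embed_dim: 180`, `split_size: [8,32]`, `expansion_factor: 2`, and so on) are identical across the two scales.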
options/Test/test_DAT_2_x4.yml
ADDED
@@ -0,0 +1,93 @@
+# general settings
+name: test_DAT_2_x4
+model_type: SRModel
+scale: 4
+num_gpu: 1
+manual_seed: 10
+
+datasets:
+  test_1: # the 1st test dataset
+    task: SR
+    name: Set5
+    type: PairedImageDataset
+    dataroot_gt: datasets/benchmark/Set5/HR
+    dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+    filename_tmpl: '{}x4'
+    io_backend:
+      type: disk
+
+  test_2: # the 2nd test dataset
+    task: SR
+    name: Set14
+    type: PairedImageDataset
+    dataroot_gt: datasets/benchmark/Set14/HR
+    dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
+    filename_tmpl: '{}x4'
+    io_backend:
+      type: disk
+
+  test_3: # the 3rd test dataset
+    task: SR
+    name: B100
+    type: PairedImageDataset
+    dataroot_gt: datasets/benchmark/B100/HR
+    dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
+    filename_tmpl: '{}x4'
+    io_backend:
+      type: disk
+
+  test_4: # the 4th test dataset
+    task: SR
+    name: Urban100
+    type: PairedImageDataset
+    dataroot_gt: datasets/benchmark/Urban100/HR
+    dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
+    filename_tmpl: '{}x4'
+    io_backend:
+      type: disk
+
+  test_5: # the 5th test dataset
+    task: SR
+    name: Manga109
+    type: PairedImageDataset
+    dataroot_gt: datasets/benchmark/Manga109/HR
+    dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
+    filename_tmpl: '{}_LRBI_x4'
+    io_backend:
+      type: disk
+
+
+# network structures
+network_g:
+  type: DAT
+  upscale: 4
+  in_chans: 3
+  img_size: 64
+  img_range: 1.
+  split_size: [8,32]
+  depth: [6,6,6,6,6,6]
+  embed_dim: 180
+  num_heads: [6,6,6,6,6,6]
+  expansion_factor: 2
+  resi_connection: '1conv'
+
+# path
+path:
+  pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x4.pth
+  strict_load_g: True
+
+# validation settings
+val:
+  save_img: True
+  suffix: ~ # add suffix to saved images, if None, use exp name
+  use_chop: False
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      type: calculate_psnr
+      crop_border: 4
+      test_y_channel: True
+    ssim:
+      type: calculate_ssim
+      crop_border: 4
+      test_y_channel: True
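Each test dataset in this config pairs a ground-truth folder with a low-resolution folder through `filename_tmpl`. A minimal sketch of that naming rule follows, assuming the template is filled with the ground-truth file stem and the extension is kept unchanged. `lq_name` is a hypothetical helper written for this example, and the file names are placeholders rather than actual dataset entries.

```python
from pathlib import PurePosixPath


def lq_name(gt_filename, filename_tmpl):
    """Derive the low-resolution file name that matches a ground-truth file.

    Assumption: the template receives the GT stem (name without extension)
    and the original extension is appended afterwards.
    """
    gt = PurePosixPath(gt_filename)
    return filename_tmpl.format(gt.stem) + gt.suffix


if __name__ == "__main__":
    # Template used by Set5, Set14, B100 and Urban100 in the x4 config.
    print(lq_name("baby.png", "{}x4"))
    # Template used by Manga109 in the x4 config.
    print(lq_name("sample.png", "{}_LRBI_x4"))
```

If the low-resolution files on disk follow a different naming scheme, adjusting `filename_tmpl` in the config is likely simpler than renaming the files.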
options/Test/test_DAT_S_x2.yml
CHANGED
@@ -73,12 +73,12 @@ network_g:
 
 # path
 path:
-  pretrain_network_g: experiments/pretrained_models/DAT/DAT_S_x2.pth
+  pretrain_network_g: experiments/pretrained_models/DAT-S/DAT_S_x2.pth
   strict_load_g: True
 
 # validation settings
 val:
-  save_img:
+  save_img: True
   suffix: ~ # add suffix to saved images, if None, use exp name
   use_chop: False
 
options/Test/{test_DAT_S_x3.yml.yml → test_DAT_S_x3.yml}
RENAMED
@@ -72,12 +72,12 @@ network_g:
 
 # path
 path:
-  pretrain_network_g: experiments/pretrained_models/DAT/DAT_S_x3.pth
+  pretrain_network_g: experiments/pretrained_models/DAT-S/DAT_S_x3.pth
   strict_load_g: True
 
 # validation settings
 val:
-  save_img:
+  save_img: True
   suffix: ~ # add suffix to saved images, if None, use exp name
   use_chop: False
 
options/Test/test_DAT_S_x4.yml
CHANGED
@@ -73,12 +73,12 @@ network_g:
 
 # path
 path:
-  pretrain_network_g: experiments/pretrained_models/DAT/DAT_S_x4.pth
+  pretrain_network_g: experiments/pretrained_models/DAT-S/DAT_S_x4.pth
   strict_load_g: True
 
 # validation settings
 val:
-  save_img:
+  save_img: True
   suffix: ~ # add suffix to saved images, if None, use exp name
   use_chop: False
 
options/Test/test_DAT_x2.yml
CHANGED
@@ -78,7 +78,7 @@ path:
 
 # validation settings
 val:
-  save_img:
+  save_img: True
   suffix: ~ # add suffix to saved images, if None, use exp name
   use_chop: False
 
options/Test/test_DAT_x3.yml
CHANGED
@@ -77,7 +77,7 @@ path:
 
 # validation settings
 val:
-  save_img:
+  save_img: True
   suffix: ~ # add suffix to saved images, if None, use exp name
   use_chop: False
 
options/Test/test_DAT_x4.yml
CHANGED
@@ -78,7 +78,7 @@ path:
 
 # validation settings
 val:
-  save_img:
+  save_img: True
   suffix: ~ # add suffix to saved images, if None, use exp name
   use_chop: False
 
options/Train/train_DAT_2_x2.yml
ADDED
@@ -0,0 +1,106 @@
+# general settings
+name: train_DAT_2_x2
+model_type: SRModel
+scale: 2
+num_gpu: auto
+manual_seed: 10
+
+# dataset and data loader settings
+datasets:
+  train:
+    task: SR
+    name: DF2K
+    type: PairedImageDataset
+    dataroot_gt: datasets/DF2K/HR
+    dataroot_lq: datasets/DF2K/LR_bicubic/X2
+    filename_tmpl: '{}x2'
+    io_backend:
+      type: disk
+
+    gt_size: 128
+    use_hflip: True
+    use_rot: True
+
+    # data loader
+    use_shuffle: True
+    num_worker_per_gpu: 12
+    batch_size_per_gpu: 8
+    dataset_enlarge_ratio: 100
+    prefetch_mode: ~
+
+  val:
+    task: SR
+    name: Set5
+    type: PairedImageDataset
+    dataroot_gt: datasets/benchmark/Set5/HR
+    dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
+    filename_tmpl: '{}x2'
+    io_backend:
+      type: disk
+
+# network structures
+network_g:
+  type: DAT
+  upscale: 2
+  in_chans: 3
+  img_size: 64
+  img_range: 1.
+  split_size: [8,32]
+  depth: [6,6,6,6,6,6]
+  embed_dim: 180
+  num_heads: [6,6,6,6,6,6]
+  expansion_factor: 2
+  resi_connection: '1conv'
+
+# path
+path:
+  pretrain_network_g: ~
+  strict_load_g: True
+  resume_state: ~
+
+# training settings
+train:
+  optim_g:
+    type: Adam
+    lr: !!float 2e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [250000, 400000, 450000, 475000]
+    gamma: 0.5
+
+  total_iter: 500000
+  warmup_iter: -1 # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+
+# validation settings
+val:
+  val_freq: !!float 5e3
+  save_img: False
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      type: calculate_psnr
+      crop_border: 2
+      test_y_channel: True
+
+# logging settings
+logger:
+  print_freq: 200
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: True
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
options/Train/train_DAT_2_x3.yml
ADDED
@@ -0,0 +1,109 @@
+# general settings
+name: train_DAT_2_x3
+model_type: SRModel
+scale: 3
+num_gpu: auto
+manual_seed: 10
+
+# dataset and data loader settings
+datasets:
+  train:
+    task: SR
+    name: DF2K
+    type: PairedImageDataset
+    dataroot_gt: datasets/DF2K/HR
+    dataroot_lq: datasets/DF2K/LR_bicubic/X3
+    filename_tmpl: '{}x3'
+    io_backend:
+      type: disk
+
+    gt_size: 192
+    use_hflip: True
+    use_rot: True
+
+    # data loader
+    use_shuffle: True
+    num_worker_per_gpu: 12
+    batch_size_per_gpu: 8
+    dataset_enlarge_ratio: 100
+    prefetch_mode: ~
+
+  val:
+    task: SR
+    name: Set5
+    type: PairedImageDataset
+    dataroot_gt: datasets/benchmark/Set5/HR
+    dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X3
+    filename_tmpl: '{}x3'
+    io_backend:
+      type: disk
+
+# network structures
+network_g:
+  type: DAT
+  upscale: 3
+  in_chans: 3
+  img_size: 64
+  img_range: 1.
+  split_size: [8,32]
+  depth: [6,6,6,6,6,6]
+  embed_dim: 180
+  num_heads: [6,6,6,6,6,6]
+  expansion_factor: 2
+  resi_connection: '1conv'
+
+# path
+path:
+  pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x2.pth # save half of training time if we finetune from x2 and halve initial lr.
+  strict_load_g: False
+  resume_state: ~
+
+# training settings
+train:
+  optim_g:
+    type: Adam
+    # lr: !!float 2e-4
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    # milestones: [ 250000, 400000, 450000, 475000 ]
+    milestones: [ 125000, 200000, 225000, 237500 ]
+    gamma: 0.5
+
+  # total_iter: 500000
+  total_iter: 250000
+  warmup_iter: -1 # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+
+# validation settings
+val:
+  val_freq: !!float 5e3
+  save_img: False
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      type: calculate_psnr
+      crop_border: 4
+      test_y_channel: True
+
+# logging settings
+logger:
+  print_freq: 200
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: True
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
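The comment on `pretrain_network_g` in this config says that fine-tuning from the x2 model with a halved initial learning rate saves half of the training time. The sketch below compares the two schedules numerically, assuming the usual multi-step rule in which the rate is multiplied by `gamma` once an iteration reaches a milestone. `multistep_lr` is a hypothetical helper written for this example and does not call PyTorch.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, iteration):
    """Learning rate at `iteration` under a multi-step decay schedule."""
    # Each milestone that has been reached multiplies the rate by gamma once.
    return base_lr * gamma ** bisect_right(milestones, iteration)


# Values copied from the x2 (from scratch) and x3 (fine-tune) train configs.
FROM_SCRATCH = dict(base_lr=2e-4, milestones=[250000, 400000, 450000, 475000], gamma=0.5)
FINE_TUNE = dict(base_lr=1e-4, milestones=[125000, 200000, 225000, 237500], gamma=0.5)

if __name__ == "__main__":
    for fraction in (0.0, 0.5, 0.8, 0.9, 0.95):
        scratch_iter = int(fraction * 500000)
        tune_iter = int(fraction * 250000)
        print(
            f"{fraction:>4.0%}  "
            f"scratch@{scratch_iter}: {multistep_lr(iteration=scratch_iter, **FROM_SCRATCH):.3e}  "
            f"fine-tune@{tune_iter}: {multistep_lr(iteration=tune_iter, **FINE_TUNE):.3e}"
        )
```

Under these assumptions the fine-tune schedule is the from-scratch schedule compressed to half the iterations, with every learning rate halved as well.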
options/Train/train_DAT_2_x4.yml
ADDED
@@ -0,0 +1,110 @@
+# general settings
+name: test_DAT_2_x4
+model_type: SRModel
+scale: 4
+num_gpu: auto
+manual_seed: 10
+
+# dataset and data loader settings
+datasets:
+  train:
+    task: SR
+    name: DF2K
+    type: PairedImageDataset
+    dataroot_gt: datasets/DF2K/HR
+    dataroot_lq: datasets/DF2K/LR_bicubic/X4
+    filename_tmpl: '{}x4'
+    io_backend:
+      type: disk
+
+    gt_size: 256
+    use_hflip: true
+    use_rot: true
+
+    # data loader
+    use_shuffle: True
+    num_worker_per_gpu: 12
+    batch_size_per_gpu: 8
+    dataset_enlarge_ratio: 100
+    prefetch_mode: ~
+
+  val:
+    task: SR
+    name: Set5
+    type: PairedImageDataset
+    dataroot_gt: datasets/benchmark/Set5/HR
+    dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+    filename_tmpl: '{}x4'
+    io_backend:
+      type: disk
+
+
+# network structures
+network_g:
+  type: DAT
+  upscale: 4
+  in_chans: 3
+  img_size: 64
+  img_range: 1.
+  split_size: [8,32]
+  depth: [6,6,6,6,6,6]
+  embed_dim: 180
+  num_heads: [6,6,6,6,6,6]
+  expansion_factor: 2
+  resi_connection: '1conv'
+
+# path
+path:
+  pretrain_network_g: experiments/pretrained_models/DAT-2/DAT_2_x2.pth # save half of training time if we finetune from x2 and halve initial lr.
+  strict_load_g: False
+  resume_state: ~
+
+# training settings
+train:
+  optim_g:
+    type: Adam
+    # lr: !!float 2e-4
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    # milestones: [ 250000, 400000, 450000, 475000 ]
+    milestones: [ 125000, 200000, 225000, 237500 ]
+    gamma: 0.5
+
+  # total_iter: 500000
+  total_iter: 250000
+  warmup_iter: -1 # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+
+# validation settings
+val:
+  val_freq: !!float 5e3
+  save_img: False
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      type: calculate_psnr
+      crop_border: 4
+      test_y_channel: True
+
+# logging settings
+logger:
+  print_freq: 200
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: True
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
options/Train/{train_DAT_S_x3.yml.yml → train_DAT_S_x3.yml}
RENAMED
File without changes
options/Train/train_DAT_x4.yml
CHANGED
@@ -1,5 +1,5 @@
 # general settings
-name:
+name: test_DAT_x4
 model_type: SRModel
 scale: 4
 num_gpu: auto
@@ -55,7 +55,7 @@ network_g:
 
 # path
 path:
-  pretrain_network_g: experiments/pretrained_models/DAT
+  pretrain_network_g: experiments/pretrained_models/DAT/DAT_x2.pth # save half of training time if we finetune from x2 and halve initial lr.
   strict_load_g: False
   resume_state: ~
 