463465810cz commited on
Commit
87bcf27
·
1 Parent(s): 6e0dcf2
Files changed (1) hide show
  1. basicsr/archs/dat_arch.py +4 -4
basicsr/archs/dat_arch.py CHANGED
@@ -137,8 +137,8 @@ class Spatial_Attention(nn.Module):
137
  It supports rectangle window (containing square window).
138
  Args:
139
  dim (int): Number of input channels.
140
- idx (int): The indentix of different shape window.
141
- split_size (tuple(int)): Height or Width of spatial window.
142
  dim_out (int | None): The dimension of the attention output. Default: None
143
  num_heads (int): Number of attention heads. Default: 6
144
  attn_drop (float): Dropout ratio of attention weight. Default: 0.0
@@ -581,7 +581,7 @@ class ResidualGroup(nn.Module):
581
  drop_paths (float | None): Stochastic depth rate.
582
  act_layer (nn.Module): Activation layer. Default: nn.GELU
583
  norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
584
- depth (int): Number of Cross Aggregation Transformer blocks in residual group.
585
  use_chk (bool): Whether to use checkpointing to save memory.
586
  resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
587
  """
@@ -692,7 +692,7 @@ class DAT(nn.Module):
692
  act_layer (nn.Module): Activation layer. Default: nn.GELU
693
  norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
694
  use_chk (bool): Whether to use checkpointing to save memory.
695
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for compress artifact reduction
696
  img_range: Image range. 1. or 255.
697
  resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
698
  """
 
137
  It supports rectangle window (containing square window).
138
  Args:
139
  dim (int): Number of input channels.
140
+ idx (int): The indentix of window. (0/1)
141
+ split_size (tuple(int)): Height and Width of spatial window.
142
  dim_out (int | None): The dimension of the attention output. Default: None
143
  num_heads (int): Number of attention heads. Default: 6
144
  attn_drop (float): Dropout ratio of attention weight. Default: 0.0
 
581
  drop_paths (float | None): Stochastic depth rate.
582
  act_layer (nn.Module): Activation layer. Default: nn.GELU
583
  norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
584
+ depth (int): Number of dual aggregation Transformer blocks in residual group.
585
  use_chk (bool): Whether to use checkpointing to save memory.
586
  resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
587
  """
 
692
  act_layer (nn.Module): Activation layer. Default: nn.GELU
693
  norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
694
  use_chk (bool): Whether to use checkpointing to save memory.
695
+ upscale: Upscale factor. 2/3/4 for image SR
696
  img_range: Image range. 1. or 255.
697
  resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
698
  """