463465810cz commited on
Commit
5841fb1
·
1 Parent(s): 86de95a

Update dat_arch.py

Browse files

Former-commit-id: 6846c8798e4a0579451982bd8b6e441f4b1e78b0

Files changed (1) hide show
  1. basicsr/archs/dat_arch.py +2 -2
basicsr/archs/dat_arch.py CHANGED
@@ -646,7 +646,7 @@ class ResidualGroup(nn.Module):
646
  x = checkpoint.checkpoint(blk, x, x_size)
647
  else:
648
  x = blk(x, x_size)
649
- x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
650
  x = self.conv(x)
651
  x = rearrange(x, "b c h w -> b (h w) c")
652
  x = res + x
@@ -835,7 +835,7 @@ class DAT(nn.Module):
835
  for layer in self.layers:
836
  x = layer(x, x_size)
837
  x = self.norm(x)
838
- x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
839
 
840
  return x
841
 
 
646
  x = checkpoint.checkpoint(blk, x, x_size)
647
  else:
648
  x = blk(x, x_size)
649
+ x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous()
650
  x = self.conv(x)
651
  x = rearrange(x, "b c h w -> b (h w) c")
652
  x = res + x
 
835
  for layer in self.layers:
836
  x = layer(x, x_size)
837
  x = self.norm(x)
838
+ x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous()
839
 
840
  return x
841