463465810cz
commited on
Commit
·
5841fb1
1
Parent(s):
86de95a
Update dat_arch.py
Browse filesFormer-commit-id: 6846c8798e4a0579451982bd8b6e441f4b1e78b0
basicsr/archs/dat_arch.py
CHANGED
|
@@ -646,7 +646,7 @@ class ResidualGroup(nn.Module):
|
|
| 646 |
x = checkpoint.checkpoint(blk, x, x_size)
|
| 647 |
else:
|
| 648 |
x = blk(x, x_size)
|
| 649 |
-
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
|
| 650 |
x = self.conv(x)
|
| 651 |
x = rearrange(x, "b c h w -> b (h w) c")
|
| 652 |
x = res + x
|
|
@@ -835,7 +835,7 @@ class DAT(nn.Module):
|
|
| 835 |
for layer in self.layers:
|
| 836 |
x = layer(x, x_size)
|
| 837 |
x = self.norm(x)
|
| 838 |
-
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
|
| 839 |
|
| 840 |
return x
|
| 841 |
|
|
|
|
| 646 |
x = checkpoint.checkpoint(blk, x, x_size)
|
| 647 |
else:
|
| 648 |
x = blk(x, x_size)
|
| 649 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous()
|
| 650 |
x = self.conv(x)
|
| 651 |
x = rearrange(x, "b c h w -> b (h w) c")
|
| 652 |
x = res + x
|
|
|
|
| 835 |
for layer in self.layers:
|
| 836 |
x = layer(x, x_size)
|
| 837 |
x = self.norm(x)
|
| 838 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous()
|
| 839 |
|
| 840 |
return x
|
| 841 |
|