tahamajs committed on
Commit
cab61c7
·
verified ·
1 Parent(s): 20f747b

Initial model upload: COCO inpainting U-Net

Browse files
Files changed (3) hide show
  1. README.md +82 -0
  2. inpainting_model_coco.pth +3 -0
  3. model.py +59 -0
README.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ---
3
+ license: apache-2.0
4
+ language: en
5
+ library_name: pytorch
6
+ tags:
7
+ - image-inpainting
8
+ - computer-vision
9
+ - pytorch
10
+ - unet
11
+ - coco
12
+ datasets:
13
+ - coco
14
+ ---
15
+
16
+ # U-Net for Image Inpainting on COCO 2017
17
+
18
+ This repository contains a PyTorch implementation of a deep U-Net with Residual Blocks, trained on the COCO 2017 dataset for image inpainting.
19
+
20
+ ## Model Description
21
+
22
+ The model is a `ComplexUNet` architecture, a variant of the standard U-Net adapted for 256x256 images. It features a deep structure with 5 downsampling/upsampling stages and uses residual blocks for more stable training.
23
+
24
+ ## How to Use
25
+
26
+ To use this model, you need to have `torch`, `torchvision`, and `Pillow` installed (the example below loads images with PIL).
27
+
28
+ 1. Place the `model.py` file in your project directory.
29
+ 2. Download the `inpainting_model_coco.pth` file from the 'Files and versions' tab.
30
+ 3. Load the model as shown below.
31
+
32
+ ```python
33
+ import torch
34
+ from model import ComplexUNet # Import the class from model.py
35
+ from PIL import Image
36
+ import torchvision.transforms as T
37
+
38
+ # --- Setup ---
39
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ MODEL_PATH = "inpainting_model_coco.pth" # <-- Make sure you've downloaded this file
41
+
42
+ # --- Load Model ---
43
+ # Note: Use base_channels=64 as it was during training
44
+ model = ComplexUNet(base_channels=64)
45
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
46
+ model.to(DEVICE)
47
+ model.eval()
48
+
49
+ print("Model loaded successfully!")
50
+
51
+ # --- Example: Inpaint an image ---
52
+ # 1. Load your masked image
53
+ # masked_image = Image.open("path/to/your/masked_image.png").convert("RGB")
54
+ #
55
+ # 2. Create a tensor from your image
56
+ # transform = T.Compose([
57
+ # T.Resize(256),
58
+ # T.CenterCrop(256),
59
+ # T.ToTensor()
60
+ # ])
61
+ # masked_tensor = transform(masked_image).unsqueeze(0).to(DEVICE)
62
+ #
63
+ # 3. Get the reconstructed image
64
+ # with torch.no_grad():
65
+ # reconstructed_tensor = model(masked_tensor)
66
+ #
67
+ # 4. Convert tensor back to PIL Image
68
+ # reconstructed_image = T.ToPILImage()(reconstructed_tensor.squeeze(0).cpu())
69
+ # reconstructed_image.save("reconstructed_result.png")
70
+ # print("Inpainting complete. Saved to reconstructed_result.png")
71
+ ```
72
+
73
+ ## Training Details
74
+
75
+ - **Framework**: PyTorch
76
+ - **Dataset**: COCO 2017
77
+ - **Epochs**: 10
78
+ - **Batch Size**: 16
79
+ - **Learning Rate**: 0.001
80
+ - **Optimizer**: Adam
81
+ - **Loss Function**: Mean Squared Error (MSE)
82
+ - **Image Resolution**: 256x256
inpainting_model_coco.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:add6cd16e13bbc0003bee51af922041e992ef2a6e58af6eea5f64c8fb1b385ec
3
+ size 520110786
model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ # Model architecture copied directly from the training script.
6
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm layers plus a skip path.

    The main branch is conv -> BN -> ReLU -> conv -> BN. Its result is
    added to the skip branch and passed through a final ReLU. Both 3x3
    convolutions use ``padding=1``, so height and width are unchanged.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main branch. The convolutions carry no bias because each one is
        # immediately followed by a BatchNorm2d layer.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: pass the input through untouched when the channel
        # counts already agree, otherwise project it with a 1x1 conv + BN
        # so that it can be added to the main branch.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        skip = self.shortcut(x)

        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))

        features += skip
        return self.relu(features)
25
+
26
class ComplexUNet(nn.Module):
    """Five-level U-Net built from ``ResidualBlock`` stages.

    The encoder applies five residual blocks, each followed by 2x2 max
    pooling, doubling the channel count at every level. A bottleneck block
    doubles it once more. The decoder mirrors this: each stage upsamples
    with a stride-2 transposed convolution, concatenates the matching
    encoder output along the channel axis, and refines the result with a
    residual block. A 1x1 convolution maps to 3 channels and a sigmoid
    squashes the output into ``[0, 1]``.

    Attribute names are kept stable because they determine the
    ``state_dict`` keys used by saved checkpoints.

    Args:
        base_channels: Channel count of the first encoder level; deeper
            levels use 2x, 4x, 8x, 16x and (bottleneck) 32x this value.
    """

    def __init__(self, base_channels=64):
        super().__init__()
        c = base_channels
        w1, w2, w3, w4, w5, w6 = c, c * 2, c * 4, c * 8, c * 16, c * 32

        self.pool = nn.MaxPool2d(2, 2)

        # Encoder: 3 -> c -> 2c -> 4c -> 8c -> 16c channels.
        self.enc1 = ResidualBlock(3, w1)
        self.enc2 = ResidualBlock(w1, w2)
        self.enc3 = ResidualBlock(w2, w3)
        self.enc4 = ResidualBlock(w3, w4)
        self.enc5 = ResidualBlock(w4, w5)

        self.bottleneck = ResidualBlock(w5, w6)

        # Each transposed convolution doubles the spatial size and halves
        # the channel count.
        self.upconv1 = nn.ConvTranspose2d(w6, w5, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.upconv5 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)

        # Decoder blocks take twice the upsampled width as input because
        # the encoder skip tensor is concatenated before they run.
        self.dec_conv1 = ResidualBlock(w6, w5)
        self.dec_conv2 = ResidualBlock(w5, w4)
        self.dec_conv3 = ResidualBlock(w4, w3)
        self.dec_conv4 = ResidualBlock(w3, w2)
        self.dec_conv5 = ResidualBlock(w2, w1)

        self.final_conv = nn.Conv2d(w1, 3, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder and return the sigmoid-activated output.

        Args:
            x: Input tensor; assumed to be ``(N, 3, H, W)``. Because the
                input is halved five times and then doubled five times,
                ``H`` and ``W`` need to be multiples of 32 for the skip
                concatenations to line up.

        Returns:
            Tensor with 3 channels, same spatial size as ``x``, and values
            in ``[0, 1]``.
        """
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5)
        upsamplers = (
            self.upconv1,
            self.upconv2,
            self.upconv3,
            self.upconv4,
            self.upconv5,
        )
        refiners = (
            self.dec_conv1,
            self.dec_conv2,
            self.dec_conv3,
            self.dec_conv4,
            self.dec_conv5,
        )

        # Contracting path: remember each pre-pooling activation for reuse
        # as a skip connection.
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)
            features = self.pool(features)

        features = self.bottleneck(features)

        # Expanding path: the deepest skip is consumed first.
        for upsample, refine, skip in zip(upsamplers, refiners, reversed(skips)):
            features = upsample(features)
            features = torch.cat([features, skip], dim=1)
            features = refine(features)

        return torch.sigmoid(self.final_conv(features))