Initial model upload: COCO inpainting U-Net
- README.md +82 -0
- inpainting_model_coco.pth +3 -0
- model.py +59 -0
README.md ADDED
@@ -0,0 +1,82 @@
---
license: apache-2.0
language: en
library_name: pytorch
tags:
- image-inpainting
- computer-vision
- pytorch
- unet
- coco
datasets:
- coco
---

# U-Net for Image Inpainting on COCO 2017

This repository contains a PyTorch implementation of a deep U-Net with residual blocks, trained on the COCO 2017 dataset for image inpainting.

## Model Description

The model is a `ComplexUNet`, a variant of the standard U-Net adapted for 256x256 images. It has a deep encoder-decoder structure with 5 downsampling/upsampling stages and uses residual blocks for more stable training.
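
As a quick sanity check of those shapes, here is a minimal sketch that instantiates the network and runs a dummy forward pass (it assumes `model.py` from this repository is on your import path):

```python
import torch
from model import ComplexUNet

model = ComplexUNet(base_channels=64)
model.eval()

x = torch.randn(1, 3, 256, 256)  # dummy RGB batch at the training resolution
with torch.no_grad():
    y = model(x)

print(y.shape)  # torch.Size([1, 3, 256, 256]); the output passes through a sigmoid, so values lie in [0, 1]
```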

## How to Use

To use this model, you need `torch` and `torchvision` installed (e.g. `pip install torch torchvision`).

1. Place the `model.py` file in your project directory.
2. Download the `inpainting_model_coco.pth` file from the 'Files and versions' tab.
3. Load the model as shown below.

```python
import torch
from model import ComplexUNet  # Import the class from model.py
from PIL import Image
import torchvision.transforms as T

# --- Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "inpainting_model_coco.pth"  # <-- Make sure you've downloaded this file

# --- Load Model ---
# Note: use base_channels=64, the value used during training
model = ComplexUNet(base_channels=64)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

print("Model loaded successfully!")

# --- Example: Inpaint an image ---
# 1. Load your masked image
masked_image = Image.open("path/to/your/masked_image.png").convert("RGB")

# 2. Create a tensor from your image
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(256),
    T.ToTensor()
])
masked_tensor = transform(masked_image).unsqueeze(0).to(DEVICE)

# 3. Get the reconstructed image
with torch.no_grad():
    reconstructed_tensor = model(masked_tensor)

# 4. Convert the tensor back to a PIL Image
reconstructed_image = T.ToPILImage()(reconstructed_tensor.squeeze(0).cpu())
reconstructed_image.save("reconstructed_result.png")
print("Inpainting complete. Saved to reconstructed_result.png")
```
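
The masking scheme used during training is not documented in this card, so if you need to produce a masked input for testing, the square-mask helper below is an illustrative assumption, not the training procedure:

```python
import torch

def mask_center_square(img_tensor, size=64):
    """Hypothetical helper: zero out a centered square region of an (N, C, H, W) tensor."""
    masked = img_tensor.clone()
    _, _, h, w = masked.shape
    top, left = (h - size) // 2, (w - size) // 2
    masked[:, :, top:top + size, left:left + size] = 0.0
    return masked

# e.g. masked_tensor = mask_center_square(transform(image).unsqueeze(0).to(DEVICE))
```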

## Training Details

- **Framework**: PyTorch
- **Dataset**: COCO 2017
- **Epochs**: 10
- **Batch Size**: 16
- **Learning Rate**: 0.001
- **Optimizer**: Adam
- **Loss Function**: Mean Squared Error (MSE)
- **Image Resolution**: 256x256
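
The training script itself is not part of this upload; the loop below is a minimal sketch consistent with the hyperparameters above. The `TensorDataset` stand-in and the assumption that batches are (masked, clean) pairs are hypothetical, not the author's pipeline:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from model import ComplexUNet

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in data so the sketch runs; replace with a real COCO pipeline
# yielding (masked, clean) 256x256 image pairs.
dummy = TensorDataset(torch.rand(32, 3, 256, 256), torch.rand(32, 3, 256, 256))
dataloader = DataLoader(dummy, batch_size=16, shuffle=True)  # Batch Size: 16

model = ComplexUNet(base_channels=64).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam, lr 0.001
criterion = nn.MSELoss()  # MSE reconstruction loss

model.train()
for epoch in range(10):  # Epochs: 10
    for masked, clean in dataloader:
        masked, clean = masked.to(DEVICE), clean.to(DEVICE)
        optimizer.zero_grad()
        loss = criterion(model(masked), clean)
        loss.backward()
        optimizer.step()
```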
inpainting_model_coco.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:add6cd16e13bbc0003bee51af922041e992ef2a6e58af6eea5f64c8fb1b385ec
size 520110786
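
This file is a Git LFS pointer, so the roughly 520 MB checkpoint must be fetched via LFS or the Hub. A sketch using `huggingface_hub` (the `repo_id` below is a placeholder, not this model's actual repository id):

```python
from huggingface_hub import hf_hub_download

# "user/repo" is a placeholder; substitute the actual repository id.
checkpoint_path = hf_hub_download(repo_id="user/repo", filename="inpainting_model_coco.pth")
print(checkpoint_path)
```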
model.py ADDED
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn

# Model architecture copied directly from the training script.
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Project the shortcut with a 1x1 conv when the channel count changes.
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class ComplexUNet(nn.Module):
    def __init__(self, base_channels=64):
        super(ComplexUNet, self).__init__()
        c = base_channels
        self.pool = nn.MaxPool2d(2, 2)
        # Encoder: five residual stages, doubling channels at each downsampling step.
        self.enc1 = ResidualBlock(3, c)
        self.enc2 = ResidualBlock(c, c * 2)
        self.enc3 = ResidualBlock(c * 2, c * 4)
        self.enc4 = ResidualBlock(c * 4, c * 8)
        self.enc5 = ResidualBlock(c * 8, c * 16)
        self.bottleneck = ResidualBlock(c * 16, c * 32)
        # Decoder: transposed convolutions for upsampling, followed by residual
        # blocks applied after each skip-connection concatenation.
        self.upconv1 = nn.ConvTranspose2d(c * 32, c * 16, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(c * 16, c * 8, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(c * 8, c * 4, kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(c * 4, c * 2, kernel_size=2, stride=2)
        self.upconv5 = nn.ConvTranspose2d(c * 2, c, kernel_size=2, stride=2)
        self.dec_conv1 = ResidualBlock(c * 32, c * 16)
        self.dec_conv2 = ResidualBlock(c * 16, c * 8)
        self.dec_conv3 = ResidualBlock(c * 8, c * 4)
        self.dec_conv4 = ResidualBlock(c * 4, c * 2)
        self.dec_conv5 = ResidualBlock(c * 2, c)
        self.final_conv = nn.Conv2d(c, 3, kernel_size=1)

    def forward(self, x):
        # Encoder path; keep each pre-pool feature map for the skip connections.
        e1 = self.enc1(x)
        p1 = self.pool(e1)
        e2 = self.enc2(p1)
        p2 = self.pool(e2)
        e3 = self.enc3(p2)
        p3 = self.pool(e3)
        e4 = self.enc4(p3)
        p4 = self.pool(e4)
        e5 = self.enc5(p4)
        p5 = self.pool(e5)
        b = self.bottleneck(p5)
        # Decoder path: upsample, concatenate the matching encoder feature map,
        # then fuse with a residual block.
        d1 = self.upconv1(b)
        d1 = torch.cat([d1, e5], dim=1)
        d1 = self.dec_conv1(d1)
        d2 = self.upconv2(d1)
        d2 = torch.cat([d2, e4], dim=1)
        d2 = self.dec_conv2(d2)
        d3 = self.upconv3(d2)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec_conv3(d3)
        d4 = self.upconv4(d3)
        d4 = torch.cat([d4, e2], dim=1)
        d4 = self.dec_conv4(d4)
        d5 = self.upconv5(d4)
        d5 = torch.cat([d5, e1], dim=1)
        d5 = self.dec_conv5(d5)
        out = self.final_conv(d5)
        # Sigmoid keeps outputs in [0, 1] to match ToTensor-normalized images.
        return torch.sigmoid(out)