U-Net for Image Inpainting on COCO 2017
This repository contains a PyTorch implementation of a deep U-Net with Residual Blocks, trained on the COCO 2017 dataset for image inpainting.
Model Description
The model is a ComplexUNet architecture, a variant of the standard U-Net adapted for 256x256 images. It features a deep structure with 5 downsampling/upsampling stages and uses residual blocks for more stable training.
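
As a rough illustration of what 5 downsampling stages mean for a 256x256 input, the sketch below traces the spatial resolution through the encoder. It assumes each stage halves the resolution, as in a standard U-Net; the exact layers in model.py are not described here, and the helper name `encoder_resolutions` is hypothetical.

```python
def encoder_resolutions(input_size=256, num_stages=5):
    """Return the feature-map side length at each encoder level.

    Assumption: every downsampling stage halves the resolution
    (stride-2 pooling or convolution), as in a standard U-Net.
    """
    sizes = [input_size]
    for _ in range(num_stages):
        if sizes[-1] % 2 != 0:
            raise ValueError(
                f"side length {sizes[-1]} cannot be halved cleanly; "
                f"use an input divisible by {2 ** num_stages}"
            )
        sizes.append(sizes[-1] // 2)
    return sizes


print(encoder_resolutions())  # [256, 128, 64, 32, 16, 8]
```

Under that assumption the bottleneck works on 8x8 feature maps, and input sides should be divisible by 32 so the skip connections line up on the way back up.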
How to Use
To use this model, you need torch, torchvision, and Pillow (imported as PIL) installed.
- Place the model.py file in your project directory.
- Download the inpainting_model_coco.pth file from the 'Files and versions' tab.
- Load the model as shown below.
import torch
from model import ComplexUNet # Import the class from model.py
from PIL import Image
import torchvision.transforms as T
# --- Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "inpainting_model_coco.pth" # <-- Make sure you've downloaded this file
# --- Load Model ---
# Note: Use base_channels=64 as it was during training
model = ComplexUNet(base_channels=64)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()
print("Model loaded successfully!")
# --- Example: Inpaint an image ---
# 1. Load your masked image
# masked_image = Image.open("path/to/your/masked_image.png").convert("RGB")
#
# 2. Create a tensor from your image
# transform = T.Compose([
#     T.Resize(256),
#     T.CenterCrop(256),
#     T.ToTensor()
# ])
# masked_tensor = transform(masked_image).unsqueeze(0).to(DEVICE)
#
# 3. Get the reconstructed image
# with torch.no_grad():
#     reconstructed_tensor = model(masked_tensor)
#
# 4. Convert tensor back to PIL Image
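# (clamp to [0, 1] first so out-of-range values do not wrap around in the 8-bit conversion)
# reconstructed_image = T.ToPILImage()(reconstructed_tensor.squeeze(0).clamp(0, 1).cpu())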
# reconstructed_image.save("reconstructed_result.png")
# print("Inpainting complete. Saved to reconstructed_result.png")
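
The example above expects an image that is already masked. For quick testing, one way to produce such an input is to blank out a rectangle in a clean image tensor. This is only a sketch: it assumes masked pixels are set to zero, which may not match the masking scheme used during training, and `apply_box_mask` is a hypothetical helper, not part of model.py.

```python
import torch


def apply_box_mask(image, top, left, height, width, fill=0.0):
    """Overwrite a rectangle of a (C, H, W) image tensor with `fill`.

    Returns the masked image and a (1, H, W) mask where 1 marks kept
    pixels and 0 marks the hole. The zero fill is an assumption.
    """
    if image.dim() != 3:
        raise ValueError("expected a (C, H, W) tensor")
    mask = torch.ones_like(image[:1])
    mask[:, top:top + height, left:left + width] = 0.0
    masked = image * mask + fill * (1.0 - mask)
    return masked, mask


# Random data stands in for a real image loaded with T.ToTensor().
image = torch.rand(3, 256, 256)
masked, mask = apply_box_mask(image, top=96, left=96, height=64, width=64)
print(masked.shape, int((mask == 0).sum()))
```

The resulting `masked` tensor can be passed to the model after `unsqueeze(0)`, in place of `masked_tensor` above.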
Training Details
- Framework: PyTorch
- Dataset: COCO 2017
- Epochs: 10
- Batch Size: 16
- Learning Rate: 0.001
- Optimizer: Adam
- Loss Function: Mean Squared Error (MSE)
- Image Resolution: 256x256
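
The settings listed above roughly correspond to a training step like the one sketched here. This is not the original training script: the data pipeline and masking strategy are not documented, so a single convolution stands in for ComplexUNet, random tensors stand in for COCO batches, and `train_step` is a hypothetical helper.

```python
import torch
from torch import nn


def train_step(model, optimizer, masked_batch, target_batch):
    """Run one optimization step with MSE loss and return the loss value."""
    model.train()
    optimizer.zero_grad()
    prediction = model(masked_batch)
    loss = nn.functional.mse_loss(prediction, target_batch)
    loss.backward()
    optimizer.step()
    return loss.item()


# Stand-in model (assumption); replace with ComplexUNet(base_channels=64).
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# One synthetic batch: 16 images at 256x256 with a zeroed square hole.
target = torch.rand(16, 3, 256, 256)
masked = target.clone()
masked[:, :, 96:160, 96:160] = 0.0

loss = train_step(model, optimizer, masked, target)
print(f"loss: {loss:.4f}")
```

In a full run this step would be repeated over every batch of the dataset for 10 epochs.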