U-Net for Image Inpainting on COCO 2017

This repository contains a PyTorch implementation of a deep U-Net with residual blocks, trained for image inpainting on the COCO 2017 dataset.

Model Description

The model, ComplexUNet, is a variant of the standard U-Net adapted for 256x256 input images. It has five downsampling stages with five matching upsampling stages, and it uses residual blocks to make training more stable.
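To make the description concrete, here is a minimal sketch of a residual U-Net with five downsampling and five upsampling stages. It is not the contents of model.py. The class names SketchUNet and ResidualBlock, the channel doubling per stage, the BatchNorm/ReLU layout, max-pooling for downsampling, transposed convolutions for upsampling, and the final sigmoid are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Two 3x3 conv layers plus a shortcut connection (hypothetical layout)."""

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        # Identity shortcut when channel counts match, otherwise a 1x1 projection.
        self.shortcut = (nn.Identity() if in_ch == out_ch
                         else nn.Conv2d(in_ch, out_ch, kernel_size=1))
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Adding the input back lets gradients bypass the conv stack.
        return self.act(self.body(x) + self.shortcut(x))


class SketchUNet(nn.Module):
    """Hypothetical residual U-Net: `depth` down stages and `depth` up stages."""

    def __init__(self, base_channels: int = 64, depth: int = 5) -> None:
        super().__init__()
        # Channel width per level, doubling at each stage (assumption).
        chs = [base_channels * 2 ** i for i in range(depth + 1)]
        self.stem = ResidualBlock(3, chs[0])
        # Each down stage halves H and W, then widens the channels.
        self.downs = nn.ModuleList(
            nn.Sequential(nn.MaxPool2d(2), ResidualBlock(chs[i], chs[i + 1]))
            for i in range(depth)
        )
        # Each up stage doubles H and W; the block then fuses the matching skip.
        self.ups = nn.ModuleList(
            nn.ConvTranspose2d(chs[i + 1], chs[i], kernel_size=2, stride=2)
            for i in reversed(range(depth))
        )
        self.up_blocks = nn.ModuleList(
            ResidualBlock(chs[i] * 2, chs[i]) for i in reversed(range(depth))
        )
        self.head = nn.Conv2d(chs[0], 3, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        skips = []
        for down in self.downs:
            skips.append(x)  # keep features from before each pooling step
            x = down(x)
        for up, block in zip(self.ups, self.up_blocks):
            x = up(x)
            # pop() returns the deepest remaining skip, matching the current resolution.
            x = block(torch.cat([x, skips.pop()], dim=1))
        # Sigmoid keeps outputs in [0, 1], the same range as ToTensor() inputs (assumption).
        return torch.sigmoid(self.head(x))
```

With five stages, a 256x256 input is reduced to 8x8 at the bottleneck (256 / 2^5 = 8) before being upsampled back to 256x256. In this sketch, doubling from base_channels=64 gives 2048 channels at the bottleneck, which is large; the real ComplexUNet may cap or scale its channels differently.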

How to Use

To use this model, you need torch and torchvision installed. The example below also uses Pillow (PIL) to load and save images.

  1. Place the model.py file in your project directory.
  2. Download the inpainting_model_coco.pth file from the 'Files and versions' tab.
  3. Load the model as shown below.
import torch
from model import ComplexUNet # Import the class from model.py
from PIL import Image
import torchvision.transforms as T

# --- Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "inpainting_model_coco.pth" # <-- Make sure you've downloaded this file

# --- Load Model ---
# Note: base_channels must be 64 to match the value used during training
model = ComplexUNet(base_channels=64)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

print("Model loaded successfully!")

# --- Example: Inpaint an image ---
# 1. Load your masked image
# masked_image = Image.open("path/to/your/masked_image.png").convert("RGB")
#
# 2. Create a tensor from your image
# transform = T.Compose([
#     T.Resize(256),
#     T.CenterCrop(256),
#     T.ToTensor()
# ])
# masked_tensor = transform(masked_image).unsqueeze(0).to(DEVICE)
#
# 3. Get the reconstructed image
# with torch.no_grad():
#     reconstructed_tensor = model(masked_tensor)
#
# 4. Convert tensor back to PIL Image (ToPILImage expects floats in [0, 1], so clamp first)
# reconstructed_image = T.ToPILImage()(reconstructed_tensor.squeeze(0).clamp(0, 1).cpu())
# reconstructed_image.save("reconstructed_result.png")
# print("Inpainting complete. Saved to reconstructed_result.png")
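The example above assumes you already have a masked image. One way to create a test input, and to paste only the predicted hole pixels back into it, is sketched here. The helper names make_rect_mask, apply_mask, and composite are hypothetical, and filling holes with 0 (black) is an assumption: the card does not state how masked regions were encoded during training, so match whatever convention the training code used.

```python
import torch


def make_rect_mask(size: int = 256, top: int = 96, left: int = 96, hole: int = 64) -> torch.Tensor:
    """Return a (1, 1, size, size) mask: 1.0 inside a square hole, 0.0 elsewhere."""
    mask = torch.zeros(1, 1, size, size)
    mask[..., top:top + hole, left:left + hole] = 1.0
    return mask


def apply_mask(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Blank out the hole region by setting it to 0 (assumed hole encoding)."""
    return image * (1.0 - mask)


def composite(masked: torch.Tensor, reconstructed: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Keep the known pixels from the input and take only hole pixels from the model."""
    return mask * reconstructed + (1.0 - mask) * masked
```

For example, with the variables from the snippet above: mask = make_rect_mask().to(DEVICE), then result = composite(masked_tensor, reconstructed_tensor, mask). The 1-channel mask broadcasts across the 3 RGB channels. Compositing guarantees that pixels outside the hole are returned unchanged, even if the network slightly alters them.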

Training Details

  • Framework: PyTorch
  • Dataset: COCO 2017
  • Epochs: 10
  • Batch Size: 16
  • Learning Rate: 0.001
  • Optimizer: Adam
  • Loss Function: Mean Squared Error (MSE)
  • Image Resolution: 256x256
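These hyperparameters correspond to a standard supervised training loop. The following is a hedged sketch, not the original training script. The function name train_inpainting is hypothetical, and it assumes a dataset that yields (masked_image, original_image) tensor pairs and computes MSE over the whole image; the card specifies neither.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


def train_inpainting(model: nn.Module, dataset: Dataset, epochs: int = 10,
                     batch_size: int = 16, lr: float = 1e-3,
                     device: str = "cpu") -> list[float]:
    """Train with Adam + MSE and return the mean training loss of each epoch.

    Assumes `dataset` yields (masked_image, original_image) pairs of shape (3, H, W).
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()  # pixel-wise MSE over the full image (assumption)
    model.to(device).train()

    epoch_losses = []
    for epoch in range(epochs):
        running, seen = 0.0, 0
        for masked, target in loader:
            masked, target = masked.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(masked), target)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean stays correct for a short last batch.
            running += loss.item() * masked.size(0)
            seen += masked.size(0)
        epoch_losses.append(running / seen)
        print(f"epoch {epoch + 1}/{epochs}: mse={epoch_losses[-1]:.4f}")
    return epoch_losses
```

The defaults mirror the listed settings (10 epochs, batch size 16, learning rate 0.001). Other Adam settings are left at PyTorch defaults, which is also an assumption.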