Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader | |
| from types import SimpleNamespace | |
| from deepfillv2 import test_dataset, utils | |
| from config import * | |
class InpaintingTester:
    """Run a pretrained inpainting generator over image/mask pairs.

    The generator is created and its weights are loaded once in
    ``__init__``; ``process_image`` and ``process_multiple_images`` reuse
    that instance. Network options, the checkpoint path and the target
    device all come from the names star-imported from ``config``.
    """

    def __init__(self, save_path, resize_to=None):
        """Build the generator, load its weights and move it to the device.

        Args:
            save_path: Directory under which ``process_multiple_images``
                places its results.
            resize_to: Size forwarded to ``test_dataset.InpaintDataset``.
                Falls back to ``RESIZE_TO`` from ``config`` when ``None``.
        """
        if resize_to is None:
            resize_to = RESIZE_TO
        self.save_path = save_path
        self.setsize = resize_to

        # Options consumed by ``utils.create_generator``.
        opt = SimpleNamespace(
            pad_type=PAD_TYPE,
            in_channels=IN_CHANNELS,
            out_channels=OUT_CHANNELS,
            latent_channels=LATENT_CHANNELS,
            activation=ACTIVATION,
            norm=NORM,
            init_type=INIT_TYPE,
            init_gain=INIT_GAIN,
            use_cuda=CUDA,
            gpu_device=GPU_DEVICE,
        )

        # Built once and kept in eval mode; every later call reuses it.
        self.generator = utils.create_generator(opt).eval()
        self.load_model_generator(self.generator)
        self.generator = self.generator.to(GPU_DEVICE)

    def load_model_generator(self, generator):
        """Load the checkpoint at ``DEEPFILL_MODEL_PATH`` into *generator*.

        Args:
            generator: Module whose ``load_state_dict`` receives the
                loaded checkpoint.
        """
        pretrained_dict = torch.load(
            DEEPFILL_MODEL_PATH,
            map_location=torch.device(GPU_DEVICE),
            weights_only=True,
        )
        generator.load_state_dict(pretrained_dict)

    @staticmethod
    def _png_output_path(save_image_path):
        """Return *save_image_path* with its file extension swapped for ``.png``."""
        # splitext only touches the final component's extension, so directory
        # names containing the same text and extension-less names are safe.
        return os.path.splitext(save_image_path)[0] + ".png"

    def process_image(self, in_image, mask_image, save_image_path):
        """Inpaint one image/mask pair and save the blended result.

        Args:
            in_image: Image argument forwarded to
                ``test_dataset.InpaintDataset``.
            mask_image: Mask argument forwarded to
                ``test_dataset.InpaintDataset``.
            save_image_path: Its directory part is used as the output
                folder and its basename as the sample name passed to
                ``utils.save_sample_png``.
        """
        trainset = test_dataset.InpaintDataset(in_image, mask_image, self.setsize)
        # NOTE(review): eight workers are started on every call; that is
        # probably more than needed if the dataset holds one item - confirm.
        dataloader = DataLoader(
            trainset,
            batch_size=1,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
        )

        # A bare file name has an empty dirname; fall back to the current
        # directory so ``os.makedirs`` does not fail on "".
        results_path = os.path.dirname(save_image_path) or "."
        sample_name = os.path.basename(save_image_path)

        for img, mask in dataloader:
            img = img.to(GPU_DEVICE)
            mask = mask.to(GPU_DEVICE)

            with torch.no_grad():
                # The generator returns two outputs; only the second is saved.
                _first_out, second_out = self.generator(img, mask)
                # Keep the input where mask == 0 and take the generator
                # output where mask == 1.
                second_out_wholeimg = img * (1 - mask) + second_out * mask

            os.makedirs(results_path, exist_ok=True)
            utils.save_sample_png(
                sample_folder=results_path,
                sample_name=sample_name,
                img_list=[second_out_wholeimg],
                name_list=["second_out"],
                pixel_max_cnt=255,
            )

    def process_multiple_images(self, image_mask_pairs):
        """Process every ``(image_path, mask_path)`` pair in order.

        A failure on one pair is printed and does not stop the remaining
        pairs from being processed.

        Args:
            image_mask_pairs: Iterable of ``(image_path, mask_path)`` tuples.

        Returns:
            A list with one entry per pair, in input order: the result path
            (``save_path`` joined with the image's basename, extension
            replaced by ``.png``) on success, or ``None`` if processing
            that pair raised.
        """
        png_images = []
        for img_path, mask_path in image_mask_pairs:
            try:
                save_image_path = os.path.join(
                    self.save_path, os.path.basename(img_path)
                )
                print(f"Processing: {img_path} and {mask_path}")
                self.process_image(img_path, mask_path, save_image_path)
                png_images.append(self._png_output_path(save_image_path))
            except Exception as e:
                # Deliberate best-effort: record the failure in this pair's
                # slot so results stay aligned with the input order.
                png_images.append(None)
                print(f"Error: {e}")
        return png_images
# Main execution (example usage, kept commented out):
# if __name__ == "__main__":
#     save_path = "./output"
#     resize_to = None  # Default size from config
#     # List of image and mask pairs
#     image_mask_pairs = [
#         ("./input/image.jpg", "./input/mask.jpg"),
#     ]
#     tester = InpaintingTester(save_path, resize_to)
#     # Process multiple images using a loop
#     results = tester.process_multiple_images(image_mask_pairs)
#     print(results)