import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as T
from typing import Tuple
import scipy.ndimage
import cv2

from train.src.condition.util import HWC3, common_input_validate
def check_image_mask(image, mask, name):
    """Normalize image to [B, H, W, C] and mask to [B, H, W], then match their batch sizes."""
    if len(image.shape) < 4:
        # image tensor shape should be [B, H, W, C], but the batch dimension is missing
        image = image[None, :, :, :]
    if len(mask.shape) > 3:
        # mask tensor shape should be [B, H, W] but we got [B, H, W, C] (probably an image);
        # keep only the red channel of each mask
        mask = mask[:, :, :, 0]
    elif len(mask.shape) < 3:
        # mask tensor shape should be [B, H, W], but the batch dimension is missing
        mask = mask[None, :, :]
    if image.shape[0] > mask.shape[0]:
        print(name, "gets a batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
        if mask.shape[0] == 1:
            print(name, "will copy the mask to fill the batch")
            mask = torch.cat([mask] * image.shape[0], dim=0)
        else:
            print(name, "will add empty masks to fill the batch")
            empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
            mask = torch.cat([mask, empty_mask], dim=0)
    elif image.shape[0] < mask.shape[0]:
        print(name, "gets a batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
        mask = mask[:image.shape[0], :, :]
    return (image, mask)
def cv2_resize_shortest_edge(image, size):
    """Resize so that the shortest edge equals `size`, preserving the aspect ratio."""
    h, w = image.shape[:2]
    if h < w:
        new_h = size
        new_w = int(round(w / h * size))
    else:
        new_w = size
        new_h = int(round(h / w * size))
    resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return resized_image
def apply_color(img, res=512):
    """Reduce the image to coarse 64x64-pixel color blocks (a spatial color palette)."""
    img = cv2_resize_shortest_edge(img, res)
    h, w = img.shape[:2]
    input_img_color = cv2.resize(img, (w // 64, h // 64), interpolation=cv2.INTER_CUBIC)
    input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
    return input_img_color
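
# Worked example (illustrative numbers, not from the repo): a 768x512 (HxW) photo keeps its
# size after the shortest-edge resize to 512, is reduced to a 12x8 (HxW) grid of colors
# (768 // 64 = 12, 512 // 64 = 8), then stretched back to 768x512 with nearest-neighbor so
# each grid cell becomes a flat 64x64 color block.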
# Color T2I-Adapter-style color map built from multiples-of-64 blocks; the downscale
# (INTER_CUBIC) and upscale (INTER_NEAREST) methods are fixed.
class ColorDetector:
    """Produce the coarse color-grid map used as a color condition."""

    def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image = HWC3(input_image)
        detected_map = HWC3(apply_color(input_image, detect_resolution))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
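
# Example usage (illustrative sketch; the file path and variable names are assumptions,
# not taken from this repo):
#   detector = ColorDetector()
#   color_hint = detector(np.asarray(Image.open("photo.png").convert("RGB")),
#                         detect_resolution=512, output_type="pil")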
class InpaintPreprocessor:
    """Resize the mask to the image and mark masked pixels with a sentinel value."""

    def preprocess(self, image, mask, black_pixel_for_xinsir_cn=False):
        mask = torch.nn.functional.interpolate(
            mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
            size=(image.shape[1], image.shape[2]),
            mode="bilinear",
        )
        mask = mask.movedim(1, -1).expand((-1, -1, -1, 3))
        image = image.clone()
        if black_pixel_for_xinsir_cn:
            # xinsir ControlNets expect masked pixels painted black
            masked_pixel = 0.0
        else:
            masked_pixel = -1.0
        image[mask > 0.5] = masked_pixel
        return (image,)
class BlendInpaint:
    """Blend an inpainted result back into the original image using a Gaussian-blurred mask."""

    def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask,
                      kernel: int, sigma: int, origin=None) -> Tuple[torch.Tensor, torch.Tensor]:
        original, mask = check_image_mask(original, mask, 'Blend Inpaint')

        if len(inpaint.shape) < 4:
            # image tensor shape should be [B, H, W, C], but the batch dimension is missing
            inpaint = inpaint[None, :, :, :]

        if inpaint.shape[0] < original.shape[0]:
            print("Blend Inpaint gets a batch of original images (%d) but only (%d) inpaint images"
                  % (original.shape[0], inpaint.shape[0]))
            original = original[:inpaint.shape[0], :, :]
            mask = mask[:inpaint.shape[0], :, :]

        if inpaint.shape[0] > original.shape[0]:
            # more inpaint images than originals: repeat originals, masks (and origins)
            # until they match the inpaint batch size
            count = 0
            original_list = []
            mask_list = []
            origin_list = []
            while count < inpaint.shape[0]:
                for i in range(original.shape[0]):
                    original_list.append(original[i][None, :, :, :])
                    mask_list.append(mask[i][None, :, :])
                    if origin is not None:
                        origin_list.append(origin[i][None, :])
                    count += 1
                    if count >= inpaint.shape[0]:
                        break
            original = torch.concat(original_list, dim=0)
            mask = torch.concat(mask_list, dim=0)
            if origin is not None:
                origin = torch.concat(origin_list, dim=0)

        if kernel % 2 == 0:
            # GaussianBlur requires an odd kernel size
            kernel += 1
        transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))

        ret = []
        blurred = []
        for i in range(inpaint.shape[0]):
            if origin is None:
                blurred_mask = transform(mask[i][None, None, :, :]).to(original.device).to(original.dtype)
                blurred.append(blurred_mask[0])

                result = torch.nn.functional.interpolate(
                    inpaint[i][None, :, :, :].permute(0, 3, 1, 2),
                    size=(
                        original[i].shape[0],
                        original[i].shape[1],
                    )
                ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
            else:
                # the crop origin comes from CutForInpaint: pad the mask and the inpainted
                # crop back to the original image size before blending
                height, width, _ = original[i].shape
                x0 = origin[i][0].item()
                y0 = origin[i][1].item()

                if mask[i].shape[0] < height or mask[i].shape[1] < width:
                    padded_mask = F.pad(input=mask[i],
                                        pad=(x0, width - x0 - mask[i].shape[1],
                                             y0, height - y0 - mask[i].shape[0]),
                                        mode='constant', value=0)
                else:
                    padded_mask = mask[i]
                blurred_mask = transform(padded_mask[None, None, :, :]).to(original.device).to(original.dtype)
                blurred.append(blurred_mask[0][0])

                result = F.pad(input=inpaint[i],
                               pad=(0, 0, x0, width - x0 - inpaint[i].shape[1],
                                    y0, height - y0 - inpaint[i].shape[0]),
                               mode='constant', value=0)
                result = result[None, :, :, :].to(original.device).to(original.dtype)

            ret.append(original[i] * (1.0 - blurred_mask[0][0][:, :, None])
                       + result[0] * blurred_mask[0][0][:, :, None])

        return (torch.stack(ret), torch.stack(blurred),)
def resize_mask(mask, shape):
    return torch.nn.functional.interpolate(
        mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
        size=(shape[0], shape[1]),
        mode="bilinear",
    ).squeeze(1)
class JoinImageWithAlpha:
    """Attach an alpha channel (the inverted, resized mask) to an RGB image batch."""

    def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
        batch_size = min(len(image), len(alpha))
        out_images = []
        alpha = 1.0 - resize_mask(alpha, image.shape[1:])
        for i in range(batch_size):
            out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2))
        result = (torch.stack(out_images),)
        return result
class GrowMask:
    """Dilate (expand > 0) or erode (expand < 0) a mask with a 3x3 structuring element."""

    def expand_mask(self, mask, expand, tapered_corners):
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        out = []
        for m in mask:
            output = m.numpy()
            for _ in range(abs(expand)):
                if expand < 0:
                    output = scipy.ndimage.grey_erosion(output, footprint=kernel)
                else:
                    output = scipy.ndimage.grey_dilation(output, footprint=kernel)
            output = torch.from_numpy(output)
            out.append(output)
        return (torch.stack(out, dim=0),)
class InvertMask:
    """Return the complement of a mask (1.0 - mask)."""

    def invert(self, mask):
        out = 1.0 - mask
        return (out,)
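
# Minimal smoke-test sketch (not part of the original pipeline). It runs the mask and
# blending helpers on synthetic tensors; the shapes and values below are illustrative
# assumptions rather than values used by the repo.
if __name__ == "__main__":
    image = torch.rand(1, 64, 64, 3)        # [B, H, W, C] image in 0..1
    mask = torch.zeros(1, 64, 64)
    mask[:, 16:48, 16:48] = 1.0             # square region to inpaint

    grown = GrowMask().expand_mask(mask, expand=4, tapered_corners=True)[0]
    inverted = InvertMask().invert(grown)[0]
    masked_image = InpaintPreprocessor().preprocess(image, grown)[0]

    # Stand-in for a diffusion model's inpainted output.
    inpainted = torch.rand(1, 64, 64, 3)
    blended, blurred = BlendInpaint().blend_inpaint(inpainted, image, grown, kernel=9, sigma=3)
    print(masked_image.shape, inverted.shape, blended.shape, blurred.shape)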