import random
from collections import Counter
import numpy as np
from torchvision import transforms
import cv2 # OpenCV
import torch
import re
import io
import base64
from PIL import Image, ImageOps
from src.pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS

def get_bounding_box_from_mask(mask, padded=False):
    """Return the normalized (x1, y1, x2, y2) bounding box of a binary mask.

    If ``padded`` is True, coordinates are normalized against the square canvas
    obtained by padding the shorter side to match the longer one.
    """
    mask = mask.squeeze()
    rows, cols = torch.where(mask > 0.5)
    if len(rows) == 0 or len(cols) == 0:
        return (0, 0, 0, 0)
    height, width = mask.shape
    if padded:
        padded_size = max(width, height)
        if width < height:
            offset_x = (padded_size - width) / 2
            offset_y = 0
        else:
            offset_y = (padded_size - height) / 2
            offset_x = 0
        top_left_x = round(float((torch.min(cols).item() + offset_x) / padded_size), 3)
        bottom_right_x = round(float((torch.max(cols).item() + offset_x) / padded_size), 3)
        top_left_y = round(float((torch.min(rows).item() + offset_y) / padded_size), 3)
        bottom_right_y = round(float((torch.max(rows).item() + offset_y) / padded_size), 3)
    else:
        top_left_x = round(float(torch.min(cols).item() / width), 3)
        bottom_right_x = round(float(torch.max(cols).item() / width), 3)
        top_left_y = round(float(torch.min(rows).item() / height), 3)
        bottom_right_y = round(float(torch.max(rows).item() / height), 3)
    return (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
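
# Example usage sketch, assuming `mask` is a [1, H, W] (or [H, W]) float mask in [0, 1]:
#   x1, y1, x2, y2 = get_bounding_box_from_mask(mask)                   # normalized to mask size
#   px1, py1, px2, py2 = get_bounding_box_from_mask(mask, padded=True)  # normalized to padded square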

def extract_bbox(text):
    """Extract the first ``[x1, y1, x2, y2]`` integer bounding box found in ``text``."""
    pattern = r"\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]"
    match = re.search(pattern, text)
    if match is None:
        raise ValueError(f"No bounding box of the form [x1, y1, x2, y2] found in: {text!r}")
    return (int(match.group(1)), int(match.group(2)), int(match.group(3)), int(match.group(4)))

def resize_bbox(bbox, width_ratio, height_ratio):
    """Scale a pixel-space (x1, y1, x2, y2) box by the given width/height ratios."""
    x1, y1, x2, y2 = bbox
    new_x1 = int(x1 * width_ratio)
    new_y1 = int(y1 * height_ratio)
    new_x2 = int(x2 * width_ratio)
    new_y2 = int(y2 * height_ratio)
    return (new_x1, new_y1, new_x2, new_y2)
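
# Example usage sketch (the caption text below is a hypothetical example, not a
# format this module prescribes):
#   bbox = extract_bbox("The object is at [12, 34, 256, 300].")      # -> (12, 34, 256, 300)
#   scaled = resize_bbox(bbox, width_ratio=0.5, height_ratio=0.5)    # -> (6, 17, 128, 150)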

def tensor_to_base64(tensor, quality=80, method=6):
    """Encode an image tensor as a base64-encoded WebP string."""
    tensor = tensor.squeeze(0).clone().detach().cpu()
    if tensor.dtype in (torch.float16, torch.float32, torch.float64):
        tensor *= 255
    tensor = tensor.to(torch.uint8)
    if tensor.ndim == 2:  # grayscale image
        pil_image = Image.fromarray(tensor.numpy(), 'L')
        pil_image = pil_image.convert('RGB')
    elif tensor.ndim == 3:
        if tensor.shape[2] == 1:  # single channel
            pil_image = Image.fromarray(tensor.numpy().squeeze(2), 'L')
            pil_image = pil_image.convert('RGB')
        elif tensor.shape[2] == 3:  # RGB
            pil_image = Image.fromarray(tensor.numpy(), 'RGB')
        elif tensor.shape[2] == 4:  # RGBA
            pil_image = Image.fromarray(tensor.numpy(), 'RGBA')
        else:
            raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
    else:
        raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")
    buffered = io.BytesIO()
    pil_image.save(buffered, format="WEBP", quality=quality, method=method, lossless=False)
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_str
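
# Example usage sketch, assuming `frame` is a float tensor shaped [1, H, W, 3]
# with values in [0, 1] (the layout the other helpers in this file produce):
#   b64 = tensor_to_base64(frame, quality=80)
#   data_url = "data:image/webp;base64," + b64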

def load_and_preprocess_image(image_path, convert_to='RGB', has_alpha=False):
    """Load an image, flatten any alpha onto white, and return a [1, H, W, C] float tensor in [0, 1]."""
    image = Image.open(image_path)
    image = ImageOps.exif_transpose(image)
    if image.mode == 'RGBA':
        # Composite onto a white background so transparency does not leak into the RGB values.
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image)
    image = image.convert(convert_to)
    image_array = np.array(image).astype(np.float32) / 255.0
    if has_alpha and convert_to == 'RGBA':
        image_tensor = torch.from_numpy(image_array)[None,]
    else:
        if len(image_array.shape) == 3 and image_array.shape[2] > 3:
            image_array = image_array[:, :, :3]
        image_tensor = torch.from_numpy(image_array)[None,]
    return image_tensor
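
# Example usage sketch ("input.png" is a placeholder path, not a file shipped with this Space):
#   rgb = load_and_preprocess_image("input.png")   # float tensor [1, H, W, 3] in [0, 1]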

def process_background(base64_image, convert_to='RGB', size=None):
    image_data = read_base64_image(base64_image)
    image = Image.open(image_data)
    image = ImageOps.exif_transpose(image)
    image = image.convert(convert_to)
    # Select preferred size by closest aspect ratio, then snap to multiple_of
    w0, h0 = image.size
    aspect_ratio = (w0 / h0) if h0 != 0 else 1.0
    # Choose the (w, h) whose aspect ratio is closest to the input
    _, tw, th = min((abs(aspect_ratio - w / h), w, h) for (w, h) in PREFERRED_KONTEXT_RESOLUTIONS)
    multiple_of = 16  # default: vae_scale_factor (8) * 2
    tw = (tw // multiple_of) * multiple_of
    th = (th // multiple_of) * multiple_of
    if (w0, h0) != (tw, th):
        image = image.resize((tw, th), resample=Image.BICUBIC)
    image_array = np.array(image).astype(np.uint8)
    image_tensor = torch.from_numpy(image_array)[None,]
    return image_tensor
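
# Example usage sketch; "<...>" stands in for real base64-encoded WebP data:
#   background = process_background("data:image/webp;base64,<...>")
#   # -> uint8 tensor [1, th, tw, 3], resized to the closest preferred Kontext resolution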

def read_base64_image(base64_image):
    if base64_image.startswith("data:image/png;base64,"):
        base64_image = base64_image.split(",")[1]
    elif base64_image.startswith("data:image/jpeg;base64,"):
        base64_image = base64_image.split(",")[1]
    elif base64_image.startswith("data:image/webp;base64,"):
        base64_image = base64_image.split(",")[1]
    else:
        raise ValueError("Unsupported image format.")
    image_data = base64.b64decode(base64_image)
    return io.BytesIO(image_data)

def create_alpha_mask(image_path):
    """Create an inverted alpha mask from an image's alpha channel
    (1.0 where the image is transparent, 0.0 where it is opaque)."""
    image = Image.open(image_path)
    image = ImageOps.exif_transpose(image)
    mask = torch.zeros((1, image.height, image.width), dtype=torch.float32)
    if 'A' in image.getbands():
        alpha_channel = np.array(image.getchannel('A')).astype(np.float32) / 255.0
        mask[0] = 1.0 - torch.from_numpy(alpha_channel)
    return mask
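
# Example usage sketch ("cutout.png" is a placeholder path for an RGBA image):
#   mask = create_alpha_mask("cutout.png")          # [1, H, W]; 1.0 = transparent
#   norm_box = get_bounding_box_from_mask(mask)     # normalized (x1, y1, x2, y2)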

def get_mask_bbox(mask_tensor, padding=10):
    """Return the pixel-space (x_min, y_min, x_max, y_max) box of a [1, H, W] mask,
    expanded by ``padding`` pixels and clamped to the image bounds, or None if the
    mask is empty."""
    assert len(mask_tensor.shape) == 3 and mask_tensor.shape[0] == 1
    _, H, W = mask_tensor.shape
    mask_2d = mask_tensor.squeeze(0)
    y_coords, x_coords = torch.where(mask_2d > 0)
    if len(y_coords) == 0:
        return None
    x_min = int(torch.min(x_coords))
    y_min = int(torch.min(y_coords))
    x_max = int(torch.max(x_coords))
    y_max = int(torch.max(y_coords))
    x_min = max(0, x_min - padding)
    y_min = max(0, y_min - padding)
    x_max = min(W - 1, x_max + padding)
    y_max = min(H - 1, y_max + padding)
    return x_min, y_min, x_max, y_max
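
# Example usage sketch, assuming `mask` is a [1, H, W] float mask such as the one
# returned by create_alpha_mask above:
#   box = get_mask_bbox(mask, padding=10)
#   if box is not None:
#       x_min, y_min, x_max, y_max = box   # pixel coordinates, padded and clamped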

def tensor_to_pil(tensor):
    """Convert an image tensor to a PIL image (grayscale or RGB)."""
    tensor = tensor.squeeze(0).clone().detach().cpu()
    if tensor.dtype in [torch.float32, torch.float64, torch.float16]:
        if tensor.max() <= 1.0:
            tensor *= 255
    tensor = tensor.to(torch.uint8)
    if tensor.ndim == 2:  # grayscale image [H, W]
        return Image.fromarray(tensor.numpy(), 'L')
    elif tensor.ndim == 3:
        if tensor.shape[2] == 1:  # single channel [H, W, 1]
            return Image.fromarray(tensor.numpy().squeeze(2), 'L')
        elif tensor.shape[2] >= 3:  # RGB(A) [H, W, >=3]; keep only the first three channels
            return Image.fromarray(tensor.numpy()[:, :, :3], 'RGB')
        else:
            raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
    else:
        raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")