import base64
import io
import re

import numpy as np
import torch
from PIL import Image, ImageOps

from src.pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS

def get_bounding_box_from_mask(mask, padded=False):
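    """Return the normalized (x1, y1, x2, y2) bounding box of a binary mask.

    Coordinates are fractions in [0, 1]. If ``padded`` is True, they are taken
    relative to the square canvas obtained by centering the mask on its longer
    side. Returns (0, 0, 0, 0) when the mask is empty.
    """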
    mask = mask.squeeze()
    rows, cols = torch.where(mask > 0.5)
    if len(rows) == 0 or len(cols) == 0:
        return (0, 0, 0, 0)
    height, width = mask.shape
    if padded:
        padded_size = max(width, height)
        if width < height:
            offset_x = (padded_size - width) / 2
            offset_y = 0
        else:
            offset_y = (padded_size - height) / 2
            offset_x = 0
        top_left_x = round(float((torch.min(cols).item() + offset_x) / padded_size), 3)
        bottom_right_x = round(float((torch.max(cols).item() + offset_x) / padded_size), 3)
        top_left_y = round(float((torch.min(rows).item() + offset_y) / padded_size), 3)
        bottom_right_y = round(float((torch.max(rows).item() + offset_y) / padded_size), 3)
    else:
        offset_x = 0
        offset_y = 0

        top_left_x = round(float(torch.min(cols).item() / width), 3)
        bottom_right_x = round(float(torch.max(cols).item() / width), 3)
        top_left_y = round(float(torch.min(rows).item() / height), 3)
        bottom_right_y = round(float(torch.max(rows).item() / height), 3)

    return (top_left_x, top_left_y, bottom_right_x, bottom_right_y)

def extract_bbox(text):
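    """Extract the first ``[x1, y1, x2, y2]`` integer bounding box found in ``text``."""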
    pattern = r"\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]"
    match = re.search(pattern, text)
    if match is None:
        raise ValueError(f"No bounding box of the form [x1, y1, x2, y2] found in: {text!r}")
    return (int(match.group(1)), int(match.group(2)), int(match.group(3)), int(match.group(4)))

def resize_bbox(bbox, width_ratio, height_ratio):
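    """Scale a pixel-space bounding box by the given width and height ratios."""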
    x1, y1, x2, y2 = bbox
    new_x1 = int(x1 * width_ratio)
    new_y1 = int(y1 * height_ratio)
    new_x2 = int(x2 * width_ratio)
    new_y2 = int(y2 * height_ratio)

    return (new_x1, new_y1, new_x2, new_y2)


def tensor_to_base64(tensor, quality=80, method=6):
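    """Encode an image tensor of shape [1, H, W, C] or [1, H, W] as a base64 WEBP string."""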
    tensor = tensor.squeeze(0).clone().detach().cpu()
    
    # Float inputs are assumed to be normalized to [0, 1]; rescale to [0, 255].
    if tensor.dtype in (torch.float16, torch.float32, torch.float64):
        tensor *= 255
    tensor = tensor.to(torch.uint8)
    
    if tensor.ndim == 2:  # grayscale image [H, W]
        pil_image = Image.fromarray(tensor.numpy(), 'L')
        pil_image = pil_image.convert('RGB')
    elif tensor.ndim == 3:
        if tensor.shape[2] == 1:  # single channel [H, W, 1]
            pil_image = Image.fromarray(tensor.numpy().squeeze(2), 'L')
            pil_image = pil_image.convert('RGB')
        elif tensor.shape[2] == 3:  # RGB
            pil_image = Image.fromarray(tensor.numpy(), 'RGB')
        elif tensor.shape[2] == 4:  # RGBA
            pil_image = Image.fromarray(tensor.numpy(), 'RGBA')
        else:
            raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
    else:
        raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")
    
    buffered = io.BytesIO()
    pil_image.save(buffered, format="WEBP", quality=quality, method=method, lossless=False)
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_str

def load_and_preprocess_image(image_path, convert_to='RGB', has_alpha=False):
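    """Load an image as a float tensor of shape [1, H, W, C] with values in [0, 1].

    EXIF orientation is applied first; RGBA images are flattened onto a white
    background unless both ``has_alpha`` is True and ``convert_to`` is 'RGBA'.
    """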
    image = Image.open(image_path)
    image = ImageOps.exif_transpose(image)

    # Flatten transparency onto a white background unless the caller wants to keep the alpha channel.
    if image.mode == 'RGBA' and not (has_alpha and convert_to == 'RGBA'):
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image)
    image = image.convert(convert_to)
    image_array = np.array(image).astype(np.float32) / 255.0

    if has_alpha and convert_to == 'RGBA':
        image_tensor = torch.from_numpy(image_array)[None,]
    else:
        if len(image_array.shape) == 3 and image_array.shape[2] > 3:
            image_array = image_array[:, :, :3]
        image_tensor = torch.from_numpy(image_array)[None,]
    
    return image_tensor

def process_background(base64_image, convert_to='RGB', size=None):
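    """Decode a base64 background image, resize it to the preferred Kontext resolution
    with the closest aspect ratio (snapped to a multiple of 16), and return a uint8
    tensor of shape [1, H, W, C].
    """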
    image_data = read_base64_image(base64_image)
    image = Image.open(image_data)
    image = ImageOps.exif_transpose(image)
    image = image.convert(convert_to)

    # Select preferred size by closest aspect ratio, then snap to multiple_of
    w0, h0 = image.size
    aspect_ratio = (w0 / h0) if h0 != 0 else 1.0
    # Choose the (w, h) whose aspect ratio is closest to the input
    _, tw, th = min((abs(aspect_ratio - w / h), w, h) for (w, h) in PREFERRED_KONTEXT_RESOLUTIONS)
    multiple_of = 16  # default: vae_scale_factor (8) * 2
    tw = (tw // multiple_of) * multiple_of
    th = (th // multiple_of) * multiple_of

    if (w0, h0) != (tw, th):
        image = image.resize((tw, th), resample=Image.BICUBIC)

    image_array = np.array(image).astype(np.uint8)
    image_tensor = torch.from_numpy(image_array)[None,]
    return image_tensor

def read_base64_image(base64_image):
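    """Strip the data-URI prefix from a base64 PNG/JPEG/WEBP string and return the decoded bytes as a BytesIO."""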
    if base64_image.startswith("data:image/png;base64,"):
        base64_image = base64_image.split(",")[1]
    elif base64_image.startswith("data:image/jpeg;base64,"):
        base64_image = base64_image.split(",")[1]
    elif base64_image.startswith("data:image/webp;base64,"):
        base64_image = base64_image.split(",")[1]
    else:
        raise ValueError("Unsupported image format.")
    image_data = base64.b64decode(base64_image)
    return io.BytesIO(image_data)

def create_alpha_mask(image_path):
    """Create an alpha mask from the alpha channel of an image."""
    image = Image.open(image_path)
    image = ImageOps.exif_transpose(image)
    mask = torch.zeros((1, image.height, image.width), dtype=torch.float32)
    if 'A' in image.getbands():
        alpha_channel = np.array(image.getchannel('A')).astype(np.float32) / 255.0
        mask[0] = 1.0 - torch.from_numpy(alpha_channel)
    return mask

def get_mask_bbox(mask_tensor, padding=10):
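    """Return the (x_min, y_min, x_max, y_max) pixel bbox of a [1, H, W] mask, expanded by ``padding``.

    Returns None if the mask has no nonzero pixels.
    """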
    assert mask_tensor.ndim == 3 and mask_tensor.shape[0] == 1, "expected a mask of shape [1, H, W]"
    _, H, W = mask_tensor.shape
    mask_2d = mask_tensor.squeeze(0)
    
    y_coords, x_coords = torch.where(mask_2d > 0)
    
    if len(y_coords) == 0:
        return None
    
    x_min = int(torch.min(x_coords))
    y_min = int(torch.min(y_coords))
    x_max = int(torch.max(x_coords))
    y_max = int(torch.max(y_coords))
    
    x_min = max(0, x_min - padding)
    y_min = max(0, y_min - padding)
    x_max = min(W - 1, x_max + padding)
    y_max = min(H - 1, y_max + padding)
    
    return x_min, y_min, x_max, y_max

def tensor_to_pil(tensor):
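    """Convert an image tensor of shape [1, H, W, C] or [1, H, W] to a PIL Image.

    Float tensors whose maximum value is <= 1.0 are rescaled to [0, 255].
    """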
    tensor = tensor.squeeze(0).clone().detach().cpu()
    if tensor.dtype in [torch.float32, torch.float64, torch.float16]:
        if tensor.max() <= 1.0:
            tensor *= 255
        tensor = tensor.to(torch.uint8)
    
    if tensor.ndim == 2:  # grayscale image [H, W]
        return Image.fromarray(tensor.numpy(), 'L')
    elif tensor.ndim == 3:
        if tensor.shape[2] == 1:  # single channel [H, W, 1]
            return Image.fromarray(tensor.numpy().squeeze(2), 'L')
        elif tensor.shape[2] >= 3:  # RGB(A) [H, W, C]; keep only the first three channels
            return Image.fromarray(tensor.numpy()[:, :, :3], 'RGB')
        else:
            raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
    else:
        raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")