# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO

import insightface
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from withanyone.flux.modules.layers import (
DoubleStreamBlockLoraProcessor,
DoubleStreamBlockProcessor,
SingleStreamBlockLoraProcessor,
SingleStreamBlockProcessor,
)
from withanyone.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from withanyone.flux.util import (
load_ae,
load_clip,
load_flow_model_no_lora,
load_flow_model_diffusers,
load_t5,
)
from withanyone.flux.model import SiglipEmbedding, create_person_cross_attention_mask_varlen


def preprocess_ref(raw_image: Image.Image, long_size: int = 512) -> Image.Image:
    """Resize `raw_image` so its long side equals `long_size`, then center-crop so
    both dimensions are multiples of 16 (required by the latent patching)."""
    image_w, image_h = raw_image.size
if image_w >= image_h:
new_w = long_size
new_h = int((long_size / image_w) * image_h)
else:
new_h = long_size
new_w = int((long_size / image_h) * image_w)
raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
target_w = new_w // 16 * 16
target_h = new_h // 16 * 16
left = (new_w - target_w) // 2
top = (new_h - target_h) // 2
right = left + target_w
bottom = top + target_h
raw_image = raw_image.crop((left, top, right, bottom))
raw_image = raw_image.convert("RGB")
return raw_image
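

# Usage sketch (hypothetical path): prepare a reference image whose sides are
# multiples of 16, as the 16x16 latent patching downstream requires:
#     ref = preprocess_ref(Image.open("face.jpg"), long_size=512)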


class FaceExtractor:
    """Detects faces with insightface (antelopev2) and crops square 512x512 face
    references for the pipeline."""

    def __init__(self, model_path="./"):
        self.model = insightface.app.FaceAnalysis(
            name="antelopev2", root=model_path, providers=["CUDAExecutionProvider"]
        )
        self.model.prepare(ctx_id=0, det_thresh=0.45)

    def extract_moref(self, img, bboxes, face_size_restriction=1):
        """
        Extract faces from an image based on bounding boxes.
        Makes each face square and resizes it to 512x512.
        Args:
            img: PIL Image (or raw bytes decodable by PIL)
            bboxes: list of [x1, y1, x2, y2] bounding boxes
            face_size_restriction: minimum width/height (px) a face must have;
                if any face is smaller, the whole image is skipped
        Returns:
            List of PIL Images, each 512x512, containing the extracted faces
        """
        try:
            # Ensure img is a PIL Image.
            if not isinstance(img, (Image.Image, torch.Tensor)):
                img = Image.open(BytesIO(img))
            new_bboxes = bboxes
            # Skip the image entirely if any face is smaller than the restriction.
            for bbox in new_bboxes:
                x1, y1, x2, y2 = bbox
                if x2 - x1 < face_size_restriction or y2 - y1 < face_size_restriction:
                    return []
            faces = []
            for bbox in new_bboxes:
                # Convert coordinates to integers
                x1, y1, x2, y2 = map(int, bbox)
# Calculate width and height
width = x2 - x1
height = y2 - y1
# Make the bounding box square by expanding the shorter dimension
if width > height:
# Height is shorter, expand it
diff = width - height
y1 -= diff // 2
y2 += diff - (diff // 2) # Handle odd differences
elif height > width:
# Width is shorter, expand it
diff = height - width
x1 -= diff // 2
x2 += diff - (diff // 2) # Handle odd differences
# Ensure coordinates are within image boundaries
img_width, img_height = img.size
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(img_width, x2)
y2 = min(img_height, y2)
# Extract face region
face_region = img.crop((x1, y1, x2, y2))
# Resize to 512x512
face_region = face_region.resize((512, 512), Image.LANCZOS)
faces.append(face_region)
# print("len of faces:", len(faces))
return faces
except Exception as e:
print(f"Error processing image: {e}")
return []

    def __call__(self, img):
        """Detect a face in `img` and return it as a square 512x512 crop;
        returns the input unchanged if no face is found."""
        # Normalize the input into both numpy (for detection) and PIL (for cropping).
        if isinstance(img, torch.Tensor):
            img_np = img.cpu().numpy()
            img_pil = Image.fromarray(img_np)
        elif isinstance(img, Image.Image):
            img_pil = img
            img_np = np.array(img)
        elif isinstance(img, np.ndarray):
            img_np = img
            img_pil = Image.fromarray(img)
        else:
            raise ValueError("Unsupported image format. Please provide a PIL Image or numpy array.")
        # Detect faces in the image and keep only the first detection.
        faces = self.model.get(img_np)
        if len(faces) > 0:
            bboxes = []
            face = faces[0]
            bbox = face.bbox.astype(int)
            bboxes.append(bbox)
return self.extract_moref(img_pil, bboxes)[0]
else:
print("Warning: No faces detected in the image.")
return img_pil
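

# Usage sketch (hypothetical path; assumes the antelopev2 detector weights are
# available under `model_path`):
#     extractor = FaceExtractor(model_path="./")
#     face_ref = extractor(Image.open("photo.jpg"))  # square 512x512 face crop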
class WithAnyonePipeline:
    """Identity-preserving FLUX pipeline that conditions generation on SigLIP and
    ArcFace embeddings of reference faces, optionally restricted to per-person
    bounding boxes via cross-attention masks."""

    def __init__(
self,
model_type: str,
ipa_path: str,
device: torch.device,
offload: bool = False,
only_lora: bool = False,
no_lora: bool = False,
lora_rank: int = 16,
        face_extractor=None,
        additional_lora_ckpt: str | None = None,
lora_weight: float = 1.0,
clip_path: str = "openai/clip-vit-large-patch14",
t5_path: str = "xlabs-ai/xflux_text_encoders",
flux_path: str = "black-forest-labs/FLUX.1-dev",
siglip_path: str = "google/siglip-base-patch16-256-i18n",
):
self.device = device
self.offload = offload
self.model_type = model_type
self.clip = load_clip(clip_path, self.device)
self.t5 = load_t5(t5_path, self.device, max_length=512)
self.ae = load_ae(flux_path, model_type, device="cpu" if offload else self.device)
self.use_fp8 = "fp8" in model_type
if additional_lora_ckpt is not None:
self.model = load_flow_model_diffusers(
model_type,
flux_path,
ipa_path,
device="cpu" if offload else self.device,
lora_rank=lora_rank,
use_fp8=self.use_fp8,
additional_lora_ckpt=additional_lora_ckpt,
lora_weight=lora_weight,
).to("cuda" if torch.cuda.is_available() else "cpu")
else:
self.model = load_flow_model_no_lora(
model_type,
flux_path,
ipa_path,
device="cpu" if offload else self.device,
use_fp8=self.use_fp8
)
if face_extractor is not None:
self.face_extractor = face_extractor
else:
self.face_extractor = FaceExtractor()
self.siglip = SiglipEmbedding(siglip_path=siglip_path)

    def load_ckpt(self, ckpt_path):
        if ckpt_path is not None:
            from safetensors.torch import load_file as load_sft
            print("Loading checkpoint to replace old keys")
            # load_sft doesn't accept a torch.device, so load on CPU first.
            if ckpt_path.endswith('safetensors'):
                sd = load_sft(ckpt_path, device='cpu')
            else:
                dit_state = torch.load(ckpt_path, map_location='cpu')
                sd = {k.replace('module.', ''): v for k, v in dit_state.items()}
            missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
            self.model.to(str(self.device))
            print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
def __call__(
self,
prompt: str,
width: int = 512,
height: int = 512,
guidance: float = 4,
num_steps: int = 50,
seed: int = 123456789,
**kwargs
):
        # Snap dimensions down to multiples of 16 (latent patch size).
        width = 16 * (width // 16)
        height = 16 * (height // 16)
        device_type = self.device if isinstance(self.device, str) else self.device.type
        if device_type == "mps":
            device_type = "cpu"  # autocast fallback for macOS MPS
with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16):
return self.forward(
prompt,
width,
height,
guidance,
num_steps,
seed,
**kwargs
)

    @torch.inference_mode()
def forward(
self,
prompt: str,
width: int,
height: int,
guidance: float,
num_steps: int,
seed: int,
        ref_imgs: list[Image.Image] | None = None,
        arcface_embeddings: torch.Tensor | None = None,
        bboxes=None,
id_weight: float = 1.0,
siglip_weight: float = 1.0,
):
x = get_noise(
1, height, width, device=self.device,
dtype=torch.bfloat16, seed=seed
)
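        # get_schedule shifts the timestep schedule based on this resolution-
        # dependent value (larger images get a more shifted schedule), following
        # the FLUX convention.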
timesteps = get_schedule(
num_steps,
(width // 8) * (height // 8) // (16 * 16),
shift=True,
)
if self.offload:
self.ae.encoder = self.ae.encoder.to(self.device)
        if ref_imgs is None:
            siglip_embeddings = None
        else:
            # shape: num_ref, (1), n, d
            siglip_embeddings = self.siglip(ref_imgs).to(self.device, torch.bfloat16).permute(1, 0, 2, 3)
        if arcface_embeddings is not None:
            # shape: num_ref, 1, 512
            arcface_embeddings = arcface_embeddings.unsqueeze(1).to(self.device, torch.bfloat16)
if self.offload:
self.offload_model_to_cpu(self.ae.encoder)
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
        inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
if self.offload:
self.offload_model_to_cpu(self.t5, self.clip)
self.model = self.model.to(self.device)
img = inp_cond["img"]
img_length = img.shape[1]
        # Build cross-attention masks restricting each identity's ArcFace and
        # SigLIP tokens to its own bounding-box region.
        arc_mask = siglip_mask = None
        if bboxes is not None:
            arc_mask = create_person_cross_attention_mask_varlen(
                batch_size=img.shape[0],
                img_len=img_length,
                id_len=8,  # ArcFace ID tokens per identity
                bbox_lists=bboxes,
                max_num_ids=len(bboxes[0]),
                original_width=width,
                original_height=height,
            ).to(img.device)
            siglip_mask = create_person_cross_attention_mask_varlen(
                batch_size=img.shape[0],
                img_len=img_length,
                id_len=256 + 8,  # SigLIP patch tokens plus the ID tokens
                bbox_lists=bboxes,
                max_num_ids=len(bboxes[0]),
                original_width=width,
                original_height=height,
            ).to(img.device)
results = denoise(
self.model,
**inp_cond,
timesteps=timesteps,
guidance=guidance,
arcface_embeddings=arcface_embeddings,
siglip_embeddings=siglip_embeddings,
bboxes=bboxes,
id_weight=id_weight,
siglip_weight=siglip_weight,
img_height=height,
img_width=width,
            arc_mask=arc_mask,
            siglip_mask=siglip_mask,
)
x = results
if self.offload:
self.offload_model_to_cpu(self.model)
self.ae.decoder.to(x.device)
x = unpack(x.float(), height, width)
x = self.ae.decode(x)
self.offload_model_to_cpu(self.ae.decoder)
        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        # Map from [-1, 1] to [0, 255] and convert to a PIL image.
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        if not self.offload:
            return
        for model in models:
            model.cpu()
            torch.cuda.empty_cache()
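

# Minimal end-to-end sketch (all paths and the `model_type` string are
# hypothetical; adjust to your checkpoints and hardware). It wires FaceExtractor
# and WithAnyonePipeline together the way the methods above expect: square face
# crops as `ref_imgs` and one [x1, y1, x2, y2] box per identity in `bboxes`.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = WithAnyonePipeline(
        model_type="flux-dev",             # assumed identifier
        ipa_path="ckpts/ipa.safetensors",  # hypothetical path
        device=device,
    )
    ref = pipeline.face_extractor(Image.open("person.jpg"))  # hypothetical input
    image = pipeline(
        prompt="a person reading in a sunlit cafe",
        width=768,
        height=768,
        num_steps=25,
        seed=42,
        ref_imgs=[ref],
        bboxes=[[[192, 192, 448, 448]]],  # batch of 1, one identity
    )
    image.save("output.png")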