# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Literal

import torch
import torchvision.transforms.functional as TVF
from einops import rearrange
from PIL import ExifTags, Image

from withanyone.flux.modules.layers import (
    DoubleStreamBlockLoraProcessor,
    DoubleStreamBlockProcessor,
    SingleStreamBlockLoraProcessor,
    SingleStreamBlockProcessor,
)
from withanyone.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from withanyone.flux.util import (
    load_ae,
    load_clip,
    load_flow_model_no_lora,
    load_flow_model_diffusers,
    load_t5,
)
from withanyone.flux.model import SiglipEmbedding, create_person_cross_attention_mask_varlen

def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
    """Resize a reference image so its long side equals `long_size`, then
    center-crop both dimensions down to the nearest multiple of 16."""
    image_w, image_h = raw_image.size

    # Resize so the longer side becomes `long_size`, preserving aspect ratio.
    if image_w >= image_h:
        new_w = long_size
        new_h = int((long_size / image_w) * image_h)
    else:
        new_h = long_size
        new_w = int((long_size / image_h) * image_w)
    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)

    # Center-crop so both dimensions are divisible by 16.
    target_w = new_w // 16 * 16
    target_h = new_h // 16 * 16
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    right = left + target_w
    bottom = top + target_h
    raw_image = raw_image.crop((left, top, right, bottom))
    raw_image = raw_image.convert("RGB")
    return raw_image
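
# Illustrative sizing example, derived from the arithmetic above: a 1000x700
# reference is resized to 512x358 (long side -> 512), then center-cropped to
# 512x352 so both sides are multiples of 16.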

from io import BytesIO

import insightface
import numpy as np

class FaceExtractor:
    def __init__(self, model_path="./"):
        # antelopev2 model pack: face detection plus ArcFace-style recognition.
        self.model = insightface.app.FaceAnalysis(
            name="antelopev2", root=model_path, providers=["CUDAExecutionProvider"]
        )
        self.model.prepare(ctx_id=0, det_thresh=0.45)

    def extract_moref(self, img, bboxes, face_size_restriction=1):
        """
        Extract faces from an image based on bounding boxes.
        Makes each face square and resizes it to 512x512.

        Args:
            img: PIL Image or raw image bytes.
            bboxes: list of (x1, y1, x2, y2) bounding boxes.
            face_size_restriction: minimum face width/height in pixels; if any
                face is smaller, the whole image is skipped.

        Returns:
            List of PIL Images, each 512x512, containing the extracted faces.
        """
        try:
            # Ensure img is a PIL Image.
            if not isinstance(img, Image.Image) and not isinstance(img, torch.Tensor):
                img = Image.open(BytesIO(img))

            new_bboxes = bboxes

            # If any face is smaller than the size restriction, ignore this image.
            for bbox in new_bboxes:
                x1, y1, x2, y2 = bbox
                if x2 - x1 < face_size_restriction or y2 - y1 < face_size_restriction:
                    return []

            faces = []
            for bbox in new_bboxes:
                # Convert coordinates to integers.
                x1, y1, x2, y2 = map(int, bbox)

                # Calculate width and height.
                width = x2 - x1
                height = y2 - y1

                # Make the bounding box square by expanding the shorter dimension.
                if width > height:
                    # Height is shorter, expand it.
                    diff = width - height
                    y1 -= diff // 2
                    y2 += diff - (diff // 2)  # Handle odd differences.
                elif height > width:
                    # Width is shorter, expand it.
                    diff = height - width
                    x1 -= diff // 2
                    x2 += diff - (diff // 2)  # Handle odd differences.

                # Ensure coordinates are within image boundaries.
                img_width, img_height = img.size
                x1 = max(0, x1)
                y1 = max(0, y1)
                x2 = min(img_width, x2)
                y2 = min(img_height, y2)

                # Extract the face region and resize it to 512x512.
                face_region = img.crop((x1, y1, x2, y2))
                face_region = face_region.resize((512, 512), Image.LANCZOS)
                faces.append(face_region)

            return faces
        except Exception as e:
            print(f"Error processing image: {e}")
            return []
    def __call__(self, img):
        # Normalize the input to both a PIL image and a numpy array.
        if isinstance(img, torch.Tensor):
            img_np = img.cpu().numpy()
            img_pil = Image.fromarray(img_np)
        elif isinstance(img, Image.Image):
            img_pil = img
            img_np = np.array(img)
        elif isinstance(img, np.ndarray):
            img_np = img
            img_pil = Image.fromarray(img)
        else:
            raise ValueError("Unsupported image format. Please provide a PIL Image or numpy array.")

        # Detect faces and crop around the first detection.
        faces = self.model.get(img_np)
        if len(faces) > 0:
            bboxes = []
            face = faces[0]
            bbox = face.bbox.astype(int)
            bboxes.append(bbox)
            return self.extract_moref(img_pil, bboxes)[0]
        else:
            print("Warning: No faces detected in the image.")
            return img_pil
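
# A minimal usage sketch for FaceExtractor. The model root, image path, and
# bounding box below are illustrative assumptions, not values shipped with
# this file:
#
#   extractor = FaceExtractor(model_path="./models")   # expects the antelopev2 pack under this root
#   ref_img = Image.open("person.jpg")
#   face_crop = extractor(ref_img)                      # 512x512 crop around the first detected face
#   face_crops = extractor.extract_moref(ref_img, [(30, 40, 220, 260)])  # explicit-bbox variant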

class WithAnyonePipeline:
    def __init__(
        self,
        model_type: str,
        ipa_path: str,
        device: torch.device,
        offload: bool = False,
        only_lora: bool = False,
        no_lora: bool = False,
        lora_rank: int = 16,
        face_extractor=None,
        additional_lora_ckpt: str = None,
        lora_weight: float = 1.0,
        clip_path: str = "openai/clip-vit-large-patch14",
        t5_path: str = "xlabs-ai/xflux_text_encoders",
        flux_path: str = "black-forest-labs/FLUX.1-dev",
        siglip_path: str = "google/siglip-base-patch16-256-i18n",
    ):
        self.device = device
        self.offload = offload
        self.model_type = model_type

        self.clip = load_clip(clip_path, self.device)
        self.t5 = load_t5(t5_path, self.device, max_length=512)
        self.ae = load_ae(flux_path, model_type, device="cpu" if offload else self.device)
        self.use_fp8 = "fp8" in model_type

        if additional_lora_ckpt is not None:
            self.model = load_flow_model_diffusers(
                model_type,
                flux_path,
                ipa_path,
                device="cpu" if offload else self.device,
                lora_rank=lora_rank,
                use_fp8=self.use_fp8,
                additional_lora_ckpt=additional_lora_ckpt,
                lora_weight=lora_weight,
            ).to("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.model = load_flow_model_no_lora(
                model_type,
                flux_path,
                ipa_path,
                device="cpu" if offload else self.device,
                use_fp8=self.use_fp8,
            )

        if face_extractor is not None:
            self.face_extractor = face_extractor
        else:
            self.face_extractor = FaceExtractor()
        self.siglip = SiglipEmbedding(siglip_path=siglip_path)
    def load_ckpt(self, ckpt_path):
        if ckpt_path is not None:
            from safetensors.torch import load_file as load_sft

            print("Loading checkpoint to replace old keys")
            # load_sft doesn't support torch.device
            if ckpt_path.endswith("safetensors"):
                sd = load_sft(ckpt_path, device="cpu")
            else:
                dit_state = torch.load(ckpt_path, map_location="cpu")
                sd = {k.replace("module.", ""): v for k, v in dit_state.items()}
            missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
            self.model.to(str(self.device))
            print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
    def __call__(
        self,
        prompt: str,
        width: int = 512,
        height: int = 512,
        guidance: float = 4,
        num_steps: int = 50,
        seed: int = 123456789,
        **kwargs,
    ):
        # Snap the requested resolution down to multiples of 16.
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        device_type = self.device if isinstance(self.device, str) else self.device.type
        if device_type == "mps":
            device_type = "cpu"  # autocast fallback to support macOS MPS

        with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16):
            return self.forward(
                prompt,
                width,
                height,
                guidance,
                num_steps,
                seed,
                **kwargs,
            )
    def forward(
        self,
        prompt: str,
        width: int,
        height: int,
        guidance: float,
        num_steps: int,
        seed: int,
        ref_imgs: list[Image.Image] | None = None,
        arcface_embeddings: torch.Tensor | None = None,
        bboxes=None,
        id_weight: float = 1.0,
        siglip_weight: float = 1.0,
    ):
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed,
        )
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        if self.offload:
            self.ae.encoder = self.ae.encoder.to(self.device)

        if ref_imgs is None:
            siglip_embeddings = None
        else:
            # num_ref, (1), n, d
            siglip_embeddings = self.siglip(ref_imgs).to(self.device, torch.bfloat16).permute(1, 0, 2, 3)

        if arcface_embeddings is not None:
            # num_ref, 1, 512
            arcface_embeddings = arcface_embeddings.unsqueeze(1)
            arcface_embeddings = arcface_embeddings.to(self.device, torch.bfloat16)

        if self.offload:
            self.offload_model_to_cpu(self.ae.encoder)
            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)

        inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)

        if self.offload:
            self.offload_model_to_cpu(self.t5, self.clip)
            self.model = self.model.to(self.device)

        img = inp_cond["img"]
        img_length = img.shape[1]

        ##### create masks for siglip and arcface #####
        if bboxes is not None:
            arc_mask = create_person_cross_attention_mask_varlen(
                batch_size=img.shape[0],
                img_len=img_length,
                id_len=8,
                bbox_lists=bboxes,
                max_num_ids=len(bboxes[0]),
                original_width=width,
                original_height=height,
            ).to(img.device)
            siglip_mask = create_person_cross_attention_mask_varlen(
                batch_size=img.shape[0],
                img_len=img_length,
                id_len=256 + 8,
                bbox_lists=bboxes,
                max_num_ids=len(bboxes[0]),
                original_width=width,
                original_height=height,
            ).to(img.device)

        results = denoise(
            self.model,
            **inp_cond,
            timesteps=timesteps,
            guidance=guidance,
            arcface_embeddings=arcface_embeddings,
            siglip_embeddings=siglip_embeddings,
            bboxes=bboxes,
            id_weight=id_weight,
            siglip_weight=siglip_weight,
            img_height=height,
            img_width=width,
            arc_mask=arc_mask if bboxes is not None else None,
            siglip_mask=siglip_mask if bboxes is not None else None,
        )
        x = results

        if self.offload:
            self.offload_model_to_cpu(self.model)
            self.ae.decoder.to(x.device)

        # Decode the denoised latents back to image space.
        x = unpack(x.float(), height, width)
        x = self.ae.decode(x)
        self.offload_model_to_cpu(self.ae.decoder)

        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img
    def offload_model_to_cpu(self, *models):
        if not self.offload:
            return
        for model in models:
            model.cpu()
            torch.cuda.empty_cache()
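
if __name__ == "__main__":
    # Minimal usage sketch. The model-type string, checkpoint path, prompt,
    # reference image, and bounding box are illustrative assumptions, not
    # values defined by this file.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = WithAnyonePipeline(
        model_type="flux-dev",                     # assumed model-type name
        ipa_path="./ckpt/withanyone.safetensors",  # assumed checkpoint path
        device=device,
        offload=False,
    )

    raw_ref = Image.open("reference.jpg")          # assumed reference portrait
    ref = pipeline.face_extractor(raw_ref)         # 512x512 crop of the detected face

    # Optional ArcFace identity embedding from the same insightface model;
    # shape (num_refs, 512) before forward() unsqueezes it.
    detections = pipeline.face_extractor.model.get(np.array(raw_ref))
    arcface = torch.from_numpy(detections[0].embedding).unsqueeze(0) if detections else None

    result = pipeline(
        "a person reading a book in a sunlit cafe",
        width=768,
        height=768,
        num_steps=25,
        ref_imgs=[ref],
        arcface_embeddings=arcface,
        bboxes=[[(200, 150, 450, 450)]],           # one list of face boxes per batch item
    )
    result.save("output.png")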