import math
from dataclasses import dataclass

import torch
from torch import Tensor, nn
from torch.cuda.amp import autocast
from einops import rearrange
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, AutoProcessor, AutoTokenizer, SiglipModel

from .modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    PerceiverAttentionCA,
    SingleStreamBlock,
    timestep_embedding,
)


def create_person_cross_attention_mask_varlen(
    batch_size, img_len, id_len,
    bbox_lists, original_width, original_height,
    max_num_ids=2,  # default supports 2 identities
    vae_scale_factor=8, patch_size=2, num_heads=24,
):
    """
    Create boolean attention masks that restrict image tokens to interacting only with
    the ID tokens of the person whose bounding box covers them.

    Parameters:
    - batch_size: number of samples in the batch
    - num_heads: number of attention heads
    - img_len: length of the image token sequence
    - id_len: length of EACH identity embedding (not the total)
    - bbox_lists: list where bbox_lists[i] contains all bboxes for batch item i;
      each batch item may have a different number of bboxes/identities
    - max_num_ids: maximum number of identities to support (for padding)
    - original_width/original_height: original image dimensions in pixels
    - vae_scale_factor: VAE downsampling factor (default 8)
    - patch_size: patch size used for token creation (default 2)

    Returns:
    - Boolean attention mask of shape [batch_size, num_heads, img_len, max_num_ids * id_len]
    """
    # Total length of ID tokens, based on the maximum number of identities
    total_id_len = max_num_ids * id_len

    # Initialize the mask to block all attention
    mask = torch.zeros((batch_size, num_heads, img_len, total_id_len), dtype=torch.bool)

    # Latent and patch-grid dimensions
    latent_width = original_width // vae_scale_factor
    latent_height = original_height // vae_scale_factor
    patches_width = latent_width // patch_size
    patches_height = latent_height // patch_size

    # Convert a pixel-space bounding box to flat image-token indices
    def bbox_to_token_indices(bbox):
        x1, y1, x2, y2 = bbox
        # Convert to patch-space coordinates
        if isinstance(x1, torch.Tensor):
            x1_patch = max(0, int(x1.item()) // vae_scale_factor // patch_size)
            y1_patch = max(0, int(y1.item()) // vae_scale_factor // patch_size)
            x2_patch = min(patches_width, math.ceil(int(x2.item()) / vae_scale_factor / patch_size))
            y2_patch = min(patches_height, math.ceil(int(y2.item()) / vae_scale_factor / patch_size))
        elif isinstance(x1, int):
            x1_patch = max(0, x1 // vae_scale_factor // patch_size)
            y1_patch = max(0, y1 // vae_scale_factor // patch_size)
            x2_patch = min(patches_width, math.ceil(x2 / vae_scale_factor / patch_size))
            y2_patch = min(patches_height, math.ceil(y2 / vae_scale_factor / patch_size))
        elif isinstance(x1, float):
            x1_patch = max(0, int(x1) // vae_scale_factor // patch_size)
            y1_patch = max(0, int(y1) // vae_scale_factor // patch_size)
            x2_patch = min(patches_width, math.ceil(x2 / vae_scale_factor / patch_size))
            y2_patch = min(patches_height, math.ceil(y2 / vae_scale_factor / patch_size))
        else:
            raise TypeError(f"Unsupported type: {type(x1)}")

        # Collect all token indices inside this region
        indices = []
        for y in range(y1_patch, y2_patch):
            for x in range(x1_patch, x2_patch):
                indices.append(y * patches_width + x)
        return indices

    for b in range(batch_size):
        # All bboxes for this batch item
        batch_bboxes = bbox_lists[b] if b < len(bbox_lists) else []
        # Process each bbox up to max_num_ids
        for identity_idx, bbox in enumerate(batch_bboxes[:max_num_ids]):
            # Image token indices covered by this bbox
            image_indices = bbox_to_token_indices(bbox)
            # ID token slice belonging to this identity
            id_start = identity_idx * id_len
            id_slice = slice(id_start, id_start + id_len)
            # Enable attention between this region's image tokens and the identity's tokens
            for h in range(num_heads):
                for idx in image_indices:
                    mask[b, h, idx, id_slice] = True
    return mask
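
# Illustrative usage sketch (not part of the original module; image size, bbox values and
# shapes are assumptions). For a 512x512 image: 512 / 8 (VAE) / 2 (patch) = 32 patches per
# side, so img_len = 32 * 32 = 1024.
#   bboxes = [[(100, 50, 220, 300), (260, 40, 400, 310)]]  # pixel-space boxes, one list per batch item
#   mask = create_person_cross_attention_mask_varlen(
#       batch_size=1, img_len=1024, id_len=8,
#       bbox_lists=bboxes, original_width=512, original_height=512,
#       max_num_ids=2,
#   )
#   # mask.shape == (1, 24, 1024, 16): image tokens inside each box may attend
#   # only to that identity's 8 ID tokens.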


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool
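
# Example configuration (illustrative values matching the publicly released FLUX.1-dev
# settings; this is an assumption -- nothing in this file fixes these numbers):
#   params = FluxParams(
#       in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072,
#       mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38,
#       axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True,
#   )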


class SiglipEmbedding(nn.Module):
    def __init__(self, siglip_path="google/siglip-base-patch16-256-i18n", use_matting=False):
        super().__init__()
        self.model = SiglipModel.from_pretrained(siglip_path).vision_model.to(torch.bfloat16)
        self.processor = AutoProcessor.from_pretrained(siglip_path)
        # self.model.to(torch.cuda.current_device())
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # BiRefNet matting setup
        self.use_matting = use_matting
        if self.use_matting:
            self.birefnet = AutoModelForImageSegmentation.from_pretrained(
                'briaai/RMBG-2.0', trust_remote_code=True).to(self.device, dtype=torch.bfloat16)
            # Apply half precision to the entire model after loading
            self.matting_transform = transforms.Compose([
                # transforms.Resize((512, 512)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

    def apply_matting(self, image):
        """Apply BiRefNet matting to remove the background from an image."""
        if not self.use_matting:
            return image

        # Convert to input format and move to the model device
        input_image = self.matting_transform(image).unsqueeze(0).to(self.device, dtype=torch.bfloat16)

        # Generate the foreground prediction
        with torch.no_grad(), autocast(dtype=torch.bfloat16):
            preds = self.birefnet(input_image)[-1].sigmoid().cpu()

        # Process the mask
        pred = preds[0].squeeze().float()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image.size)
        binary_mask = mask.convert("L")

        # Paste the subject onto a black background
        result = Image.new("RGB", image.size, (0, 0, 0))
        result.paste(image, (0, 0), binary_mask)
        return result

    def get_id_embedding(self, refimage):
        '''
        refimage is a list (batch) of lists (persons) of PIL images;
        the number of persons is assumed to be the same for every item in the batch.
        '''
        siglip_embedding = []
        if isinstance(refimage, list):
            batch_size = len(refimage)
            for batch_idx, refimage_batch in enumerate(refimage):
                # Apply matting if enabled
                if self.use_matting:
                    processed_images = [self.apply_matting(img) for img in refimage_batch]
                else:
                    processed_images = refimage_batch
                pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
                pixel_values = pixel_values.to(self.device, dtype=torch.bfloat16)
                last_hidden_state = self.model(pixel_values).last_hidden_state  # (num_of_person, 256, 768)
                # pooled_output = self.model(pixel_values).pooler_output  # (num_of_person, 768)
                siglip_embedding.append(last_hidden_state)
                # siglip_embedding.append(pooled_output)
            siglip_embedding = torch.stack(siglip_embedding, dim=0)  # (batch_size, num_of_person, 256, 768)
            if batch_size < 4:
                # Warm-up passes to avoid first-time CUDA memory-allocation overhead for small batches
                for _ in range(4 - batch_size):
                    pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
                    pixel_values = pixel_values.to(self.device, dtype=torch.bfloat16)
                    last_hidden_state = self.model(pixel_values).last_hidden_state
        elif isinstance(refimage, torch.Tensor):
            # refimage is a tensor of shape (batch_size, num_of_person, 3, H, W)
            batch_size, num_of_person, C, H, W = refimage.shape
            refimage = refimage.view(batch_size * num_of_person, C, H, W)
            refimage = refimage.to(self.device, dtype=torch.bfloat16)
            last_hidden_state = self.model(refimage).last_hidden_state
            siglip_embedding = last_hidden_state.view(batch_size, num_of_person, 256, 768)
        return siglip_embedding

    def forward(self, refimage):
        return self.get_id_embedding(refimage)
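
# Illustrative usage sketch (not part of the original module; file names are hypothetical):
#   embedder = SiglipEmbedding(use_matting=True)
#   refs = [[Image.open("person_a.jpg"), Image.open("person_b.jpg")]]  # 1 batch item, 2 persons
#   id_feats = embedder(refs)  # -> (1, 2, 256, 768) bfloat16 SigLIP patch features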


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """
    _supports_gradient_checkpointing = True

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        self.gradient_checkpointing = False

        # Identity injection via cross attention: one ArcFace adapter and one SigLIP adapter per block
        self.ipa_arc = nn.ModuleList([
            PerceiverAttentionCA(dim=self.hidden_size, kv_dim=self.hidden_size, heads=self.num_heads)
            for _ in range(self.params.depth_single_blocks + self.params.depth)
        ])
        self.ipa_sig = nn.ModuleList([
            PerceiverAttentionCA(dim=self.hidden_size, kv_dim=self.hidden_size, heads=self.num_heads)
            for _ in range(self.params.depth_single_blocks + self.params.depth)
        ])
        self.arcface_in_arc = nn.Sequential(
            nn.Linear(512, 4 * self.hidden_size, bias=True),
            nn.GELU(),
            nn.LayerNorm(4 * self.hidden_size),
            nn.Linear(4 * self.hidden_size, 8 * self.hidden_size, bias=True),
        )
        self.arcface_in_sig = nn.Sequential(
            nn.Linear(512, 4 * self.hidden_size, bias=True),
            nn.GELU(),
            nn.LayerNorm(4 * self.hidden_size),
            nn.Linear(4 * self.hidden_size, 8 * self.hidden_size, bias=True),
        )
        self.siglip_in_sig = nn.Sequential(
            nn.Linear(768, self.hidden_size, bias=True),
            nn.GELU(),
            nn.LayerNorm(self.hidden_size),
            nn.Linear(self.hidden_size, self.hidden_size, bias=True),
        )
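
    # Adapter-count note (added for clarity): every double-stream and single-stream block gets
    # its own ipa_arc / ipa_sig pair, i.e. depth + depth_single_blocks pairs in total. With the
    # illustrative FLUX.1-dev depths assumed above, that would be 19 + 38 = 57 adapter pairs.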

    def lq_in_arc(self, txt_lq, siglip_embeddings, arcface_embeddings):
        """
        Project the ArcFace embeddings into per-identity ID tokens for the ipa_arc branch.
        """
        # arcface_embeddings: (num_refs, bs, 512)
        arcface_embeddings = self.arcface_in_arc(arcface_embeddings)
        # now (num_refs, bs, 8 * hidden_size): split into 8 tokens of hidden_size each
        arcface_embeddings = rearrange(arcface_embeddings, 'b n (t d) -> b n t d', t=8, d=self.hidden_size)
        # (num_refs, bs, t, d) -> (bs, num_refs, t, d) -> (bs, num_refs * t, d)
        arcface_embeddings = arcface_embeddings.permute(1, 0, 2, 3)
        arcface_embeddings = rearrange(arcface_embeddings, 'b n t d -> b (n t) d')
        return arcface_embeddings

    def lq_in_sig(self, txt_lq, siglip_embeddings, arcface_embeddings):
        """
        Project the SigLIP and ArcFace embeddings into per-identity tokens for the ipa_sig branch.
        """
        # arcface_embeddings: (num_refs, bs, 512)
        arcface_embeddings = self.arcface_in_sig(arcface_embeddings)
        arcface_embeddings = rearrange(arcface_embeddings, 'b n (t d) -> b n t d', t=8, d=self.hidden_size)
        # (num_refs, bs, t, d) -> (bs, num_refs, t, d)
        arcface_embeddings = arcface_embeddings.permute(1, 0, 2, 3)
        # siglip_embeddings: (bs, num_refs, 256, 768) -> (bs, num_refs, 256, hidden_size)
        siglip_embeddings = self.siglip_in_sig(siglip_embeddings)
        # Concatenate along the token dimension:
        # (bs, num_refs, 256, hidden_size) cat (bs, num_refs, 8, hidden_size) -> (bs, num_refs, 264, hidden_size)
        arcface_embeddings = torch.cat((siglip_embeddings, arcface_embeddings), dim=2)
        arcface_embeddings = rearrange(arcface_embeddings, 'b n t d -> b (n t) d')
        return arcface_embeddings
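
    # Token-layout note (added for clarity): with 2 reference identities, and arcface_embeddings
    # arriving as (2, bs, 512) as the comments above indicate,
    #   lq_in_arc -> (bs, 2 * 8, hidden_size)    i.e. 8 ArcFace-derived ID tokens per identity
    #   lq_in_sig -> (bs, 2 * 264, hidden_size)  i.e. 256 SigLIP tokens + 8 ArcFace tokens per identity
    # which is why forward() uses id_len = 8 and siglip_len = 256 + 8 when building the
    # per-identity attention masks.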

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self):
        # collect processors recursively, keyed by module path
        processors = {}  # type: dict[str, nn.Module]

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        siglip_embeddings: Tensor | None = None,  # (bs, num_refs, 256, 768)
        arcface_embeddings: Tensor | None = None,  # (bs, num_refs, 512)
        bbox_lists: list | None = None,  # bbox_lists[i] holds the bboxes (one per identity) for batch item i,
        #   so each item may carry a different number of identities; the count should align with dim 1 of
        #   arcface_embeddings. This replaces the older bbox_A/bbox_B arguments, which are deprecated but
        #   retained elsewhere for compatibility.
        use_mask: bool = True,
        id_weight: float = 1.0,
        siglip_weight: float = 1.0,
        siglip_mask=None,
        arc_mask=None,
        img_height: int = 512,
        img_width: int = 512,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        text_length = txt.shape[1]
        img_length = img.shape[1]
        img_end = img.shape[1]

        use_ip = arcface_embeddings is not None
        if use_ip:
            id_embeddings = self.lq_in_arc(None, siglip_embeddings, arcface_embeddings)
            siglip_embeddings = self.lq_in_sig(None, siglip_embeddings, arcface_embeddings)
            # 8 ID tokens per identity for the ArcFace branch, 256 + 8 for the SigLIP branch
            id_len = 8
            siglip_len = 256 + 8
            if bbox_lists is not None and use_mask and (arc_mask is None or siglip_mask is None):
                arc_mask = create_person_cross_attention_mask_varlen(
                    batch_size=img.shape[0],
                    num_heads=self.params.num_heads,
                    img_len=img_length,
                    id_len=id_len,
                    bbox_lists=bbox_lists,
                    max_num_ids=len(bbox_lists[0]),
                    original_width=img_width,
                    original_height=img_height,
                ).to(img.device)
                siglip_mask = create_person_cross_attention_mask_varlen(
                    batch_size=img.shape[0],
                    num_heads=self.params.num_heads,
                    img_len=img_length,
                    id_len=siglip_len,
                    bbox_lists=bbox_lists,
                    max_num_ids=len(bbox_lists[0]),
                    original_width=img_width,
                    original_height=img_height,
                ).to(img.device)
        else:
            arc_mask = None
            siglip_mask = None

        # update text_ids
        txt_ids = torch.zeros((txt.shape[0], text_length, 3)).to(img_ids.device)  # (bs, T, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)  # (bs, T + I, 3)
        pe = self.pe_embedder(ids)

        # identity cross-attention adapter index
        ipa_idx = 0
        for index_block, block in enumerate(self.double_blocks):
            if self.training and self.gradient_checkpointing:
                img, txt = torch.utils.checkpoint.checkpoint(
                    block,
                    img=img,
                    txt=txt,
                    vec=vec,
                    pe=pe,
                    text_length=text_length,
                    image_length=img_length,
                    use_reentrant=False,
                )
            else:
                img, txt = block(
                    img=img,
                    txt=txt,
                    vec=vec,
                    pe=pe,
                    text_length=text_length,
                    image_length=img_length,
                )
            if use_ip:
                img = (
                    img
                    + id_weight * self.ipa_arc[ipa_idx](id_embeddings, img, mask=arc_mask)
                    + siglip_weight * self.ipa_sig[ipa_idx](siglip_embeddings, img, mask=siglip_mask)
                )
                ipa_idx += 1

        img = torch.cat((txt, img), 1)
        for index_block, block in enumerate(self.single_blocks):
            if self.training and self.gradient_checkpointing:
                img = torch.utils.checkpoint.checkpoint(
                    block,
                    img, vec=vec, pe=pe,
                    text_length=text_length,
                    image_length=img_length,
                    return_map=False,
                    use_reentrant=False,
                )
            else:
                img = block(img, vec=vec, pe=pe, text_length=text_length, image_length=img_length, return_map=False)
            # IPA
            if use_ip:
                txt, real_img = img[:, :text_length, :], img[:, text_length:, :]
                id_ca = (
                    id_weight * self.ipa_arc[ipa_idx](id_embeddings, real_img, mask=arc_mask)
                    + siglip_weight * self.ipa_sig[ipa_idx](siglip_embeddings, real_img, mask=siglip_mask)
                )
                real_img = real_img + id_ca
                img = torch.cat((txt, real_img), dim=1)
                ipa_idx += 1

        img = img[:, txt.shape[1]:, ...]
        img = img[:, :img_end, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
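

# End-to-end sketch (illustrative only, not part of the original module; tensor shapes follow the
# usual Flux conventions and are assumptions -- e.g. 512x512 pixels -> 32x32 latent patches ->
# 1024 image tokens, and a 512-token T5 text sequence):
#   model = Flux(params).to("cuda", dtype=torch.bfloat16)  # `params` as in the example above
#   pred = model(
#       img=img, img_ids=img_ids,            # (1, 1024, 64), (1, 1024, 3)
#       txt=txt, txt_ids=txt_ids,            # (1, 512, 4096), (1, 512, 3)
#       timesteps=t, y=clip_pooled,          # (1,), (1, 768)
#       guidance=torch.full((1,), 3.5),      # only used when guidance_embed=True
#       siglip_embeddings=sig,               # (1, num_refs, 256, 768) from SiglipEmbedding
#       arcface_embeddings=arc,              # 512-d ArcFace features per identity
#       bbox_lists=[[(100, 50, 220, 300)]],  # one pixel-space box per identity
#       img_height=512, img_width=512,
#   )                                        # -> (1, 1024, 64) predicted latent patches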