# --------------------------------------------------------
# NaViL
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.layers import DropPath
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_navil_vit import NaViLVisionConfig
from .modular_intern_vit import (
    InternVisionFlashAttention2,
    InternVisionSdpaAttention,
    InternMLP,
    NORM2FN,
    InternVisionRotaryEmbedding,
)

logger = logging.get_logger(__name__)

try:
    # Neither name is referenced elsewhere in this module; the import acts as an
    # availability probe that selects the attention implementation below.
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
    has_flash_attn = True
except Exception:
    # Kept broad (any failure while loading flash_attn falls back to SDPA), but unlike a
    # bare ``except:`` it no longer swallows KeyboardInterrupt / SystemExit.
    logger.warning('FlashAttention is not installed.')
    has_flash_attn = False


class NaViLVisionEmbeddingsAnyRes(nn.Module):
    """Patch embedding for the any-resolution (packed) NaViL vision tower.

    A single ``Conv2d`` with ``kernel_size == stride == patch_size`` projects RGB pixels
    to ``hidden_size`` channels. No position embedding is added here; positions are
    injected as 2-D rotary embeddings inside :class:`NaViLVisionEncoderAnyRes`.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.merge_size = int(1.0 / config.downsample_ratio)

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        # Not used by forward(); kept as public attributes.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed pixel patches.

        Args:
            pixel_values: 4-D tensor ``(N, 3, H, W)``. Presumably each entry is a single
                ``patch_size x patch_size`` patch (TODO confirm with the caller), in which
                case the result is ``(N, hidden_size)`` -- the packed layout the encoder uses.

        Returns:
            The conv output flattened from dim 1, i.e. ``(N, hidden_size * H' * W')``.
        """
        # Cast to the conv weight dtype so that e.g. fp32 pixels work with a bf16/fp16 model.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # (N, C, H', W')
        return patch_embeds.flatten(1)


class NaViLVisionEncoderLayerAnyRes(nn.Module):
    """Pre-norm transformer block operating on packed, variable-length sequences.

    Attention is ``InternVisionFlashAttention2`` when ``flash_attn`` imported successfully
    and ``InternVisionSdpaAttention`` otherwise. Each residual branch is scaled by a
    learnable per-channel vector (``ls1`` / ``ls2``, initialised to
    ``config.initializer_factor``) and wrapped in ``DropPath`` when ``drop_path_rate > 0``.
    """

    def __init__(self, config: NaViLVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        if has_flash_attn:
            self.attn = InternVisionFlashAttention2(config)
        else:
            self.attn = InternVisionSdpaAttention(config)

        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and MLP residual branches.

        Args:
            hidden_states: packed tokens; as called by :class:`NaViLVisionEncoderAnyRes`
                this is 2-D ``(total_tokens, embed_dim)``.
            cu_seqlens: cumulative segment boundaries (starting at 0) forwarded to the
                attention module; tokens only attend within their segment.
            rotary_pos_emb: per-token rotary frequencies, in the same token order as
                ``hidden_states``.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        hidden_states = hidden_states + self.drop_path1(
            self.attn(
                self.norm1(hidden_states),
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
            ) * self.ls1)

        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)

        return hidden_states


class NaViLVisionEncoderAnyRes(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`NaViLVisionEncoderLayerAnyRes`.

    Tokens of all images are packed into one sequence. Before the first layer they are
    reordered (in groups of ``merge_size * merge_size`` tokens) so that each attention
    window is contiguous. Layers listed in ``config.fullatt_block_indexes`` -- or every
    layer when it is ``None`` -- attend across a whole frame; all other layers attend only
    inside their window.

    Args:
        config (`NaViLVisionConfig`): the vision configuration.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        # stochastic depth decay rule: drop-path rate grows linearly from 0 to config.drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList([
            NaViLVisionEncoderLayerAnyRes(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
        # Enabled by default; forward() only checkpoints while self.training is True.
        self.gradient_checkpointing = True

        head_dim = config.hidden_size // config.num_attention_heads
        self.rotary_pos_emb = InternVisionRotaryEmbedding(head_dim // 2)
        self.merge_size = int(1.0 / config.downsample_ratio)
        self.merge_unit = self.merge_size * self.merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size

    def rot_pos_emb(self, grid_thw):
        """Build 2-D rotary embeddings for every patch token.

        For each ``(t, h, w)`` row of ``grid_thw`` this creates (row, col) position ids for
        an ``h x w`` patch grid, reorders them so each ``merge_size x merge_size`` block is
        contiguous (presumably the order in which patch tokens are packed -- TODO confirm
        with the preprocessor), and repeats them ``t`` times. Row and column frequencies are
        looked up in one shared table and concatenated per token.

        Returns:
            Tensor with ``sum(t * h * w)`` rows, one per patch token.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.merge_size,
                self.merge_size,
                w // self.merge_size,
                self.merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.merge_size,
                self.merge_size,
                w // self.merge_size,
                self.merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        # (tokens, 2, freq) -> (tokens, 2 * freq): row and column frequencies side by side.
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the window-attention permutation and window boundaries.

        Works at the granularity of merged units (``merge_size x merge_size`` patches).
        Each frame's unit grid is padded up to the window size with ``-100`` sentinels,
        split into ``window_size // merge_size`` square windows, and the real unit indices
        are read out window by window.

        Returns:
            window_index: 1-D tensor permuting merged units into window order (offset per image).
            cu_window_seqlens: Python list of cumulative window lengths in *patch* tokens,
                starting with 0. It may contain repeated values (empty padding windows).

        Raises:
            ValueError: if ``window_size`` is smaller than ``merge_size``.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.merge_size
        if vit_merger_window_size <= 0:
            raise ValueError(
                f'window_size ({self.window_size}) must be >= merge_size ({self.merge_size})')

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.merge_size,
                grid_w // self.merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # NOTE: when a dimension is already a multiple of the window size this pads a whole
            # extra window; those empty windows are dropped in forward() via unique_consecutive.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Real (non-sentinel) units per window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Convert unit counts to patch-token counts.
            cu_seqlens_tmp = seqlens.cumsum(0) * self.merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(total_tokens, hidden_size)`):
                Packed patch embeddings of all images; `total_tokens` must be divisible by
                `merge_size ** 2`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
                Per-image `(t, h, w)` patch grid. Required despite the `None` default.

        Raises:
            ValueError: if `grid_thw` is not given.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if grid_thw is None:
            raise ValueError('grid_thw is required by NaViLVisionEncoderAnyRes')

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop repeated boundaries produced by empty padding windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens (and their rotary embeddings) into window order, one merged unit at a time.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        # One segment per frame: h * w tokens, repeated t times for each image.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for idx, encoder_layer in enumerate(self.layers):
            if (self.fullatt_block_indexes is None) or (idx in self.fullatt_block_indexes):
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    partial(encoder_layer, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb),
                    hidden_states)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    rotary_pos_emb=rotary_pos_emb,
                )
            hidden_states = layer_outputs

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        # NOTE(review): the window_index permutation is not undone here, so returned tokens stay
        # in window order (undo with hidden_states[torch.argsort(window_index)] at merge-unit
        # granularity). TODO confirm callers expect this ordering.
        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )


class NaViLVisionModelAnyRes(PreTrainedModel):
    """NaViL any-resolution vision tower: patch embedding followed by the packed encoder."""

    main_input_name = 'pixel_values'
    config_class = NaViLVisionConfig
    _no_split_modules = ['NaViLVisionEncoderLayerAnyRes']

    def __init__(self, config: NaViLVisionConfig):
        super().__init__(config)
        self.config = config
        self.merge_size = int(1.0 / config.downsample_ratio)

        self.embeddings = NaViLVisionEmbeddingsAnyRes(config)
        self.encoder = NaViLVisionEncoderAnyRes(config)

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
            grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images.

        Args:
            pixel_values: 4-D pixel tensor passed to the patch embedding.
            output_hidden_states: whether to also return the per-layer hidden states.
            return_dict: whether to return a ``BaseModelOutputWithPooling`` instead of a tuple.
            pixel_embeds: precomputed packed patch embeddings; when given, ``pixel_values``
                is ignored.
            grid_thw: per-image ``(t, h, w)`` patch grid (required by the encoder).

        Returns:
            ``last_hidden_state`` reshaped to ``(num_units, merge_size, merge_size, hidden)``
            (``pooler_output`` is ``None``), plus hidden states when requested.

        Raises:
            ValueError: if neither input is given or ``pixel_values`` is not 4-D.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            grid_thw=grid_thw
        )
        # Index instead of attribute access: with return_dict=False the encoder returns a
        # plain tuple, and `.last_hidden_state` would raise AttributeError.
        last_hidden_state = encoder_outputs[0]
        # Group each merge_size x merge_size unit of consecutive tokens.
        last_hidden_state = last_hidden_state.reshape(
            -1, self.merge_size, self.merge_size, last_hidden_state.shape[-1])

        if not return_dict:
            return (last_hidden_state, ) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )