|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import copy |
|
|
import math |
|
|
import re |
|
|
from collections import defaultdict |
|
|
from typing import Any, Callable, Optional, Sequence, Union |
|
|
|
|
|
import PIL.Image |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import ( |
|
|
AutoImageProcessor, |
|
|
AutoModel, |
|
|
AutoTokenizer, |
|
|
BatchFeature, |
|
|
Cache, |
|
|
Qwen3Config, |
|
|
Qwen3ForCausalLM, |
|
|
Qwen3PreTrainedModel, |
|
|
) |
|
|
from transformers.cache_utils import SlidingWindowCache, StaticCache |
|
|
from transformers.generation.utils import GenerationMixin |
|
|
from transformers.image_processing_utils_fast import ( |
|
|
BaseImageProcessorFast, |
|
|
SizeDict, |
|
|
group_images_by_shape, |
|
|
reorder_images, |
|
|
DefaultFastImageProcessorKwargs, |
|
|
) |
|
|
from transformers.image_utils import ( |
|
|
ChannelDimension, |
|
|
PILImageResampling, |
|
|
) |
|
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
|
|
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer |
|
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding |
|
|
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig |
|
|
from transformers.models.siglip2.modeling_siglip2 import ( |
|
|
Siglip2Attention, |
|
|
Siglip2Encoder as HFSiglip2Encoder, |
|
|
Siglip2EncoderLayer as HFSiglip2EncoderLayer, |
|
|
Siglip2VisionEmbeddings as HFSiglip2VisionEmbeddings, |
|
|
) |
|
|
from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, Unpack |
|
|
from transformers.tokenization_utils import TensorType |
|
|
from transformers.utils import auto_docstring |
|
|
from transformers.utils.generic import can_return_tuple |
|
|
|
|
|
|
|
|
from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN |
|
|
from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD |
|
|
from transformers.utils.import_utils import is_torchdynamo_compiling |
|
|
|
|
|
try: |
|
|
from genesis.public.tensorstream.tensor_stream import ( |
|
|
Event, |
|
|
Stream, |
|
|
TensorStream, |
|
|
TextType, |
|
|
VisionType, |
|
|
create_stream, |
|
|
group_streams, |
|
|
) |
|
|
from genesis.public.tensorstream.tensor_stream_utils import ( |
|
|
compute_mrope_pos_tensor, |
|
|
modality_mask, |
|
|
reconstruct_tensor_stream_from_compact_dict, |
|
|
tensor_stream_token_view, |
|
|
) |
|
|
from genesis.public.tensorstream.tensor_stream_utils import ( |
|
|
slice as ts_slice, |
|
|
) |
|
|
except ModuleNotFoundError as exc: |
|
|
raise ModuleNotFoundError( |
|
|
"genesis.public.tensorstream is required for the Isaac HuggingFace integration. " |
|
|
"Ensure the TensorStream package is installed and on PYTHONPATH." |
|
|
) from exc |
|
|
|
|
|
|
|
|
# Snapshot, taken at import time, of whichever stock attention kernels are
# registered under these keys; absent keys are simply left out of the snapshot.
_ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {}
for _attn_name in ("flash_attention_2", "sdpa", "eager"):
    if _attn_name not in ALL_ATTENTION_FUNCTIONS:
        continue
    _ORIGINAL_ATTENTION_FUNCTIONS[_attn_name] = ALL_ATTENTION_FUNCTIONS[_attn_name]
|
|
|
|
|
|
|
|
class IsaacVisionConfig(Siglip2VisionConfig):
    """SigLIP2 vision configuration augmented with Isaac's pixel-shuffle fields.

    Args:
        pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1):
            Spatial factor applied before pixel shuffle reduces the resolution.
        num_patches (`int`, *optional*, defaults to 256):
            Maximum number of learnable positional embeddings to initialize.
        **kwargs:
            Passed through unchanged to `Siglip2VisionConfig.__init__`.
    """

    model_type = "isaac_vision"
    base_config_key = "vision_config"
    _attn_implementation: str | None = None

    def __init__(
        self,
        pixel_shuffle_scale_factor: int = 1,
        num_patches: int = 256,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Isaac-specific fields, stored after the base configuration is built.
        self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor
        self.num_patches = num_patches

        # If nothing has selected an attention implementation by now,
        # default to flash attention.
        if self._attn_implementation is None:
            self._attn_implementation = "flash_attention_2"
|
|
|
|
|
|
|
|
class IsaacImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False):
    """Keyword arguments accepted by the Isaac image processor.

    Adds four patching-related options to the default fast image processor
    kwargs. Because the class is declared with ``total=False`` every key is
    optional, and each one may also be passed explicitly as ``None``.
    """

    patch_size: int | None
    max_num_patches: int | None
    min_num_patches: int | None
    pixel_shuffle_scale: int | None
|
|
|
|
|
|
|
|
class IsaacProcessorKwargs(ProcessingKwargs, total=False):
    """Processor kwargs whose image section is typed with the Isaac image kwargs."""

    images_kwargs: IsaacImageProcessorKwargs
|
|
|
|
|
|
|
|
|
|
|
# Overwrite the stored `images_kwargs` annotation with the concrete class object.
IsaacProcessorKwargs.__annotations__.update(images_kwargs=IsaacImageProcessorKwargs)
|
|
|
|
|
|
|
|
@auto_docstring
class IsaacImageProcessorFast(BaseImageProcessorFast):
    r"""Fast torch-based image processor for Isaac vision inputs."""

    # Upper bound on `height * width` of an input image; larger inputs are rejected.
    MAX_PIXELS = 60_000_000

    resample = PILImageResampling.BILINEAR
    model_input_names = ["patches", "token_grids"]
    valid_kwargs = IsaacImageProcessorKwargs
    unused_kwargs = ["size", "do_center_crop", "crop_size"]

    do_resize = True
    size: SizeDict | None = None
    default_to_square: bool | None = None
    do_center_crop = False
    crop_size: SizeDict | None = None
    patch_size: int | None = 16
    max_num_patches: int | None = 256
    min_num_patches: int | None = None
    pixel_shuffle_scale: int | None = 1
    do_pad = False
    pad_size: SizeDict | None = None
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = True
    image_mean = list(VISION_MEAN)
    image_std = list(VISION_STD)
    do_convert_rgb = True
    return_tensors = None
    data_format = ChannelDimension.FIRST
    input_data_format = None
    device = None
    disable_grouping = False
    size_divisor: int | None = None

    def __init__(
        self,
        **kwargs: Unpack[IsaacImageProcessorKwargs],
    ) -> None:
        """Build the processor and normalise `pixel_shuffle_scale` to an int >= 1.

        Raises:
            ValueError: If the configured `pixel_shuffle_scale` is below 1.
        """
        super().__init__(**kwargs)

        # A missing scale means "no pixel shuffle", i.e. a factor of 1.
        pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale)
        if pixel_shuffle_scale < 1:
            raise ValueError("`pixel_shuffle_scale` must be >= 1")
        self.pixel_shuffle_scale = pixel_shuffle_scale

    def _validate_preprocess_kwargs(self, **kwargs):
        """Run the base-class validation without the options this processor manages itself."""
        for ignored in ("do_resize", "size", "do_center_crop", "crop_size", "disable_grouping"):
            kwargs.pop(ignored, None)
        return super()._validate_preprocess_kwargs(**kwargs)

    def resize(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        interpolation: Optional[Any] = None,
        antialias: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Resize `image` to `size` with `torch.nn.functional.interpolate`.

        Args:
            image: Tensor handed to `F.interpolate` (called from `_preprocess`
                with a 4D channel-first batch).
            size: Target size; both `height` and `width` must be set.
            interpolation: A mode string, an enum-like object exposing
                `.value` / `.name`, or `None` (treated as `"bilinear"`).
            antialias: Accepted for signature compatibility only; it is NOT
                forwarded to `F.interpolate`, so resampling is never antialiased.
            **kwargs: Ignored.

        Raises:
            ValueError: If `size.height` or `size.width` is `None`.
        """
        if size.height is None or size.width is None:
            raise ValueError("IsaacImageProcessorFast requires explicit `height` and `width` when resizing.")

        resize_mode: Any = interpolation
        if resize_mode is None:
            resize_mode = "bilinear"
        elif not isinstance(resize_mode, str):
            enum_value = getattr(resize_mode, "value", None)
            if isinstance(enum_value, str):
                # Enums whose value already is the mode string.
                resize_mode = enum_value
            elif hasattr(resize_mode, "name"):
                # Integer-valued enums carry no usable value; `F.interpolate`
                # needs a mode *string*, so derive it from the member name.
                resize_mode = resize_mode.name.lower()
            elif hasattr(resize_mode, "value"):
                resize_mode = enum_value

        if isinstance(resize_mode, str):
            # `F.interpolate` only recognises lower-case mode names.
            resize_mode = resize_mode.lower()

        resize_kwargs: dict[str, Any] = {}
        if resize_mode in {"linear", "bilinear", "bicubic", "trilinear"}:
            # `align_corners` is only passed for these interpolating modes.
            resize_kwargs["align_corners"] = False

        return F.interpolate(
            image,
            size=(size.height, size.width),
            mode=resize_mode,
            **resize_kwargs,
        )

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: Optional[SizeDict],
        interpolation: Optional[Any],
        do_center_crop: bool,
        crop_size: Optional[SizeDict],
        do_rescale: Optional[bool],
        rescale_factor: Optional[float],
        do_normalize: Optional[bool],
        image_mean: Optional[Union[float, Sequence[float]]],
        image_std: Optional[Union[float, Sequence[float]]],
        disable_grouping: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        do_pad: Optional[bool] = None,
        pad_size: Optional[SizeDict] = None,
        *,
        patch_size: int | None = None,
        max_num_patches: int | None = None,
        min_num_patches: int | None = None,
        pixel_shuffle_scale: int | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Resize, rescale/normalise and patchify images, grouped by input shape.

        Relies on `get_image_size_for_max_num_patches`, `patchify_vision` and
        `_compute_residual_p_frames`, which are defined outside this class.

        Returns:
            A `BatchFeature` with keys `"patches"`, `"token_grids"`
            (`[height_tokens, width_tokens]` per image), `"virtual_pixel_size"`
            (`[1, h, w]` after dividing by `pixel_shuffle_scale`) and
            `"real_pixel_size"` (`[1, height_tokens, width_tokens]`).

        Raises:
            ValueError: For `do_center_crop` / `do_pad`, non-4D image groups,
                images larger than `MAX_PIXELS`, dimensions not divisible by
                `patch_size` when resizing is disabled, or a token grid not
                divisible by `pixel_shuffle_scale`.
        """
        if do_center_crop:
            raise ValueError("`do_center_crop` is not supported by IsaacImageProcessorFast.")
        if do_pad:
            raise ValueError("`do_pad` is not supported by IsaacImageProcessorFast.")

        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_patches_grouped: dict[tuple[int, ...], torch.Tensor] = {}
        token_grids_grouped: dict[tuple[int, ...], torch.Tensor] = {}
        virtual_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {}
        real_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {}

        for shape, stacked_images in grouped_images.items():
            if stacked_images.ndim != 4:
                raise ValueError("Expected batched channel-first image tensors.")

            batch_size, channels, original_height, original_width = stacked_images.shape

            # Single-channel inputs are expanded to three identical channels.
            if bool(self.do_convert_rgb) and channels == 1:
                stacked_images = stacked_images.repeat(1, 3, 1, 1)

            if original_height * original_width > self.MAX_PIXELS:
                raise ValueError(
                    f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`"
                )

            if do_resize:
                # The target size is only needed (and only computed) when resizing.
                target_height, target_width = get_image_size_for_max_num_patches(
                    original_height,
                    original_width,
                    patch_size,
                    max_num_patches,
                    min_num_patches=min_num_patches,
                    pixel_shuffle_scale=pixel_shuffle_scale,
                )
                image_batch = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=target_height, width=target_width),
                    interpolation=interpolation,
                )
            else:
                if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0):
                    raise ValueError(
                        "Image dimensions must be divisible by patch_size when resize is disabled."
                    )
                image_batch = stacked_images

            # Either flag alone must trigger the call: gating on `do_rescale`
            # only would silently skip a normalise-only configuration.
            if do_rescale or do_normalize:
                image_batch = self.rescale_and_normalize(
                    image_batch,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                )

            # Channel-first -> channel-last before patchification.
            nhwc_images = image_batch.permute(0, 2, 3, 1)
            # Every image in the group is flagged as a non-P-frame.
            nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size)

            patches = patchify_vision(nhwc_images, patch_size=patch_size)
            _, height_tokens, width_tokens, _ = patches.shape

            token_grid = torch.tensor(
                [height_tokens, width_tokens],
                dtype=torch.long,
                device=patches.device,
            ).unsqueeze(0).repeat(batch_size, 1)

            real_dim = torch.tensor(
                [1, height_tokens, width_tokens],
                dtype=torch.long,
                device=patches.device,
            ).unsqueeze(0).repeat(batch_size, 1)

            if pixel_shuffle_scale > 1:
                if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale):
                    raise ValueError(
                        "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled."
                    )
                virtual_height = height_tokens // pixel_shuffle_scale
                virtual_width = width_tokens // pixel_shuffle_scale
            else:
                virtual_height = height_tokens
                virtual_width = width_tokens

            virtual_dim = torch.tensor(
                [1, virtual_height, virtual_width],
                dtype=torch.long,
                device=patches.device,
            ).unsqueeze(0).repeat(batch_size, 1)

            processed_patches_grouped[shape] = patches
            token_grids_grouped[shape] = token_grid
            virtual_dims_grouped[shape] = virtual_dim
            real_dims_grouped[shape] = real_dim

        # Restore the caller's original image order before stacking.
        patches_slices = reorder_images(processed_patches_grouped, grouped_images_index)
        token_grid_slices = reorder_images(token_grids_grouped, grouped_images_index)
        virtual_dim_slices = reorder_images(virtual_dims_grouped, grouped_images_index)
        real_dim_slices = reorder_images(real_dims_grouped, grouped_images_index)

        # NOTE(review): `torch.stack` requires every image to yield the same
        # patch grid; mixed resolutions in one call would fail here.
        patches_tensor = torch.stack(patches_slices, dim=0)
        token_grids_tensor = torch.stack(token_grid_slices, dim=0)
        virtual_dims_tensor = torch.stack(virtual_dim_slices, dim=0)
        real_dims_tensor = torch.stack(real_dim_slices, dim=0)

        return BatchFeature(
            data={
                "patches": patches_tensor,
                "token_grids": token_grids_tensor,
                "virtual_pixel_size": virtual_dims_tensor,
                "real_pixel_size": real_dims_tensor,
            },
            tensor_type=return_tensors,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: |
|
|
"""Helper to compute max sequence length from cumulative sequence lengths.""" |
|
|
if cu is None or len(cu) < 2: |
|
|
return fallback |
|
|
return int((cu[1:] - cu[:-1]).max().item()) |
|
|
|
|
|
|
|
|
def build_document_attention_mask( |
|
|
cu_seqlens: torch.Tensor | None, |
|
|
total_tokens: int, |
|
|
dtype: torch.dtype, |
|
|
device: torch.device, |
|
|
) -> torch.Tensor | None: |
|
|
"""Creates an additive attention mask that blocks cross-document attention.""" |
|
|
|
|
|
if cu_seqlens is None: |
|
|
return None |
|
|
|
|
|
if cu_seqlens.numel() < 2: |
|
|
return None |
|
|
|
|
|
seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() |
|
|
if seq_sizes.numel() == 0: |
|
|
return None |
|
|
|
|
|
seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=device), seq_sizes) |
|
|
block_mask = seg_ids[:, None] != seg_ids[None, :] |
|
|
additive_mask = torch.zeros((total_tokens, total_tokens), dtype=dtype, device=device) |
|
|
additive_mask.masked_fill_(block_mask, float("-inf")) |
|
|
return additive_mask.view(1, 1, total_tokens, total_tokens) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_document_attention_mask(
    attention_mask: Optional[torch.Tensor],
    cu_seqlens: Optional[torch.Tensor],
    total_tokens: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Optional[torch.Tensor]:
    """Return `attention_mask`, deriving one from `cu_seqlens` only when it is absent.

    A caller-supplied mask always wins. When there is neither a mask nor
    `cu_seqlens`, `None` is returned unchanged.
    """
    must_build = attention_mask is None and cu_seqlens is not None
    if must_build:
        return build_document_attention_mask(
            cu_seqlens=cu_seqlens,
            total_tokens=total_tokens,
            dtype=dtype,
            device=device,
        )
    return attention_mask
|
|
|
|
|
|
|
|
def flash_attention_document_mask_forward(
    module: torch.nn.Module,
    q_lhd: torch.Tensor,
    k_lhd: torch.Tensor,
    v_lhd: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    dropout: float = 0.0,
    scaling: float | None = None,
    cum_seq_q: torch.Tensor | None = None,
    cum_seq_k: torch.Tensor | None = None,
    max_seqlen: int | None = None,
    is_causal: bool = False,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Run the aten flash-attention kernel directly on (L, H, D) tensors.

    Document boundaries come from `cum_seq_q` / `cum_seq_k`. `module`,
    `attention_mask` and `**kwargs` are accepted but not read by this function.

    Returns:
        `(output, None)`, where `output` is the first value the kernel returns.
    """
    # Unpacking three dims also rejects anything that is not a 3D tensor.
    total_tokens, _, _ = q_lhd.shape

    if max_seqlen is None:
        # Derive per-side maxima from the offsets; fall back to the full length.
        max_q = _max_from_cu(cum_seq_q, total_tokens)
        max_k = _max_from_cu(cum_seq_k, total_tokens)
    else:
        max_q = max_k = int(max_seqlen)

    # `.contiguous()` hands back the same tensor when no copy is needed.
    q_lhd, k_lhd, v_lhd = (tensor.contiguous() for tensor in (q_lhd, k_lhd, v_lhd))

    kernel_outputs = torch.ops.aten._flash_attention_forward(
        query=q_lhd,
        key=k_lhd,
        value=v_lhd,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
        dropout_p=dropout,
        is_causal=is_causal,
        return_debug_mask=False,
        scale=scaling,
        window_size_left=-1,
        window_size_right=-1,
        alibi_slopes=None,
    )
    out_lhd, *_ = kernel_outputs
    return out_lhd, None
|
|
|
|
|
|
|
|
def sdpa_document_mask_forward( |
|
|
q_lhd: torch.Tensor, |
|
|
k_lhd: torch.Tensor, |
|
|
v_lhd: torch.Tensor, |
|
|
dropout: float, |
|
|
scaling: float | None, |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
cu_seqlens: torch.Tensor | None = None, |
|
|
) -> torch.Tensor: |
|
|
"""SDPA with block-diagonal masking for variable-length sequences.""" |
|
|
L, H, D = q_lhd.shape |
|
|
|
|
|
|
|
|
Q = q_lhd.permute(1, 0, 2).unsqueeze(0) |
|
|
K = k_lhd.permute(1, 0, 2).unsqueeze(0) |
|
|
V = v_lhd.permute(1, 0, 2).unsqueeze(0) |
|
|
|
|
|
|
|
|
attn_mask = attention_mask |
|
|
if attn_mask is None: |
|
|
attn_mask = build_document_attention_mask( |
|
|
cu_seqlens=cu_seqlens, |
|
|
total_tokens=L, |
|
|
dtype=q_lhd.dtype, |
|
|
device=q_lhd.device, |
|
|
) |
|
|
|
|
|
if attn_mask is not None and attn_mask.dtype != Q.dtype: |
|
|
attn_mask = attn_mask.to(Q.dtype) |
|
|
|
|
|
Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling) |
|
|
return Y.squeeze(0).permute(1, 0, 2) |
|
|
|
|
|
|
|
|
class IsaacVisionEmbeddings(HFSiglip2VisionEmbeddings):
    """SigLIP2 vision embeddings driven by a packed patch sequence.

    Patches of all images arrive concatenated along dim 0. They are padded
    into a batch for the parent implementation and flattened back afterwards.
    """

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)

    def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor:
        """Embed packed patches and return the embeddings in the same packed order."""
        padded_patches, lengths = self._pack_to_batch(seq_patches, spatial_shapes)
        if padded_patches is None:
            # No images at all: return an empty (0, embed_dim) result.
            return seq_patches.new_zeros((0, self.embed_dim))

        padded_embeddings = super().forward(padded_patches, spatial_shapes)
        return self._unpack_from_batch(padded_embeddings, lengths)

    def _pack_to_batch(
        self,
        seq_patches: torch.Tensor,
        spatial_shapes: torch.Tensor,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        """Scatter packed patches into a zero-padded `(num_images, max_len, patch_dim)` batch.

        Returns:
            `(padded, seq_lengths)`; `padded` is `None` when there are no images.

        Raises:
            ValueError: On wrong ranks, or when the patch count disagrees with
                the product of the spatial shapes.
        """
        if seq_patches.ndim != 2:
            raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).")
        if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2:
            raise ValueError("`spatial_shapes` must have shape (num_images, 2) with (height_tokens, width_tokens).")

        # Tokens per image = height_tokens * width_tokens.
        seq_lengths = spatial_shapes.long().prod(dim=-1)
        total_patches = int(seq_lengths.sum().item())
        if total_patches != seq_patches.size(0):
            raise ValueError(
                "Mismatch between packed patches and spatial shapes: got "
                f"{seq_patches.size(0)} patches but spatial shapes imply {total_patches}."
            )

        num_images = spatial_shapes.size(0)
        if num_images == 0:
            return None, seq_lengths

        longest = int(seq_lengths.max().item())
        padded = seq_patches.new_zeros(
            (num_images, longest, seq_patches.size(-1)),
            device=seq_patches.device,
        )

        # Copy each image's run of patches into its own row; the tail stays zero.
        offset = 0
        for row, n_patches in enumerate(seq_lengths.tolist()):
            if n_patches != 0:
                padded[row, :n_patches] = seq_patches[offset : offset + n_patches]
                offset += n_patches

        return padded, seq_lengths

    def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor:
        """Drop the padding rows and concatenate per-image embeddings along dim 0."""
        chunks: list[torch.Tensor] = [
            embeddings[row, :n_patches]
            for row, n_patches in enumerate(seq_lengths.tolist())
            if n_patches != 0
        ]
        if not chunks:
            return embeddings.new_zeros((0, embeddings.size(-1)))
        return torch.cat(chunks, dim=0)
|
|
|
|
|
|
|
|
class IsaacVisionAttention(Siglip2Attention):
    """SigLIP2 attention operating on one packed, variable-length sequence."""

    # Maps a configured implementation name onto the Isaac-specific registry key.
    ATTENTION_KEY_MAP: dict[str, str] = {
        "flash_attention_2": "isaac_flash_attention_2",
        "flash_attention_3": "isaac_flash_attention_3",
        "isaac_flash_attention_2": "isaac_flash_attention_2",
        "isaac_flash_attention_3": "isaac_flash_attention_3",
        "sdpa": "isaac_sdpa",
        "isaac_sdpa": "isaac_sdpa",
        "eager": "isaac_eager",
        "isaac_eager": "isaac_eager",
    }

    def __init__(self, vision_config):
        super().__init__(vision_config)
        self.vision_config = vision_config
        self._variable_length_metadata = None

    def _variable_length_context(self, *, cu_seqlens=None, max_seqlen=None):
        """Stash `(cu_seqlens, max_seqlen)` so the next `forward` call can pick them up."""
        self._variable_length_metadata = (cu_seqlens, max_seqlen)

    def _consume_variable_length_metadata(self):
        """Return the stashed `(cu_seqlens, max_seqlen)` once and clear it."""
        stashed = self._variable_length_metadata
        if stashed is None:
            return None, None
        self._variable_length_metadata = None
        cu_seqlens, max_seqlen = stashed
        return cu_seqlens, max_seqlen

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Attend within a packed sequence of shape `(1, L, embed_dim)`.

        Explicit `cu_seqlens` / `max_seqlen` kwargs take precedence over the
        values stashed through `_variable_length_context`.

        Returns:
            `(output, None)` with `output` of shape `(1, L, embed_dim)`.

        Raises:
            TypeError: On unknown keyword arguments.
            ValueError: If the batch size is not 1, or the configured
                attention implementation cannot be resolved.
        """
        cu_seqlens = kwargs.pop("cu_seqlens", None)
        max_seqlen = kwargs.pop("max_seqlen", None)
        # Accepted for interface compatibility, but unused here.
        for ignored in ("output_attentions", "output_hidden_states", "return_dict"):
            kwargs.pop(ignored, None)
        if kwargs:
            unexpected = ', '.join(sorted(kwargs))
            raise TypeError(f'Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}')

        stashed_cu, stashed_max = self._consume_variable_length_metadata()
        cu_seqlens = stashed_cu if cu_seqlens is None else cu_seqlens
        max_seqlen = stashed_max if max_seqlen is None else max_seqlen

        batch_size, seq_len, _ = hidden_states.shape
        if batch_size != 1:
            raise ValueError("packed variable-length attention expects batch_size=1")
        tokens = hidden_states[0]

        num_heads, head_dim = self.num_heads, self.head_dim
        dropout_p = self.dropout if self.training else 0.0

        # Project, then split the channel dim into heads: (L, H, D).
        q = self.q_proj(tokens).view(seq_len, num_heads, head_dim)
        k = self.k_proj(tokens).view(seq_len, num_heads, head_dim)
        v = self.v_proj(tokens).view(seq_len, num_heads, head_dim)

        attn_impl = getattr(self.vision_config, "_attn_implementation", "flash_attention_3")

        attn_mask = ensure_document_attention_mask(
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            total_tokens=seq_len,
            dtype=q.dtype,
            device=q.device,
        )

        resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl)
        attention_fn = None if resolved_key is None else ALL_ATTENTION_FUNCTIONS.get(resolved_key)
        if attention_fn is None:
            raise ValueError(f"Attention implementation {attn_impl} not found.")

        # (L, H, D) -> (1, H, L, D), the layout the registered kernels take.
        query_states, key_states, value_states = (
            tensor.transpose(0, 1).unsqueeze(0) for tensor in (q, k, v)
        )

        attention_kwargs: dict[str, Any] = {
            "dropout": dropout_p,
            "scaling": self.scale,
            "is_causal": False,
        }
        if cu_seqlens is not None:
            attention_kwargs.update(cu_seq_lens_q=cu_seqlens, cu_seq_lens_k=cu_seqlens)
        if max_seqlen is not None:
            attention_kwargs.update(max_length_q=max_seqlen, max_length_k=max_seqlen)

        attn_output, _ = attention_fn(
            self,
            query_states,
            key_states,
            value_states,
            attn_mask,
            **attention_kwargs,
        )

        # (1, H, L, D) -> (L, H, D), then merge heads for the output projection.
        packed_output = attn_output.squeeze(0).permute(1, 0, 2).contiguous()
        projected = self.out_proj(packed_output.reshape(seq_len, self.embed_dim))
        return projected.unsqueeze(0), None
|
|
|
|
|
|
|
|
class IsaacVisionEncoderLayer(HFSiglip2EncoderLayer):
    """SigLIP2 encoder layer whose self-attention is `IsaacVisionAttention`."""

    def __init__(self, vision_config: IsaacVisionConfig):
        super().__init__(vision_config)
        # Replace the stock attention module with the packed-sequence variant.
        self.self_attn = IsaacVisionAttention(vision_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: Optional[bool] = None,
    ):
        """Run the parent layer after handing packing metadata to the attention module.

        `output_hidden_states` is accepted but not forwarded to the parent layer.
        """
        has_packing_metadata = not (cu_seqlens is None and max_seqlen is None)
        if has_packing_metadata:
            # The parent `forward` is not passed these values, so they are
            # stashed on the attention module before it runs.
            self.self_attn._variable_length_context(
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )

        document_mask = ensure_document_attention_mask(
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            total_tokens=hidden_states.size(1),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        return super().forward(
            hidden_states,
            attention_mask=document_mask,
            output_attentions=output_attentions,
        )
|
|
|
|
|
|
|
|
class IsaacVisionEncoder(HFSiglip2Encoder):
    """SigLIP2 encoder rebuilt from `IsaacVisionEncoderLayer` blocks."""

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def __variable_length_context(self, cu_seqlens, max_seqlen) -> None:
        """Stash packing metadata on every Isaac layer's attention module."""
        if cu_seqlens is not None or max_seqlen is not None:
            isaac_layers = (layer for layer in self.layers if isinstance(layer, IsaacVisionEncoderLayer))
            for layer in isaac_layers:
                layer.self_attn._variable_length_context(
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                )

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode a packed sequence, isolating documents described by `cu_seqlens`.

        `output_hidden_states` and `return_dict` are accepted but not forwarded
        to the parent `forward`.
        """
        self.__variable_length_context(cu_seqlens, max_seqlen)

        # Built once here and handed to the parent encoder as `attention_mask`.
        document_mask = ensure_document_attention_mask(
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            total_tokens=inputs_embeds.size(1),
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

        return super().forward(
            inputs_embeds,
            attention_mask=document_mask,
            output_attentions=output_attentions,
        )
|
|
|
|
|
|
|
|
def _isaac_flash_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: bool = False,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Flash-attention entry point registered for packed Isaac vision attention.

    Modules other than `IsaacVisionAttention` are delegated to the stock
    `flash_attention_2` implementation captured at import time. Isaac modules
    take the packed path, which reads `cu_seq_lens_q`, `cu_seq_lens_k` and
    `max_length_q` from `kwargs`.

    Args:
        module: Calling attention module.
        query, key, value: `(1, H, L, D)` tensors on the packed path.
        attention_mask: Forwarded for interface parity only; the kernel call in
            `flash_attention_document_mask_forward` does not read it.
        dropout: Dropout probability; `None` falls back to the module's setting.
        scaling: Softmax scale; `None` falls back to `module.scale`.
        is_causal: Forwarded to the kernel.

    Returns:
        `(attn_output, attn_weights)` with `attn_output` of shape `(1, H, L, D)`.

    Raises:
        ValueError: If a fallback is needed but no stock implementation was
            captured, or if a packed input is not 4D with batch size 1.
    """
    if not isinstance(module, IsaacVisionAttention):
        # The stock kernel is only required when actually falling back; the
        # packed path below must not fail just because it is missing.
        base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("flash_attention_2")
        if base_fn is None:
            raise ValueError("Base flash attention function unavailable for fallback.")
        return base_fn(
            module,
            query,
            key,
            value,
            attention_mask,
            dropout=dropout,
            scaling=scaling,
            is_causal=is_causal,
            **kwargs,
        )

    if query.dim() != 4 or query.size(0) != 1:
        raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.")

    # (1, H, L, D) -> (L, H, D)
    _, num_heads, seq_len, head_dim = query.shape
    q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim)
    k_lhd = key.transpose(1, 2).reshape(seq_len, num_heads, head_dim)
    v_lhd = value.transpose(1, 2).reshape(seq_len, num_heads, head_dim)

    cum_seq_q = kwargs.get("cu_seq_lens_q")
    cum_seq_k = kwargs.get("cu_seq_lens_k", cum_seq_q)
    max_seqlen = kwargs.get("max_length_q")

    effective_dropout = dropout if dropout is not None else (module.dropout if module.training else 0.0)
    effective_scaling = module.scale if scaling is None else scaling

    # No dense (L, L) mask is materialised here: the flash kernel separates
    # documents through the cumulative offsets and never reads a mask.
    attn_output_lhd, attn_weights = flash_attention_document_mask_forward(
        module,
        q_lhd,
        k_lhd,
        v_lhd,
        attention_mask=attention_mask,
        dropout=effective_dropout,
        scaling=effective_scaling,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_seqlen=max_seqlen,
        is_causal=is_causal,
    )

    # (L, H, D) -> (1, H, L, D)
    attn_output = attn_output_lhd.permute(1, 0, 2).unsqueeze(0)
    return attn_output, attn_weights
|
|
|
|
|
|
|
|
def _isaac_sdpa_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: bool = False,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """SDPA entry point for packed Isaac vision attention.

    Modules other than `IsaacVisionAttention` are delegated to the stock
    `sdpa` implementation captured at import time. Isaac modules take the
    packed path, where documents are separated by `attention_mask` or, if that
    is `None`, by a mask derived from `kwargs["cu_seq_lens_q"]`.

    Args:
        module: Calling attention module.
        query, key, value: `(1, H, L, D)` tensors on the packed path.
        attention_mask: Optional additive mask.
        dropout: Dropout probability; `None` falls back to the module's setting.
        scaling: Softmax scale; `None` falls back to `module.scale`.
        is_causal: Only forwarded on the fallback path.

    Returns:
        `(attn_output, None)` with `attn_output` of shape `(1, H, L, D)`.

    Raises:
        ValueError: If a fallback is needed but no stock implementation was
            captured, or if a packed input is not 4D with batch size 1.
    """
    if not isinstance(module, IsaacVisionAttention):
        # The stock kernel is only required when actually falling back; the
        # packed path below must not fail just because it is missing.
        base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("sdpa")
        if base_fn is None:
            raise ValueError("Base SDPA function unavailable for fallback.")
        return base_fn(
            module,
            query,
            key,
            value,
            attention_mask,
            dropout=dropout,
            scaling=scaling,
            is_causal=is_causal,
            **kwargs,
        )

    if query.dim() != 4 or query.size(0) != 1:
        raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.")

    # (1, H, L, D) -> (L, H, D)
    _, num_heads, seq_len, head_dim = query.shape
    q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim)
    k_lhd = key.transpose(1, 2).reshape(seq_len, num_heads, head_dim)
    v_lhd = value.transpose(1, 2).reshape(seq_len, num_heads, head_dim)

    cum_seq = kwargs.get("cu_seq_lens_q")
    effective_dropout = dropout if dropout is not None else (module.dropout if module.training else 0.0)
    effective_scaling = module.scale if scaling is None else scaling

    # `sdpa_document_mask_forward` derives the block-diagonal mask from
    # `cu_seqlens` itself when `attention_mask` is None, so it is not rebuilt here.
    attn_output_lhd = sdpa_document_mask_forward(
        q_lhd,
        k_lhd,
        v_lhd,
        dropout=effective_dropout,
        scaling=effective_scaling,
        attention_mask=attention_mask,
        cu_seqlens=cum_seq,
    )

    # (L, H, D) -> (1, H, L, D)
    attn_output = attn_output_lhd.permute(1, 0, 2).unsqueeze(0)
    return attn_output, None
|
|
|
|
|
|
|
|
def _isaac_eager_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: bool = False,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Eager (matmul + softmax) attention for packed Isaac vision sequences.

    Modules that are not `IsaacVisionAttention` are forwarded untouched to the
    function stored under `"eager"` in `_ORIGINAL_ATTENTION_FUNCTIONS`.

    Args:
        module: Attention module issuing the call.
        query, key, value: Tensors of shape `(1, num_heads, seq_len, head_dim)`.
        attention_mask: Optional additive mask. A 4D mask has its two leading
            dims squeezed (when they are singletons) before being added to the
            scores.
        dropout: Dropout probability, applied only when `module.training`.
        scaling: Softmax scaling; `None` falls back to `module.scale`.
        is_causal: Only forwarded on the fallback path.
        **kwargs: Only forwarded on the fallback path.

    Returns:
        `(attn_output, attn_weights)` with shapes
        `(1, num_heads, seq_len, head_dim)` and `(num_heads, seq_len, seq_len)`.

    Raises:
        ValueError: If the base eager function is not registered, or if `query`
            is not 4D with batch size 1.
    """
    base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("eager")
    if not isinstance(module, IsaacVisionAttention) or base_fn is None:
        if base_fn is None:
            raise ValueError("Base eager attention function unavailable for fallback.")
        return base_fn(
            module,
            query,
            key,
            value,
            attention_mask,
            dropout=dropout,
            scaling=scaling,
            is_causal=is_causal,
            **kwargs,
        )

    if query.dim() != 4 or query.size(0) != 1:
        raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.")

    # Keep heads as the batch axis: (1, H, L, D) -> (H, L, D). Using a
    # token-first (L, H, D) layout here would make `matmul` contract over
    # head_dim per token and yield (L, H, H) scores, i.e. attention across
    # heads instead of across tokens.
    q_hld = query.squeeze(0)
    k_hld = key.squeeze(0)
    v_hld = value.squeeze(0)

    effective_scaling = module.scale if scaling is None else scaling
    attn_weights = torch.matmul(q_hld, k_hld.transpose(-1, -2)) * effective_scaling  # (H, L, L)

    # NOTE(review): unlike `_isaac_sdpa_forward`, no document mask is derived
    # from `cu_seq_lens_q` when `attention_mask` is None, so every token then
    # attends to every other token in the pack.
    if attention_mask is not None:
        mask = attention_mask
        if mask.dim() == 4:
            mask = mask.squeeze(0).squeeze(0)
        attn_weights = attn_weights + mask  # (L, L) broadcasts over heads

    attn_weights = torch.softmax(attn_weights, dim=-1)
    if dropout and module.training:
        attn_weights = F.dropout(attn_weights, p=dropout, training=True)

    attn_output = torch.matmul(attn_weights, v_hld).unsqueeze(0)  # (1, H, L, D)
    return attn_output, attn_weights
|
|
|
|
|
|
|
|
# Register the Isaac attention entry points under dedicated `isaac_*` keys.
# Both flash-attention keys (v2 and v3) resolve to the same wrapper function.
ALL_ATTENTION_FUNCTIONS.register("isaac_flash_attention_2", _isaac_flash_attention_forward)
ALL_ATTENTION_FUNCTIONS.register("isaac_flash_attention_3", _isaac_flash_attention_forward)
ALL_ATTENTION_FUNCTIONS.register("isaac_sdpa", _isaac_sdpa_forward)
ALL_ATTENTION_FUNCTIONS.register("isaac_eager", _isaac_eager_forward)
|
|
|
|
|
|
|
|
def create_pixel_shuffle_index_map(
    seq_sizes: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Build a gather-index map that tells us, for every *output* token after
    pixel-shuffle, which `scale_factor**2` *input* tokens are being merged.

    Args
    ----
    seq_sizes     : (num_images,)  - #patches in each image (row-major order)
    token_grids   : (num_images,2) - (height, width) for every image
    scale_factor  : spatial down-scale factor (≥2). Note that the signature
                    default of 1 is rejected; callers must pass a value ≥ 2.
    device        : (optional) overrides `seq_sizes.device`

    Returns
    -------
    gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
                 gather_idx[i, j] is the *flat* index into the *original*
                 packed sequence for the j-th sub-patch that forms the
                 i-th output token. Empty `(0, scale_factor**2)` when there
                 are no images.

    Raises
    ------
    ValueError     : if `scale_factor` < 2.
    AssertionError : if some grid height/width is not divisible by
                     `scale_factor` (check skipped while torch.compile traces).
    """
    if device is None:
        device = seq_sizes.device

    scale_factor = int(scale_factor)
    if scale_factor < 2:
        raise ValueError("`scale_factor` must be ≥ 2")

    # The data-dependent check below is skipped while dynamo is tracing.
    if not is_torchdynamo_compiling():
        if not (
            (token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all()
        ):
            raise AssertionError(
                "Every (H,W) in `token_grids` must be divisible by "
                f"scale_factor={scale_factor}, got {token_grids.tolist()}"
            )

    merged = scale_factor * scale_factor
    gather_chunks: list[torch.Tensor] = []
    tok_offset = 0

    for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False):
        # Flat indices of this image inside the packed sequence.
        grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset

        # Split both spatial axes into (blocks, within-block) in one step:
        # (H*W,) -> (H/r, r, W/r, r). Fails if seq_len != h * w.
        grid = grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor)

        # Group the r x r neighbours of each output token together:
        # (H/r, W/r, r, r).
        grid = grid.permute(0, 2, 1, 3).contiguous()

        gather_chunks.append(grid.reshape(-1, merged))
        tok_offset += seq_len

    if not gather_chunks:
        # `torch.cat` rejects an empty list; return a well-formed empty map.
        return torch.empty((0, merged), dtype=torch.int64, device=device)

    return torch.cat(gather_chunks, dim=0)
|
|
|
|
|
|
|
|
def pixel_shuffle_varlen(
    x: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
) -> torch.Tensor:
    r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.

    Args:
        x (`torch.Tensor`):
            Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes
            produced by stacking image patches.
        token_grids (`torch.Tensor`):
            Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes
            corresponding to each image segment inside `x`.
        scale_factor (`int`, *optional*, defaults to 1):
            Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge
            `scale_factor**2` neighboring patches into a single embedding channel-group. A value of 1 is a
            no-op and returns `x` unchanged.

    Returns:
        `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention:
        `(new_seq_len, hidden_size * scale_factor**2)` when the input was 2D, or
        `(1, new_seq_len, hidden_size * scale_factor**2)` if the singleton batch dimension was present.

    Raises:
        AssertionError: If a 3D input has more than one batch item.
        ValueError: If `scale_factor` is smaller than 1 (raised by `create_pixel_shuffle_index_map`).
    """
    keep_batch_dim = x.dim() == 3
    if keep_batch_dim:
        if x.size(0) != 1:
            raise AssertionError("Packed sequence is expected to have batch_size == 1")
        x_ = x.squeeze(0)
    else:
        x_ = x

    scale_factor = int(scale_factor)
    if scale_factor == 1:
        # Nothing to merge: the documented default must not reach the index
        # builder, which only accepts factors >= 2.
        return x

    embed_dim = x_.size(-1)

    # Number of patches per image, derived from its (height, width) grid.
    seq_sizes = torch.prod(token_grids, dim=-1)

    gather_idx = create_pixel_shuffle_index_map(
        seq_sizes=seq_sizes,
        token_grids=token_grids,
        scale_factor=scale_factor,
        device=x_.device,
    )

    # (new_seq_len, scale_factor**2, embed_dim)
    gathered = x_[gather_idx]

    # Concatenate the merged patches along the channel axis.
    out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor)

    if keep_batch_dim:
        out = out.unsqueeze(0)
    return out
|
|
|
|
|
|
|
|
class IsaacVisionTransformer(nn.Module):
    """Vision tower operating on a packed (batch-free) sequence of image patches.

    Pipeline: patch embeddings -> encoder over the packed sequence (with
    cumulative sequence lengths marking image boundaries) -> layer norm ->
    optional pixel shuffle.
    """

    def __init__(self, config: IsaacVisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = IsaacVisionEmbeddings(config)
        self.encoder = IsaacVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor

    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
        """Encode packed patches.

        Args:
            packed_seq_patches: Pair `(seq_patches, token_grids)`; `token_grids`
                has one row per image whose product gives that image's patch
                count.

        Returns:
            Encoded features with the temporary batch dimension removed.
        """
        patches, grids = packed_seq_patches
        tokens_per_image = torch.prod(grids, dim=-1)

        # The encoder is fed a pseudo batch of size one.
        embedded = self.embeddings(patches, grids).unsqueeze(0)

        # Cumulative boundaries [0, n0, n0+n1, ...] delimiting each image.
        boundaries = torch.zeros(tokens_per_image.size(0) + 1, dtype=torch.int32, device=embedded.device)
        boundaries[1:] = tokens_per_image.cumsum(0)
        if tokens_per_image.numel() > 0:
            longest = int(tokens_per_image.max().item())
        else:
            longest = 0

        encoded = self.encoder(
            inputs_embeds=embedded,
            cu_seqlens=boundaries,
            max_seqlen=longest,
            return_dict=True,
        ).last_hidden_state

        features = self.post_layernorm(encoded)

        if self.pixel_shuffle_scale_factor > 1:
            features = pixel_shuffle_varlen(
                x=features,
                token_grids=grids,
                scale_factor=self.pixel_shuffle_scale_factor,
            )

        # Drop the pseudo batch dimension added above.
        return features.squeeze(0)
|
|
|
|
|
|
|
|
def get_scaled_image_size(
    scale: float,
    original_size: int,
    patch_size: int,
    pixel_shuffle_scale: int,
) -> int:
    """Scale one image dimension and round it up to a valid multiple.

    The result is the smallest multiple of `patch_size * pixel_shuffle_scale`
    that is at least `scale * original_size`, and never less than one such
    multiple.
    """
    step = patch_size * pixel_shuffle_scale
    rounded_up = math.ceil((scale * original_size) / step) * step
    return int(max(step, rounded_up))
|
|
|
|
|
|
|
|
def get_image_size_for_max_num_patches( |
|
|
image_height: int, |
|
|
image_width: int, |
|
|
patch_size: int, |
|
|
max_num_patches: int, |
|
|
min_num_patches: int | None = None, |
|
|
eps: float = 1e-5, |
|
|
pixel_shuffle_scale: int = 1, |
|
|
) -> tuple[int, int]: |
|
|
r"""Compute a target resolution whose patch grid satisfies patching parametrization. |
|
|
|
|
|
Args: |
|
|
image_height (`int`): |
|
|
Height in pixels of the source image prior to any resizing. |
|
|
image_width (`int`): |
|
|
Width in pixels of the source image prior to any resizing. |
|
|
patch_size (`int`): |
|
|
Size of the square patch used by the vision encoder. |
|
|
max_num_patches (`int`): |
|
|
Upper bound on `(height / patch_size) * (width / patch_size)` after resizing. |
|
|
min_num_patches (`int`, *optional*): |
|
|
Lower bound on the number of patches. When provided the image will be scaled up if necessary. |
|
|
eps (`float`, *optional*, defaults to 1e-5): |
|
|
Convergence tolerance for the internal binary search to determing the target dimensions. |
|
|
pixel_shuffle_scale (`int`, *optional*, defaults to 1): |
|
|
Additional stride multiplier applied when pixel shuffle later reduces spatial resolution. |
|
|
|
|
|
Returns: |
|
|
`tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` |
|
|
and respect both the maximum and optional minimum patch-count constraints. |
|
|
""" |
|
|
|
|
|
|
|
|
divisor = patch_size * pixel_shuffle_scale |
|
|
adjusted_height = math.ceil(image_height / divisor) * divisor |
|
|
adjusted_height = max(divisor, adjusted_height) |
|
|
adjusted_width = math.ceil(image_width / divisor) * divisor |
|
|
adjusted_width = max(divisor, adjusted_width) |
|
|
|
|
|
num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) |
|
|
|
|
|
if min_num_patches is not None and num_patches < min_num_patches: |
|
|
|
|
|
scale_min, scale_max = 1.0, 100.0 |
|
|
while (scale_max - scale_min) >= eps: |
|
|
scale = (scale_min + scale_max) / 2 |
|
|
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) |
|
|
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) |
|
|
num_patches = (target_height / patch_size) * (target_width / patch_size) |
|
|
if num_patches >= min_num_patches: |
|
|
scale_max = scale |
|
|
else: |
|
|
scale_min = scale |
|
|
scale = scale_max |
|
|
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) |
|
|
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) |
|
|
return target_height, target_width |
|
|
elif num_patches <= max_num_patches: |
|
|
return adjusted_height, adjusted_width |
|
|
else: |
|
|
|
|
|
scale_min, scale_max = eps / 10, 1.0 |
|
|
while (scale_max - scale_min) >= eps: |
|
|
scale = (scale_min + scale_max) / 2 |
|
|
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) |
|
|
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) |
|
|
num_patches = (target_height / patch_size) * (target_width / patch_size) |
|
|
if num_patches <= max_num_patches: |
|
|
scale_min = scale |
|
|
else: |
|
|
scale_max = scale |
|
|
scale = scale_min |
|
|
target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) |
|
|
target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) |
|
|
return target_height, target_width |
|
|
|
|
|
|
|
|
def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    r"""Cut channel-last images into non-overlapping square patches.

    Args:
        image (`torch.Tensor`):
            Tensor of shape `(num_images, height, width, channels)`.
        patch_size (`int`):
            Edge length of the square patches.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(num_images, height // patch_size, width // patch_size,
            channels * patch_size * patch_size)`; the last axis holds one patch's pixels
            flattened in (row, column, channel) order.

    Raises:
        ValueError: If `height` or `width` is not divisible by `patch_size`.
    """
    batch, height, width, channels = image.shape
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.")

    rows = height // patch_size
    cols = width // patch_size

    # (N, rows, p, cols, p, C): split each spatial axis into (grid, in-patch).
    tiled = image.reshape(batch, rows, patch_size, cols, patch_size, channels)
    # (N, rows, cols, p, p, C): put the two grid axes next to each other.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5)
    return tiled.reshape(batch, rows, cols, channels * patch_size * patch_size)
|
|
|
|
|
|
|
|
class IsaacConfig(Qwen3Config):
    """Configuration class for Isaac multimodal model.

    Combines a `Qwen3Config` text configuration with an `IsaacVisionConfig`.
    The text settings are applied both to this object (through the parent
    constructor) and to a separate `text_config` instance.

    Args:
        vision_config: `IsaacVisionConfig`, a dict of its keyword arguments, or
            `None` for the defaults.
        text_config: `Qwen3Config`, a dict of its keyword arguments, or `None`.
        vision_rescale_factor: Stored as a float; read by `IsaacProcessor`.
        max_sequence_length: Stored as-is; read by `IsaacProcessor`.
        vision_token: Placeholder string marking image positions in text.
        **kwargs: Merged on top of the text configuration values.

    Raises:
        TypeError: If `text_config` or `vision_config` has an unsupported type.
    """

    model_type = "isaac"
    sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config}
    image_processor_type = "IsaacImageProcessor"

    def __init__(
        self,
        vision_config: IsaacVisionConfig | dict | None = None,
        text_config: Qwen3Config | dict | None = None,
        vision_rescale_factor: float = 1/255,
        max_sequence_length: int = 16384,
        vision_token: str = "<image>",
        **kwargs,
    ):
        # Must exist before the parent constructor runs: presumably it assigns
        # `rope_scaling`, which goes through the property setter below - TODO confirm.
        self._rope_scaling: dict[str, Any] | None = None
        resolved_text_config = kwargs.pop("text_config", text_config)
        if isinstance(resolved_text_config, Qwen3Config):
            text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict())
        elif isinstance(resolved_text_config, dict):
            text_config_kwargs = copy.deepcopy(resolved_text_config)
        elif resolved_text_config is None:
            text_config_kwargs = {}
        else:
            raise TypeError("`text_config` must be a mapping or `Qwen3Config` instance when provided.")

        # Explicit keyword arguments win over values from `text_config`.
        text_config_kwargs.update(kwargs)

        super().__init__(**text_config_kwargs)
        self.text_config = Qwen3Config(**text_config_kwargs)
        # Keep this object and its text sub-config in agreement on rope scaling.
        if self._rope_scaling is None:
            self._rope_scaling = getattr(self.text_config, "rope_scaling", None)
        else:
            self.text_config.rope_scaling = self._rope_scaling

        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif isinstance(vision_config, IsaacVisionConfig):
            self.vision_config = vision_config
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            # Previously an unsupported type silently left `vision_config`
            # unset; fail here instead, mirroring the `text_config` handling.
            raise TypeError("`vision_config` must be a mapping or `IsaacVisionConfig` instance when provided.")

        self.vision_rescale_factor = float(vision_rescale_factor)

        self.max_sequence_length = max_sequence_length
        self.vision_token = vision_token

    def get_text_config(self, *_, **kwargs) -> Qwen3Config:
        """Return the text sub-configuration; all arguments are accepted and ignored."""
        kwargs.pop("decoder", None)
        kwargs.pop("encoder", None)
        return self.text_config

    @property
    def rope_scaling(self):
        """Rope scaling of `text_config` when it exists, else the locally stored value."""
        if hasattr(self, "text_config") and self.text_config is not None:
            return getattr(self.text_config, "rope_scaling", None)
        return self._rope_scaling

    @rope_scaling.setter
    def rope_scaling(self, value):
        self._rope_scaling = value
        if hasattr(self, "text_config") and self.text_config is not None:
            self.text_config.rope_scaling = value

    @property
    def vision_attn_implementation(self) -> str | None:
        """Attention implementation of the vision config (`_attn_implementation` first, then `attn_implementation`)."""
        value = getattr(self.vision_config, "_attn_implementation", None)
        if value is None:
            value = getattr(self.vision_config, "attn_implementation", None)
        return value

    @vision_attn_implementation.setter
    def vision_attn_implementation(self, value: str | None) -> None:
        self.vision_config._attn_implementation = value
        if value is not None:
            self.vision_config.attn_implementation = value
        elif hasattr(self.vision_config, "attn_implementation"):
            # Clearing the value also removes the public mirror attribute.
            delattr(self.vision_config, "attn_implementation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event:
    r"""Tokenize `text` and wrap the ids in a text `Event`.

    Args:
        tokenizer (`AutoTokenizer`):
            Tokenizer used to encode `text` (no special tokens are added).
        text (`str`):
            Plain-text fragment to encode.
        time (`float`, *optional*, defaults to 0.0):
            Used as both the start and the end time of the event.

    Returns:
        `Event`: Text event whose data is a `(num_tokens, 1)` tensor of token ids, with
        `dims_virtual == dims_real == [num_tokens, 1]` and `idx_range == (0, num_tokens)`.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0)

    token_count = len(token_ids)
    virtual_shape = [token_count, 1]
    real_shape = list(virtual_shape)

    # Give 1D id vectors a trailing singleton axis.
    payload = token_ids.unsqueeze(-1) if token_ids.dim() == 1 else token_ids

    return Event(
        data=payload,
        type=TextType.text,
        time=(time, time),
        dims_virtual=virtual_shape,
        dims_real=real_shape,
        idx_range=(0, token_count),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IsaacProcessor(ProcessorMixin):
    """Turns a prompt with vision-token placeholders plus images into a `TensorStream`.

    Only a single prompt per call is supported.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("IsaacImageProcessorFast",)
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
    valid_processor_kwargs = IsaacProcessorKwargs

    def __init__(
        self,
        image_processor: IsaacImageProcessorFast | None = None,
        tokenizer: Qwen2Tokenizer | None = None,
        *,
        vision_token: str = "<image>",
        max_sequence_length: int = 16384,
        rescale_factor: float | None = None,
        config: IsaacConfig | dict | None = None,
    ) -> None:
        """Set up the processor.

        When `config` is given (a dict is converted to `IsaacConfig`), its
        `max_sequence_length`, `vision_token` and `vision_rescale_factor`
        override the matching keyword arguments, and the resolved rescale
        factor is written back onto the config.

        Raises:
            ValueError: If `tokenizer` is `None`.
        """
        if tokenizer is None:
            raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.")

        if isinstance(config, dict):
            config = IsaacConfig(**config)

        has_config = config is not None
        if has_config:
            max_sequence_length = config.max_sequence_length
            vision_token = config.vision_token
            rescale_factor = config.vision_rescale_factor

        if rescale_factor is None:
            rescale_value = float(1/255)
        else:
            rescale_value = float(rescale_factor)

        if has_config:
            config.vision_rescale_factor = rescale_value

        self.image_processor = image_processor

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self.config = config

        # Reuse the tokenizer's chat template when it has one.
        self.chat_template = getattr(self.tokenizer, "chat_template", None)

        self.vision_token = vision_token
        self.max_sequence_length = max_sequence_length

    def build_event_stream_simple(
        self,
        text: str,
        images: list[PIL.Image.Image] | None = None,
    ) -> Stream:
        """Split `text` on the vision token and emit one event per piece.

        Text pieces become text events; each vision token consumes the next
        image in order. The position of a piece in the split is used as its
        event time.

        Raises:
            ValueError: If a vision token has no matching image.
        """
        # The capturing group keeps the vision tokens in the split result.
        segments = re.split(f"({re.escape(self.vision_token)})", text)

        events = []
        next_image = 0
        for step, segment in enumerate(segments):
            if segment != self.vision_token:
                if segment:
                    events.append(create_text_event(self.tokenizer, segment, time=step))
                continue

            if images is None or next_image >= len(images):
                raise ValueError("Encountered vision token without a corresponding image.")

            features = self.image_processor(
                images=images[next_image],
                return_tensors=TensorType.PYTORCH,
            )

            patches = features["patches"][0]
            virtual_dims = features["virtual_pixel_size"][0].tolist()
            real_dims = features["real_pixel_size"][0].tolist()

            events.append(
                Event(
                    data=patches.reshape(-1, patches.shape[-1]),
                    type=VisionType.image,
                    time=(step, step),
                    dims_virtual=virtual_dims,
                    dims_real=real_dims,
                    idx_range=(0, math.prod(virtual_dims)),
                )
            )
            next_image += 1

        return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True)

    def __call__(
        self,
        text: str | list[str],
        images: PIL.Image.Image | list[PIL.Image.Image] | None = None,
        return_tensors: str | TensorType | None = TensorType.PYTORCH,
        **kwargs,
    ) -> BatchFeature:
        """
        Process text and images into TensorStream format.

        Args:
            text: Prompt, or a one-element list of prompts, containing vision tokens.
            images: PIL image or list of images (optional).
            return_tensors: `"pt"` / `TensorType.PYTORCH` converts the token view to a
                long tensor; any other value leaves it untouched.
            **kwargs: Accepted but not used.

        Returns:
            BatchFeature with `input_ids` and `tensor_stream`.

        Raises:
            ValueError: If more than one prompt is given, or if images are given and
                their number differs from the number of vision tokens.
        """
        texts = [text] if isinstance(text, str) else text

        if images is None:
            images_list = None
        elif isinstance(images, PIL.Image.Image):
            images_list = [images]
        else:
            images_list = images

        if len(texts) != 1:
            raise ValueError("IsaacProcessor currently supports batch_size=1")
        prompt = texts[0]

        if images_list is not None:
            vision_token_count = prompt.count(self.vision_token)
            if vision_token_count != len(images_list):
                raise ValueError(
                    f"Number of {self.vision_token} tokens in text ({vision_token_count}) "
                    f"must match number of images ({len(images_list)})"
                )

        tensor_stream = TensorStream([self.build_event_stream_simple(text=prompt, images=images_list)])

        # Keep only the last `max_sequence_length` positions when too long.
        _, total_len = tensor_stream.shape
        if total_len > self.max_sequence_length:
            tensor_stream = ts_slice(tensor_stream, start=total_len - self.max_sequence_length, end=total_len)

        tokens = tensor_stream_token_view(tensor_stream)
        if return_tensors in (TensorType.PYTORCH, "pt"):
            input_ids = torch.as_tensor(tokens, dtype=torch.long)
        else:
            input_ids = tokens

        return BatchFeature(data={"input_ids": input_ids, "tensor_stream": tensor_stream})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    r"""Build 3-axis position ids for a plain token batch.

    Args:
        input_ids (`torch.Tensor`):
            Tensor of shape `(batch_size, seq_len)`; only its shape and device are used.

    Returns:
        `torch.Tensor`: Tensor of shape `(batch_size, seq_len, 3)` in which every one of the three
        channels holds the 1D position `0 .. seq_len - 1`.
    """
    batch_size, seq_length = input_ids.shape
    positions = torch.arange(seq_length, device=input_ids.device)
    # Broadcast (without copying) over the batch axis and the three axes.
    return positions[None, :, None].expand(batch_size, seq_length, 3)
|
|
|
|
|
|
|
|
class IsaacRotaryEmbedding(nn.Module):
    """Multi-axis rotary embedding built on top of `Qwen2_5_VLRotaryEmbedding`.

    The wrapped rotary module yields cos/sin tables per position axis; this
    class merges them into one table by taking consecutive feature sections
    from the axes in round-robin order, sized by `mrope_section`.
    """

    # Keys stripped from `rope_scaling` before it reaches the wrapped module.
    EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"}

    def __init__(self, config: IsaacConfig, device=None):
        super().__init__()

        if hasattr(config, "get_text_config"):
            rope_cfg = config.get_text_config()
        else:
            rope_cfg = config
        raw_scaling = getattr(rope_cfg, "rope_scaling", None) or {}

        base_scaling = {name: val for name, val in raw_scaling.items() if name not in self.EXTRA_ROPE_KEYS}
        # Shallow copy so the caller's config keeps its original rope_scaling.
        rope_only_cfg = copy.copy(rope_cfg)
        rope_only_cfg.rope_scaling = base_scaling if base_scaling else None

        # A missing or "meta" device is not forwarded to the wrapped module.
        targets_meta = getattr(device, "type", None) == "meta"
        init_device = None if (device is None or targets_meta) else device
        self._qwen_rotary = Qwen2_5_VLRotaryEmbedding(rope_only_cfg, device=init_device)

        half_dim = self._qwen_rotary.inv_freq.shape[0]
        self.mrope_section = self._resolve_mrope_section(raw_scaling.get("mrope_section"), half_dim)
        self.hidden_size = getattr(rope_cfg, "hidden_size", None) or config.hidden_size

    @staticmethod
    def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]:
        """Return three section sizes summing to `rotary_half_dim`.

        Without an explicit `section`, the half-dimension is split 2:1:1 and
        any rounding remainder goes to the first section.

        Raises:
            ValueError: If `section` does not have three entries or does not
                sum to `rotary_half_dim`.
        """
        if section is None:
            ratios = (2, 1, 1)
            ratio_total = sum(ratios)
            sizes = [rotary_half_dim * ratio // ratio_total for ratio in ratios]
            sizes[0] += rotary_half_dim - sum(sizes)
            return sizes

        sizes = [int(entry) for entry in section]
        if len(sizes) != 3:
            raise ValueError("`mrope_section` must contain exactly three elements (temporal, height, width)")
        if sum(sizes) != rotary_half_dim:
            raise ValueError(
                f"`mrope_section` must sum to the rotary half-dimension ({rotary_half_dim}). Received {sizes}."
            )
        return sizes

    def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor:
        """Merge per-axis tables: the i-th feature section is taken from axis `i % 3`."""
        # The section list is repeated once, giving six chunks along the last dim.
        pieces = tensor.split(tuple(self.mrope_section * 2), dim=-1)
        selected = []
        for index, piece in enumerate(pieces):
            selected.append(piece[index % 3])
        return torch.cat(selected, dim=-1)

    @property
    def inv_freq(self) -> torch.Tensor:
        """Inverse frequencies of the wrapped rotary module."""
        return self._qwen_rotary.inv_freq

    def forward(
        self,
        position_ids: torch.Tensor,
        modality_tensor: torch.Tensor,
        hidden_states: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the combined cos/sin tables.

        Args:
            position_ids: `(batch, seq_len, 3)` positions.
            modality_tensor: `(batch, seq_len)` modality ids; positions whose id
                differs from `VisionType.image.value` reuse axis 0 on all axes.
            hidden_states: Passed to the wrapped rotary module and used for the
                output dtype; a float32 zero tensor is created when omitted.

        Returns:
            `(cos, sin)` cast to `hidden_states.dtype`.

        Raises:
            ValueError: On mismatching `position_ids` / `modality_tensor` shapes.
        """
        if position_ids.ndim != 3 or position_ids.size(-1) != 3:
            raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE")
        if modality_tensor.shape != position_ids.shape[:2]:
            raise ValueError("`modality_tensor` must align with the first two dims of `position_ids`")

        if hidden_states is None:
            batch, seq_len, _ = position_ids.shape
            hidden_states = torch.zeros(
                batch,
                seq_len,
                self.hidden_size,
                dtype=torch.float32,
                device=position_ids.device,
            )

        with torch.no_grad():
            axes_pos = position_ids.clone()
            non_image = modality_tensor != VisionType.image.value
            if non_image.any():
                # Non-image tokens get the same (first-axis) position everywhere.
                first_axis = axes_pos[non_image][..., 0].unsqueeze(-1)
                axes_pos[non_image] = first_axis.expand(-1, axes_pos.shape[-1])

            # (batch, seq_len, 3) -> (3, batch, seq_len)
            axes_first = axes_pos.permute(2, 0, 1).contiguous()

        cos_per_axis, sin_per_axis = self._qwen_rotary(hidden_states, axes_first)

        target_dtype = hidden_states.dtype
        cos_per_axis = cos_per_axis.to(target_dtype)
        sin_per_axis = sin_per_axis.to(target_dtype)

        return self._combine_axes(cos_per_axis), self._combine_axes(sin_per_axis)
|
|
|
|
|
class IsaacModel(Qwen3PreTrainedModel): |
|
|
supports_gradient_checkpointing = True |
|
|
|
|
|
def __init__(self, config: IsaacConfig):
    """Build the text backbone, the rotary embedding and the vision embedding stack.

    Raises:
        ValueError: If `config.vision_config` is `None`.
    """
    Qwen3PreTrainedModel.__init__(self, config)

    # Build the text model from a private copy of the text configuration,
    # carrying over this model's attention implementation.
    resolve_text_cfg = getattr(config, "get_text_config", lambda: config)
    text_cfg = copy.deepcopy(resolve_text_cfg())
    text_cfg._attn_implementation = config._attn_implementation
    self.text_model = AutoModel.from_config(text_cfg)

    # After construction the text model points at the full Isaac config.
    self.text_model.config = config

    self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)

    if config.vision_config is None:
        raise ValueError("IsaacConfig should always have vision_config")

    vision_cfg = config.vision_config
    # Pixel shuffle multiplies the channel count by scale_factor**2.
    shuffled_width = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
    projector_width = 4 * shuffled_width
    self.vision_embedding = nn.Sequential(
        IsaacVisionTransformer(vision_cfg),
        nn.Linear(shuffled_width, projector_width, bias=False),
        nn.SiLU(),
        nn.Linear(projector_width, config.hidden_size, bias=False),
    )

    # Embedding function per modality family, used by `embed_stream`.
    self.embed_fns = {
        TextType: self.embed_text_tokens,
        VisionType: self.embed_vision,
    }
|
|
|
|
|
def get_input_embeddings(self) -> nn.Module:
    """Return the input embedding module of the wrapped text model."""
    return self.text_model.get_input_embeddings()
|
|
|
|
|
def set_input_embeddings(self, value: nn.Module) -> None:
    """Replace the input embedding module of the wrapped text model with `value`."""
    self.text_model.set_input_embeddings(value)
|
|
|
|
|
@property
def embed_tokens(self) -> nn.Module:
    """Token embedding module, exposed directly from the wrapped text model."""
    return self.text_model.embed_tokens

@embed_tokens.setter
def embed_tokens(self, value: nn.Module) -> None:
    # Assignments are written through to the wrapped text model.
    self.text_model.embed_tokens = value
|
|
|
|
|
@property
def layers(self) -> nn.ModuleList:
    """The `layers` attribute of the wrapped text model."""
    return self.text_model.layers
|
|
|
|
|
@property
def norm(self) -> nn.Module:
    """The `norm` attribute of the wrapped text model."""
    return self.text_model.norm
|
|
|
|
|
def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
    """Forward the gradient-checkpointing setting to the wrapped text model only."""
    self.text_model._set_gradient_checkpointing(
        enable=enable, gradient_checkpointing_func=gradient_checkpointing_func
    )
|
|
|
|
|
def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
    """Embed text token ids and drop a singleton second-to-last axis, if any.

    Ids shaped `(num_tokens, 1)` therefore yield `(num_tokens, hidden)`
    instead of `(num_tokens, 1, hidden)`.
    """
    embedded = self.text_model.embed_tokens(token_ids)
    has_singleton_axis = embedded.dim() >= 2 and embedded.size(-2) == 1
    return embedded[..., 0, :] if has_singleton_axis else embedded
|
|
|
|
|
def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Embed vision tokens by running them through `self.vision_embedding`.

    `vision_tokens` is handed over unchanged; `embed_stream` passes a
    `(packed_patches, token_grids)` pair.
    """
    return self.vision_embedding(vision_tokens)
|
|
|
|
|
def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor:
    """
    Embed every event type of `tensor_stream` with its modality's embedding
    function, then put the results back into the original stream layout and
    return the compacted tensor.
    """
    flat = tensor_stream.flat_stream()
    streams_by_type = group_streams(flat, group_fn=lambda ev: ev.type, schedule=False)
    payload_by_type = {event_type: grouped.compact() for event_type, grouped in streams_by_type.items()}

    # Real (non-virtual) dims of every event, grouped by event type.
    grids_by_type = defaultdict(list)
    for stream in tensor_stream.streams:
        for event in stream:
            grids_by_type[event.type].append(event.dims(virtual=False))

    embedded_by_type = {}
    for event_type, payload in payload_by_type.items():
        model_input = payload
        if event_type.modality == VisionType:
            grids = grids_by_type.get(event_type, [])
            if grids:
                # Drop the first dims column (presumably the temporal axis -
                # TODO confirm) and keep the rest as the per-image token grid.
                grid_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:]
                model_input = (payload, grid_tensor)
        embedded_by_type[event_type] = self.embed_fns[event_type.modality](model_input)

    rebuilt = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_by_type)
    return rebuilt.compact()
|
|
|
|
|
def forward(
    self,
    input_ids: torch.LongTensor | None = None,
    tensor_stream: TensorStream | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    modality_tensor: torch.LongTensor | None = None,
    past_key_values: list[torch.FloatTensor] | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    use_cache: bool | None = None,
    output_hidden_states: bool | None = None,
    return_dict: bool | None = None,
    cache_position: torch.LongTensor | None = None,
    **kwargs,
) -> tuple | BaseModelOutputWithPast:
    """
    Forward pass with MRoPE position embeddings.

    Rotary ``(cos, sin)`` tables are computed once from ``position_ids`` and
    ``modality_tensor`` and handed to every decoder layer.

    Args:
        input_ids: Text token ids; embedded with ``self.text_model.embed_tokens``.
        tensor_stream: Multimodal input; embedded with ``self.embed_stream``.
        attention_mask: Optional mask, converted by ``self._update_causal_mask``.
        position_ids: Optional precomputed position ids. When omitted they are derived
            from ``tensor_stream`` or ``input_ids``.
        modality_tensor: Optional per-token modality ids. When omitted it is taken from
            ``tensor_stream`` or filled with ``TextType.text.value``.
        past_key_values: Cache object forwarded to every decoder layer.
        inputs_embeds: Precomputed embeddings; mutually exclusive with the other inputs.
        use_cache: Falls back to ``self.config.use_cache`` when ``None``.
        output_hidden_states: Falls back to ``self.config.output_hidden_states`` when
            ``None``. When true, the input embeddings, each intermediate layer output and
            the final normalized output are returned in ``hidden_states``.
        return_dict: Accepted for API compatibility; a ``BaseModelOutputWithPast`` is
            always returned.
        cache_position: Forwarded to the mask builder and the decoder layers.
        **kwargs: Forwarded unchanged to every decoder layer.

    Raises:
        ValueError: On conflicting or missing inputs, or when ``position_ids`` cannot be
            derived because only ``inputs_embeds`` was supplied.
    """
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    # Resolve exactly one input source into `inputs_embeds`.
    if tensor_stream is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both tensor_stream and inputs_embeds")
    elif tensor_stream is not None:
        inputs_embeds = self.embed_stream(tensor_stream)
        if modality_tensor is None:
            modality_tensor = modality_mask(tensor_stream)
    elif input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        inputs_embeds = self.text_model.embed_tokens(input_ids)
    elif inputs_embeds is None:
        raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds")

    # Without a stream every position is treated as text. This also covers the
    # `inputs_embeds`-only case, so the rotary embedding never receives `None`.
    if modality_tensor is None:
        reference = input_ids if input_ids is not None else inputs_embeds
        batch_size, seq_length = reference.shape[:2]
        modality_tensor = torch.full(
            (batch_size, seq_length), TextType.text.value, device=reference.device, dtype=torch.long
        )

    if position_ids is None:
        if tensor_stream is not None:
            position_ids = compute_mrope_pos_tensor(tensor_stream)
        elif input_ids is not None:
            position_ids = compute_position_ids_input_ids(input_ids)
        else:
            # Positions cannot be derived from embeddings alone; fail with a clear message.
            raise ValueError("`position_ids` must be provided when only `inputs_embeds` is given")

    # Rotary tables are shared by all layers, so compute them once up front.
    cos, sin = self.rotary_emb(
        position_ids,
        modality_tensor,
        hidden_states=inputs_embeds,
    )
    cos = cos.to(inputs_embeds.dtype)
    sin = sin.to(inputs_embeds.dtype)

    if attention_mask is not None:
        attention_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, False
        )

    hidden_states = inputs_embeds
    all_hidden_states = () if output_hidden_states else None

    for decoder_layer in self.text_model.layers:
        if all_hidden_states is not None:
            all_hidden_states += (hidden_states,)

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=(cos, sin),
            **kwargs,
        )
        # Layers may return either a bare tensor or a tuple whose first item is the hidden state.
        hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs

    hidden_states = self.text_model.norm(hidden_states)
    if all_hidden_states is not None:
        all_hidden_states += (hidden_states,)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
        hidden_states=all_hidden_states,
    )
|
|
|
|
|
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool = False,
):
    """Turn a user-supplied attention mask into the form the attention backend expects.

    Returns ``None`` when no explicit mask is needed, the untouched 2D mask for
    Flash Attention 2 when it contains zeros, or otherwise the 4D mask built by
    ``_prepare_4d_causal_attention_mask_with_cache_position``.

    Raises:
        ValueError: With Flash Attention 2, a cache, and a mask whose last column is
            not all ones (right padding).
    """
    attn_impl = self.config._attn_implementation
    have_mask = attention_mask is not None

    if attn_impl == "flash_attention_2":
        if have_mask and past_key_values is not None:
            # The last column must sum to the batch size, i.e. no row ends in padding.
            if attention_mask[:, -1].sum().item() != input_tensor.size()[0]:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )
        return attention_mask if have_mask and 0.0 in attention_mask else None

    seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    fixed_size_cache = isinstance(past_key_values, (StaticCache, SlidingWindowCache))

    # With SDPA and a growing cache the backend may be able to apply causality itself.
    if attn_impl == "sdpa" and not fixed_size_cache and not output_attentions:
        can_skip_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=seen_tokens,
            sliding_window=self.config.sliding_window,
            is_training=self.training,
        )
        if can_skip_mask:
            return None

    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]

    if fixed_size_cache:
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = seen_tokens + sequence_length + 1

    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
        config=self.config,
        past_key_values=past_key_values,
    )

    needs_unmask = (
        attn_impl == "sdpa"
        and have_mask
        and attention_mask.device.type in ["cuda", "xpu", "npu"]
        and not output_attentions
    )
    if needs_unmask:
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

    return causal_mask
|
|
|
|
|
@staticmethod |
|
|
def _prepare_4d_causal_attention_mask_with_cache_position( |
|
|
attention_mask: torch.Tensor, |
|
|
sequence_length: int, |
|
|
target_length: int, |
|
|
dtype: torch.dtype, |
|
|
device: torch.device, |
|
|
cache_position: torch.Tensor, |
|
|
batch_size: int, |
|
|
config: Qwen3Config, |
|
|
past_key_values: Cache, |
|
|
): |
|
|
""" |
|
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
|
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
|
|
|
|
|
Args: |
|
|
attention_mask (`torch.Tensor`): |
|
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. |
|
|
sequence_length (`int`): |
|
|
The sequence length being processed. |
|
|
target_length (`int`): |
|
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. |
|
|
dtype (`torch.dtype`): |
|
|
The dtype to use for the 4D attention mask. |
|
|
device (`torch.device`): |
|
|
The device to place the 4D attention mask on. |
|
|
cache_position (`torch.Tensor`): |
|
|
Indices depicting the position of the input sequence tokens in the sequence. |
|
|
batch_size (`torch.Tensor`): |
|
|
Batch size. |
|
|
config (`Qwen3Config`): |
|
|
The model's configuration class |
|
|
past_key_values (`Cache`): |
|
|
The cache class that is being used currently to generate |
|
|
""" |
|
|
if attention_mask is not None and attention_mask.dim() == 4: |
|
|
|
|
|
causal_mask = attention_mask |
|
|
else: |
|
|
min_dtype = torch.finfo(dtype).min |
|
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) |
|
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
|
|
if config.sliding_window is not None: |
|
|
|
|
|
|
|
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: |
|
|
sliding_attend_mask = torch.arange(target_length, device=device) <= ( |
|
|
cache_position.reshape(-1, 1) - config.sliding_window |
|
|
) |
|
|
diagonal_attend_mask.bitwise_or_(sliding_attend_mask) |
|
|
causal_mask *= diagonal_attend_mask |
|
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
|
|
if attention_mask is not None: |
|
|
causal_mask = causal_mask.clone() |
|
|
if attention_mask.shape[-1] > target_length: |
|
|
attention_mask = attention_mask[:, :target_length] |
|
|
mask_length = attention_mask.shape[-1] |
|
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( |
|
|
causal_mask.device |
|
|
) |
|
|
padding_mask = padding_mask == 0 |
|
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
|
|
padding_mask, min_dtype |
|
|
) |
|
|
return causal_mask |
|
|
|
|
|
|
|
|
class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
    """Isaac multimodal model for conditional generation.

    Wraps an ``IsaacModel`` backbone with a linear language-model head and keeps the
    MRoPE offsets (``rope_deltas``) computed during a TensorStream forward so that later
    text-only decode steps can continue from the right positions.
    """

    config_class = IsaacConfig

    def __init__(self, config: IsaacConfig):
        super().__init__(config)
        self.model = IsaacModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Set by `forward` whenever a TensorStream is processed; read on later decode steps.
        self.rope_deltas = None

    def get_rope_index(
        self,
        input_ids: torch.Tensor | None,
        tensor_stream: TensorStream | None,
        attention_mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute MRoPE position ids from a TensorStream (or 1D fallback).

        Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE.
        rope_deltas is (B,1) used to advance positions in decode.

        Raises:
            ValueError: If neither ``tensor_stream`` nor ``input_ids`` is given.
        """
        if tensor_stream is None and input_ids is None:
            raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices")

        if tensor_stream is not None:
            pos_3d = compute_mrope_pos_tensor(tensor_stream)
        else:
            pos_3d = compute_position_ids_input_ids(input_ids)
        _, seq_len, _ = pos_3d.shape

        # Largest position index used by each batch row, over all tokens and axes.
        max_pos_per_batch = pos_3d.amax(dim=(1, 2))

        # Number of attended tokens per row (full length when no mask is given).
        if attention_mask is None:
            seq_lens = torch.full_like(max_pos_per_batch, seq_len)
        else:
            seq_lens = attention_mask.eq(1).sum(dim=-1).to(
                dtype=max_pos_per_batch.dtype, device=max_pos_per_batch.device
            )

        rope_deltas = (max_pos_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1)
        return pos_3d, rope_deltas

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        tensor_stream: TensorStream | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple | CausalLMOutputWithPast:
        """
        Forward pass for conditional generation supporting both standard inputs and TensorStream.

        When ``tensor_stream`` is given it takes precedence over ``input_ids``. Position ids
        are derived from the stream (also refreshing ``self.rope_deltas``) or, for text-only
        steps, from ``input_ids`` shifted by ``cache_position[0] + self.rope_deltas`` when
        both are available.

        Returns:
            ``CausalLMOutputWithPast`` with logits, the optional loss (when ``labels`` is
            given) and the backbone's ``past_key_values`` / ``hidden_states``.

        Raises:
            ValueError: If none of ``input_ids``, ``inputs_embeds`` or ``tensor_stream`` is provided.
        """
        # A TensorStream already carries the text tokens, so ignore `input_ids` in that case.
        if tensor_stream is not None:
            input_ids = None
        if input_ids is None and inputs_embeds is None and tensor_stream is None:
            raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.")

        if position_ids is None and tensor_stream is not None:
            position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask)
        elif position_ids is None and input_ids is not None:
            position_ids = compute_position_ids_input_ids(input_ids)
            if cache_position is not None and self.rope_deltas is not None:
                rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device)
                batch_size = input_ids.shape[0]
                rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0)
                # `rope_delta` is (B, 1) while `position_ids` is (B, L, 3): add a trailing axis
                # so the offset is applied per batch row instead of being broadcast across
                # the sequence axis (which yields a (B, B, 3) result for B > 1, L == 1).
                position_ids = position_ids.add(rope_delta.unsqueeze(-1))

        if tensor_stream is not None:
            modality_tensor = modality_mask(tensor_stream)
        else:
            # Text-only input: mark every position as text. Fall back to `inputs_embeds`
            # for the shape so an embeddings-only call does not dereference `None`.
            reference = input_ids if input_ids is not None else inputs_embeds
            batch_size, seq_len = reference.shape[:2]
            device = position_ids.device if position_ids is not None else reference.device
            # Integer dtype, matching the default the backbone builds for itself.
            modality_tensor = torch.full(
                (batch_size, seq_len), TextType.text.value, device=device, dtype=torch.long
            )

        outputs = self.model(
            input_ids=input_ids,
            tensor_stream=tensor_stream,
            attention_mask=attention_mask,
            position_ids=position_ids,
            modality_tensor=modality_tensor,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: list[torch.FloatTensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        tensor_stream: TensorStream | None = None,
        cache_position: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        use_cache: bool = True,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Prepare inputs for generation, handling TensorStream inputs properly.

        The parent implementation builds the base inputs; on top of that the TensorStream
        is only forwarded on the first step and ``position_ids`` is always cleared so that
        ``forward`` recomputes it.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            **kwargs,
        )

        # The stream is only consumed on the first (prefill) step.
        is_first_step = cache_position is None or cache_position[0] == 0
        if tensor_stream is not None and is_first_step:
            model_inputs["tensor_stream"] = tensor_stream

        # `forward` only applies the cached rope deltas when it builds position ids itself.
        model_inputs["position_ids"] = None

        # After the first step the already-cached stream must not be embedded again.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["tensor_stream"] = None
        return model_inputs

    def can_generate(self) -> bool:
        """Report that this model supports ``generate``."""
        return True
|
|
|
|
|
|
|
|
# Register the fast Isaac image processor with `AutoImageProcessor` for `IsaacConfig`.
AutoImageProcessor.register(IsaacConfig, fast_image_processor_class=IsaacImageProcessorFast, exist_ok=True)


# Names exported by a star-import of this module.
__all__ = ["IsaacConfig", "IsaacModel", "IsaacForConditionalGeneration", "IsaacImageProcessorFast", "IsaacProcessor"]
|
|
|
|
|
|
|
|
def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor:
    """Compute residuals for P-frames to stay in sync with the training pipeline.

    Every P-frame that has an earlier I-frame is replaced by its difference to the
    most recent such I-frame. I-frames are left untouched, and so are P-frames that
    appear before the first I-frame, since they have no reference to subtract.

    Args:
        frames: Tensor indexed by frame along its first dimension. Modified in place.
        is_p_frame: One flag per frame; ``True`` marks a P-frame, ``False`` an I-frame.

    Returns:
        The same ``frames`` tensor object, with residuals written into it.
    """
    if not any(is_p_frame):
        return frames

    device = frames.device
    p_mask = torch.tensor(is_p_frame, dtype=torch.bool, device=device)
    frame_indices = torch.arange(len(is_p_frame), device=device)

    # 1-based index at I-frames and 0 elsewhere; the running maximum minus one is the
    # index of the latest I-frame seen so far, or -1 when none has been seen yet.
    last_i_indices = torch.cummax((~p_mask) * (1 + frame_indices), dim=0).values.long() - 1

    # Skip P-frames without a reference: indexing with -1 would wrap to the last frame.
    p_indices = frame_indices[p_mask & (last_i_indices >= 0)]

    # The right-hand side is fully evaluated before the in-place write.
    frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]]
    return frames
|
|
|