| from typing import Optional |
| import torch |
| from torch import nn |
| from transformers import PreTrainedModel |
| from transformers.models.qwen3_vl import Qwen3VLModel |
from transformers.utils import logging

from .configuration_ops_colqwen3 import OpsColQwen3Config

logger = logging.get_logger(__name__)


class OpsColQwen3PreTrainedModel(PreTrainedModel):
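    """Base class that ties the OpsColQwen3 config to `PreTrainedModel` and declares
    support for gradient checkpointing, FlashAttention-2, and SDPA."""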
| config_class = OpsColQwen3Config |
| base_model_prefix = "ops_colqwen3" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Qwen3VLVisionBlock", "Qwen3DecoderLayer"] |
| _skip_keys_device_placement = "past_key_values" |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
    _supports_cache_class = True


class OpsColQwen3Model(OpsColQwen3PreTrainedModel):
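    """ColQwen3-style late-interaction retrieval model.

    Wraps a Qwen3-VL backbone with a linear projection head and returns one
    L2-normalized embedding per token, suitable for ColBERT-style MaxSim scoring.
    """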
    _checkpoint_conversion_mapping = {
        r"^language_model": r"qwen3vl.language_model",
        r"^visual": r"qwen3vl.visual",
    }

    def __init__(self, config: OpsColQwen3Config):
        super().__init__(config)
        self.config = config

        self.qwen3vl = Qwen3VLModel(config)
        # The projection always maps to the full hidden size; `self.dims` only
        # controls how many leading components are kept in `forward`
        # (Matryoshka-style truncation), so checkpoint shapes stay compatible.
        # It may be overridden again in `from_pretrained`.
        self.dims = getattr(config, "dims", None) or config.text_config.hidden_size
        self.custom_text_proj = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size
        )

        self.mask_non_image_embeddings = config.mask_non_image_embeddings
        self.post_init()

    @classmethod
| def from_pretrained(cls, *args, config: Optional[OpsColQwen3Config] = None, **kwargs): |
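        """Load weights with the class's checkpoint key remapping applied, optionally
        overriding the truncation width via a `dims` kwarg (falls back to `config.dims`)."""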
| key_mapping = kwargs.pop("key_mapping", None) |
| if key_mapping is None: |
| key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None) |
| dims = None |
| if 'dims' in kwargs: |
| dims = kwargs.pop('dims') |
| elif config is not None: |
| dims = config.dims |
|
|
| model = super().from_pretrained(*args, config=config, **kwargs, key_mapping=key_mapping) |
| if dims is not None: |
| model.dims = dims |
| return model |
|
|
| def forward(self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: |
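        """Encode text and/or image inputs into per-token multi-vector embeddings.

        Returns:
            Tensor of shape `(batch_size, sequence_length, dims)` whose rows are
            L2-normalized; padding (and optionally non-image) positions are zeroed.
        """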
        has_pixel_values = pixel_values is not None

        if has_pixel_values:
            if image_grid_thw is None:
                raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")
            if not torch.is_tensor(image_grid_thw):
                image_grid_thw = torch.as_tensor(image_grid_thw, device=pixel_values.device)

            # `pixel_values` arrives padded to the longest patch sequence in the
            # batch; strip the per-image padding and concatenate into the flat
            # (total_patches, patch_dim) layout the vision tower expects.
            offsets = image_grid_thw.prod(dim=1)  # patches per image: t * h * w
            unpadded = [
                pixel_sequence[: int(offset.item())]
                for pixel_sequence, offset in zip(pixel_values, offsets)
            ]
            pixel_values = torch.cat(unpadded, dim=0) if unpadded else None

        outputs = self.qwen3vl(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_states = outputs.last_hidden_state
        proj = self.custom_text_proj(last_hidden_states)

        # Matryoshka-style truncation: keep only the first `dims` components when
        # a smaller embedding size was requested.
        if self.dims < self.config.text_config.hidden_size:
            proj = proj[..., : self.dims]

        # L2-normalize each token embedding for cosine-similarity retrieval.
        proj = proj / proj.norm(dim=-1, keepdim=True)

        # Zero out embeddings at padding positions.
        if attention_mask is not None:
            proj = proj * attention_mask.unsqueeze(-1)

        # For image inputs, optionally keep only the image-token embeddings.
        if has_pixel_values and self.mask_non_image_embeddings and input_ids is not None:
            image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            proj = proj * image_mask

        return proj

    @property
    def patch_size(self) -> int:
        """Edge length, in pixels, of one vision-tower patch."""
        return self.qwen3vl.visual.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        """Number of patches merged per spatial side when vision features are pooled."""
        return self.qwen3vl.visual.config.spatial_merge_size
|
|
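# Illustrative only: a minimal sketch of how the per-token embeddings returned by
# `OpsColQwen3Model.forward` are typically consumed for retrieval, using ColBERT-style
# late interaction (MaxSim). This helper is an assumption, not part of the checkpoint.
def score_multi_vector(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """Score every query against every document.

    Args:
        query_embeddings: `(num_queries, query_len, dims)`, L2-normalized, zero-padded.
        doc_embeddings: `(num_docs, doc_len, dims)`, L2-normalized, zero-padded.

    Returns:
        `(num_queries, num_docs)` similarity matrix.
    """
    # Token-level cosine similarities; zeroed padding rows contribute nothing.
    sim = torch.einsum("qnd,pmd->qpnm", query_embeddings, doc_embeddings)
    # MaxSim: best-matching document token per query token, summed over query tokens.
    return sim.amax(dim=-1).sum(dim=-1)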