|
|
import re |
|
|
import torch |
|
|
from transformers import ProcessorMixin, BatchFeature, CLIPImageProcessorFast |
|
|
from transformers.image_processing_utils import BaseImageProcessor |
|
|
from transformers.image_utils import ImageInput |
|
|
from typing import Any, Dict, List, Optional, Union |
|
|
from PIL import Image |
|
|
|
|
|
from .llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN |
|
|
|
|
|
|
|
|
def expand_to_square(image: torch.Tensor, background_color=0) -> torch.Tensor:
    """Pad an image with a constant colour so that it becomes square.

    The image is centred along its shorter spatial axis. When the amount of
    padding is odd, the extra row/column goes to the bottom/right, because the
    offset is computed with floor division.

    Args:
        image: Tensor whose last two dims are (height, width), e.g. (C, H, W).
            Any leading dims are preserved.
        background_color: Fill value for the padded area. Either a scalar, or a
            per-channel sequence/tensor applied along dim -3 (one value per
            channel). The value is cast to ``image.dtype``.

    Returns:
        ``image`` itself when it is already square. Otherwise, a new tensor of
        shape (..., S, S) with S = max(H, W), on the same device and with the
        same dtype as ``image``.
    """
    *leading, height, width = image.shape
    if width == height:
        return image

    side = max(height, width)

    # Cast the fill to the image dtype/device, so a float colour cannot promote
    # an integer image and a GPU image stays on its device.
    fill = torch.as_tensor(background_color, dtype=image.dtype, device=image.device)
    if fill.ndim > 0:
        # Per-channel colour: broadcast one value per channel over (S, S).
        fill = fill.reshape(-1, 1, 1)

    result = torch.empty((*leading, side, side), dtype=image.dtype, device=image.device)
    # Copy into fresh storage so a caller-supplied fill tensor is never mutated.
    result.copy_(fill.expand_as(result))

    top = (side - height) // 2
    left = (side - width) // 2
    result[..., top : top + height, left : left + width] = image
    return result
|
|
|
|
|
|
|
|
class FastVLMImageProcessor(CLIPImageProcessorFast):
    """CLIP fast image processor that pads every image to a square first.

    Alongside ``pixel_values``, it returns ``image_sizes``: each input's
    (width, height), captured before the square padding is applied.
    """

    def _preprocess(self, images, **kwargs):
        """Pad images to squares, run the CLIP pipeline, and stack the result.

        Args:
            images: Sequence of image tensors whose last two dims are
                (height, width). Presumably (C, H, W) tensors from the
                fast-processor pipeline; TODO confirm.
            **kwargs: Forwarded unchanged to ``CLIPImageProcessorFast._preprocess``.

        Returns:
            BatchFeature with ``pixel_values`` (all images stacked along a new
            dim 0) and ``image_sizes`` (one ``torch.Size`` of (width, height)
            per input image).
        """
        # Record (width, height) before padding changes the spatial size.
        image_sizes = [image.shape[-2:][::-1] for image in images]
        squared = [expand_to_square(image) for image in images]

        processed = super()._preprocess(squared, **kwargs)
        pixel_values = processed["pixel_values"]
        # The parent may already return a stacked tensor (e.g. when
        # ``return_tensors`` is forwarded in kwargs). Stacking a tensor again
        # would raise, so only stack a list of per-image tensors.
        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.stack(list(pixel_values), dim=0)

        return BatchFeature(data={"pixel_values": pixel_values, "image_sizes": image_sizes})
|
|
|
|
|
class FastVLMProcessor(ProcessorMixin):
    """Combine a tokenizer and an image processor into FastVLM model inputs.

    Each prompt may contain at most one ``DEFAULT_IMAGE_TOKEN``. That marker is
    not tokenized as text; it is replaced by the single id ``IMAGE_TOKEN_INDEX``
    in ``input_ids``.
    """

    attributes = ["tokenizer", "image_processor"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer,
        image_processor,
        chat_template=None,
        **kwargs
    ):
        super().__init__(tokenizer, image_processor, chat_template=chat_template, **kwargs)

    def __call__(
        self,
        images: ImageInput = None,
        text: Optional[Union[str, List[str]]] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """Tokenize prompts and, optionally, preprocess images into model inputs.

        Args:
            images: Images forwarded to ``self.image_processor``; skipped when None.
            text: A prompt string, or a list/tuple of prompt strings. Each prompt
                may contain at most one ``DEFAULT_IMAGE_TOKEN``.
            return_tensors: Tensor type passed to ``BatchFeature``.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            BatchFeature containing:
            - ``input_ids`` and ``attention_mask`` of shape (batch, max_len).
              Shorter prompts are padded on the side given by the tokenizer's
              ``padding_side`` attribute (right if absent), with mask 0 on the
              padding.
            - Whatever the image processor returned.

        Raises:
            TypeError: If ``text`` is not a string or a list/tuple of strings.
            ValueError: If a prompt contains more than one image token.
        """
        prompts = self._normalize_text(text)

        image_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images)

        sequences = [self._encode_prompt(prompt) for prompt in prompts]
        input_ids, attention_mask = self._pad_sequences(sequences)

        return BatchFeature(
            data={"input_ids": input_ids, "attention_mask": attention_mask, **image_inputs},
            tensor_type=return_tensors,
        )

    @staticmethod
    def _normalize_text(text) -> List[str]:
        """Return ``text`` as a list of prompt strings, or raise TypeError."""
        if isinstance(text, str):
            return [text]
        # Every element is checked, not just the first, and None is rejected
        # with a clear message instead of a subscript error.
        if isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text):
            return list(text)
        raise TypeError("Invalid input text. Please provide a string, or a list of strings")

    def _encode_prompt(self, prompt: str) -> torch.Tensor:
        """Tokenize one prompt into a 1-D int64 tensor, splicing in the image id."""
        num_image_tokens = prompt.count(DEFAULT_IMAGE_TOKEN)
        if num_image_tokens > 1:
            raise ValueError(
                f"Expected up to 1 image tokens per prompt, got {num_image_tokens} instead."
            )

        pre, sep, post = prompt.partition(DEFAULT_IMAGE_TOKEN)
        # Cast each piece to int64: tokenizing an empty string may yield an
        # empty tensor of a different dtype, which would promote the cat.
        pre_ids = self.tokenizer(pre, return_tensors="pt", add_special_tokens=False).input_ids[0]
        pre_ids = pre_ids.to(dtype=torch.int64)
        if not sep:
            # Text-only prompt: there is no placeholder, so no image id is added.
            return pre_ids

        post_ids = self.tokenizer(post, return_tensors="pt", add_special_tokens=False).input_ids[0]
        image_token = torch.tensor([IMAGE_TOKEN_INDEX], dtype=torch.int64)
        return torch.cat([pre_ids, image_token, post_ids.to(dtype=torch.int64)])

    def _pad_sequences(self, sequences: List[torch.Tensor]) -> "tuple[torch.Tensor, torch.Tensor]":
        """Pad 1-D id tensors into (batch, max_len) ``input_ids`` and ``attention_mask``."""
        batch_size = len(sequences)
        max_len = max((seq.numel() for seq in sequences), default=0)

        # Fall back to EOS, then 0, when the tokenizer defines no pad token;
        # padded positions are masked out either way.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is None:
            pad_id = 0
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        input_ids = torch.full((batch_size, max_len), pad_id, dtype=torch.int64)
        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.int64)
        for row, seq in enumerate(sequences):
            length = seq.numel()
            cols = slice(max_len - length, max_len) if pad_left else slice(0, length)
            input_ids[row, cols] = seq
            attention_mask[row, cols] = 1
        return input_ids, attention_mask
|
|
|