from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
from torchvision.transforms.v2 import (
    Compose,
    InterpolationMode,
    Lambda,
    Normalize,
    Resize,
)

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ChannelDimension, to_numpy_array
from transformers.utils import TensorType, logging

logger = logging.get_logger(__name__)


class VAEImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        image_size: Tuple[int, int] = (64, 64),
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.do_resize = do_resize
        self.image_size = image_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Avoid mutable default arguments; fall back to the [-1, 1] convention.
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]

    def preprocess(
        self,
        images: Union["PIL.Image.Image", np.ndarray, List["PIL.Image.Image"], List[np.ndarray]],
        is_video: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = "pt",
        input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST,
        **kwargs,
    ):
        # Convert the input(s) into a single float tensor.
        if isinstance(images, list):
            images = [to_numpy_array(image) for image in images]
            images = torch.from_numpy(np.stack(images, axis=0)).float()
        else:
            images = to_numpy_array(images)
            images = torch.from_numpy(images).float()

        # Add a batch dimension if it is missing, then move channels first:
        # videos become (batch, frames, channels, height, width) and images
        # become (batch, channels, height, width).
        if is_video:
            if images.dim() == 4:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.LAST:
                images = images.permute(0, 1, 4, 2, 3)
        else:
            if images.dim() == 3:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.LAST:
                images = images.permute(0, 3, 1, 2)

        # Each step degrades to an identity Lambda when its flag is disabled.
        # Rescaling honors the configured rescale_factor instead of a
        # hard-coded 1 / 255.
        identity = Lambda(lambda x: x)
        compose_tf = Compose(
            [
                Resize(self.image_size, interpolation=InterpolationMode.BICUBIC) if self.do_resize else identity,
                Lambda(lambda x: x * self.rescale_factor) if self.do_rescale else identity,
                Normalize(self.image_mean, self.image_std) if self.do_normalize else identity,
            ]
        )
        images = compose_tf(images)

        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)

    def postprocess(
        self,
        images: "torch.Tensor",
        is_video: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = "np",
        input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
        **kwargs,
    ):
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images).float()
        if isinstance(images, list):
            images = torch.stack(images, dim=0)
        if not isinstance(images, torch.Tensor):
            raise ValueError("images must be a torch.Tensor")

        # Add a batch dimension if it is missing, then move channels last so
        # the per-channel mean/std below broadcast over the trailing axis.
        if is_video:
            if images.dim() == 4:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.FIRST:
                images = images.permute(0, 1, 3, 4, 2)
        else:
            if images.dim() == 3:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.FIRST:
                images = images.permute(0, 2, 3, 1)

        # Invert normalization, then invert rescaling back to uint8.
        if self.do_normalize:
            images = images * torch.tensor(self.image_std) + torch.tensor(self.image_mean)
        if self.do_rescale:
            images = torch.clamp(images, 0, 1)
            # Rounding avoids the off-by-one truncation of a bare uint8 cast.
            images = (images / self.rescale_factor).round().type(torch.uint8)

        if return_tensors == TensorType.NUMPY:
            images = images.numpy()

        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
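
# --- Usage sketch (not part of the original module) -------------------------
# A minimal round trip under assumed shapes: a random HWC uint8 image and an
# 8-frame video go through `preprocess`, and a decoder-style CHW float tensor
# in [-1, 1] goes back through `postprocess`. All sizes here are illustrative.
if __name__ == "__main__":
    processor = VAEImageProcessor(image_size=(64, 64))

    # Single image: (H, W, C) uint8 -> normalized (1, 3, 64, 64) float tensor.
    image = np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)
    pixel_values = processor.preprocess(image)["pixel_values"]
    print(pixel_values.shape, pixel_values.dtype)  # torch.Size([1, 3, 64, 64]) torch.float32

    # Video: (frames, H, W, C) uint8 -> (1, frames, 3, 64, 64) float tensor.
    video = np.random.randint(0, 256, size=(8, 128, 128, 3), dtype=np.uint8)
    video_values = processor.preprocess(video, is_video=True)["pixel_values"]
    print(video_values.shape)  # torch.Size([1, 8, 3, 64, 64])

    # Decoder output back to channels-last uint8 numpy: (1, 64, 64, 3).
    decoded = torch.rand(1, 3, 64, 64) * 2 - 1  # values in [-1, 1], as after Normalize
    restored = processor.postprocess(decoded)["pixel_values"]
    print(restored.shape, restored.dtype)  # (1, 64, 64, 3) uint8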