from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
from torchvision.transforms.v2 import (
    Compose,
    InterpolationMode,
    Lambda,
    Normalize,
    Resize,
)

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ChannelDimension, to_numpy_array
from transformers.utils import TensorType, logging

logger = logging.get_logger(__name__)


class VAEImageProcessor(BaseImageProcessor):
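    """Image processor for preparing images or videos for a VAE and decoding its outputs.

    `preprocess` converts PIL/NumPy inputs to channel-first float tensors and optionally resizes,
    rescales, and normalizes them (with the default mean/std of 0.5 this maps pixels to [-1, 1]);
    `postprocess` undoes the normalization and rescaling and returns channel-last data.
    """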
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        image_size: Tuple[int, int] = (64, 64),
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[List[float]] = [0.5, 0.5, 0.5],
        image_std: Optional[List[float]] = [0.5, 0.5, 0.5],
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.do_resize = do_resize
        self.image_size = image_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std

    def preprocess(
        self,
        images: Union["PIL.Image.Image", np.ndarray, List["PIL.Image.Image"], List[np.ndarray]],
        is_video: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = "pt",
        input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST,
        **kwargs,
    ):
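        """Convert `images` to a channel-first float batch and apply resize/rescale/normalize.

        Single inputs are promoted to a batch of one. With `is_video=True`, inputs are expected to
        carry a frame dimension and are promoted to (batch, frames, channels, height, width).
        """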
        # Convert the input(s) to a single float tensor, stacking lists along a new batch axis.
        if isinstance(images, list):
            images = [to_numpy_array(image) for image in images]
            images = torch.from_numpy(np.stack(images, axis=0)).float()
        else:
            images = to_numpy_array(images)
            images = torch.from_numpy(images).float()
        if is_video:
            # Videos: promote (frames, H, W, C) to (batch, frames, H, W, C), then move channels first.
            if images.dim() == 4:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.LAST:
                images = images.permute(0, 1, 4, 2, 3)
        else:
            # Images: promote (H, W, C) to (batch, H, W, C), then move channels first.
            if images.dim() == 3:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.LAST:
                images = images.permute(0, 3, 1, 2)
        # Build the transform pipeline; disabled steps fall through as identity transforms.
        compose_tf = Compose(
            [
                Resize(self.image_size, interpolation=InterpolationMode.BICUBIC) if self.do_resize else Lambda(lambda x: x),
                Lambda(lambda x: x * self.rescale_factor) if self.do_rescale else Lambda(lambda x: x),
                Normalize(self.image_mean, self.image_std) if self.do_normalize else Lambda(lambda x: x),
            ]
        )
        images = compose_tf(images)
        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)

    def postprocess(
        self,
        images: "torch.Tensor",
        is_video: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = "np",
        input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
        **kwargs,
    ):
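        """Undo normalization and rescaling, returning channel-last batches (uint8 when rescaling is enabled)."""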
        # Accept NumPy arrays, lists of arrays/tensors, or tensors; everything ends up as one tensor.
        if isinstance(images, list):
            images = torch.stack([torch.as_tensor(image) for image in images], dim=0)
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images).float()
        if not isinstance(images, torch.Tensor):
            raise ValueError("images must be a torch.Tensor")
        if is_video:
            # Videos: promote to (batch, frames, C, H, W), then move channels last.
            if images.dim() == 4:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.FIRST:
                images = images.permute(0, 1, 3, 4, 2)
        else:
            # Images: promote to (batch, C, H, W), then move channels last.
            if images.dim() == 3:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.FIRST:
                images = images.permute(0, 2, 3, 1)
        if self.do_normalize:
            # Undo normalization; mean/std broadcast over the trailing channel dimension.
            images = (images * torch.tensor(self.image_std)) + torch.tensor(self.image_mean)
        if self.do_rescale:
            # Undo rescaling back to the original pixel range (0-255 with the default factor).
            images = torch.clamp(images, 0, 1)
            images = (images / self.rescale_factor).type(torch.uint8)
        if return_tensors == TensorType.NUMPY:
            images = images.numpy()
        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
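

# Minimal usage sketch (illustrative only, not part of the processor itself): round-trip a random
# image through preprocess/postprocess with the default settings. The 128x128 input size below is
# an arbitrary choice for demonstration.
if __name__ == "__main__":
    processor = VAEImageProcessor(image_size=(64, 64))
    dummy_image = np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)

    batch = processor.preprocess(dummy_image)
    print(batch["pixel_values"].shape)  # expected: torch.Size([1, 3, 64, 64])

    restored = processor.postprocess(batch["pixel_values"])
    print(restored["pixel_values"].shape)  # expected: (1, 64, 64, 3)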