# vae_test/image_processing_vae.py
from typing import List, Optional, Union, Tuple
import PIL
import torch
from torchvision.transforms.v2 import (
Compose,
Lambda,
Resize,
Normalize,
InterpolationMode,
)
import numpy as np
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ChannelDimension, to_numpy_array
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
class VAEImageProcessor(BaseImageProcessor):
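    """Image processor that maps images (or video frames) to VAE-ready tensors and back.

    preprocess: resize -> rescale to [0, 1] -> normalize, returning channels-first tensors.
    postprocess inverts these steps and returns channels-last uint8 pixel values.
    """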
model_input_names = ["pixel_values"]
def __init__(
self,
        do_resize: bool = True,
        image_size: Tuple[int, int] = (64, 64),
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
        image_mean: Optional[List[float]] = [0.5, 0.5, 0.5],
        image_std: Optional[List[float]] = [0.5, 0.5, 0.5],
*args,
**kwargs
):
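        """The flags mirror standard transformers image processors: image_size is
        (height, width) for Resize, rescale_factor defaults to 1/255, and
        image_mean/image_std are per-channel normalization statistics."""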
super().__init__(*args, **kwargs)
self.do_resize = do_resize
self.image_size = image_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
def preprocess(
self,
images: Union["PIL.Image.Image", np.ndarray, List["PIL.Image.Image"], List[np.ndarray]],
is_video: bool = False,
return_tensors: Optional[Union[str, TensorType]] = "pt",
input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST,
**kwargs
):
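        """Convert image(s) to a float tensor of shape (B, C, H, W), or
        (B, T, C, H, W) when is_video=True, wrapped in a BatchFeature."""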
if isinstance(images, list):
images = [to_numpy_array(image) for image in images]
images = torch.from_numpy(np.stack(images, axis=0)).float()
else:
images = to_numpy_array(images)
images = torch.from_numpy(images).float()
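        # Add a batch dim if it is missing and move channels first: (B, [T,] C, H, W).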
if is_video:
if images.dim() == 4:
images = images.unsqueeze(0)
if input_data_format == ChannelDimension.LAST:
images = images.permute(0, 1, 4, 2, 3)
else:
if images.dim() == 3:
images = images.unsqueeze(0)
if input_data_format == ChannelDimension.LAST:
images = images.permute(0, 3, 1, 2)
        # Build the transform pipeline; identity Lambdas stand in for disabled steps.
        # Note: rescaling uses the configured rescale_factor (the original hardcoded 1/255).
        compose_tf = Compose(
            [
                Resize(self.image_size, interpolation=InterpolationMode.BICUBIC) if self.do_resize else Lambda(lambda x: x),
                Lambda(lambda x: x * self.rescale_factor) if self.do_rescale else Lambda(lambda x: x),
                Normalize(self.image_mean, self.image_std) if self.do_normalize else Lambda(lambda x: x),
            ]
        )
images = compose_tf(images)
return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
def postprocess(
self,
images: "torch.Tensor",
is_video: bool = False,
return_tensors: Optional[Union[str, TensorType]] = "np",
input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
**kwargs
):
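        """Invert preprocess: denormalize, rescale back to uint8, and move channels
        last, returning pixel values of shape (B, H, W, C) or (B, T, H, W, C)."""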
        # Handle lists first so that a list of ndarrays is converted element-wise
        # (the original order would hand raw ndarrays to torch.stack and fail).
        if isinstance(images, list):
            images = torch.stack(
                [torch.from_numpy(im).float() if isinstance(im, np.ndarray) else im for im in images],
                dim=0,
            )
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images).float()
        if not isinstance(images, torch.Tensor):
            raise ValueError("images must be a torch.Tensor, np.ndarray, or a list of them")
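        # Add a batch dim if it is missing and move channels last for pixel output.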
if is_video:
if images.dim() == 4:
images = images.unsqueeze(0)
if input_data_format == ChannelDimension.FIRST:
images = images.permute(0, 1, 3, 4, 2)
else:
if images.dim() == 3:
images = images.unsqueeze(0)
if input_data_format == ChannelDimension.FIRST:
images = images.permute(0, 2, 3, 1)
        if self.do_normalize:
            # Undo Normalize: x * std + mean, broadcasting over the trailing channel dim.
            std = torch.tensor(self.image_std, device=images.device)
            mean = torch.tensor(self.image_mean, device=images.device)
            images = images * std + mean
        if self.do_rescale:
            images = torch.clamp(images, 0, 1)
            images = (images / self.rescale_factor).type(torch.uint8)
        if return_tensors == TensorType.NUMPY:
            images = images.cpu().numpy()
return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
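

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). Round-trips a
# dummy uint8 HWC image through preprocess/postprocess; the dummy array shape
# is an arbitrary assumption chosen only for demonstration.
if __name__ == "__main__":
    processor = VAEImageProcessor()
    dummy = np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)
    batch = processor.preprocess(dummy)
    print(batch["pixel_values"].shape)      # torch.Size([1, 3, 64, 64]), roughly in [-1, 1]
    restored = processor.postprocess(batch["pixel_values"])
    print(restored["pixel_values"].shape)   # (1, 64, 64, 3), dtype uint8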