Spaces:
Runtime error
Runtime error
# python3.7
"""Contains the base class for generators."""
| import os | |
| import sys | |
| import logging | |
| import numpy as np | |
| import torch | |
| from . import model_settings | |
# Public API of this module.
__all__ = ['BaseGenerator']
def get_temp_logger(logger_name='logger'):
    """Gets a temporary logger.

    When the named logger has no handler yet, it is set to `DEBUG` level and
    given a stream handler that writes every level of message to stdout.

    NOTE: The check uses `Logger.hasHandlers()`, which also looks at ancestor
    loggers. If the logger itself, or an ancestor it propagates to, already
    has a handler, the logger is returned without any configuration.

    Args:
        logger_name: Name of the logger.

    Returns:
        A `logging.Logger`.

    Raises:
        ValueError: If the input `logger_name` is empty.
    """
    if not logger_name:
        raise ValueError('Input `logger_name` should not be empty!')

    logger = logging.getLogger(logger_name)
    if logger.hasHandlers():
        # Already configured (here or further up the hierarchy): adding another
        # handler would duplicate every message.
        return logger

    logger.setLevel(logging.DEBUG)
    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(
        logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s'))
    logger.addHandler(handler)
    return logger
class BaseGenerator:
    """Base class for generator used in GAN variants.

    NOTE: The model should be defined with pytorch, and only used for
    inference.

    Derived classes are expected to implement `build()`, `load()`,
    `convert_tf_model()`, `sample()`, `preprocess()` and `synthesize()`; the
    versions defined here only raise `NotImplementedError`.
    """

    def __init__(self, model_name, logger=None):
        """Initializes with specific settings.

        The model should be registered in `model_settings.py` with proper
        settings first. Among them, some attributes are necessary, including:
        (1) gan_type: Type of the GAN model.
        (2) latent_space_dim: Dimension of the latent space. Should be a tuple.
        (3) resolution: Resolution of the synthesis.
        (4) min_val: Minimum value of the raw output. (default -1.0)
        (5) max_val: Maximum value of the raw output. (default 1.0)
        (6) channel_order: Channel order of the output image. (default: `RGB`)
        (7) output_channels: Number of output channels. (default 3)

        Args:
            model_name: Name with which the model is registered.
            logger: Logger for recording log messages. If set as `None`, a
                default logger, which prints messages from all levels to
                screen, will be created. (default: None)

        Raises:
            AttributeError: If some necessary attributes are missing.
            AssertionError: If `channel_order` is neither `RGB` nor `BGR`, or
                if `self.model` is still falsy after `build()` and weight
                loading.
        """
        self.model_name = model_name
        # Every registered setting becomes an attribute of this instance.
        for key, val in model_settings.MODEL_POOL[model_name].items():
            setattr(self, key, val)
        self.use_cuda = model_settings.USE_CUDA
        self.batch_size = model_settings.MAX_IMAGES_ON_DEVICE
        self.logger = logger or get_temp_logger(model_name + '_generator')
        self.model = None
        self.run_device = 'cuda' if self.use_cuda else 'cpu'
        self.cpu_device = 'cpu'

        # Check necessary settings.
        self.check_attr('gan_type')
        self.check_attr('latent_space_dim')
        self.check_attr('resolution')

        # Optional settings fall back to defaults when not registered.
        self.min_val = getattr(self, 'min_val', -1.0)
        self.max_val = getattr(self, 'max_val', 1.0)
        self.output_channels = getattr(self, 'output_channels', 3)
        self.channel_order = getattr(self, 'channel_order', 'RGB').upper()
        assert self.channel_order in ['RGB', 'BGR'], (
            f'`channel_order` should be `RGB` or `BGR`, '
            f'but `{self.channel_order}` is received!')

        # Build model and load pre-trained weights.
        self.build()
        if os.path.isfile(getattr(self, 'model_path', '')):
            self.load()
        elif os.path.isfile(getattr(self, 'tf_model_path', '')):
            self.convert_tf_model()
        else:
            self.logger.warning('No pre-trained model will be loaded!')

        # Change to inference mode and GPU mode if needed.
        assert self.model, '`build()` should set `self.model`!'
        self.model.eval().to(self.run_device)

    def check_attr(self, attr_name):
        """Checks the existence of a particular attribute.

        Args:
            attr_name: Name of the attribute to check.

        Raises:
            AttributeError: If the target attribute is missing.
        """
        if not hasattr(self, attr_name):
            raise AttributeError(
                f'`{attr_name}` is missing for model `{self.model_name}`!')

    def build(self):
        """Builds the graph."""
        raise NotImplementedError('Should be implemented in derived class!')

    def load(self):
        """Loads pre-trained weights."""
        raise NotImplementedError('Should be implemented in derived class!')

    def convert_tf_model(self, test_num=10):
        """Converts models weights from tensorflow version.

        Args:
            test_num: Number of images to generate for testing whether the
                conversion is done correctly. `0` means skipping the test.
                (default 10)
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def sample(self, num):
        """Samples latent codes randomly.

        Args:
            num: Number of latent codes to sample. Should be positive.

        Returns:
            A `numpy.ndarray` as sampled latent codes.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def preprocess(self, latent_codes):
        """Preprocesses the input latent code if needed.

        Args:
            latent_codes: The input latent codes for preprocessing.

        Returns:
            The preprocessed latent codes which can be used as final input for
            the generator.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def easy_sample(self, num):
        """Wraps functions `sample()` and `preprocess()` together."""
        return self.preprocess(self.sample(num))

    def synthesize(self, latent_codes):
        """Synthesizes images with given latent codes.

        NOTE: The latent codes should have already been preprocessed.

        Args:
            latent_codes: Input latent codes for image synthesis.

        Returns:
            A dictionary whose values are raw outputs from the generator.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def get_value(self, tensor):
        """Gets value of a `torch.Tensor`.

        A `numpy.ndarray` input is returned as is; a `torch.Tensor` is moved
        to `self.cpu_device`, detached and converted with `.numpy()`.

        Args:
            tensor: The input tensor to get value from.

        Returns:
            A `numpy.ndarray`.

        Raises:
            ValueError: If the tensor is with neither `torch.Tensor` type nor
                `numpy.ndarray` type.
        """
        if isinstance(tensor, np.ndarray):
            return tensor
        if isinstance(tensor, torch.Tensor):
            return tensor.to(self.cpu_device).detach().numpy()
        raise ValueError(f'Unsupported input type `{type(tensor)}`!')

    def postprocess(self, images):
        """Postprocesses the output images if needed.

        This function assumes the input numpy array is with shape
        [batch_size, channel, height, width]. Here, `channel = 3` for color
        image and `channel = 1` for grayscale image. The return images are
        with shape [batch_size, height, width, channel]. NOTE: The channel
        order of output image will always be `RGB`.

        When `self.model_name` contains `stylegan2` or `stylegan3`, the input
        is returned unchanged (no shape check, rescaling or transposition).
        NOTE(review): the extent of this special case was reconstructed from a
        source whose indentation had been lost -- confirm that the whole
        conversion, and not only the shape check, is meant to be skipped.

        Args:
            images: The raw output from the generator.

        Returns:
            The postprocessed images with dtype `numpy.uint8` with range
            [0, 255] (or the untouched input for the models named above).

        Raises:
            ValueError: If the input `images` are not with type
                `numpy.ndarray` or not with shape [batch_size, channel,
                height, width].
        """
        if not isinstance(images, np.ndarray):
            raise ValueError('Images should be with type `numpy.ndarray`!')

        if 'stylegan3' in self.model_name or 'stylegan2' in self.model_name:
            return images

        images_shape = images.shape
        if len(images_shape) != 4 or images_shape[1] not in [1, 3]:
            raise ValueError(f'Input should be with shape [batch_size, channel, '
                             f'height, width], where channel equals to 1 or 3. '
                             f'But {images_shape} is received!')

        # Map [min_val, max_val] linearly onto [0, 255]; the `+ 0.5` makes the
        # truncating cast to uint8 round to the nearest integer.
        images = (images - self.min_val) * 255 / (self.max_val - self.min_val)
        images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
        # NCHW -> NHWC.
        images = images.transpose(0, 2, 3, 1)
        if self.channel_order == 'BGR':
            # Reverse the channel axis so that the result is RGB.
            images = images[:, :, :, ::-1]
        return images

    def easy_synthesize(self, latent_codes, **kwargs):
        """Wraps functions `synthesize()` and `postprocess()` together.

        Only the `image` entry of the synthesis outputs (if present) is
        postprocessed; all other entries are returned as produced.
        """
        outputs = self.synthesize(latent_codes, **kwargs)
        if 'image' in outputs:
            outputs['image'] = self.postprocess(outputs['image'])
        return outputs

    def get_batch_inputs(self, latent_codes):
        """Gets batch inputs from a collection of latent codes.

        This function will yield at most `self.batch_size` latent_codes at a
        time.

        Args:
            latent_codes: The input latent codes for generation. First
                dimension should be the total number.

        Yields:
            Consecutive slices of `latent_codes` along the first dimension.
        """
        total_num = latent_codes.shape[0]
        for start in range(0, total_num, self.batch_size):
            yield latent_codes[start:start + self.batch_size]