|
|
from enum import Enum |
|
|
from transformers import PretrainedConfig |
|
|
|
|
|
from .module_layers import Encoder, Decoder |
|
|
from .module_layers_attn import Encoder as AttnEncoder, Decoder as AttnDecoder |
|
|
|
|
|
|
|
|
|
|
|
class EncoderType(Enum):
    """Registry of available encoder implementations.

    Member *names* are what configs store (e.g. ``"Simple"``); member
    *values* are the encoder classes themselves, so a class can be
    recovered via ``EncoderType[name].value``.
    """

    # Plain convolutional encoder from .module_layers.
    Simple = Encoder

    # Attention-augmented encoder from .module_layers_attn.
    Attn = AttnEncoder
|
|
|
|
|
|
|
|
class DecoderType(Enum):
    """Registry of available decoder implementations.

    Mirrors :class:`EncoderType`: member names are serialized in configs,
    member values are the decoder classes (``DecoderType[name].value``).
    """

    # Plain convolutional decoder from .module_layers.
    Simple = Decoder

    # Attention-augmented decoder from .module_layers_attn.
    Attn = AttnDecoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAEConfig(PretrainedConfig):
    """HuggingFace-style configuration for a VAE/VQ-VAE image autoencoder.

    Stores architecture hyperparameters (channel widths, resolution,
    codebook size, attention placement) together with training loss
    weights. Encoder/decoder variants are stored by *name* (a member name
    of :class:`EncoderType` / :class:`DecoderType`) so the config remains
    JSON-serializable.
    """

    model_type = "vae"

    def __init__(self, **kwargs):
        """Build the config, consuming known keys from ``kwargs``.

        Known keys are ``pop``-ed (not ``get``-ed) so they are not
        forwarded to ``PretrainedConfig.__init__``, which would otherwise
        setattr each of them a second time as "extra" kwargs. Remaining
        kwargs (e.g. ``return_dict``) are passed through to the base class.
        """
        # --- encoder / decoder selection (stored as enum member names) ---
        self.encoder_type = kwargs.pop("encoder_type", EncoderType.Simple.name)
        self.decoder_type = kwargs.pop("decoder_type", DecoderType.Simple.name)

        # --- tensor channel layout ---
        self.in_channels = kwargs.pop("in_channels", 3)
        self.out_channels = kwargs.pop("out_channels", 3)
        self.z_channels = kwargs.pop("z_channels", 256)   # latent channels
        self.channels = kwargs.pop("channels", 128)       # base width

        # Per-stage width multipliers; stage i uses channels * channels_mult[i].
        self.channels_mult = kwargs.pop("channels_mult", [1, 1, 2, 2])

        # --- vector-quantization codebook ---
        self.codebook_dim = kwargs.pop("codebook_dim", 8)
        self.codebook_size = kwargs.pop("codebook_size", 1024)

        # --- architecture details ---
        # Resolutions at which attention blocks are inserted (empty = none).
        self.attn_resolutions = kwargs.pop("attn_resolutions", [])
        self.num_res_blocks = kwargs.pop("num_res_blocks", 2)
        self.resolution = kwargs.pop("resolution", [64, 64])  # input H, W
        self.dropout = kwargs.pop("dropout", 0.)

        # --- input normalization statistics (per RGB channel) ---
        self.image_mean = kwargs.pop("image_mean", [0.1616, 0.1646, 0.1618])
        self.image_std = kwargs.pop("image_std", [0.2206, 0.2233, 0.2214])

        # --- loss weights ---
        self.w_mse = kwargs.pop("w_mse", 2)            # pixel MSE
        self.w_l1 = kwargs.pop("w_l1", 0.2)            # pixel L1
        self.w_perceptual = kwargs.pop("w_perceptual", 0.1)
        self.w_commit = kwargs.pop("w_commit", 1)      # VQ commitment
        self.w_dino = kwargs.pop("w_dino", 0.1)        # DINO feature loss
        self.w_kl = kwargs.pop("w_kl", 0.1)            # KL divergence

        # Forward only the *unconsumed* kwargs to the base config.
        super().__init__(**kwargs)