from enum import Enum

from transformers import PretrainedConfig

from .module_layers import Encoder, Decoder
from .module_layers_attn import Encoder as AttnEncoder, Decoder as AttnDecoder
# from .module_quantizers import VectorQuantizer


class EncoderType(Enum):
    """Available encoder implementations; each member's value is the class."""

    Simple = Encoder
    Attn = AttnEncoder


class DecoderType(Enum):
    """Available decoder implementations; each member's value is the class."""

    Simple = Decoder
    Attn = AttnDecoder


# Quantizer selection is currently disabled (its import above is commented out).
# class QuantizerType(Enum):
#     VQ = VectorQuantizer


class VAEConfig(PretrainedConfig):
    """Configuration for the VAE model.

    Every option is looked up in ``**kwargs`` and falls back to the default
    listed in ``__init__`` when absent. The complete ``kwargs`` mapping is
    then forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "vae"

    def __init__(self, **kwargs):
        """Populate the configuration attributes from ``kwargs``.

        Args:
            **kwargs: Any of the option names assigned below, plus whatever
                ``PretrainedConfig`` itself accepts. Missing options take
                the defaults shown here.
        """
        # Local alias: every option below is an optional keyword lookup.
        option = kwargs.get

        # --- module selection -------------------------------------------
        # ref ./modules/__init__.py
        # Stored as the enum member *name* (a string), not the class itself.
        self.encoder_type = option("encoder_type", EncoderType.Simple.name)
        self.decoder_type = option("decoder_type", DecoderType.Simple.name)
        # self.quantizer_type = kwargs.get("quantizer_type", QuantizerType.VQ.name)

        # --- channel layout ---------------------------------------------
        # in_ch -> channels * channels_mult -> z_channels -> codebook_dim
        #       -> z_channels -> channels * channels_mult -> out_ch
        self.in_channels = option("in_channels", 3)
        self.out_channels = option("out_channels", 3)
        self.z_channels = option("z_channels", 256)  # embedding dim
        self.channels = option("channels", 128)
        # features = [channels * mult for mult in channels_mult]
        # res -> res // 2**(len(channels_mult)-1)
        self.channels_mult = option("channels_mult", [1, 1, 2, 2])

        # --- codebook ---------------------------------------------------
        self.codebook_dim = option("codebook_dim", 8)
        self.codebook_size = option("codebook_size", 1024)

        # --- attention / residual blocks --------------------------------
        # if res = 128 and ch_mult = [1, 1, 2, 2], select any from
        # [128/1, 128/2, 128/2**2, 128/2**3]
        # in taming-transformers use
        # attn_resolutions = [res/2**(len(ch_mult)-1)]
        self.attn_resolutions = option("attn_resolutions", [])
        self.num_res_blocks = option("num_res_blocks", 2)
        self.resolution = option("resolution", [64, 64])
        self.dropout = option("dropout", 0.0)

        # --- input normalisation ----------------------------------------
        # imagenet mean [0.1616, 0.1646, 0.1618], std [0.2206, 0.2233, 0.2214]
        # nusc     mean [0.3814, 0.3861, 0.3778], std [0.2219, 0.2188, 0.2248]
        self.image_mean = option("image_mean", [0.1616, 0.1646, 0.1618])
        self.image_std = option("image_std", [0.2206, 0.2233, 0.2214])

        # --- w_* coefficients -------------------------------------------
        # NOTE(review): presumably per-term loss weights; confirm against
        # the training code, which is not visible from this file.
        self.w_mse = option("w_mse", 2)
        self.w_l1 = option("w_l1", 0.2)
        self.w_perceptual = option("w_perceptual", 0.1)
        self.w_commit = option("w_commit", 1)
        self.w_dino = option("w_dino", 0.1)
        self.w_kl = option("w_kl", 0.1)

        # Hand the full mapping to the base class last, as before.
        super().__init__(**kwargs)