# vae_test / configuration_vae.py
# stonesstones's picture
# End of training
# 372980e verified
from enum import Enum
from transformers import PretrainedConfig
from .module_layers import Encoder, Decoder
from .module_layers_attn import Encoder as AttnEncoder, Decoder as AttnDecoder
# from .module_quantizers import VectorQuantizer
class EncoderType(Enum):
    """Encoder implementations that can be selected by member name.

    Each member's value is the encoder class imported at the top of this
    module; ``VAEConfig`` uses ``EncoderType.Simple.name`` as its default
    ``encoder_type``.
    """

    # Encoder imported from .module_layers
    Simple = Encoder
    # Encoder imported from .module_layers_attn
    Attn = AttnEncoder
class DecoderType(Enum):
    """Decoder implementations that can be selected by member name.

    Each member's value is the decoder class imported at the top of this
    module; ``VAEConfig`` uses ``DecoderType.Simple.name`` as its default
    ``decoder_type``.
    """

    # Decoder imported from .module_layers
    Simple = Decoder
    # Decoder imported from .module_layers_attn
    Attn = AttnDecoder
# class QuantizerType(Enum):
# VQ = VectorQuantizer
class VAEConfig(PretrainedConfig):
    """Configuration for the ``"vae"`` model type.

    Every option is looked up by name in ``**kwargs`` and falls back to the
    default written in ``__init__`` when the key is absent. The complete,
    unmodified ``kwargs`` dict is afterwards forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "vae"

    def __init__(self, **kwargs):
        """Fill in all config attributes from ``kwargs``, applying defaults.

        Args:
            **kwargs: Any of the option names assigned below, plus whatever
                the parent ``PretrainedConfig`` constructor accepts.
        """
        # Plain lookup with fallback; nothing is removed from kwargs, so the
        # parent constructor still receives every key.
        option = kwargs.get

        # --- module selection (ref ./modules/__init__.py) ---
        # Stored as enum member names (strings), not as the enum members.
        self.encoder_type = option("encoder_type", EncoderType.Simple.name)
        self.decoder_type = option("decoder_type", DecoderType.Simple.name)
        # self.quantizer_type = kwargs.get("quantizer_type", QuantizerType.VQ.name)

        # --- channel layout ---
        # in_ch -> channels * channels_mult -> z_channels -> codebook_dim
        #       -> z_channels -> channels * channels_mult -> out_ch
        self.in_channels = option("in_channels", 3)
        self.out_channels = option("out_channels", 3)
        self.z_channels = option("z_channels", 256)  # embedding dim
        self.channels = option("channels", 128)
        # features = [channels * mult for mult in channels_mult]
        # res -> res // 2**(len(channels_mult)-1)
        self.channels_mult = option("channels_mult", [1, 1, 2, 2])

        # --- codebook ---
        self.codebook_dim = option("codebook_dim", 8)
        self.codebook_size = option("codebook_size", 1024)

        # --- attention / residual blocks ---
        # With res = 128 and ch_mult = [1, 1, 2, 2], pick any of
        # [128/1, 128/2, 128/2**2, 128/2**3].
        # taming-transformers uses attn_resolutions = [res/2**(len(ch_mult)-1)].
        self.attn_resolutions = option("attn_resolutions", [])
        self.num_res_blocks = option("num_res_blocks", 2)
        self.resolution = option("resolution", [64, 64])
        self.dropout = option("dropout", 0.)

        # --- input normalisation statistics ---
        # imagenet mean [0.1616, 0.1646, 0.1618], std [0.2206, 0.2233, 0.2214]
        # nusc     mean [0.3814, 0.3861, 0.3778], std [0.2219, 0.2188, 0.2248]
        self.image_mean = option('image_mean', [0.1616, 0.1646, 0.1618])
        self.image_std = option("image_std", [0.2206, 0.2233, 0.2214])

        # --- w_* scalars ---
        # NOTE(review): presumably per-term loss weights, judging by the
        # names only; confirm against the training code.
        self.w_mse = option("w_mse", 2)
        self.w_l1 = option("w_l1", 0.2)
        self.w_perceptual = option("w_perceptual", 0.1)
        self.w_commit = option("w_commit", 1)
        self.w_dino = option("w_dino", 0.1)
        self.w_kl = option("w_kl", 0.1)

        super().__init__(**kwargs)