File size: 2,536 Bytes
372980e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from enum import Enum
from transformers import PretrainedConfig
from .module_layers import Encoder, Decoder
from .module_layers_attn import Encoder as AttnEncoder, Decoder as AttnDecoder
# from .module_quantizers import VectorQuantizer
class EncoderType(Enum):
    """Registry of the encoder implementations selectable by name.

    ``VAEConfig`` stores a member *name* (e.g. ``EncoderType.Simple.name``) in
    its ``encoder_type`` field; the member *value* is the imported encoder
    object itself.
    """

    # ``Encoder`` imported from ``.module_layers``.
    Simple = Encoder
    # ``Encoder`` imported from ``.module_layers_attn`` (aliased ``AttnEncoder``).
    Attn = AttnEncoder
class DecoderType(Enum):
    """Registry of the decoder implementations selectable by name.

    ``VAEConfig`` stores a member *name* (e.g. ``DecoderType.Simple.name``) in
    its ``decoder_type`` field; the member *value* is the imported decoder
    object itself.
    """

    # ``Decoder`` imported from ``.module_layers``.
    Simple = Decoder
    # ``Decoder`` imported from ``.module_layers_attn`` (aliased ``AttnDecoder``).
    Attn = AttnDecoder
# class QuantizerType(Enum):
# VQ = VectorQuantizer
class VAEConfig(PretrainedConfig):
    """Configuration object for the ``"vae"`` model type.

    Every option is looked up in ``**kwargs`` and falls back to the default
    listed in ``__init__`` when absent. The complete, unmodified ``kwargs``
    are afterwards forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "vae"

    def __init__(self, **kwargs):
        """Set each known option from ``kwargs`` or from its default.

        The recognised option names are exactly the keys of ``defaults``
        below; any other keyword is simply handed on to the base class.
        """
        # The table is rebuilt on every call so the list-valued defaults are
        # fresh objects and never shared between config instances.
        defaults = {
            # Member names of EncoderType / DecoderType.
            # ref ./modules/__init__.py
            "encoder_type": EncoderType.Simple.name,
            "decoder_type": DecoderType.Simple.name,
            # (disabled) "quantizer_type": QuantizerType.VQ.name
            #
            # Channel flow:
            # in_ch -> channels * channels_mult -> z_channels -> codebook_dim
            #       -> z_channels -> channels * channels_mult -> out_ch
            "in_channels": 3,
            "out_channels": 3,
            "z_channels": 256,  # embedding dim
            "channels": 128,
            # features = [channels * mult for mult in channels_mult]
            # res -> res // 2**(len(channels_mult)-1)
            "channels_mult": [1, 1, 2, 2],
            "codebook_dim": 8,
            "codebook_size": 1024,
            # if res = 128 and ch_mult = [1, 1, 2, 2], select any from
            # [128/1, 128/2, 128/2**2, 128/2**3]
            # in taming-transformers use
            # attn_resolutions = [res/2**(len(ch_mult)-1)]
            "attn_resolutions": [],
            "num_res_blocks": 2,
            "resolution": [64, 64],
            "dropout": 0.,
            # imagenet mean [0.1616, 0.1646, 0.1618], std [0.2206, 0.2233, 0.2214]
            # nusc mean [0.3814, 0.3861, 0.3778], std [0.2219, 0.2188, 0.2248]
            "image_mean": [0.1616, 0.1646, 0.1618],
            "image_std": [0.2206, 0.2233, 0.2214],
            # NOTE(review): the w_* values are presumably per-term loss
            # weights -- confirm against the training code.
            "w_mse": 2,
            "w_l1": 0.2,
            "w_perceptual": 0.1,
            "w_commit": 1,
            "w_dino": 0.1,
            "w_kl": 0.1,
        }

        # Assign in table order; a key present in kwargs always wins, even
        # when its value is falsy or None.
        for option, fallback in defaults.items():
            setattr(self, option, kwargs.get(option, fallback))

        super().__init__(**kwargs)