stonesstones committed on
Commit 372980e · verified · 1 Parent(s): 10a1d80

End of training

README.md ADDED
@@ -0,0 +1,53 @@
+ ---
+ library_name: transformers
+ tags:
+ - image-classification
+ - generated_from_trainer
+ model-index:
+ - name: vae_test
+   results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # vae_test
+
+ This model is a fine-tuned version of [](https://huggingface.co/), trained on a dataset defined by train_file=/home/pj24002027/ku50001104/data/mutual_dataset/few_data/train.jsonl and validation_file=/home/pj24002027/ku50001104/data/mutual_dataset/few_data/test.jsonl, with max_train_samples=2048, max_eval_samples=2048, and use_sensor_keys=CAM_FRONT.
+
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 0.0002
+ - train_batch_size: 128
+ - eval_batch_size: 32
+ - seed: 42
+ - optimizer: ADAMW_TORCH with betas=(0.5, 0.9) and epsilon=1e-08 (no additional optimizer arguments)
+ - lr_scheduler_type: cosine_with_min_lr
+ - lr_scheduler_warmup_steps: 1000
+ - num_epochs: 0.1
+
+ ### Training results
+
+
+
+ ### Framework versions
+
+ - Transformers 4.51.3
+ - Pytorch 2.6.0+cu126
+ - Datasets 3.5.1
+ - Tokenizers 0.21.1
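Since `config.json` and `preprocessor_config.json` below register the custom classes via `auto_map`, a checkpoint like this is normally loaded with `trust_remote_code=True`. A minimal loading sketch, assuming the repository id is `stonesstones/vae_test` (the repo id and the image path are placeholders, not confirmed by this commit):

```python
# Minimal loading sketch; "stonesstones/vae_test" is an assumed repo id.
import torch
from PIL import Image
from transformers import AutoConfig, AutoModel, AutoImageProcessor

repo_id = "stonesstones/vae_test"  # hypothetical

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)             # configuration_vae.VAEConfig
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()        # modeling_vae.VAEModel
processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)  # image_processing_vae.VAEImageProcessor

image = Image.open("frame.jpg").convert("RGB")            # placeholder image path
pixel_values = processor(images=image)["pixel_values"]    # (1, 3, 64, 64) after resize + normalize

with torch.no_grad():
    out = model(s0_img=pixel_values, s1_img=pixel_values, action=torch.zeros(1, 1),
                return_loss=True, return_dict=True)
recon = processor.postprocess(out.reconstruction)["pixel_values"]  # uint8, channel-last
```

With `return_loss=True` the output also carries the individual loss terms that are combined into `loss` using the `w_*` weights from the config.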
all_results.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "epoch": 0.125,
+   "train_loss": 1.1237460374832153,
+   "train_runtime": 20.6388,
+   "train_samples_per_second": 9.923,
+   "train_steps_per_second": 0.097
+ }
config.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "architectures": [
+     "VAEModel"
+   ],
+   "attn_resolutions": [],
+   "auto_map": {
+     "AutoConfig": "configuration_vae.VAEConfig",
+     "AutoModel": "modeling_vae.VAEModel"
+   },
+   "channels": 128,
+   "channels_mult": [
+     1,
+     1,
+     1,
+     2,
+     2
+   ],
+   "codebook_dim": 0,
+   "codebook_size": 0,
+   "decoder_type": "Simple",
+   "drop_out": 0,
+   "dropout": 0.0,
+   "encoder_type": "Simple",
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "in_channels": 3,
+   "model_type": "vae",
+   "num_res_blocks": 2,
+   "out_channels": 3,
+   "quantizer_type": "VQ",
+   "resolution": [
+     64,
+     64
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.51.3",
+   "w_commit": 0,
+   "w_dino": 0,
+   "w_kl": 1,
+   "w_l1": 0.2,
+   "w_mse": 2,
+   "w_perceptual": 0,
+   "z_channels": 64
+ }
configuration_vae.py ADDED
@@ -0,0 +1,57 @@
+ from enum import Enum
+ from transformers import PretrainedConfig
+
+ from .module_layers import Encoder, Decoder
+ from .module_layers_attn import Encoder as AttnEncoder, Decoder as AttnDecoder
+ # from .module_quantizers import VectorQuantizer
+
+
+ class EncoderType(Enum):
+     Simple = Encoder
+     Attn = AttnEncoder
+
+
+ class DecoderType(Enum):
+     Simple = Decoder
+     Attn = AttnDecoder
+
+
+ # class QuantizerType(Enum):
+ #     VQ = VectorQuantizer
+
+
+ class VAEConfig(PretrainedConfig):
+     model_type = "vae"
+
+     def __init__(self, **kwargs):
+         # ref ./modules/__init__.py
+         self.encoder_type = kwargs.get("encoder_type", EncoderType.Simple.name)
+         self.decoder_type = kwargs.get("decoder_type", DecoderType.Simple.name)
+         # self.quantizer_type = kwargs.get("quantizer_type", QuantizerType.VQ.name)
+         # in_ch -> channels * channels_mult -> z_channels -> codebook_dim -> z_channels -> channels * channels_mult -> out_ch
+         self.in_channels = kwargs.get("in_channels", 3)
+         self.out_channels = kwargs.get("out_channels", 3)
+         self.z_channels = kwargs.get("z_channels", 256)  # embedding dim
+         self.channels = kwargs.get("channels", 128)
+         # features = [channels * mult for mult in channels_mult]
+         # res -> res // 2**(len(channels_mult)-1)
+         self.channels_mult = kwargs.get("channels_mult", [1, 1, 2, 2])
+         self.codebook_dim = kwargs.get("codebook_dim", 8)
+         self.codebook_size = kwargs.get("codebook_size", 1024)
+         # if res = 128 and ch_mult = [1, 1, 2, 2], select any from [128/1, 128/2, 128/2**2, 128/2**3]
+         # in taming-transformers use attn_resolutions = [res/2**(len(ch_mult)-1)]
+         self.attn_resolutions = kwargs.get("attn_resolutions", [])
+         self.num_res_blocks = kwargs.get("num_res_blocks", 2)
+         self.resolution = kwargs.get("resolution", [64, 64])
+         self.dropout = kwargs.get("dropout", 0.)
+         # imagenet mean [0.1616, 0.1646, 0.1618], std [0.2206, 0.2233, 0.2214]
+         # nusc mean [0.3814, 0.3861, 0.3778], std [0.2219, 0.2188, 0.2248]
+         self.image_mean = kwargs.get('image_mean', [0.1616, 0.1646, 0.1618])
+         self.image_std = kwargs.get("image_std", [0.2206, 0.2233, 0.2214])
+         self.w_mse = kwargs.get("w_mse", 2)
+         self.w_l1 = kwargs.get("w_l1", 0.2)
+         self.w_perceptual = kwargs.get("w_perceptual", 0.1)
+         self.w_commit = kwargs.get("w_commit", 1)
+         self.w_dino = kwargs.get("w_dino", 0.1)
+         self.w_kl = kwargs.get("w_kl", 0.1)
+         super().__init__(**kwargs)
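The comments above pin down the spatial bookkeeping that `modeling_vae.py` relies on. A quick arithmetic check with the values from the uploaded `config.json` (plain Python, nothing project-specific):

```python
# Latent resolution and flattened encoder output size implied by the config comments.
resolution = 64
channels_mult = [1, 1, 1, 2, 2]
z_channels = 64

latent_res = resolution // 2 ** (len(channels_mult) - 1)  # 64 // 16 = 4
enc_out_dim = z_channels * latent_res ** 2                # 64 * 4 * 4 = 1024
print(latent_res, enc_out_dim)                            # 4 1024
```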
image_processing_vae.py ADDED
@@ -0,0 +1,117 @@
+ from typing import List, Optional, Union, Tuple
+ import PIL
+ import torch
+ from torchvision.transforms.v2 import (
+     Compose,
+     Lambda,
+     Resize,
+     Normalize,
+     InterpolationMode,
+ )
+ import numpy as np
+
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+ from transformers.image_utils import ChannelDimension, to_numpy_array
+ from transformers.utils import TensorType, logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class VAEImageProcessor(BaseImageProcessor):
+
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         do_resize: bool = True,
+         image_size: Tuple[int, int] = [64, 64],
+         do_rescale: bool = True,
+         rescale_factor: Union[int, float] = 1 / 255,
+         do_normalize: bool = True,
+         image_mean: Optional[List[float]] = [0.5, 0.5, 0.5],
+         image_std: Optional[List[float]] = [0.5, 0.5, 0.5],
+         *args,
+         **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+         self.do_resize = do_resize
+         self.image_size = image_size
+         self.do_rescale = do_rescale
+         self.rescale_factor = rescale_factor
+         self.do_normalize = do_normalize
+         self.image_mean = image_mean
+         self.image_std = image_std
+
+     def preprocess(
+         self,
+         images: Union["PIL.Image.Image", np.ndarray, List["PIL.Image.Image"], List[np.ndarray]],
+         is_video: bool = False,
+         return_tensors: Optional[Union[str, TensorType]] = "pt",
+         input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST,
+         **kwargs
+     ):
+         if isinstance(images, list):
+             images = [to_numpy_array(image) for image in images]
+             images = torch.from_numpy(np.stack(images, axis=0)).float()
+         else:
+             images = to_numpy_array(images)
+             images = torch.from_numpy(images).float()
+
+         if is_video:
+             if images.dim() == 4:
+                 images = images.unsqueeze(0)
+             if input_data_format == ChannelDimension.LAST:
+                 images = images.permute(0, 1, 4, 2, 3)
+         else:
+             if images.dim() == 3:
+                 images = images.unsqueeze(0)
+             if input_data_format == ChannelDimension.LAST:
+                 images = images.permute(0, 3, 1, 2)
+         compose_tf = Compose(
+             [
+                 Resize(self.image_size, interpolation=InterpolationMode.BICUBIC) if self.do_resize else Lambda(lambda x: x),
+                 Lambda(lambda x: x / 255.0) if self.do_rescale else Lambda(lambda x: x),
+                 Normalize(self.image_mean, self.image_std) if self.do_normalize else Lambda(lambda x: x),
+             ]
+         )
+         images = compose_tf(images)
+
+         return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+     def postprocess(
+         self,
+         images: "torch.Tensor",
+         is_video: bool = False,
+         return_tensors: Optional[Union[str, TensorType]] = "np",
+         input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
+         **kwargs
+     ):
+         if isinstance(images, np.ndarray):
+             images = torch.from_numpy(images).float()
+         if isinstance(images, list):
+             images = torch.stack(images, dim=0)
+         if not isinstance(images, torch.Tensor):
+             raise ValueError("images must be a torch.Tensor")
+
+         if is_video:
+             if images.dim() == 4:
+                 images = images.unsqueeze(0)
+             if input_data_format == ChannelDimension.FIRST:
+                 images = images.permute(0, 1, 3, 4, 2)
+         else:
+             if images.dim() == 3:
+                 images = images.unsqueeze(0)
+             if input_data_format == ChannelDimension.FIRST:
+                 images = images.permute(0, 2, 3, 1)
+
+         if self.do_normalize:
+             images = (images * torch.tensor(self.image_std)) + torch.tensor(self.image_mean)
+         if self.do_rescale:
+             images = torch.clamp(images, 0, 1)
+             images = (images * 255).type(torch.uint8)
+
+         if return_tensors == TensorType.NUMPY:
+             images = images.numpy()
+
+         return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
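A round-trip sketch of this processor on a random camera-sized frame (shapes are illustrative; it assumes `image_processing_vae.py` is importable from the working directory):

```python
import numpy as np
from image_processing_vae import VAEImageProcessor

processor = VAEImageProcessor(image_size=[64, 64])

frame = np.random.randint(0, 256, size=(900, 1600, 3), dtype=np.uint8)  # fake HWC uint8 frame

batch = processor.preprocess(frame)              # resize -> /255 -> normalize with mean/std 0.5
pixel_values = batch["pixel_values"]             # torch.FloatTensor, (1, 3, 64, 64), roughly in [-1, 1]

restored = processor.postprocess(pixel_values)   # un-normalize, clamp, scale back to uint8
print(restored["pixel_values"].shape)            # (1, 64, 64, 3)
```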
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d3385c46b833dccf98c5e3ceb6b4a9b174795a85374120706fa0a0c9a42f2197
+ size 31338740
modeling_vae.py ADDED
@@ -0,0 +1,239 @@
+ from typing import Optional, Union, Tuple
+ from dataclasses import dataclass
+
+ import torch
+ from torch import nn
+ from torch import Tensor
+ from transformers import PreTrainedModel
+ from transformers.utils import logging, ModelOutput
+
+ from torchvision.models import vgg16, VGG16_Weights
+ import torch.nn.functional as F
+
+ from einops import rearrange
+
+ from .configuration_vae import VAEConfig, EncoderType, DecoderType
+
+
+ logger = logging.get_logger(__name__)
+
+
+ @dataclass
+ class VAEOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     reconstruction: torch.FloatTensor = None
+     mse_loss: Optional[torch.FloatTensor] = None
+     l1_loss: Optional[torch.FloatTensor] = None
+     perceptual_loss: Optional[torch.FloatTensor] = None
+     dino_loss: Optional[torch.FloatTensor] = None
+     kl_loss: Optional[torch.FloatTensor] = None
+
+
+ class Vgg16(nn.Module):
+     # ref https://github.com/dxyang/StyleTransfer/blob/master/vgg.py
+     def __init__(self, layers):
+         super().__init__()
+         features = vgg16(weights=VGG16_Weights.DEFAULT).features
+         self.to_relu_1_2 = nn.Sequential()
+         self.to_relu_2_2 = nn.Sequential()
+         self.to_relu_3_3 = nn.Sequential()
+         self.to_relu_4_3 = nn.Sequential()
+
+         for x in range(4):
+             self.to_relu_1_2.add_module(str(x), features[x])
+         for x in range(4, 9):
+             self.to_relu_2_2.add_module(str(x), features[x])
+         for x in range(9, 16):
+             self.to_relu_3_3.add_module(str(x), features[x])
+         for x in range(16, 23):
+             self.to_relu_4_3.add_module(str(x), features[x])
+
+         # don't need the gradients, just want the features
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def forward(self, x):
+         h = self.to_relu_1_2(x)
+         h_relu_1_2 = h
+         h = self.to_relu_2_2(h)
+         h_relu_2_2 = h
+         h = self.to_relu_3_3(h)
+         h_relu_3_3 = h
+         h = self.to_relu_4_3(h)
+         h_relu_4_3 = h
+         out = (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3)
+         return out
+
+
+ class PerceptualLoss(nn.Module):
+     def __init__(self, layers=(3, 8, 15, 22), unnorm_mean=None, unnorm_std=None, weights=None):
+         super().__init__()
+         self.vgg = Vgg16(layers=layers)
+         self.layers = layers
+         self.weights = weights or [1.0 / len(layers)] * len(layers)
+
+     def forward(self, x, y):
+         x_vgg = self.vgg(x)
+         y_vgg = self.vgg(y)
+         loss = 0.0
+         for x_vgg_layer, y_vgg_layer in zip(x_vgg, y_vgg):
+             loss += F.mse_loss(x_vgg_layer, y_vgg_layer)
+         return loss
+
+ class DinoLoss(nn.Module):
+     def __init__(self, patch_size, use_large=False):
+         super().__init__()
+         size = 'b' if use_large else 's'
+         dino = f'dino_vit{size}{patch_size}'
+         self.vit = torch.hub.load('facebookresearch/dino:main', dino)
+         print('use ', dino)
+         self.vit.eval()
+         for param in self.vit.parameters():
+             param.requires_grad = False
+
+     def forward(self, gt, embed):
+         with torch.no_grad():
+             dino_features = self.vit.prepare_tokens(gt)
+             for blk in self.vit.blocks:
+                 dino_features = blk(dino_features)
+             dino_features = self.vit.norm(dino_features)
+             dino_features = dino_features[:, 1:]
+         embed_features = rearrange(embed, 'b c h w -> b (h w) c').contiguous()
+         dtype = embed.dtype
+         dino_loss = 1 - F.cosine_similarity(dino_features.to(torch.float32), embed_features.to(torch.float32), dim=2)
+         dino_loss = dino_loss.mean()
+         dino_loss = dino_loss.to(dtype)
+         return dino_loss
+
+
+ class VAEModel(PreTrainedModel):
+     config_class = VAEConfig
+     main_input_name = "s0_img"
+
+     def __init__(self, config: VAEConfig):
+         super().__init__(config)
+         dict_config = config.to_dict()
+         self.encoder = EncoderType[config.encoder_type].value(**dict_config)
+         enc_out_dim = self.config.z_channels * (self.config.resolution[0] // (2 ** (len(self.config.channels_mult) - 1))) ** 2
+         latent_dim = 64
+         self.cond_mlp = nn.Sequential(
+             nn.Linear(enc_out_dim * 2, config.z_channels),
+             nn.ReLU(),
+             nn.Linear(config.z_channels, config.z_channels),
+             nn.ReLU(),
+             nn.Linear(config.z_channels, latent_dim * 2),
+         )
+         self.in_mlp = nn.Sequential(
+             nn.Linear(enc_out_dim, config.z_channels),
+             nn.ReLU(),
+             nn.Linear(config.z_channels, config.z_channels),
+             nn.ReLU(),
+             nn.Linear(config.z_channels, latent_dim * 2),
+         )
+         self.cond_mlp_out = nn.Sequential(
+             nn.Linear(latent_dim + enc_out_dim, config.z_channels),
+             nn.ReLU(),
+             nn.Linear(config.z_channels, config.z_channels),
+             nn.ReLU(),
+             nn.Linear(config.z_channels, enc_out_dim),
+         )
+         self.out_mlp = nn.Sequential(
+             nn.Linear(latent_dim, config.z_channels),
+             nn.ReLU(),
+             nn.Linear(config.z_channels, config.z_channels),
+             nn.ReLU(),
+             nn.Linear(config.z_channels, enc_out_dim),
+         )
+         self.decoder = DecoderType[config.decoder_type].value(**dict_config)
+         if config.w_perceptual > 0:
+             self.perceptual_loss = PerceptualLoss(
+                 unnorm_mean=config.image_mean,
+                 unnorm_std=config.image_std
+             )
+         if config.w_dino > 0:
+             assert config.z_channels in [384, 768]
+             patch_size = 2 ** (len(config.channels_mult) - 1)
+             self.dino_loss = DinoLoss(patch_size=patch_size)
+         self.log_state = {
+             "loss": None,
+             "mse_loss": None,
+             "l1_loss": None,
+             "perceptual_loss": None,
+             "dino_loss": None,
+             "gt": None,
+             "recon": None,
+         }
+         self.post_init()
+
+     def encode(self, s0_img: Tensor, s1_img: Tensor, a0: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+         # s0 = self.encoder(s0_img).reshape(s0_img.shape[0], -1)
+         s0 = None
+         s1 = self.encoder(s1_img).reshape(s1_img.shape[0], -1)
+         # s1_mean_var = self.cond_mlp(torch.cat([s0, s1], dim=1))
+         s1_mean_var = self.in_mlp(s1)
+         s1_mean, s1_logvar = s1_mean_var.chunk(2, dim=1)
+         s1_stddev = torch.exp(s1_logvar * 0.5)
+         s1_latent = s1_mean + s1_stddev * torch.randn_like(s1_mean)
+         return s1_latent, s0, s1_mean, s1_logvar
+
+     def decode(self, s1_latent: Tensor, s0: Tensor) -> Tensor:
+         quant_h = int(self.config.resolution[0] / (2 ** (len(self.config.channels_mult) - 1)))
+         quant_w = int(self.config.resolution[1] / (2 ** (len(self.config.channels_mult) - 1)))
+         # s1_latent = self.cond_mlp_out(torch.cat([s1_latent, s0], dim=1)).reshape(s1_latent.shape[0], self.config.z_channels, quant_h, quant_w)
+         s1_latent = self.out_mlp(s1_latent).reshape(s1_latent.shape[0], self.config.z_channels, quant_h, quant_w)
+         return self.decoder(s1_latent)
+
+     def forward(
+         self,
+         s0_img: Tensor,
+         s1_img: Tensor,
+         action: Tensor,
+         return_loss: bool = True,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, VAEOutput]:
+         return_dict = return_dict if return_dict is not None else False
+         s1_latent, s0, s1_mean, s1_logvar = self.encode(s0_img, s1_img, action)
+         recon = self.decode(s1_latent, s0)
+
+         loss = None
+         mse_loss = l1_loss = perceptual_loss = dino_loss = kl_loss = None
+         if return_loss:
+             # recon loss
+             mse_loss = F.mse_loss(recon, s1_img)
+             l1_loss = F.l1_loss(recon, s1_img)
+             if self.config.w_perceptual > 0:
+                 perceptual_loss = self.perceptual_loss(recon, s1_img)
+             else:
+                 perceptual_loss = torch.zeros_like(mse_loss).to(mse_loss.device)
+             if self.config.w_dino > 0:
+                 # NOTE: DinoLoss expects the encoder feature map as `embed`; passing None will fail if w_dino > 0.
+                 dino_loss = self.dino_loss(s1_img, None)
+             else:
+                 dino_loss = torch.zeros_like(mse_loss).to(mse_loss.device)
+             # kl loss
+             kl_loss = torch.mean(-0.5 * torch.sum(1 + s1_logvar - s1_mean**2 - s1_logvar.exp(), dim=1))
+
+             loss = self.config.w_mse * mse_loss + \
+                 self.config.w_l1 * l1_loss + \
+                 self.config.w_perceptual * perceptual_loss + \
+                 self.config.w_dino * dino_loss + \
+                 self.config.w_kl * kl_loss
+         if not return_dict:
+             self.log_state["loss"] = loss.item()
+             self.log_state["mse_loss"] = mse_loss.item()
+             self.log_state["l1_loss"] = l1_loss.item()
+             self.log_state["perceptual_loss"] = perceptual_loss.item()
+             self.log_state["dino_loss"] = dino_loss.item()
+             self.log_state["kl_loss"] = kl_loss.item()
+             self.log_state["gt"] = s0_img.clone().detach().cpu()[:4].to(torch.float32)
+             self.log_state["recon"] = recon.clone().detach().cpu()[:4].to(torch.float32)
+             return ((loss,) + (recon,)) if loss is not None else recon
+         return VAEOutput(
+             loss=loss,
+             reconstruction=recon,
+             mse_loss=mse_loss,
+             l1_loss=l1_loss,
+             perceptual_loss=perceptual_loss,
+             dino_loss=dino_loss,
+             kl_loss=kl_loss,
+         )
+
+     def get_last_layer(self):
+         raise NotImplementedError
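For orientation, a smoke-test sketch of the model on random tensors with the settings from `config.json`. The KL term above is the usual closed-form KL divergence between N(mu, sigma^2) and N(0, I). The `vae_test` package path is an assumption; the repo modules use relative imports, so they need to live inside a package (or be loaded through `trust_remote_code`):

```python
import torch
from vae_test.configuration_vae import VAEConfig  # hypothetical package layout
from vae_test.modeling_vae import VAEModel

config = VAEConfig(
    resolution=[64, 64], z_channels=64, channels=128,
    channels_mult=[1, 1, 1, 2, 2], num_res_blocks=2,
    w_mse=2, w_l1=0.2, w_perceptual=0, w_commit=0, w_dino=0, w_kl=1,
)
model = VAEModel(config).eval()

s0 = torch.randn(2, 3, 64, 64)   # previous frame (currently unused by encode())
s1 = torch.randn(2, 3, 64, 64)   # frame to reconstruct
action = torch.zeros(2, 1)       # placeholder action, also unused by encode()

out = model(s0_img=s0, s1_img=s1, action=action, return_loss=True, return_dict=True)
print(out.reconstruction.shape)  # torch.Size([2, 3, 64, 64])
print(out.loss, out.mse_loss, out.kl_loss)
```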
module_layers.py ADDED
@@ -0,0 +1,95 @@
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+
+ class DoubleConv(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None):
+         super().__init__()
+         if mid_channels is None:
+             mid_channels = out_channels
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(mid_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.conv(x)
+
+
+ class Down(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int):
+         super().__init__()
+         self.maxpool_conv = nn.Sequential(
+             nn.MaxPool2d(2),
+             DoubleConv(in_channels, out_channels)
+         )
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         return self.maxpool_conv(x)
+
+
+ class Up(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int, bilinear: bool = False):
+         super().__init__()
+         if bilinear:
+             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+         else:
+             self.up = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
+             self.conv = DoubleConv(in_channels, out_channels)
+
+     def forward(self, x):
+         x = self.up(x)
+         return self.conv(x)
+
+
+ class Encoder(nn.Module):
+     def __init__(self, z_channels: int, in_channels: int, channels: int, channels_mult: list[int], **ignore_kwargs):
+         super().__init__()
+         self.encoder = nn.ModuleList()
+         num_resolutions = len(channels_mult)
+         in_ch_mult = (1,) + tuple(channels_mult)
+
+         self.encoder.append(DoubleConv(in_channels, channels))
+         for i_level in range(num_resolutions):
+             block_in = channels * in_ch_mult[i_level]
+             block_out = channels * channels_mult[i_level]
+             if i_level != num_resolutions - 1:
+                 self.encoder.append(Down(block_in, block_out))
+             else:
+                 self.encoder.append(DoubleConv(block_in, block_out))
+             block_in = block_out
+         self.encoder.append(nn.Conv2d(block_in, z_channels, kernel_size=(1, 1)))
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         for layer in self.encoder:
+             x = layer(x)
+         return x
+
+
+ class Decoder(nn.Module):
+     def __init__(self, z_channels: int, out_channels: int, channels: int, channels_mult: list[int], **ignore_kwargs):
+         super().__init__()
+         self.decoder = nn.ModuleList()
+         num_resolutions = len(channels_mult)
+
+         block_in = channels*channels_mult[num_resolutions-1]
+         self.decoder.append(nn.Conv2d(z_channels, block_in, kernel_size=(1, 1)))
+         for i_level in reversed(range(num_resolutions)):
+             block_out = channels * channels_mult[i_level]
+             if i_level != 0:
+                 self.decoder.append(Up(block_in, block_out))
+             else:
+                 self.decoder.append(DoubleConv(block_in, block_out))
+             block_in = block_out
+         self.final_conv = nn.Conv2d(block_in, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         for layer in self.decoder:
+             x = layer(x)
+         return self.final_conv(x)
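A shape check for this "Simple" encoder/decoder pair with the uploaded settings (a sketch assuming `module_layers.py` is importable; four `Down`/`Up` stages give a 16x spatial reduction):

```python
import torch
from module_layers import Encoder, Decoder

enc = Encoder(z_channels=64, in_channels=3, channels=128, channels_mult=[1, 1, 1, 2, 2])
dec = Decoder(z_channels=64, out_channels=3, channels=128, channels_mult=[1, 1, 1, 2, 2])

x = torch.randn(2, 3, 64, 64)
z = enc(x)
print(z.shape)       # torch.Size([2, 64, 4, 4]); 64 / 2**(5-1) = 4
print(dec(z).shape)  # torch.Size([2, 3, 64, 64])
```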
module_layers_attn.py ADDED
@@ -0,0 +1,335 @@
+ # pytorch_diffusion + derived encoder decoder
+ # Ref [https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py]
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+
+ def nonlinearity(x):
+     # swish
+     return x*torch.sigmoid(x)
+
+
+ def Normalize(in_channels):
+     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x):
+         x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+         if self.with_conv:
+             x = self.conv(x)
+         return x
+
+
+ class Downsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             # no asymmetric padding in torch conv, must do it ourselves
+             self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+     def forward(self, x):
+         if self.with_conv:
+             pad = (0, 1, 0, 1)
+             x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+             x = self.conv(x)
+         else:
+             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+         return x
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                  dropout, temb_channels=512):
+         super().__init__()
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+         self.use_conv_shortcut = conv_shortcut
+
+         self.norm1 = Normalize(in_channels)
+         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         if temb_channels > 0:
+             self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+         self.norm2 = Normalize(out_channels)
+         self.dropout = torch.nn.Dropout(dropout)
+         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+             else:
+                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x, temb):
+         h = x
+         h = self.norm1(h)
+         h = nonlinearity(h)
+         h = self.conv1(h)
+
+         if temb is not None:
+             h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+         h = self.norm2(h)
+         h = nonlinearity(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 x = self.conv_shortcut(x)
+             else:
+                 x = self.nin_shortcut(x)
+
+         return x + h
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = q.reshape(b, c, h*w)
+         q = q.permute(0, 2, 1)    # b,hw,c
+         k = k.reshape(b, c, h*w)  # b,c,hw
+         w_ = torch.bmm(q, k)      # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+         w_ = w_ * (int(c)**(-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = v.reshape(b, c, h*w)
+         w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+         h_ = torch.bmm(v, w_)     # b,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+         h_ = h_.reshape(b, c, h, w)
+
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ class Encoder(nn.Module):
+     def __init__(self,
+                  in_channels: int,
+                  channels: int,
+                  channels_mult: list[int],
+                  num_res_blocks: int,
+                  attn_resolutions: list[int],
+                  dropout: float,
+                  resolution: list[int],
+                  z_channels: int,
+                  **ignore_kwargs):
+         super().__init__()
+         self.ch = channels
+         self.temb_ch = 0
+         self.num_resolutions = len(channels_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+
+         # downsampling
+         self.conv_in = torch.nn.Conv2d(in_channels,
+                                        self.ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+         curr_res = resolution if isinstance(resolution, int) else resolution[0]
+         in_ch_mult = (1,)+tuple(channels_mult)
+         self.down = nn.ModuleList()
+         for i_level in range(self.num_resolutions):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_in = channels*in_ch_mult[i_level]
+             block_out = channels*channels_mult[i_level]
+             for i_block in range(self.num_res_blocks):
+                 block.append(ResnetBlock(in_channels=block_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(AttnBlock(block_in))
+             down = nn.Module()
+             down.block = block
+             down.attn = attn
+             if i_level != self.num_resolutions-1:
+                 down.downsample = Downsample(block_in, True)
+                 curr_res = curr_res // 2
+             self.down.append(down)
+
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout)
+         self.mid.attn_1 = AttnBlock(block_in)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout)
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         z_channels,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, x):
+         # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+         # timestep embedding
+         temb = None
+
+         # downsampling
+         hs = [self.conv_in(x)]
+         for i_level in range(self.num_resolutions):
+             for i_block in range(self.num_res_blocks):
+                 h = self.down[i_level].block[i_block](hs[-1], temb)
+                 if len(self.down[i_level].attn) > 0:
+                     h = self.down[i_level].attn[i_block](h)
+                 hs.append(h)
+             if i_level != self.num_resolutions-1:
+                 hs.append(self.down[i_level].downsample(hs[-1]))
+
+         # middle
+         h = hs[-1]
+         h = self.mid.block_1(h, temb)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h, temb)
+
+         # end
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
+
+
+ class Decoder(nn.Module):
+     def __init__(self,
+                  out_channels: int,
+                  channels: int,
+                  channels_mult: list[int],
+                  num_res_blocks: int,
+                  attn_resolutions: list[int],
+                  dropout: float,
+                  resolution: list[int],
+                  z_channels: int,
+                  **ignorekwargs):
+         super().__init__()
+         self.ch = channels
+         self.temb_ch = 0
+         self.num_resolutions = len(channels_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+
+         # compute in_ch_mult, block_in and curr_res at lowest res
+         in_ch_mult = (1,)+tuple(channels_mult)
+         block_in = channels*channels_mult[self.num_resolutions-1]
+         curr_res = resolution if isinstance(resolution, int) else resolution[0]
+         curr_res = curr_res // 2**(self.num_resolutions-1)
+         self.z_shape = (1, z_channels, curr_res, curr_res)
+         # print("Working with z of shape {} = {} dimensions.".format(
+         #     self.z_shape, np.prod(self.z_shape)))
+
+         # z to block_in
+         self.conv_in = torch.nn.Conv2d(z_channels,
+                                        block_in,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout)
+         self.mid.attn_1 = AttnBlock(block_in)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout)
+
+         # upsampling
+         self.up = nn.ModuleList()
+         for i_level in reversed(range(self.num_resolutions)):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_out = channels*channels_mult[i_level]
+             for i_block in range(self.num_res_blocks+1):
+                 block.append(ResnetBlock(in_channels=block_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(AttnBlock(block_in))
+             up = nn.Module()
+             up.block = block
+             up.attn = attn
+             if i_level != 0:
+                 up.upsample = Upsample(block_in, True)
+                 curr_res = curr_res * 2
+             self.up.insert(0, up)  # prepend to get consistent order
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         out_channels,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, z):
+         # assert z.shape[1:] == self.z_shape[1:]
+         self.last_z_shape = z.shape
+
+         # timestep embedding
+         temb = None
+
+         # z to block_in
+         h = self.conv_in(z)
+
+         # middle
+         h = self.mid.block_1(h, temb)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h, temb)
+
+         # upsampling
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks+1):
+                 h = self.up[i_level].block[i_block](h, temb)
+                 if len(self.up[i_level].attn) > 0:
+                     h = self.up[i_level].attn[i_block](h)
+             if i_level != 0:
+                 h = self.up[i_level].upsample(h)
+
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
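This checkpoint uses `encoder_type: "Simple"`, so the attention variant above is unused here; still, a small sketch of how it would be instantiated (assuming `module_layers_attn.py` is importable; the 128x128 resolution and `attn_resolutions=[16]` are illustrative choices, mirroring the taming-transformers convention noted in `configuration_vae.py`):

```python
import torch
from module_layers_attn import Encoder as AttnEncoder

enc = AttnEncoder(
    in_channels=3, channels=128, channels_mult=[1, 1, 2, 2],
    num_res_blocks=2, attn_resolutions=[16], dropout=0.0,
    resolution=[128, 128], z_channels=64,
)
x = torch.randn(2, 3, 128, 128)
print(enc(x).shape)  # torch.Size([2, 64, 16, 16]); attention runs at the 16x16 level
```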
my_config.json ADDED
@@ -0,0 +1,85 @@
+ {
+   "output_dir": "logs/vae_test",
+   "overwrite_output_dir": true,
+   "model_type": "vae",
+   "report_to": [
+     "wandb"
+   ],
+   "wandb_project_name": "train_vae",
+   "run_name": "train_vae",
+   "num_train_epochs": 0.1,
+   "logging_strategy": "steps",
+   "logging_steps": 0.01,
+   "save_strategy": "epoch",
+   "save_steps": 1,
+   "eval_strategy": "no",
+   "eval_steps": 0.1,
+   "do_train": true,
+   "do_eval": false,
+   "resume_from_checkpoint": null,
+   "remove_unused_columns": false,
+   "per_device_train_batch_size": 128,
+   "per_device_eval_batch_size": 32,
+   "gradient_accumulation_steps": 1,
+   "max_grad_norm": 5.0,
+   "bf16": true,
+   "fp16": false,
+   "use_cpu": false,
+   "save_only_model": false,
+   "adam_beta1": 0.5,
+   "adam_beta2": 0.9,
+   "learning_rate": 0.0002,
+   "weight_decay": 0.01,
+   "warmup_steps": 1000,
+   "lr_scheduler_type": "cosine_with_min_lr",
+   "lr_scheduler_kwargs": {
+     "min_lr": 1e-05
+   },
+   "train_file": "/home/pj24002027/ku50001104/data/mutual_dataset/few_data/train.jsonl",
+   "validation_file": "/home/pj24002027/ku50001104/data/mutual_dataset/few_data/test.jsonl",
+   "max_eval_samples": 2048,
+   "max_train_samples": 2048,
+   "exp_setup": null,
+   "use_sensor_keys": "CAM_FRONT",
+   "dataloader_num_workers": 1,
+   "disable_tqdm": true,
+   "config_overrides": {
+     "encoder_type": "Simple",
+     "decoder_type": "Simple",
+     "quantizer_type": "VQ",
+     "resolution": [
+       64,
+       64
+     ],
+     "z_channels": 64,
+     "codebook_dim": 0,
+     "codebook_size": 0,
+     "num_res_blocks": 2,
+     "channels": 128,
+     "channels_mult": [
+       1,
+       1,
+       1,
+       2,
+       2
+     ],
+     "attn_resolutions": [],
+     "drop_out": 0,
+     "image_mean": [
+       0.5,
+       0.5,
+       0.5
+     ],
+     "image_std": [
+       0.5,
+       0.5,
+       0.5
+     ],
+     "w_mse": 2,
+     "w_l1": 0.2,
+     "w_perceptual": 0,
+     "w_commit": 0,
+     "w_dino": 0,
+     "w_kl": 1
+   }
+ }
preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "image_processing_vae.VAEImageProcessor"
+   },
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_processor_type": "VAEImageProcessor",
+   "image_size": [
+     64,
+     64
+   ],
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "rescale_factor": 0.00392156862745098
+ }
train_results.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "epoch": 0.125,
+   "train_loss": 1.1237460374832153,
+   "train_runtime": 20.6388,
+   "train_samples_per_second": 9.923,
+   "train_steps_per_second": 0.097
+ }
trainer_state.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "best_global_step": null,
+   "best_metric": null,
+   "best_model_checkpoint": null,
+   "epoch": 0.125,
+   "eval_steps": 1,
+   "global_step": 2,
+   "is_hyper_param_search": false,
+   "is_local_process_zero": true,
+   "is_world_process_zero": true,
+   "log_history": [
+     {
+       "epoch": 0.0625,
+       "grad_norm": 4.648393154144287,
+       "learning_rate": 0.0,
+       "loss": 1.1228,
+       "step": 1
+     },
+     {
+       "epoch": 0.125,
+       "grad_norm": 4.56954288482666,
+       "learning_rate": 2.0000000000000002e-07,
+       "loss": 1.1247,
+       "step": 2
+     },
+     {
+       "epoch": 0.125,
+       "step": 2,
+       "total_flos": 147645328785408.0,
+       "train_loss": 1.1237460374832153,
+       "train_runtime": 20.6388,
+       "train_samples_per_second": 9.923,
+       "train_steps_per_second": 0.097
+     }
+   ],
+   "logging_steps": 1,
+   "max_steps": 2,
+   "num_input_tokens_seen": 0,
+   "num_train_epochs": 1,
+   "save_steps": 1,
+   "stateful_callbacks": {
+     "TrainerControl": {
+       "args": {
+         "should_epoch_stop": false,
+         "should_evaluate": false,
+         "should_log": false,
+         "should_save": true,
+         "should_training_stop": true
+       },
+       "attributes": {}
+     }
+   },
+   "total_flos": 147645328785408.0,
+   "train_batch_size": 128,
+   "trial_name": null,
+   "trial_params": null
+ }
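The tiny step count is exactly what the run configuration implies. A back-of-the-envelope check (assuming a single GPU, so the effective batch size is `per_device_train_batch_size`, and that the Trainer rounds a fractional epoch up to whole optimizer steps):

```python
import math

max_train_samples = 2048
train_batch_size = 128
num_train_epochs = 0.1

steps_per_epoch = max_train_samples // train_batch_size    # 16
max_steps = math.ceil(num_train_epochs * steps_per_epoch)  # ceil(1.6) = 2
epoch_reached = max_steps / steps_per_epoch                # 2 / 16 = 0.125
print(steps_per_epoch, max_steps, epoch_reached)           # 16 2 0.125
```

This matches `max_steps: 2` and `epoch: 0.125` above; with `warmup_steps: 1000`, the learning rate is still at the very start of warmup (0.0 and 2e-07) when training stops.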
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:494f49e8ffe42f196e661babfa0f4516d40ccf2f9a923a613986c80ce7a70477
+ size 5368