from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn
from torchvision.models import VGG16_Weights, vgg16
from transformers import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_vae import VAEConfig, DecoderType, EncoderType

logger = logging.get_logger(__name__)


@dataclass
class VAEOutput(ModelOutput):
    """Output of ``VAEModel.forward``: the reconstruction plus the individual loss terms."""

    loss: Optional[torch.FloatTensor] = None
    reconstruction: Optional[torch.FloatTensor] = None
    mse_loss: Optional[torch.FloatTensor] = None
    l1_loss: Optional[torch.FloatTensor] = None
    perceptual_loss: Optional[torch.FloatTensor] = None
    dino_loss: Optional[torch.FloatTensor] = None
    kl_loss: Optional[torch.FloatTensor] = None


class Vgg16(nn.Module):
    """Frozen VGG16 feature extractor used as the backbone of the perceptual loss."""

    def __init__(self, layers=(3, 8, 15, 22)):
        super().__init__()
        features = vgg16(weights=VGG16_Weights.DEFAULT).features
        self.to_relu_1_2 = nn.Sequential()
        self.to_relu_2_2 = nn.Sequential()
        self.to_relu_3_3 = nn.Sequential()
        self.to_relu_4_3 = nn.Sequential()

        # Split the VGG16 feature stack at the given activation indices
        # (defaults: relu1_2, relu2_2, relu3_3, relu4_3).
        for x in range(layers[0] + 1):
            self.to_relu_1_2.add_module(str(x), features[x])
        for x in range(layers[0] + 1, layers[1] + 1):
            self.to_relu_2_2.add_module(str(x), features[x])
        for x in range(layers[1] + 1, layers[2] + 1):
            self.to_relu_3_3.add_module(str(x), features[x])
        for x in range(layers[2] + 1, layers[3] + 1):
            self.to_relu_4_3.add_module(str(x), features[x])

        # The extractor is a fixed loss network; it is never trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        h = self.to_relu_1_2(x)
        h_relu_1_2 = h
        h = self.to_relu_2_2(h)
        h_relu_2_2 = h
        h = self.to_relu_3_3(h)
        h_relu_3_3 = h
        h = self.to_relu_4_3(h)
        h_relu_4_3 = h
        out = (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3)
        return out
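
# Note: torchvision's ImageNet-pretrained VGG16 assumes ImageNet-normalized RGB input.
# The perceptual loss below compares whatever tensors it receives, so callers are
# responsible for feeding images in the normalization the backbone expects.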


class PerceptualLoss(nn.Module):
    """VGG16 feature-matching loss: a weighted sum of MSEs between intermediate activations."""

    def __init__(self, layers=(3, 8, 15, 22), unnorm_mean=None, unnorm_std=None, weights=None):
        super().__init__()
        self.vgg = Vgg16(layers=layers)
        self.layers = layers
        # Normalization statistics are stored for reference but not applied here.
        self.unnorm_mean = unnorm_mean
        self.unnorm_std = unnorm_std
        self.weights = weights or [1.0 / len(layers)] * len(layers)

    def forward(self, x, y):
        x_vgg = self.vgg(x)
        y_vgg = self.vgg(y)
        loss = 0.0
        for weight, (x_vgg_layer, y_vgg_layer) in zip(self.weights, zip(x_vgg, y_vgg)):
            loss += weight * F.mse_loss(x_vgg_layer, y_vgg_layer)
        return loss


class DinoLoss(nn.Module):
    """Cosine-distance loss between DINO ViT patch tokens and an encoder feature map."""

    def __init__(self, patch_size, use_large=False):
        super().__init__()
        size = 'b' if use_large else 's'
        dino = f'dino_vit{size}{patch_size}'
        self.vit = torch.hub.load('facebookresearch/dino:main', dino)
        logger.info("Using DINO backbone %s", dino)
        self.vit.eval()
        for param in self.vit.parameters():
            param.requires_grad = False

    def forward(self, gt, embed):
        with torch.no_grad():
            dino_features = self.vit.prepare_tokens(gt)
            for blk in self.vit.blocks:
                dino_features = blk(dino_features)
            dino_features = self.vit.norm(dino_features)
            # Drop the [CLS] token so only patch tokens remain.
            dino_features = dino_features[:, 1:]
        embed_features = rearrange(embed, 'b c h w -> b (h w) c').contiguous()
        dtype = embed.dtype
        # Cosine distance over the channel dimension, computed in float32 for
        # numerical stability and cast back to the input dtype.
        dino_loss = 1 - F.cosine_similarity(
            dino_features.to(torch.float32), embed_features.to(torch.float32), dim=2
        )
        dino_loss = dino_loss.mean()
        return dino_loss.to(dtype)
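
# With patch_size equal to the encoder's downsampling factor and matching input
# resolution, the DINO patch grid has the same height and width as the encoder
# feature grid, so patch tokens and encoder features align one-to-one.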


class VAEModel(PreTrainedModel):
    """VAE over target images ``s1_img`` with an MLP bottleneck between encoder and decoder."""

    config_class = VAEConfig
    main_input_name = "s0_img"

    def __init__(self, config: VAEConfig):
        super().__init__(config)
        dict_config = config.to_dict()
        self.encoder = EncoderType[config.encoder_type].value(**dict_config)
        # Flattened size of the encoder output: z_channels over a spatial grid
        # downsampled by a factor of 2 per channel-multiplier stage.
        enc_out_dim = config.z_channels * (config.resolution[0] // (2 ** (len(config.channels_mult) - 1))) ** 2
        latent_dim = 64  # dimensionality of the VAE latent
        self.cond_mlp = nn.Sequential(
            nn.Linear(enc_out_dim * 2, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, latent_dim * 2),
        )
        self.in_mlp = nn.Sequential(
            nn.Linear(enc_out_dim, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, config.z_channels),
            nn.ReLU(),
            # Twice the latent size: mean and log-variance.
            nn.Linear(config.z_channels, latent_dim * 2),
        )
        self.cond_mlp_out = nn.Sequential(
            nn.Linear(latent_dim + enc_out_dim, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, enc_out_dim),
        )
        self.out_mlp = nn.Sequential(
            nn.Linear(latent_dim, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, enc_out_dim),
        )
        self.decoder = DecoderType[config.decoder_type].value(**dict_config)
        if config.w_perceptual > 0:
            self.perceptual_loss = PerceptualLoss(
                unnorm_mean=config.image_mean,
                unnorm_std=config.image_std,
            )
        if config.w_dino > 0:
            # z_channels must match the DINO embedding width (384 for ViT-S, 768 for ViT-B).
            assert config.z_channels in [384, 768]
            patch_size = 2 ** (len(config.channels_mult) - 1)
            self.dino_loss = DinoLoss(patch_size=patch_size)
        self.log_state = {
            "loss": None,
            "mse_loss": None,
            "l1_loss": None,
            "perceptual_loss": None,
            "dino_loss": None,
            "kl_loss": None,
            "gt": None,
            "recon": None,
        }
        self.post_init()
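
    # Data flow: encoder -> flatten -> in_mlp -> (mean, logvar) -> sample -> out_mlp
    # -> unflatten -> decoder. The cond_mlp / cond_mlp_out branches are not used by
    # the methods below.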

    def encode(self, s0_img: Tensor, s1_img: Tensor, a0: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
        # s0_img and a0 belong to the unused conditioning path; s0 stays None.
        s0 = None
        s1_feat = self.encoder(s1_img)  # (B, z_channels, H', W') spatial feature map
        s1 = s1_feat.reshape(s1_img.shape[0], -1)

        s1_mean_var = self.in_mlp(s1)
        s1_mean, s1_logvar = s1_mean_var.chunk(2, dim=1)
        s1_stddev = torch.exp(s1_logvar * 0.5)
        # Reparameterization trick: sample the latent as mean + std * eps.
        s1_latent = s1_mean + s1_stddev * torch.randn_like(s1_mean)
        return s1_latent, s0, s1_mean, s1_logvar, s1_feat

    def decode(self, s1_latent: Tensor, s0: Tensor) -> Tensor:
        # Spatial size of the decoder's input grid.
        quant_h = self.config.resolution[0] // (2 ** (len(self.config.channels_mult) - 1))
        quant_w = self.config.resolution[1] // (2 ** (len(self.config.channels_mult) - 1))

        s1_latent = self.out_mlp(s1_latent).reshape(s1_latent.shape[0], self.config.z_channels, quant_h, quant_w)
        return self.decoder(s1_latent)
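
    # A minimal sampling sketch (not called anywhere in this module), assuming the
    # latent_dim of 64 set in __init__ and the standard-normal prior enforced by the
    # KL term; decode ignores s0:
    #
    #     z = torch.randn(batch_size, 64, device=model.device)
    #     samples = model.decode(z, s0=None)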

    def forward(
        self,
        s0_img: Tensor,
        s1_img: Tensor,
        action: Tensor,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, VAEOutput]:
        return_dict = return_dict if return_dict is not None else False
        s1_latent, s0, s1_mean, s1_logvar, s1_feat = self.encode(s0_img, s1_img, action)
        recon = self.decode(s1_latent, s0)

        loss = mse_loss = l1_loss = perceptual_loss = dino_loss = kl_loss = None
        if return_loss:
            mse_loss = F.mse_loss(recon, s1_img)
            l1_loss = F.l1_loss(recon, s1_img)
            if self.config.w_perceptual > 0:
                perceptual_loss = self.perceptual_loss(recon, s1_img)
            else:
                perceptual_loss = torch.zeros_like(mse_loss)
            if self.config.w_dino > 0:
                # Align the encoder feature map with the DINO patch tokens of the target.
                dino_loss = self.dino_loss(s1_img, s1_feat)
            else:
                dino_loss = torch.zeros_like(mse_loss)

            # KL divergence between the diagonal-Gaussian posterior and a standard normal prior.
            kl_loss = torch.mean(-0.5 * torch.sum(1 + s1_logvar - s1_mean**2 - s1_logvar.exp(), dim=1))

            loss = self.config.w_mse * mse_loss + \
                self.config.w_l1 * l1_loss + \
                self.config.w_perceptual * perceptual_loss + \
                self.config.w_dino * dino_loss + \
                self.config.w_kl * kl_loss

        if not return_dict:
            if loss is not None:
                self.log_state["loss"] = loss.item()
                self.log_state["mse_loss"] = mse_loss.item()
                self.log_state["l1_loss"] = l1_loss.item()
                self.log_state["perceptual_loss"] = perceptual_loss.item()
                self.log_state["dino_loss"] = dino_loss.item()
                self.log_state["kl_loss"] = kl_loss.item()
                self.log_state["gt"] = s1_img.clone().detach().cpu()[:4].to(torch.float32)
                self.log_state["recon"] = recon.clone().detach().cpu()[:4].to(torch.float32)
            return (loss, recon) if loss is not None else recon

        return VAEOutput(
            loss=loss,
            reconstruction=recon,
            mse_loss=mse_loss,
            l1_loss=l1_loss,
            perceptual_loss=perceptual_loss,
            dino_loss=dino_loss,
            kl_loss=kl_loss,
        )

    def get_last_layer(self):
        raise NotImplementedError
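

# A minimal usage sketch, assuming a valid VAEConfig; shapes and field values below
# are illustrative, not part of this module:
#
#     config = VAEConfig(...)
#     model = VAEModel(config)
#     s0 = torch.randn(2, 3, *config.resolution)
#     s1 = torch.randn(2, 3, *config.resolution)
#     action = torch.zeros(2, 1)  # unused by the current encode path
#     out = model(s0_img=s0, s1_img=s1, action=action, return_dict=True)
#     out.loss.backward()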