# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from dataclasses import dataclass, field
from typing import Optional, Callable
from functools import partial

import numpy as np
from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model

from examples.data2vec.data.modality import Modality
from examples.data2vec.models.modalities.base import (
    MaskSeed,
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_annealed_rate,
)
from examples.data2vec.models.modalities.modules import (
    D2vDecoderConfig,
    AltBlock,
    Decoder1d,
)
from examples.data2vec.models.modalities.audio import (
    D2vAudioConfig,
    AudioEncoder,
)
from examples.data2vec.models.modalities.images import (
    D2vImageConfig,
    ImageEncoder,
)
from examples.data2vec.models.modalities.text import (
    D2vTextConfig,
    TextEncoder,
)

logger = logging.getLogger(__name__)


@dataclass
class D2vModalitiesConfig(FairseqDataclass):
    audio: D2vAudioConfig = D2vAudioConfig()
    image: D2vImageConfig = D2vImageConfig()
    text: D2vTextConfig = D2vTextConfig()


@dataclass
class Data2VecMultiConfig(FairseqDataclass):
    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )

    input_feature_ndim: int = 40

    depth: int = 8
    start_drop_path_rate: float = 0
    end_drop_path_rate: float = 0
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True
    encoder_dropout: float = 0.1
    post_mlp_drop: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0
    embed_dim: int = 768
    mlp_ratio: float = 4
    layer_norm_first: bool = False

    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    end_of_block_targets: bool = False

    clone_batch: int = 1

    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_same_dtype: bool = True
    log_norms: bool = True
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_encoder_only: bool = field(
        default=True,
        metadata={
            "help": "whether to momentum update only the shared transformer encoder"
        },
    )

    max_update: int = II("optimization.max_update")

    modalities: D2vModalitiesConfig = D2vModalitiesConfig()

    shared_decoder: Optional[D2vDecoderConfig] = None

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )

    supported_modality: Optional[Modality] = None
    mae_init: bool = False

    seed: int = II("common.seed")

    skip_ema: bool = False

    cls_loss: float = 0
    recon_loss: float = 0
    d2v_loss: float = 1

    decoder_group: bool = False
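
# A minimal configuration sketch (not part of the original file): in fairseq
# training these fields are populated from Hydra/YAML, but the dataclass can
# also be overridden directly, e.g.
#
#   cfg = Data2VecMultiConfig()
#   cfg.depth = 12
#   cfg.average_top_k_layers = 12
#   cfg.clone_batch = 8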


@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
class Data2VecMultiModel(BaseFairseqModel):
    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        if cfg.type == Modality.AUDIO:
            enc_cls = AudioEncoder
        elif cfg.type == Modality.IMAGE:
            enc_cls = ImageEncoder
        elif cfg.type == Modality.TEXT:
            enc_cls = TextEncoder
            if hasattr(task, "text_task"):
                task = task.text_task
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            enc = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )
            self.modality_encoders[mod.name] = enc

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        self.norm = None
        if cfg.layer_norm_first:
            self.norm = make_layer_norm(cfg.embed_dim)

        if self.cfg.mae_init:
            self.apply(self._init_weights)
        else:
            from fairseq.modules.transformer_sentence_encoder import init_bert_params

            self.apply(init_bert_params)

        for mod_enc in self.modality_encoders.values():
            mod_enc.reset_parameters()

        if not skip_ema:
            self.ema = self.make_ema_teacher(cfg.ema_decay)
            self.shared_decoder = (
                Decoder1d(cfg.shared_decoder, cfg.embed_dim)
                if self.cfg.shared_decoder is not None
                else None
            )
            if self.shared_decoder is not None:
                self.shared_decoder.apply(self._init_weights)

            self.recon_proj = None
            if cfg.recon_loss > 0:
                self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)

        for pn, p in self.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
                p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
            if cfg.decoder_group and "decoder" in pn:
                p.param_group = "decoder"

        self.num_updates = 0

    def _init_weights(self, m):
        try:
            from apex.normalization import FusedLayerNorm

            fn = FusedLayerNorm
        except Exception:
            fn = nn.LayerNorm

        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def make_ema_teacher(self, ema_decay):
        ema_config = EMAModuleConfig(
            ema_decay=ema_decay,
            ema_fp32=True,
            log_norms=self.cfg.log_norms,
            add_missing_params=False,
        )

        model_copy = self.make_target_model()

        return EMAModule(
            model_copy,
            ema_config,
            copy_model=False,
        )
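
    # The EMA teacher wraps a frozen copy of either the shared transformer
    # blocks (when cfg.ema_encoder_only is True) or the whole model. EMAModule
    # is expected to update it after each step roughly as
    #   p_teacher <- decay * p_teacher + (1 - decay) * p_student
    # (standard EMA; set_num_updates below controls when the step happens).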

    def make_target_model(self):
        logger.info("making target model")

        model_copy = Data2VecMultiModel(
            self.cfg, self.modalities, skip_ema=True, task=self.task
        )

        if self.cfg.ema_encoder_only:
            model_copy = model_copy.blocks
            for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
                p_t.data.copy_(p_s.data)
        else:
            for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
                p_t.data.copy_(p_s.data)

            for mod_enc in model_copy.modality_encoders.values():
                mod_enc.decoder = None
                if not mod_enc.modality_cfg.ema_local_encoder:
                    mod_enc.local_encoder = None
                    mod_enc.project_features = None

        model_copy.requires_grad_(False)
        return model_copy

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is not None and (
            (self.num_updates == 0 and num_updates > 1)
            or self.num_updates >= num_updates
        ):
            pass
        elif self.training and self.ema is not None:
            ema_weight_decay = None
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay, weight_decay=ema_weight_decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)

        self.num_updates = num_updates
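
    # Sketch of the decay schedule, assuming get_annealed_rate interpolates
    # linearly between the two decay values over ema_anneal_end_step updates:
    #   decay(t) ~= ema_decay + (ema_end_decay - ema_decay) * min(t / ema_anneal_end_step, 1)
    # e.g. with ema_decay=0.999, ema_end_decay=0.9999 and ema_anneal_end_step=100000,
    # the decay at update 50000 is roughly 0.99945.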

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        k = prefix + "_ema"
        if self.ema is not None:
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        elif k in state_dict:
            del state_dict[k]

        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecMultiConfig, task=None):
        """Build a new model instance."""
        if task is None or not hasattr(task, "supported_modalities"):
            modalities = (
                [cfg.supported_modality]
                if cfg.supported_modality is not None
                else [
                    Modality.AUDIO,
                    Modality.IMAGE,
                    Modality.TEXT,
                ]
            )
        else:
            modalities = task.supported_modalities

        return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        corpus_key=None,  # for config compatibility
    ):
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )
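
        # The modality encoder returns the locally encoded student input "x"
        # (masked and repeated clone_batch times during pre-training), plus the
        # mask bookkeeping ("encoder_mask"), the adjusted padding mask and any
        # ALiBi bias/scale tensors, all unpacked below.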
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.cfg.layerdrop == 0
                or (np.random.random() > self.cfg.layerdrop)
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append((x, lr))

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            if remove_extra_tokens:
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens :
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

        xs = []

        if self.shared_decoder is not None:
            dx = self.forward_decoder(
                x,
                feature_extractor,
                self.shared_decoder,
                encoder_mask,
            )
            xs.append(dx)
        if feature_extractor.decoder is not None:
            dx = self.forward_decoder(
                x,
                feature_extractor,
                feature_extractor.decoder,
                encoder_mask,
            )
            xs.append(dx)
            orig_x = x

        assert len(xs) > 0

        p = next(self.ema.model.parameters())
        device = x.device
        dtype = x.dtype
        ema_device = p.device
        ema_dtype = p.dtype

        if not self.cfg.ema_same_dtype:
            dtype = ema_dtype

        if ema_device != device or ema_dtype != dtype:
            logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
            self.ema.model = self.ema.model.to(dtype=dtype, device=device)
            ema_dtype = dtype

            def to_device(d):
                for k, p in d.items():
                    if isinstance(d[k], dict):
                        to_device(d[k])
                    else:
                        d[k] = p.to(device=device)

            to_device(self.ema.fp32_params)

        tm = self.ema.model

        with torch.no_grad():
            tm.eval()

            if self.cfg.ema_encoder_only:
                assert target is None
                ema_input = extractor_out["local_features"]
                ema_input = feature_extractor.contextualized_features(
                    ema_input.to(dtype=ema_dtype),
                    padding_mask,
                    mask=False,
                    remove_masked=False,
                )
                ema_blocks = tm
            else:
                ema_blocks = tm.blocks
                if feature_extractor.modality_cfg.ema_local_encoder:
                    inp = (
                        target.to(dtype=ema_dtype)
                        if target is not None
                        else source.to(dtype=ema_dtype)
                    )
                    ema_input = tm.modality_encoders[mode](
                        inp,
                        padding_mask,
                        mask=False,
                        remove_masked=False,
                    )
                else:
                    assert target is None
                    ema_input = extractor_out["local_features"]
                    ema_feature_enc = tm.modality_encoders[mode]
                    ema_input = ema_feature_enc.contextualized_features(
                        ema_input.to(dtype=ema_dtype),
                        padding_mask,
                        mask=False,
                        remove_masked=False,
                    )

            ema_padding_mask = ema_input["padding_mask"]
            ema_alibi_bias = ema_input.get("alibi_bias", None)
            ema_alibi_scale = ema_input.get("alibi_scale", None)
            ema_input = ema_input["x"]

            y = []
            ema_x = []
            extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
            for i, blk in enumerate(ema_blocks):
                ab = ema_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        ema_alibi_scale[i]
                        if ema_alibi_scale.size(0) > 1
                        else ema_alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                ema_input, lr = blk(
                    ema_input,
                    padding_mask=ema_padding_mask,
                    alibi_bias=ab,
                )
                y.append(lr[:, extra_tokens:])
                ema_x.append(ema_input[:, extra_tokens:])

        y = self.make_targets(y, self.average_top_k_layers)
        orig_targets = y

        if self.cfg.clone_batch > 1:
            y = y.repeat_interleave(self.cfg.clone_batch, 0)

        masked = encoder_mask.mask.unsqueeze(-1)
        masked_b = encoder_mask.mask.bool()
        y = y[masked_b]

        if xs[0].size(1) == masked_b.size(1):
            xs = [x[masked_b] for x in xs]
        else:
            xs = [x.reshape(-1, x.size(-1)) for x in xs]
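
        # Teacher targets y span every time step, so they are gathered at the
        # masked positions; the student predictions are gathered the same way
        # only when they still have the full sequence length, and are otherwise
        # just flattened (they are then assumed to already line up with the
        # masked targets).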

        sample_size = masked.sum().long()

        result = {
            "losses": {},
            "sample_size": sample_size,
        }
        sample_size = result["sample_size"]

        if self.cfg.cls_loss > 0:
            assert extra_tokens > 0
            cls_target = orig_targets.mean(dim=1)
            if self.cfg.clone_batch > 1:
                cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
            cls_pred = x[:, extra_tokens - 1]

            result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
                self.cfg.cls_loss * sample_size
            )

        if self.cfg.recon_loss > 0:
            with torch.no_grad():
                target = feature_extractor.patchify(source)
                mean = target.mean(dim=-1, keepdim=True)
                var = target.var(dim=-1, keepdim=True)
                target = (target - mean) / (var + 1.0e-6) ** 0.5

                if self.cfg.clone_batch > 1:
                    target = target.repeat_interleave(self.cfg.clone_batch, 0)

                if masked_b is not None:
                    target = target[masked_b]

            recon = xs[0]
            if self.recon_proj is not None:
                recon = self.recon_proj(recon)

            result["losses"]["recon"] = (
                self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
            )

        if self.cfg.d2v_loss > 0:
            for i, x in enumerate(xs):
                reg_loss = self.d2v_loss(x, y)
                n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
                result["losses"][n] = reg_loss * self.cfg.d2v_loss

        suffix = "" if len(self.modalities) == 1 else f"_{mode}"
        with torch.no_grad():
            if encoder_mask is not None:
                result["masked_pct"] = 1 - (
                    encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
                )
            for i, x in enumerate(xs):
                n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
                result[n] = self.compute_var(x.float())
            if self.ema is not None:
                for k, v in self.ema.logs.items():
                    result[k] = v

            y = y.float()
            result[f"target_var{suffix}"] = self.compute_var(y)

            if self.num_updates > 5000:
                if result[f"target_var{suffix}"] < self.cfg.min_target_var:
                    logger.error(
                        f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
                    )
                    raise Exception(
                        f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
                    )

                for k in result.keys():
                    if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
                        logger.error(
                            f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
                        )
                        raise Exception(
                            f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
                        )

            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    def forward_decoder(
        self,
        x,
        feature_extractor,
        decoder,
        mask_info,
    ):
        x = feature_extractor.decoder_input(x, mask_info)
        x = decoder(*x)

        return x

    def d2v_loss(self, x, y):
        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, x.size(-1))

        if self.loss_beta == 0:
            loss = F.mse_loss(x, y, reduction="none")
        else:
            loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        reg_loss = loss * scale

        return reg_loss
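
    # Worked example (not in the original): with the default embed_dim of 768
    # and loss_scale=None, every element of the unreduced loss is multiplied by
    # 1 / sqrt(768) ~= 0.036; setting loss_beta > 0 swaps plain MSE for the
    # smooth L1 (Huber) loss with that beta.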

    def make_targets(self, y, num_layers):
        with torch.no_grad():
            target_layer_results = y[-num_layers:]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BTC -> BCT
                ]
                permuted = True
            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]
            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]
            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]
            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

        y = target_layer_results[0].float()
        for tl in target_layer_results[1:]:
            y.add_(tl.float())
        y = y.div_(len(target_layer_results))

        if self.cfg.layer_norm_targets:
            y = F.layer_norm(y, y.shape[-1:])

        if self.cfg.instance_norm_targets:
            y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

        return y
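
    # The regression target is the average of the last `num_layers`
    # (cfg.average_top_k_layers) teacher block outputs, with optional per-layer
    # instance/batch/layer normalization before averaging and optional
    # normalization of the averaged target afterwards.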

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y**2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
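
    # The distributed branch reconstructs the unbiased variance from global
    # sums: Var = (sum(y^2) - (sum(y))^2 / n) / (n - 1), with n, sum(y) and
    # sum(y^2) accumulated across workers via all_reduce, then averages the
    # per-dimension standard deviations.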

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def remove_pretraining_modules(self, modality=None, keep_decoder=False):
        self.ema = None
        self.cfg.clone_batch = 1
        self.recon_proj = None

        if not keep_decoder:
            self.shared_decoder = None

        modality = modality.lower() if modality is not None else None
        for k in list(self.modality_encoders.keys()):
            if modality is not None and k.lower() != modality:
                del self.modality_encoders[k]
            else:
                self.modality_encoders[k].remove_pretraining_modules(
                    keep_decoder=keep_decoder
                )
                if not keep_decoder:
                    self.modality_encoders[k].decoder = None
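

# A minimal usage sketch (not part of the original file; the checkpoint path and
# the exact loading flow are assumptions). For downstream feature extraction the
# model is typically restored through fairseq's checkpoint utilities and queried
# with features_only via extract_features:
#
#   from fairseq import checkpoint_utils
#
#   models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(["model.pt"])
#   model = models[0]
#   model.remove_pretraining_modules()  # drop the EMA teacher and decoders
#   model.eval()
#
#   with torch.no_grad():
#       out = model.extract_features(source, mode="AUDIO", padding_mask=None)
#       features = out["x"]  # [batch, time, embed_dim]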