from pathlib import Path
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.torch_utilities import (
    load_pretrained_model, merge_matched_keys, create_mask_from_length,
    loss_with_mask, create_alignment_path
)

class LoadPretrainedBase(nn.Module):
    def process_state_dict(
        self, model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor]
    ):
        """
        Model-specific hook that transforms a `state_dict` loaded from a
        checkpoint into a form that can be passed to `load_state_dict`.
        By default, `merge_matched_keys` updates only the parameters whose
        names and shapes match the current model.

        Args:
            model_dict:
                The state dict of the current model, which is going to load
                pretrained parameters.
            state_dict:
                A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]:
                The updated state dict, where parameters with matched keys
                and shapes are updated with values in `state_dict`.
        """
        state_dict = merge_matched_keys(model_dict, state_dict)
        return state_dict

    def load_pretrained(self, ckpt_path: str | Path):
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )
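
# Usage sketch (hypothetical subclass and checkpoint path, for illustration
# only): a model inherits LoadPretrainedBase and adopts whatever checkpoint
# parameters match its own names and shapes.
#
#     class MyModel(LoadPretrainedBase):
#         def __init__(self):
#             super().__init__()
#             self.encoder = nn.Linear(16, 16)
#
#     MyModel().load_pretrained("checkpoints/pretrained.pt")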


class CountParamsBase(nn.Module):
    def count_params(self):
        """Return (total, trainable) parameter counts."""
        num_params = 0
        trainable_params = 0
        for param in self.parameters():
            num_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        return num_params, trainable_params
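
# Example (illustrative): a module holding a single nn.Linear(4, 2) has
# 4 * 2 + 2 = 10 parameters, all trainable by default, so count_params()
# returns (10, 10).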


class SaveTrainableParamsBase(nn.Module):
    @property
    def param_names_to_save(self):
        """Names of trainable parameters, plus all buffers."""
        names = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                names.append(name)
        for name, _ in self.named_buffers():
            names.append(name)
        return names

    def load_state_dict(self, state_dict, strict=True):
        for key in self.param_names_to_save:
            if key not in state_dict:
                raise Exception(
                    f"{key} not found in either pre-trained models (e.g. BERT)"
                    " or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
        return super().load_state_dict(state_dict, strict)
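
# Saving sketch (`save_trainable` is a hypothetical helper, not part of this
# module): persist only the entries named by `param_names_to_save`, so frozen
# backbone weights are not duplicated on disk.
#
#     def save_trainable(model: SaveTrainableParamsBase, path: str):
#         full = model.state_dict()
#         torch.save({k: full[k] for k in model.param_names_to_save}, path)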


class DurationAdapterMixin:
    def __init__(
        self,
        latent_token_rate: int,
        offset: float = 1.0,
        frame_resolution: float | None = None
    ):
        self.latent_token_rate = latent_token_rate
        self.offset = offset
        self.frame_resolution = frame_resolution

    def get_global_duration_loss(
        self,
        pred: torch.Tensor,
        latent_mask: torch.Tensor,
        reduce: bool = True,
    ):
        # The target is the log total duration in seconds (number of valid
        # latent tokens divided by the token rate), shifted by `offset`.
        target = torch.log(
            latent_mask.sum(1) / self.latent_token_rate + self.offset
        )
        loss = F.mse_loss(pred, target, reduction="mean" if reduce else "none")
        return loss
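
    # Worked example (assumed numbers): with latent_token_rate=25 and a mask
    # covering 100 latent tokens, the duration is 100 / 25 = 4 s, so the
    # regression target is log(4 + 1.0) ≈ 1.609 for the default offset.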

    def get_local_duration_loss(
        self, ground_truth: torch.Tensor, pred: torch.Tensor,
        mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool
    ):
        # Per-token target: the log frame count, where the frame count is
        # the duration divided by the frame resolution.
        n_frames = torch.round(ground_truth / self.frame_resolution)
        target = torch.log(n_frames + self.offset)
        loss = loss_with_mask(
            (target - pred)**2,
            mask,
            reduce=False,
        )
        # Only time-aligned samples contribute to the local duration loss.
        is_time_aligned = torch.as_tensor(is_time_aligned, device=loss.device)
        loss *= is_time_aligned
        if reduce:
            if is_time_aligned.sum().item() == 0:
                # No time-aligned sample in the batch: return a zero loss
                # that still participates in the graph.
                loss *= 0.0
                loss = loss.mean()
            else:
                loss = loss.sum() / is_time_aligned.sum()
        return loss

    def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
        # Invert the log-domain prediction: recover the frame count, then
        # convert it to a duration in seconds.
        pred = torch.exp(pred) * mask
        pred = torch.ceil(pred) - self.offset
        pred *= self.frame_resolution
        return pred
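
    # Example (assumed numbers): with frame_resolution=0.02 and offset=1.0,
    # a prediction of log(11) at a valid position becomes ceil(11) - 1 = 10
    # frames, i.e. 10 * 0.02 = 0.2 s.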

    def prepare_global_duration(
        self,
        global_pred: torch.Tensor,
        local_pred: torch.Tensor,
        is_time_aligned: Sequence[bool],
        use_local: bool = True,
    ):
        """
        Args:
            global_pred:
                Predicted global duration in the log domain, including the
                `offset` shift.
            local_pred:
                Predicted local (per-token) durations, used to derive the
                number of latent tokens.
        """
        global_pred = torch.exp(global_pred) - self.offset
        result = global_pred
        if use_local:
            # For time-aligned samples, derive the global duration from the
            # local predictions, rounding each token to a whole number of
            # latents to avoid accumulating per-frame rounding errors.
            pred_from_local = torch.round(local_pred * self.latent_token_rate)
            pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
            result[is_time_aligned] = pred_from_local[is_time_aligned]
        return result
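
    # Example (assumed numbers): with latent_token_rate=25, a time-aligned
    # sample whose local durations round to 5 + 10 + 5 latent tokens gets a
    # global duration of 20 / 25 = 0.8 s, overriding the decoded global_pred.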

    def expand_by_duration(
        self,
        x: torch.Tensor,
        content_mask: torch.Tensor,
        local_duration: torch.Tensor,
        global_duration: torch.Tensor | None = None,
    ):
        # Convert per-token durations (seconds) to latent token counts.
        n_latents = torch.round(local_duration * self.latent_token_rate)
        if global_duration is not None:
            latent_length = torch.round(
                global_duration * self.latent_token_rate
            )
        else:
            latent_length = n_latents.sum(1)
        latent_mask = create_mask_from_length(latent_length).to(
            content_mask.device
        )
        # (B, T_content, T_latent) validity mask between content tokens and
        # latent frames, then a monotonic alignment path within it.
        attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
        align_path = create_alignment_path(n_latents, attn_mask)
        # Repeat each content vector along the time axis according to its
        # duration: (B, T_latent, T_content) @ (B, T_content, D).
        expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
        return expanded_x, latent_mask
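
# Minimal end-to-end sketch (illustrative shapes and rates; assumes the
# utils.torch_utilities helpers behave as they are used above):
#
#     class Adapter(DurationAdapterMixin):
#         pass
#
#     adapter = Adapter(latent_token_rate=25, frame_resolution=0.02)
#     x = torch.randn(1, 3, 8)                          # 3 content tokens
#     content_mask = torch.ones(1, 3, dtype=torch.bool)
#     local_duration = torch.tensor([[0.2, 0.4, 0.2]])  # seconds per token
#     expanded, latent_mask = adapter.expand_by_duration(
#         x, content_mask, local_duration
#     )
#     # expanded: (1, 20, 8), since 0.8 s * 25 tokens/s = 20 latent frames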