# UniFlow-Audio / models/common.py
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.torch_utilities import (
load_pretrained_model, merge_matched_keys, create_mask_from_length,
loss_with_mask, create_alignment_path
)


class LoadPretrainedBase(nn.Module):
def process_state_dict(
self, model_dict: dict[str, torch.Tensor],
state_dict: dict[str, torch.Tensor]
):
"""
Custom processing functions of each model that transforms `state_dict` loaded from
checkpoints to the state that can be used in `load_state_dict`.
Use `merge_mathced_keys` to update parameters with matched names and shapes by
default.
Args
model_dict:
The state dict of the current model, which is going to load pretrained parameters
state_dict:
A dictionary of parameters from a pre-trained model.
Returns:
dict[str, torch.Tensor]:
The updated state dict, where parameters with matched keys and shape are
updated with values in `state_dict`.
"""
state_dict = merge_matched_keys(model_dict, state_dict)
return state_dict

    def load_pretrained(self, ckpt_path: str | Path):
load_pretrained_model(
self, ckpt_path, state_dict_process_fn=self.process_state_dict
)
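
# Usage sketch (illustrative; `_StripPrefixEncoder` and the "module."
# prefix are assumptions, not part of this file). A subclass overrides
# `process_state_dict` when checkpoint keys need renaming before the
# default matched-key merge:
class _StripPrefixEncoder(LoadPretrainedBase):
    def process_state_dict(self, model_dict, state_dict):
        # e.g. drop a "module." prefix left by nn.DataParallel checkpoints
        state_dict = {
            k.removeprefix("module."): v for k, v in state_dict.items()
        }
        return super().process_state_dict(model_dict, state_dict)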


class CountParamsBase(nn.Module):
def count_params(self):
num_params = 0
trainable_params = 0
for param in self.parameters():
num_params += param.numel()
if param.requires_grad:
trainable_params += param.numel()
return num_params, trainable_params
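
# Usage sketch (illustrative): any subclass gains parameter counting;
# `_TinyModel` below is hypothetical. Linear(8, 4) has 8 * 4 + 4 = 36
# parameters; freezing the bias leaves 32 trainable.
class _TinyModel(CountParamsBase):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 4)
        self.proj.bias.requires_grad_(False)


def _demo_count_params():
    total, trainable = _TinyModel().count_params()
    assert (total, trainable) == (36, 32)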


class SaveTrainableParamsBase(nn.Module):
    @property
    def param_names_to_save(self):
        names = []
        # trainable parameters...
        for name, param in self.named_parameters():
            if param.requires_grad:
                names.append(name)
        # ...plus all buffers, which carry state but are never trainable
        for name, _ in self.named_buffers():
            names.append(name)
        return names

    def load_state_dict(self, state_dict, strict=True):
for key in self.param_names_to_save:
if key not in state_dict:
raise Exception(
f"{key} not found in either pre-trained models (e.g. BERT)"
" or resumed checkpoints (e.g. epoch_40/model.pt)"
)
return super().load_state_dict(state_dict, strict)
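
# Usage sketch (illustrative; `_demo_save_trainable` is a hypothetical
# helper): persist only trainable weights and buffers, then restore them
# on top of a full state dict so the overridden `load_state_dict` finds
# every saved key.
def _demo_save_trainable(model: SaveTrainableParamsBase, ckpt_path: str):
    full = model.state_dict()
    to_save = {k: full[k] for k in model.param_names_to_save}
    torch.save(to_save, ckpt_path)
    # frozen parameters must come from elsewhere (e.g. the pre-trained
    # backbone); only the saved subset is overwritten here
    full.update(torch.load(ckpt_path, map_location="cpu"))
    model.load_state_dict(full)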


class DurationAdapterMixin:
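    """
    Mixin with helpers to convert between durations in seconds, latent
    token counts, and their log-compressed regression targets.

    latent_token_rate: number of latent tokens per second of audio.
    offset: additive constant inside the log compression, i.e.
        target = log(value + offset).
    frame_resolution: seconds per duration frame; required only by the
        local (per-token) duration methods.
    """
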
def __init__(
self,
latent_token_rate: int,
offset: float = 1.0,
frame_resolution: float | None = None
):
self.latent_token_rate = latent_token_rate
self.offset = offset
self.frame_resolution = frame_resolution

    def get_global_duration_loss(
self,
pred: torch.Tensor,
latent_mask: torch.Tensor,
reduce: bool = True,
):
        # target: log-compressed audio duration in seconds, computed from
        # the number of valid latent tokens
        target = torch.log(
latent_mask.sum(1) / self.latent_token_rate + self.offset
)
loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none")
return loss
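
    # Worked sketch (illustrative numbers; assumes `create_mask_from_length`
    # maps lengths to a (B, max_len) mask, as in `expand_by_duration`):
    # with latent_token_rate=25 and offset=1.0, a mask holding 50 valid
    # latents spans 2 s, so the target is log(50 / 25 + 1.0) = log(3.0)
    # and a perfect prediction yields zero loss:
    #
    #   adapter = DurationAdapterMixin(latent_token_rate=25)
    #   latent_mask = create_mask_from_length(torch.tensor([50]))
    #   pred = torch.log(torch.tensor([3.0]))
    #   adapter.get_global_duration_loss(pred, latent_mask)  # ~ tensor(0.)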

    def get_local_duration_loss(
        self, ground_truth: torch.Tensor, pred: torch.Tensor,
        mask: torch.Tensor, is_time_aligned: torch.Tensor, reduce: bool
    ):
        # `frame_resolution` must be set for local (per-token) durations
        n_frames = torch.round(ground_truth / self.frame_resolution)
        target = torch.log(n_frames + self.offset)
loss = loss_with_mask(
(target - pred)**2,
mask,
reduce=False,
)
        # zero out items that are not time-aligned: they provide no local
        # duration supervision
        loss *= is_time_aligned
        if reduce:
            if is_time_aligned.sum().item() == 0:
                # no time-aligned item in the batch: return a zero loss
                # that still participates in the graph
                loss *= 0.0
                loss = loss.mean()
            else:
                # average only over the time-aligned items
                loss = loss.sum() / is_time_aligned.sum()
return loss
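
    # Worked sketch (illustrative; assumes `loss_with_mask(..., reduce=False)`
    # returns one loss value per batch item): with frame_resolution=0.04,
    # a perfect prediction is log(round(d / 0.04) + offset), so the masked
    # loss is zero for a fully time-aligned batch:
    #
    #   adapter = DurationAdapterMixin(25, frame_resolution=0.04)
    #   gt = torch.tensor([[0.4, 0.8]])  # seconds per token
    #   pred = torch.log(torch.round(gt / 0.04) + 1.0)
    #   adapter.get_local_duration_loss(
    #       gt, pred, torch.ones(1, 2), torch.tensor([True]), reduce=True
    #   )  # ~ tensor(0.)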

    def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
        # invert the log(n_frames + offset) compression, then convert the
        # frame counts back to durations in seconds
        pred = torch.exp(pred) * mask
        pred = torch.ceil(pred) - self.offset
        pred *= self.frame_resolution
        return pred
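
    # Inversion sketch (illustrative): predictions in log(n_frames + offset)
    # space are mapped back to seconds per token; fractional frame counts
    # are rounded up by the ceil:
    #
    #   adapter = DurationAdapterMixin(25, frame_resolution=0.04)
    #   pred = torch.log(torch.tensor([[10.5, 20.5]]))
    #   adapter.prepare_local_duration(pred, torch.ones(1, 2))
    #   # -> ceil([10.5, 20.5]) - 1.0 = [10, 20] frames = [[0.4, 0.8]] s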

    def prepare_global_duration(
        self,
        global_pred: torch.Tensor,
        local_pred: torch.Tensor,
        is_time_aligned: torch.Tensor,
        use_local: bool = True,
    ):
        """
        global_pred: log-compressed global duration, i.e. log(duration + offset)
        local_pred: per-token local durations in seconds
        """
        global_pred = torch.exp(global_pred) - self.offset
        result = global_pred
        if use_local:
            # for time-aligned items, derive the global duration from the
            # per-token predictions instead, so that per-frame rounding
            # errors do not accumulate
            pred_from_local = torch.round(local_pred * self.latent_token_rate)
            pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
            result[is_time_aligned] = pred_from_local[is_time_aligned]
return result
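
    # Sketch (illustrative): time-aligned items re-derive their global
    # duration from the per-token predictions; the rest keep the decoded
    # global prediction:
    #
    #   adapter = DurationAdapterMixin(latent_token_rate=25)
    #   global_pred = torch.log(torch.tensor([2.0, 3.0]) + 1.0)
    #   local_pred = torch.tensor([[0.4, 0.8], [0.0, 0.0]])  # seconds
    #   adapter.prepare_global_duration(
    #       global_pred, local_pred, torch.tensor([True, False])
    #   )  # -> tensor([1.2, 3.0]): 30 latents / 25 for item 0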

    def expand_by_duration(
self,
x: torch.Tensor,
content_mask: torch.Tensor,
local_duration: torch.Tensor,
global_duration: torch.Tensor | None = None,
):
        # number of latent tokens each content token expands to
        n_latents = torch.round(local_duration * self.latent_token_rate)
        if global_duration is not None:
            latent_length = torch.round(
                global_duration * self.latent_token_rate
            )
        else:
            latent_length = n_latents.sum(1)
        latent_mask = create_mask_from_length(latent_length).to(
            content_mask.device
        )
        # (B, T_content, T_latent) mask of valid content/latent pairs
        attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
        align_path = create_alignment_path(n_latents, attn_mask)
        # repeat each content token along the latent axis via the 0/1
        # alignment path
        expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
return expanded_x, latent_mask
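

# Usage sketch (illustrative, not part of the original file): expand two
# content tokens into 10 + 20 = 30 latent frames. Assumes
# `create_alignment_path(n_latents, attn_mask)` returns a 0/1 path of
# shape (B, T_content, T_latent), matching its use in `expand_by_duration`.
def _demo_expand_by_duration():
    adapter = DurationAdapterMixin(latent_token_rate=25)
    x = torch.randn(1, 2, 16)                    # (B, T_content, D)
    content_mask = torch.ones(1, 2)
    local_duration = torch.tensor([[0.4, 0.8]])  # 10 and 20 latent tokens
    expanded, latent_mask = adapter.expand_by_duration(
        x, content_mask, local_duration
    )
    assert expanded.shape == (1, 30, 16)
    assert latent_mask.shape == (1, 30)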