from typing import Any, Optional, Union, List, Sequence
import inspect
import random
import copy

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.utils.torch_utils import randn_tensor
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling

from models.autoencoder.autoencoder_base import AutoEncoderBase
from models.content_encoder.content_encoder import ContentEncoder
from models.content_adapter import ContentAdapterBase
from models.common import (
    LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
)
from utils.torch_utilities import (
    create_alignment_path, create_mask_from_length, loss_with_mask,
    trim_or_pad_length
)
from constants import SAME_LENGTH_TASKS


class FlowMatchingMixin:
    def __init__(
        self,
        cfg_drop_ratio: float = 0.2,
        sample_strategy: str = 'normal',
        num_train_steps: int = 1000
    ) -> None:
        r"""
        Args:
            cfg_drop_ratio (float): Probability of dropping the condition
                during training, for classifier-free guidance.
            sample_strategy (str): Sampling strategy for timesteps during
                training.
            num_train_steps (int): Number of training steps for the noise
                scheduler.
        """
        self.sample_strategy = sample_strategy
        self.infer_noise_scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=num_train_steps
        )
        self.train_noise_scheduler = copy.deepcopy(self.infer_noise_scheduler)
        self.classifier_free_guidance = cfg_drop_ratio > 0.0
        self.cfg_drop_ratio = cfg_drop_ratio

    def get_input_target_and_timesteps(
        self,
        latent: torch.Tensor,
        training: bool,
    ):
        batch_size = latent.shape[0]
        noise = torch.randn_like(latent)
        if training:
            if self.sample_strategy == 'normal':
                u = compute_density_for_timestep_sampling(
                    weighting_scheme="logit_normal",
                    batch_size=batch_size,
                    logit_mean=0,
                    logit_std=1,
                    mode_scale=None,
                )
            elif self.sample_strategy == 'uniform':
                u = torch.rand(batch_size, )
            else:
                raise NotImplementedError(
                    f"{self.sample_strategy} sampling for timesteps is not supported"
                )
            indices = (
                u * self.train_noise_scheduler.config.num_train_timesteps
            ).long()
        else:
            # use the middle timestep for deterministic evaluation
            indices = (
                self.train_noise_scheduler.config.num_train_timesteps // 2
            ) * torch.ones((batch_size, )).long()
        # train_noise_scheduler.timesteps: timesteps from 1 to num_train_steps
        # with an interval of 1
        timesteps = self.train_noise_scheduler.timesteps[indices].to(
            device=latent.device
        )
        sigmas = self.get_sigmas(
            timesteps, n_dim=latent.ndim, dtype=latent.dtype
        )
        # linear interpolation between data and noise; the regression target
        # is the constant velocity pointing from data to noise
        noisy_latent = (1.0 - sigmas) * latent + sigmas * noise
        target = noise - latent
        return noisy_latent, target, timesteps

    def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
        device = timesteps.device
        # sigmas decline from 1 to 1 / num_train_steps
        sigmas = self.train_noise_scheduler.sigmas.to(
            device=device, dtype=dtype
        )
        schedule_timesteps = self.train_noise_scheduler.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [
            (schedule_timesteps == t).nonzero().item() for t in timesteps
        ]
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def retrieve_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        **kwargs,
    ):
        # used in inference: retrieve the timestep schedule for the given
        # number of inference steps, custom timesteps or custom sigmas
        scheduler = self.infer_noise_scheduler
        if timesteps is not None and sigmas is not None:
            raise ValueError(
                "Only one of `timesteps` or `sigmas` can be passed. "
                "Please choose one to set custom values"
            )
        if timesteps is not None:
            accepts_timesteps = "timesteps" in set(
                inspect.signature(scheduler.set_timesteps).parameters.keys()
            )
            if not accepts_timesteps:
                raise ValueError(
                    f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                    f" timestep schedules. Please check whether you are using the correct scheduler."
                )
            scheduler.set_timesteps(
                timesteps=timesteps, device=device, **kwargs
            )
            timesteps = scheduler.timesteps
            num_inference_steps = len(timesteps)
        elif sigmas is not None:
            accept_sigmas = "sigmas" in set(
                inspect.signature(scheduler.set_timesteps).parameters.keys()
            )
            if not accept_sigmas:
                raise ValueError(
                    f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                    f" sigmas schedules. Please check whether you are using the correct scheduler."
                )
            scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
            timesteps = scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            scheduler.set_timesteps(
                num_inference_steps, device=device, **kwargs
            )
            timesteps = scheduler.timesteps
        return timesteps, num_inference_steps


class ContentEncoderAdapterMixin:
    def __init__(
        self,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase | None = None
    ):
        self.content_encoder = content_encoder
        self.content_adapter = content_adapter

    def encode_content(
        self,
        content: list[Any],
        task: list[str],
        device: str | torch.device,
        instruction: torch.Tensor | None = None,
        instruction_lengths: torch.Tensor | None = None
    ):
        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        content, content_mask = (
            content_output["content"], content_output["content_mask"]
        )
        if instruction is not None:
            instruction_mask = create_mask_from_length(instruction_lengths)
            (
                content,
                content_mask,
                global_duration_pred,
                local_duration_pred,
            ) = self.content_adapter(
                content, content_mask, instruction, instruction_mask
            )
        return_dict = {
            "content": content,
            "content_mask": content_mask,
            "length_aligned_content": content_output["length_aligned_content"],
        }
        if instruction is not None:
            return_dict["global_duration_pred"] = global_duration_pred
            return_dict["local_duration_pred"] = local_duration_pred
        return return_dict
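

# Illustrative sketch (not used by the classes in this module): the training
# objective in `FlowMatchingMixin.get_input_target_and_timesteps` interpolates
# x_t = (1 - sigma) * x_1 + sigma * noise and regresses the constant velocity
# v = noise - x_1, so one exact Euler step from sigma back to 0 recovers the
# data. All tensors here are hypothetical stand-ins.
def _flow_matching_target_sketch():
    x1 = torch.randn(2, 8, 4)       # clean latent
    noise = torch.randn_like(x1)    # Gaussian prior sample
    sigma = torch.rand(2, 1, 1)     # per-sample noise level in (0, 1)
    x_t = (1.0 - sigma) * x1 + sigma * noise
    v = noise - x1
    x0 = x_t - sigma * v            # integrate dx/dsigma = v from sigma to 0
    assert torch.allclose(x0, x1, atol=1e-5)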


class SingleTaskCrossAttentionAudioFlowMatching(
    LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
    FlowMatchingMixin, ContentEncoderAdapterMixin
):
    def __init__(
        self,
        autoencoder: nn.Module,
        content_encoder: ContentEncoder,
        backbone: nn.Module,
        cfg_drop_ratio: float = 0.2,
        sample_strategy: str = 'normal',
        num_train_steps: int = 1000,
    ):
        nn.Module.__init__(self)
        FlowMatchingMixin.__init__(
            self, cfg_drop_ratio, sample_strategy, num_train_steps
        )
        ContentEncoderAdapterMixin.__init__(
            self, content_encoder=content_encoder
        )
        self.autoencoder = autoencoder
        for param in self.autoencoder.parameters():
            param.requires_grad = False
        if hasattr(
            self.content_encoder, "audio_encoder"
        ) and self.content_encoder.audio_encoder is not None:
            self.content_encoder.audio_encoder.model = self.autoencoder
        self.backbone = backbone
        # tracks the module's device without assuming any real parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        waveform: torch.Tensor,
        waveform_lengths: torch.Tensor,
        **kwargs
    ):
        device = self.dummy_param.device
        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )
        content_dict = self.encode_content(content, task, device)
        content, content_mask = (
            content_dict["content"], content_dict["content_mask"]
        )

        if self.training and self.classifier_free_guidance:
            # randomly drop the condition for classifier-free guidance
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                content[mask_indices] = 0

        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, training=self.training
        )
        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            context=content,
            x_mask=latent_mask,
            context_mask=content_mask
        )
        loss = F.mse_loss(pred.float(), target.float(), reduction="none")
        loss = loss_with_mask(loss, latent_mask)
        return loss

    def iterative_denoise(
        self, latent: torch.Tensor, timesteps: list[int], num_steps: int,
        verbose: bool, cfg: bool, cfg_scale: float, backbone_input: dict
    ):
        progress_bar = tqdm(range(num_steps), disable=not verbose)
        for i, timestep in enumerate(timesteps):
            # expand the latent if we are doing classifier-free guidance
            if cfg:
                latent_input = torch.cat([latent, latent])
            else:
                latent_input = latent
            noise_pred: torch.Tensor = self.backbone(
                x=latent_input, timesteps=timestep, **backbone_input
            )
            # perform guidance
            if cfg:
                noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + cfg_scale * (
                    noise_pred_content - noise_pred_uncond
                )
            latent = self.infer_noise_scheduler.step(
                noise_pred, timestep, latent
            ).prev_sample
            progress_bar.update(1)
        progress_bar.close()
        return latent

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        latent_shape: Sequence[int],
        num_steps: int = 50,
        sway_sampling_coef: float | None = -1.0,
        guidance_scale: float = 3.0,
        num_samples_per_content: int = 1,
        disable_progress: bool = True,
        **kwargs
    ):
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0
        batch_size = len(content) * num_samples_per_content

        if classifier_free_guidance:
            content, content_mask = self.encode_content_classifier_free(
                content, task, device, num_samples_per_content
            )
        else:
            content_output: dict[str, torch.Tensor] = \
                self.content_encoder.encode_content(
                    content, task, device=device
                )
            content, content_mask = (
                content_output["content"], content_output["content_mask"]
            )
            content = content.repeat_interleave(num_samples_per_content, 0)
            content_mask = content_mask.repeat_interleave(
                num_samples_per_content, 0
            )

        latent = self.prepare_latent(
            batch_size, latent_shape, content.dtype, device
        )

        if not sway_sampling_coef:
            sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
        else:
            # sway sampling: skew the schedule towards high noise levels
            t = torch.linspace(0, 1, num_steps + 1)
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
            sigmas = 1 - t
        timesteps, num_steps = self.retrieve_timesteps(
            num_steps, device, timesteps=None, sigmas=sigmas
        )
        latent = self.iterative_denoise(
            latent=latent,
            timesteps=timesteps,
            num_steps=num_steps,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            backbone_input={
                "context": content,
                "context_mask": content_mask,
            },
        )
        waveform = self.autoencoder.decode(latent)
        return waveform

    def prepare_latent(
        self, batch_size: int, latent_shape: Sequence[int],
        dtype: torch.dtype, device: str
    ):
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=dtype
        )
        return latent

    def encode_content_classifier_free(
        self,
        content: list[Any],
        task: list[str],
        device,
        num_samples_per_content: int = 1
    ):
        content_dict = self.content_encoder.encode_content(
            content, task, device=device
        )
        content, content_mask = (
            content_dict["content"], content_dict["content_mask"]
        )
        content = content.repeat_interleave(num_samples_per_content, 0)
        content_mask = content_mask.repeat_interleave(
            num_samples_per_content, 0
        )

        # get unconditional embeddings for classifier-free guidance; the zeros
        # share the already-expanded batch size, so no further repeat is needed
        uncond_content = torch.zeros_like(content)
        uncond_content_mask = content_mask.detach().clone()

        # For classifier-free guidance, we need two forward passes. We
        # concatenate the unconditional and conditional embeddings into a
        # single batch to avoid doing two forward passes.
        content = torch.cat([uncond_content, content])
        content_mask = torch.cat([uncond_content_mask, content_mask])
        return content, content_mask
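

# Illustrative sketch (not called anywhere) of the classifier-free guidance
# batching used by `encode_content_classifier_free` and `iterative_denoise`
# above: unconditional and conditional inputs share one batch so the backbone
# runs a single forward pass, and the two halves are recombined with the
# guidance scale. `fake_backbone` is a hypothetical stand-in for the backbone.
def _cfg_batching_sketch(cfg_scale: float = 3.0):
    cond = torch.randn(2, 16, 8)            # conditional context
    uncond = torch.zeros_like(cond)         # "dropped" context
    context = torch.cat([uncond, cond])     # doubled batch: [uncond; cond]

    def fake_backbone(x: torch.Tensor) -> torch.Tensor:
        return x * 0.5 + 0.1                # any deterministic map works here

    pred_uncond, pred_cond = fake_backbone(context).chunk(2)
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)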


class DurationAdapterMixin:
    def __init__(
        self,
        latent_token_rate: int,
        offset: float = 1.0,
        frame_resolution: float | None = None
    ):
        self.latent_token_rate = latent_token_rate
        self.offset = offset
        self.frame_resolution = frame_resolution

    def get_global_duration_loss(
        self,
        pred: torch.Tensor,
        latent_mask: torch.Tensor,
        reduce: bool = True,
    ):
        # regress the log-compressed total duration in seconds
        target = torch.log(
            latent_mask.sum(1) / self.latent_token_rate + self.offset
        )
        loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none")
        return loss

    def get_local_duration_loss(
        self, ground_truth: torch.Tensor, pred: torch.Tensor,
        mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool
    ):
        n_frames = torch.round(ground_truth / self.frame_resolution)
        target = torch.log(n_frames + self.offset)
        loss = loss_with_mask(
            (target - pred)**2,
            mask,
            reduce=False,
        )
        # only time-aligned samples contribute to the local duration loss
        is_time_aligned = torch.as_tensor(is_time_aligned, device=loss.device)
        loss = loss * is_time_aligned
        if reduce:
            if is_time_aligned.sum().item() == 0:
                loss *= 0.0
                loss = loss.mean()
            else:
                loss = loss.sum() / is_time_aligned.sum()
        return loss

    def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
        # invert the log / offset transform to get frame counts, then seconds
        pred = torch.exp(pred) * mask
        pred = torch.ceil(pred) - self.offset
        pred *= self.frame_resolution
        return pred

    def prepare_global_duration(
        self,
        global_pred: torch.Tensor,
        local_pred: torch.Tensor,
        is_time_aligned: Sequence[bool],
        use_local: bool = True,
    ):
        """
        Args:
            global_pred: predicted global duration, log-compressed with offset.
            local_pred: predicted per-token duration in seconds.
        """
        global_pred = torch.exp(global_pred) - self.offset
        result = global_pred
        # for time-aligned samples, sum the per-token durations instead, to
        # avoid accumulating a per-frame rounding error
        if use_local:
            pred_from_local = torch.round(local_pred * self.latent_token_rate)
            pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
            result[is_time_aligned] = pred_from_local[is_time_aligned]
        return result

    def expand_by_duration(
        self,
        x: torch.Tensor,
        content_mask: torch.Tensor,
        local_duration: torch.Tensor,
        global_duration: torch.Tensor | None = None,
    ):
        # repeat each content token according to its predicted frame count
        n_latents = torch.round(local_duration * self.latent_token_rate)
        if global_duration is not None:
            latent_length = torch.round(
                global_duration * self.latent_token_rate
            )
        else:
            latent_length = n_latents.sum(1)
        latent_mask = create_mask_from_length(latent_length).to(
            content_mask.device
        )
        attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
        align_path = create_alignment_path(n_latents, attn_mask)
        expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
        return expanded_x, latent_mask
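

# Illustrative sketch (hypothetical numbers) of the duration codec implemented
# by `DurationAdapterMixin`: training regresses log(n_frames + offset), and
# `prepare_local_duration` inverts it via exp -> ceil -> offset -> seconds.
# `ceil` keeps the round trip robust to slight under-prediction.
def _duration_codec_sketch():
    frame_resolution, offset = 0.02, 1.0      # 20 ms frames, log offset
    n_frames = torch.tensor([10.0, 25.0, 4.0])
    target = torch.log(n_frames + offset)     # regression target
    pred = target - 0.01                      # slightly under-predicted
    seconds = (torch.ceil(torch.exp(pred)) - offset) * frame_resolution
    assert torch.allclose(seconds, n_frames * frame_resolution)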


class CrossAttentionAudioFlowMatching(
    SingleTaskCrossAttentionAudioFlowMatching, DurationAdapterMixin
):
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase,
        backbone: nn.Module,
        content_dim: int,
        frame_resolution: float,
        duration_offset: float = 1.0,
        cfg_drop_ratio: float = 0.2,
        sample_strategy: str = 'normal',
        num_train_steps: int = 1000
    ):
        super().__init__(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            backbone=backbone,
            cfg_drop_ratio=cfg_drop_ratio,
            sample_strategy=sample_strategy,
            num_train_steps=num_train_steps,
        )
        ContentEncoderAdapterMixin.__init__(
            self,
            content_encoder=content_encoder,
            content_adapter=content_adapter
        )
        DurationAdapterMixin.__init__(
            self,
            latent_token_rate=autoencoder.latent_token_rate,
            offset=duration_offset
        )

    def encode_content_with_instruction(
        self, content: list[Any], task: list[str], device,
        instruction: torch.Tensor, instruction_lengths: torch.Tensor
    ):
        content_dict = self.encode_content(
            content, task, device, instruction, instruction_lengths
        )
        return (
            content_dict["content"],
            content_dict["content_mask"],
            content_dict["global_duration_pred"],
            content_dict["local_duration_pred"],
            content_dict["length_aligned_content"],
        )

    def forward(
        self,
        content: list[Any],
        task: list[str],
        waveform: torch.Tensor,
        waveform_lengths: torch.Tensor,
        instruction: torch.Tensor,
        instruction_lengths: torch.Tensor,
        loss_reduce: bool = True,
        **kwargs
    ):
        device = self.dummy_param.device
        # always reduce the loss during training
        loss_reduce = self.training or loss_reduce
        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )
        content, content_mask, global_duration_pred, _, _ = \
            self.encode_content_with_instruction(
                content, task, device, instruction, instruction_lengths
            )
        global_duration_loss = self.get_global_duration_loss(
            global_duration_pred, latent_mask, reduce=loss_reduce
        )

        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                content[mask_indices] = 0

        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, training=self.training
        )
        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            context=content,
            x_mask=latent_mask,
            context_mask=content_mask,
        )
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
        diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
        return {
            "diff_loss": diff_loss,
            "global_duration_loss": global_duration_loss,
        }

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        is_time_aligned: Sequence[bool],
        instruction: torch.Tensor,
        instruction_lengths: torch.Tensor,
        num_steps: int = 20,
        sway_sampling_coef: float | None = -1.0,
        guidance_scale: float = 3.0,
        disable_progress: bool = True,
        use_gt_duration: bool = False,
        **kwargs
    ):
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0
        (
            content,
            content_mask,
            global_duration_pred,
            local_duration_pred,
            _,
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )
        batch_size = content.size(0)

        if use_gt_duration:
            raise NotImplementedError(
                "Using ground truth global duration only is not implemented yet"
            )

        # prepare global duration
        global_duration = self.prepare_global_duration(
            global_duration_pred, local_duration_pred, is_time_aligned,
            use_local=False
        )
        # TODO: manually set duration for SE and AudioSR
        latent_length = torch.round(global_duration * self.latent_token_rate)
        # tasks whose output length must match the input keep the content length
        task_mask = torch.as_tensor([t in SAME_LENGTH_TASKS for t in task])
        latent_length[task_mask] = content[task_mask].size(1)
        latent_mask = create_mask_from_length(latent_length).to(device)
        max_latent_length = latent_mask.sum(1).max().item()

        # prepare conditional and unconditional context
        if classifier_free_guidance:
            uncond_context = torch.zeros_like(content)
            uncond_content_mask = content_mask.detach().clone()
            context = torch.cat([uncond_context, content])
            context_mask = torch.cat([uncond_content_mask, content_mask])
        else:
            context = content
            context_mask = content_mask

        latent_shape = tuple(
            max_latent_length if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=content.dtype
        )
        if not sway_sampling_coef:
            sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
        else:
            t = torch.linspace(0, 1, num_steps + 1)
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
            sigmas = 1 - t
        timesteps, num_steps = self.retrieve_timesteps(
            num_steps, device, timesteps=None, sigmas=sigmas
        )
        latent = self.iterative_denoise(
            latent=latent,
            timesteps=timesteps,
            num_steps=num_steps,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            backbone_input={
                "x_mask": latent_mask,
                "context": context,
                "context_mask": context_mask,
            }
        )
        waveform = self.autoencoder.decode(latent)
        return waveform
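

# Illustrative sketch of the sway-sampling schedule used in `inference` above
# (the same trick as in F5-TTS): a negative coefficient bends the uniform grid
# so more of the integration steps are spent at high noise levels. With
# coef=-1.0 this reduces to sigmas = cos(pi/2 * t) on a uniform t grid.
def _sway_sampling_sketch(num_steps: int = 8, coef: float = -1.0):
    t = torch.linspace(0, 1, num_steps + 1)
    t = t + coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
    sigmas = 1 - t    # descending from 1 (pure noise) to 0 (clean data)
    return sigmas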


class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase,
        backbone: nn.Module,
        content_dim: int,
        frame_resolution: float,
        duration_offset: float = 1.0,
        cfg_drop_ratio: float = 0.2,
        sample_strategy: str = 'normal',
        num_train_steps: int = 1000
    ):
        super().__init__(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            content_adapter=content_adapter,
            backbone=backbone,
            content_dim=content_dim,
            frame_resolution=frame_resolution,
            duration_offset=duration_offset,
            cfg_drop_ratio=cfg_drop_ratio,
            sample_strategy=sample_strategy,
            num_train_steps=num_train_steps
        )
        DurationAdapterMixin.__init__(
            self,
            latent_token_rate=autoencoder.latent_token_rate,
            offset=duration_offset,
            frame_resolution=frame_resolution
        )
        # learnable placeholders for the time-aligned and non-time-aligned
        # content branches
        self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
        self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))

    def get_backbone_input(
        self, target_length: int, content: torch.Tensor,
        content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
    ):
        # TODO compatibility for 2D spectrogram VAE
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, target_length, 1
        )
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, target_length, 1
        )
        # time_aligned_content: from monotonically-aligned input, without
        #     frame expansion (phoneme)
        # length_aligned_content: from frame-aligned input (f0/energy)
        time_aligned_content = time_aligned_content + length_aligned_content
        time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
            time_aligned_content.dtype
        )
        context = content
        context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
        # only use the first dummy non-time-aligned embedding
        context_mask = content_mask.detach().clone()
        context_mask[is_time_aligned, 1:] = False
        # truncate dummy non-time-aligned context
        if is_time_aligned.sum().item() < content.size(0):
            trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
        else:
            trunc_nta_length = content.size(1)
        context = context[:, :trunc_nta_length]
        context_mask = context_mask[:, :trunc_nta_length]
        return context, context_mask, time_aligned_content

    def forward(
        self,
        content: list[Any],
        duration: Sequence[float],
        task: list[str],
        is_time_aligned: Sequence[bool],
        waveform: torch.Tensor,
        waveform_lengths: torch.Tensor,
        instruction: torch.Tensor,
        instruction_lengths: torch.Tensor,
        loss_reduce: bool = True,
        **kwargs
    ):
        device = self.dummy_param.device
        # always reduce the loss during training
        loss_reduce = self.training or loss_reduce
        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )
        (
            content,
            content_mask,
            global_duration_pred,
            local_duration_pred,
            length_aligned_content
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )

        # truncate duration predictions to the longest time-aligned content
        is_time_aligned = torch.as_tensor(is_time_aligned)
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        # duration loss
        duration = torch.as_tensor(duration).to(device)
        local_duration_pred = local_duration_pred[:, :trunc_ta_length]
        ta_content_mask = content_mask[:, :trunc_ta_length]
        local_duration_loss = self.get_local_duration_loss(
            duration, local_duration_pred, ta_content_mask, is_time_aligned,
            reduce=loss_reduce
        )
        global_duration_loss = self.get_global_duration_loss(
            global_duration_pred, latent_mask, reduce=loss_reduce
        )

        # --------------------------------------------------------------------
        # prepare latent and noise
        # --------------------------------------------------------------------
        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, training=self.training
        )

        # --------------------------------------------------------------------
        # duration adapter
        # --------------------------------------------------------------------
        if is_time_aligned.sum() == 0 and \
                duration.size(1) < content_mask.size(1):
            duration = F.pad(
                duration, (0, content_mask.size(1) - duration.size(1))
            )
        time_aligned_content, _ = self.expand_by_duration(
            x=content[:, :trunc_ta_length],
            content_mask=ta_content_mask,
            local_duration=duration,
        )

        # --------------------------------------------------------------------
        # prepare input to the backbone
        # --------------------------------------------------------------------
        # TODO compatibility for 2D spectrogram VAE
        latent_length = noisy_latent.size(self.autoencoder.time_dim)
        context, context_mask, time_aligned_content = self.get_backbone_input(
            latent_length, content, content_mask, time_aligned_content,
            length_aligned_content, is_time_aligned
        )

        # --------------------------------------------------------------------
        # classifier-free guidance
        # --------------------------------------------------------------------
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                context[mask_indices] = 0
                time_aligned_content[mask_indices] = 0

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            x_mask=latent_mask,
            timesteps=timesteps,
            context=context,
            context_mask=context_mask,
            time_aligned_context=time_aligned_content,
        )
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = F.mse_loss(pred, target, reduction="none")
        diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
        return {
            "diff_loss": diff_loss,
            "local_duration_loss": local_duration_loss,
            "global_duration_loss": global_duration_loss,
        }

    def inference(
        self,
        content: list[Any],
        task: list[str],
        is_time_aligned: Sequence[bool],
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        num_steps: int = 20,
        sway_sampling_coef: float | None = -1.0,
        guidance_scale: float = 3.0,
        disable_progress: bool = True,
        use_gt_duration: bool = False,
        **kwargs
    ):
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0
        (
            content,
            content_mask,
            global_duration_pred,
            local_duration_pred,
            length_aligned_content
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )
        batch_size = content.size(0)

        # truncate duration predictions to the longest time-aligned content
        is_time_aligned = torch.as_tensor(is_time_aligned)
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        # prepare local duration
        local_duration = self.prepare_local_duration(
            local_duration_pred, content_mask
        )
        local_duration = local_duration[:, :trunc_ta_length]
        # optionally use the ground truth duration instead
        if use_gt_duration and "duration" in kwargs:
            local_duration = torch.as_tensor(kwargs["duration"]).to(device)

        # prepare global duration
        global_duration = self.prepare_global_duration(
            global_duration_pred, local_duration, is_time_aligned
        )

        # --------------------------------------------------------------------
        # duration adapter
        # --------------------------------------------------------------------
        time_aligned_content, latent_mask = self.expand_by_duration(
            x=content[:, :trunc_ta_length],
            content_mask=content_mask[:, :trunc_ta_length],
            local_duration=local_duration,
            global_duration=global_duration,
        )
        context, context_mask, time_aligned_content = self.get_backbone_input(
            target_length=time_aligned_content.size(1),
            content=content,
            content_mask=content_mask,
            time_aligned_content=time_aligned_content,
            length_aligned_content=length_aligned_content,
            is_time_aligned=is_time_aligned
        )

        # --------------------------------------------------------------------
        # prepare unconditional input
        # --------------------------------------------------------------------
        if classifier_free_guidance:
            uncond_time_aligned_content = torch.zeros_like(
                time_aligned_content
            )
            uncond_context = torch.zeros_like(context)
            uncond_context_mask = context_mask.detach().clone()
            time_aligned_content = torch.cat([
                uncond_time_aligned_content, time_aligned_content
            ])
            context = torch.cat([uncond_context, context])
            context_mask = torch.cat([uncond_context_mask, context_mask])
            latent_mask = torch.cat([
                latent_mask, latent_mask.detach().clone()
            ])

        # --------------------------------------------------------------------
        # prepare input to the backbone
        # --------------------------------------------------------------------
        latent_length = latent_mask.sum(1).max().item()
        latent_shape = tuple(
            latent_length if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=content.dtype
        )
        if not sway_sampling_coef:
            sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
        else:
            t = torch.linspace(0, 1, num_steps + 1)
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
            sigmas = 1 - t
        timesteps, num_steps = self.retrieve_timesteps(
            num_steps, device, timesteps=None, sigmas=sigmas
        )
        latent = self.iterative_denoise(
            latent=latent,
            timesteps=timesteps,
            num_steps=num_steps,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            backbone_input={
                "x_mask": latent_mask,
                "context": context,
                "context_mask": context_mask,
                "time_aligned_context": time_aligned_content,
            }
        )
        waveform = self.autoencoder.decode(latent)
        return waveform
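

# Illustrative sketch (hypothetical shapes) of what `expand_by_duration`
# computes: each content token is repeated along the time axis according to
# its predicted number of latent frames. `create_alignment_path` builds the
# equivalent 0/1 alignment matrix so the expansion is a single matmul.
def _duration_expansion_sketch():
    x = torch.arange(3, dtype=torch.float32).reshape(1, 3, 1)  # 3 tokens
    n_latents = torch.tensor([2, 1, 3])                        # frames/token
    expanded = torch.repeat_interleave(x, n_latents, dim=1)
    # expanded[0, :, 0] == [0, 0, 1, 2, 2, 2]
    return expanded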


class DoubleContentAudioFlowMatching(DummyContentAudioFlowMatching):
    def get_backbone_input(
        self, target_length: int, content: torch.Tensor,
        content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
    ):
        # TODO compatibility for 2D spectrogram VAE
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, target_length, 1
        )
        # for non-time-aligned samples, feed the content to both branches
        context_length = min(content.size(1), time_aligned_content.size(1))
        time_aligned_content[~is_time_aligned, :context_length] = \
            content[~is_time_aligned, :context_length]
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, target_length, 1
        )
        # time_aligned_content: from monotonically-aligned input, without
        #     frame expansion (phoneme)
        # length_aligned_content: from frame-aligned input (f0/energy)
        time_aligned_content = time_aligned_content + length_aligned_content
        context = content
        context_mask = content_mask.detach().clone()
        return context, context_mask, time_aligned_content


class HybridContentAudioFlowMatching(DummyContentAudioFlowMatching):
    def get_backbone_input(
        self, target_length: int, content: torch.Tensor,
        content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
    ):
        # TODO compatibility for 2D spectrogram VAE
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, target_length, 1
        )
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, target_length, 1
        )
        # time_aligned_content: from monotonically-aligned input, without
        #     frame expansion (phoneme)
        # length_aligned_content: from frame-aligned input (f0/energy)
        time_aligned_content = time_aligned_content + length_aligned_content
        time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
            time_aligned_content.dtype
        )
        # unlike DummyContentAudioFlowMatching, keep the full cross-attention
        # context for every sample
        context = content
        context_mask = content_mask.detach().clone()
        return context, context_mask, time_aligned_content
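

# End-to-end sketch of the sampling loop that the classes above implement,
# with zeros standing in for the backbone's velocity prediction (a real model
# would predict `noise - data` here). Assumes a diffusers version whose
# FlowMatchEulerDiscreteScheduler accepts custom `sigmas` in `set_timesteps`,
# as the module's own `retrieve_timesteps` already relies on.
def _denoise_loop_sketch(num_steps: int = 4):
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(
        sigmas=np.linspace(1.0, 1 / num_steps, num_steps), device="cpu"
    )
    latent = torch.randn(1, 16, 8)    # start from pure noise
    for timestep in scheduler.timesteps:
        velocity = torch.zeros_like(latent)    # stand-in for the backbone call
        latent = scheduler.step(velocity, timestep, latent).prev_sample
    return latent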