import numpy as np import sys import itertools import time import torch from torch import Tensor import math import torch.nn.functional as F import numpy as np import random as rd import lightning as L import torchmetrics from dataclasses import dataclass import gc import utils.utils as utils from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer import noise_schedule from torch.optim.lr_scheduler import _LRScheduler import roformer as roformer from utils.app import PeptideAnalyzer import pandas as pd base_path = '/path/to/your/home' def _sample_categorical(categorical_probs): gumbel_norm = ( 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()) return (categorical_probs / gumbel_norm).argmax(dim=-1).to(dtype=torch.long) def _sample_categorical_gradient(categorical_probs, temp = 1.0): gumbel_norm = ( 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()) output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2) return output def _unsqueeze(x, reference): return x.view( * x.shape, * ((1,) * (len(reference.shape) - len(x.shape)))) def sample_batched_categorical(categorical_probs, batch_size): """ Generates `m` distinct sequences sampled from categorical probabilities using the Gumbel distribution to ensure randomness while following probabilities Args: categorical_probs (torch.Tensor): tensor of shape (sequence_length, vocab_length) representing categorical probabilities m (int): number of distinct sequences to sample Returns: torch.Tensor: tensor of shape (m, sequence_length), where each row is a distinct sequence of sampled category indices. """ _, sequence_length, vocab_size = categorical_probs.shape # add Gumbel noise and sample m sequences gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_size) + 1e-10) + 1e-10)).to(categorical_probs.device) noisy_scores = torch.log(categorical_probs) + gumbel_noise # add Gumbel noise to log probabilities # select the highest score (most likely category after Gumbel noise) sampled_sequences = noisy_scores.argmax(dim=-1).to(dtype=torch.long) # shape: (m, sequence_length) return sampled_sequences def sample_batched_top_k(categorical_probs, batch_size, k): """ Generates `m` sequences sampled from the top-k probabilities of each token using Gumbel noise to ensure randomness and reduce bias towards the most likely options. Args: categorical_probs (torch.Tensor): A tensor of shape (sequence_length, vocab_length) representing categorical probabilities. m (int): Number of sequences to sample. k (int): Number of top probabilities to consider for sampling. Returns: torch.Tensor: A tensor of shape (m, sequence_length), where each row is a sampled sequence of category indices. """ _, sequence_length, vocab_length = categorical_probs.shape # Add Gumbel noise to the log probabilities gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(categorical_probs.device) noisy_scores = torch.log(categorical_probs[None, :, :]) + gumbel_noise # Shape: (m, sequence_length, vocab_length) # Get the top-k categories based on noisy scores top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1) # Shape: (m, sequence_length, k) # Convert top-k scores back to probabilities and normalize top_k_probs = torch.softmax(top_k_scores, dim=-1).to(categorical_probs.device) # Shape: (m, sequence_length, k) # Sample randomly from the top-k probabilities sampled_indices_in_top_k = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1).to(categorical_probs.device) sampled_indices_in_top_k = sampled_indices_in_top_k.view(batch_size, sequence_length).to(categorical_probs.device) # Shape: (batch_size, sequence_length) # Map sampled indices back to the original vocabulary indices sampled_sequences = torch.gather(top_k_indices, -1, sampled_indices_in_top_k.unsqueeze(-1)).squeeze(-1).to(categorical_probs.device).to(dtype=torch.long) return sampled_sequences @dataclass class Loss: loss: torch.FloatTensor nlls: torch.FloatTensor attn_mask: torch.FloatTensor class NLL(torchmetrics.aggregation.MeanMetric): pass class BPD(NLL): def compute(self) -> Tensor: """Computes the bits per dimension. Returns: bpd """ return self.mean_value / self.weight / math.log(2) class Perplexity(NLL): def compute(self) -> Tensor: """Computes the Perplexity. Returns: Perplexity """ return torch.exp(self.mean_value / self.weight) class Diffusion(L.LightningModule): def __init__( self, config, tokenizer = None, mode="finetune", device=None, ): super().__init__() self.config = config #self.save_hyperparameters() # PeptideCLM tokenizer if tokenizer is None: self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_vocab.txt', f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_splits.txt') else: self.tokenizer = tokenizer self.vocab_size = self.tokenizer.vocab_size self.mask_index = self.tokenizer.mask_token_id self.sampler = self.config.sampling.predictor self.analyzer = PeptideAnalyzer() # backbone LM PeptideCLM model self.backbone = roformer.Roformer(self.config, self.tokenizer, device=device) if mode == "finetune": self.backbone.freeze_model() self.backbone.unfreeze_n_layers(n=8) elif mode == "eval": self.backbone.freeze_model() self.backbone.requires_grad_(False) self.backbone.eval() elif mode == "train": self.backbone.requires_grad_(True) self.backbone.train() self.neg_infinity = -1000000.0 self.T = config.T # noise schedule for non-peptide bond tokens (default to log-linear) self.noise = noise_schedule.get_noise(config) # noise schedule for peptide bonds (log-polynomial) self.bond_noise = noise_schedule.LogPolyNoise() self.time_conditioning = self.config.time_conditioning self.fast_forward_epochs = None self.fast_forward_batches = None self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path self.gen_ppl_metric = Perplexity() self.lr = self.config.optim.lr self.sampling_eps = self.config.training.sampling_eps metrics = torchmetrics.MetricCollection({ 'nll': NLL(), 'bpd': BPD(), 'ppl': Perplexity(), }) metrics.set_dtype(torch.float64) self.train_metrics = metrics.clone(prefix='trainer/') self.valid_metrics = metrics.clone(prefix='val/') self.test_metrics = metrics.clone(prefix='test/') ### FOR THE EXPANSION AND ROLLOUT STEP ### def sample_finetuned_with_rnd(self, args, reward_model, pretrained, eps=1e-5): num_steps = args.total_num_steps B = args.batch_size x_rollout = self.sample_prior( B, args.seq_length).to(self.device) log_rnd = torch.zeros(args.batch_size, device=self.device) timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device) dt = (1 - eps) / num_steps for i in range(num_steps): t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device) log_p, x_next, log_policy_step, log_pretrained_step = \ self.mcts_reverse_step(x_rollout, t=t, dt=dt, pretrained=pretrained) log_rnd += log_pretrained_step - log_policy_step x_rollout = x_next # if mask token remains, fully unmask mask_positions = (x_rollout == self.mask_index) # (B, L) bool # does **any** mask remain in any sequence any_mask_global = mask_positions.any().item() # true if mask remains if any_mask_global: log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt) x_rollout = x_next childSequences = self.tokenizer.batch_decode(x_rollout) # change rewards for peptides valid_x_final = [] validSequences = [] valid_log_rnd = [] for i in range(B): # string sequence childSeq = childSequences[i] # check if the peptide is valid if self.analyzer.is_peptide(childSeq): valid_x_final.append(x_rollout[i]) validSequences.append(childSeq) valid_log_rnd.append(log_rnd[i]) # compute multi-objective rewards score_vectors = reward_model(input_seqs=validSequences) scalar_rewards = np.sum(score_vectors, axis=-1) scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=self.device) print(f"scalar reward dim{len(scalar_rewards)}") valid_log_rnd = torch.stack(valid_log_rnd, dim=0) log_rnd = valid_log_rnd + (scalar_rewards / args.alpha) # scale down by alpha valid_x_final = torch.stack(valid_x_final, dim=0) return valid_x_final, log_rnd, scalar_rewards def sample_finetuned(self, args, reward_model, batch_size=None, dataframe=False, eps=1e-5): torch.cuda.empty_cache() self.backbone.eval() self.noise.eval() print(f"device:{self.device}") if batch_size is None: batch_size = args.batch_size num_steps = args.total_num_steps x_rollout = self.sample_prior( batch_size, args.seq_length).to(self.device, dtype=torch.long) timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device) dt = torch.tensor((1 - eps) / num_steps, device=self.device) for i in range(num_steps): t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device) log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt) x_rollout = x_next x_rollout = x_rollout.to(self.device) # if mask token remains, fully unmask mask_positions = (x_rollout == self.mask_index) # (B, L) bool # does **any** mask remain in any sequence any_mask_global = mask_positions.any().item() # true if mask remains if any_mask_global: log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt) x_rollout = x_next x_rollout = x_rollout.to(self.device) childSequences = self.tokenizer.batch_decode(x_rollout) valid_x_final = [] validSequences = [] for idx, seq in enumerate(childSequences): if self.analyzer.is_peptide(seq): valid_x_final.append(x_rollout[idx]) validSequences.append(seq) valid_fraction = len(validSequences) / batch_size if (len(validSequences) != 0): # add scores to log score_vectors = reward_model(input_seqs=validSequences) # (num_children, num_objectives) average_scores = score_vectors.T affinity = average_scores[0] sol = average_scores[1] hemo = average_scores[2] nf = average_scores[3] permeability = average_scores[4] else: zeros = [0.0] affinity = zeros sol = zeros hemo = zeros nf = zeros permeability = zeros if dataframe: df = pd.DataFrame({ "Peptide Sequence": validSequences, "Binding Affinity": affinity if len(validSequences) else [0.0], "Solubility": sol if len(validSequences) else [0.0], "Hemolysis": hemo if len(validSequences) else [0.0], "Nonfouling": nf if len(validSequences) else [0.0], "Permeability": permeability if len(validSequences) else [0.0], }) return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction, df return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction def compute_log_policy(self, token_array, x_next, t, dt, attn_mask=None): torch.cuda.empty_cache() self.backbone.eval() self.noise.eval() sigma_t, _ = self.noise(t) if token_array.ndim == 1: token_array = token_array.unsqueeze(0) if x_next.ndim == 1: x_next = x_next.unsqueeze(0) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 change_prob_t = t[:, None, None] change_prob_s = (t - dt)[:, None, None] assert change_prob_t.ndim == 3, change_prob_t.shape if attn_mask is None: attn_mask = torch.ones_like(token_array).to(self.device) log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) p_x0 = log_p.exp() assert change_prob_t.ndim == p_x0.ndim q_xs = p_x0 * (change_prob_t - change_prob_s) # zero-masking probability q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0] copy_flag = (token_array != self.mask_index) assert copy_flag.dtype == torch.bool, "copy_flag must be bool" changed_mask = (~copy_flag) # compute the per-sequence log-probability under the pretrained model log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_policy_token.dtype) log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1) # returns: # log_policy_step (B, ) log probability x_next tokens under policy if log_policy_step.ndim == 1: log_policy_step = log_policy_step.squeeze(0) return log_policy_step def single_reverse_step(self, token_array, t, dt, p_x0=None, attn_mask=None): torch.cuda.empty_cache() dev = self.device self.backbone.to(dev).eval() self.noise.eval() t = t.to(dev) dt = torch.as_tensor(dt, device=dev, dtype=t.dtype) assert self.config.noise.type == 'loglinear' sigma_t, _ = self.noise(t) sigma_t = sigma_t.to(dev) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 change_prob_t = t[:, None, None] change_prob_s = (t - dt)[:, None, None] assert change_prob_t.ndim == 3, change_prob_t.shape if attn_mask is None: attn_mask = torch.ones_like(token_array, device=dev, dtype=torch.long) else: attn_mask = attn_mask.to(dev) if p_x0 is None: log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) p_x0 = log_p.exp() else: # ensure provided p_x0 is on dev log_p = None p_x0 = p_x0.to(dev) assert change_prob_t.ndim == p_x0.ndim q_xs = p_x0 * (change_prob_t - change_prob_s) # zero-masking probability q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0] x_changed = _sample_categorical(q_xs) if x_changed.device != dev or x_changed.dtype != token_array.dtype: x_changed = x_changed.to(dev, dtype=token_array.dtype) copy_flag = (token_array != self.mask_index) int_copy_flag = copy_flag.to(token_array.dtype) x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed # returns: # log_p (B, L, D) log probabilties of each token under the policy model # x_next (B, L) next sequences return log_p, x_next def single_noise_removal(self, token_array, t, dt, p_x0=None, attn_mask=None): torch.cuda.empty_cache() self.backbone.eval() self.noise.eval() assert self.config.noise.type == 'loglinear' sigma_t, _ = self.noise(t) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 change_prob_t = t[:, None, None] change_prob_s = (t - dt)[:, None, None] assert change_prob_t.ndim == 3, change_prob_t.shape if attn_mask is None: attn_mask = torch.ones_like(token_array).to(self.device) if p_x0 is None: log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) p_x0 = log_p.exp() assert change_prob_t.ndim == p_x0.ndim # changed for noise removal p_x0 = p_x0.clone() p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK q_xs = p_x0 * (change_prob_t - change_prob_s) x_changed = _sample_categorical(q_xs) copy_flag = (token_array != self.mask_index) int_copy_flag = copy_flag.to(token_array.dtype) x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed # returns: # log_p (B, L, D) log probabilties of each token under the policy model # x_next (B, L) next sequences return log_p, x_next def mcts_reverse_step(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None): torch.cuda.empty_cache() self.backbone.eval() self.noise.eval() assert self.config.noise.type == 'loglinear' sigma_t, _ = self.noise(t) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 change_prob_t = t[:, None, None] change_prob_s = (t - dt)[:, None, None] assert change_prob_t.ndim == 3, change_prob_t.shape if attn_mask is None: attn_mask = torch.ones_like(token_array).to(self.device) if p_x0 is None: log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) p_x0 = log_p.exp() assert change_prob_t.ndim == p_x0.ndim q_xs = p_x0 * (change_prob_t - change_prob_s) # zero-masking probability q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0] x_changed = _sample_categorical(q_xs) copy_flag = (token_array != self.mask_index) int_copy_flag = copy_flag.to(token_array.dtype) x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed # compute the log-probability under pretrained model at each step with torch.no_grad(): # pretrained should output log-probs over vocab at each position given the *parent* (masked) input log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) # log-prob of the *sampled token* at each position log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L] # sum only over the sites actually sampled this step (i.e., where parent was mask) assert copy_flag.dtype == torch.bool, "copy_flag must be bool" changed_mask = (~copy_flag) # mask of tokens that were unmasked in this step unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype) log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1) # compute the per-sequence log-probability under the pretrained model log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L] log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1) # returns: # log_p (B, L, D) log probabilties of each token under the policy model # x_next (B, L) next sequences # log_policy_step (B, ) log probability of all unmasked tokens under policy # log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model return log_p, x_next, log_policy_step, log_pretrained_step def mcts_noise_removal(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None): torch.cuda.empty_cache() self.backbone.eval() self.noise.eval() assert self.config.noise.type == 'loglinear' sigma_t, _ = self.noise(t) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 change_prob_t = t[:, None, None] change_prob_s = (t - dt)[:, None, None] assert change_prob_t.ndim == 3, change_prob_t.shape if attn_mask is None: attn_mask = torch.ones_like(token_array).to(self.device) if p_x0 is None: log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) p_x0 = log_p.exp() assert change_prob_t.ndim == p_x0.ndim # changed for noise removal p_x0 = p_x0.clone() p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK q_xs = p_x0 * (change_prob_t - change_prob_s) x_changed = _sample_categorical(q_xs) copy_flag = (token_array != self.mask_index) int_copy_flag = copy_flag.to(token_array.dtype) x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed # compute the log-probability under pretrained model at each step with torch.no_grad(): # pretrained should output log-probs over vocab at each position given the *parent* (masked) input log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) # log-prob of the *sampled token* at each position log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L] # sum only over the sites actually sampled this step (i.e., where parent was mask) assert copy_flag.dtype == torch.bool, "copy_flag must be bool" changed_mask = (~copy_flag) # mask of tokens that were unmasked in this step unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype) log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1) # compute the per-sequence log-probability under the pretrained model log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L] log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1) # returns: # log_p (B, L, D) log probabilties of each token under the policy model # x_next (B, L) next sequences # log_policy_step (B, ) log probability of all unmasked tokens under policy # log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model return log_p, x_next, log_policy_step, log_pretrained_step # first step in expansion def batch_mcts_reverse_step(self, token_array, t, dt, batch_size, pretrained, p_x0=None, attn_mask=None): torch.cuda.empty_cache() self.backbone.eval() self.noise.eval() assert self.config.noise.type == 'loglinear' sigma_t, _ = self.noise(t) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 change_prob_t = t[:, None, None] change_prob_s = (t - dt)[:, None, None] assert change_prob_t.ndim == 3, change_prob_t.shape if token_array.dim() == 1: token_array = token_array.unsqueeze(0) # expand to match (num_children, L) if attn_mask is None: attn_mask = torch.ones_like(token_array).to(self.device) token_array = token_array.to(self.device) sigma_t = sigma_t.to(self.device) if p_x0 is None: log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) p_x0 = log_p.exp() assert change_prob_t.ndim == p_x0.ndim q_xs = p_x0 * (change_prob_t - change_prob_s) # zero-masking probability q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0] # repeat the parent token along the first dimension which will be unmasked into distinct sequences token_array_expanded = token_array.repeat(batch_size, 1) if self.config.mcts.sampling == 0: x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size) else: x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling) copy_flag = (token_array_expanded != self.mask_index) int_copy_flag = copy_flag.to(token_array.dtype) x_children = int_copy_flag * token_array_expanded + (1 - int_copy_flag) * x_changed # compute the log-probability under pretrained model at each step with torch.no_grad(): # pretrained should output log-probs over vocab at each position given the *parent* (masked) input log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t) # expand to match the shape of x_children log_pre = log_pre.repeat(batch_size, 1, 1) # log-prob of the *sampled token* at each position log_pre_token = log_pre.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # [B*batch,L] # sum only over the sites actually sampled this step (i.e., where parent was mask) assert copy_flag.dtype == torch.bool, "copy_flag must be bool" changed_mask = (~copy_flag) # mask of tokens that were unmasked in this step unmasked_this_step = (changed_mask & (x_children != self.mask_index)).to(log_pre_token.dtype) log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1) # compute the per-child log-probability under the pretrained model log_p = log_p.repeat(batch_size, 1, 1) log_policy_token = log_p.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # (B, L) probability of each chosen token #print(log_policy_token) log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1) # returns: # log_p (B, L, D) log probabilties of each token under the policy model # x_children (B, L) child sequences # log_policy_step (B, ) log probability of all unmasked tokens under policy # log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model return log_p, x_children, log_policy_step, log_pretrained_step def compute_invalid_loss(self, logits, k=None, temp=None): """ Penalizes logits that produce invalid sequences using the `is_peptide` function, scaling penalties inversely with token probabilities. Args: logits: Tensor of shape [batch_size, seq_len, vocab_size]. k: Number of samples for Gumbel-Rao. temp: Temperature for softmax. Returns: loss: A scalar tensor representing the total loss for invalid sequences. """ #samples = self.gumbel_rao(logits, k=k, temp=temp) # (batch_size, seq_len, vocab_size) # Convert logits to sequences using the tokenizer batch_token_ids = logits.argmax(dim=-1).to(self.device) # (batch_size, seq_len) sampled_sequences = self.tokenizer.batch_decode(batch_token_ids) # Check validity of each sampled sequence (not differentiable) penalties = torch.tensor( [1 if not self.analyzer.is_peptide(seq) else 0 for seq in sampled_sequences], dtype=torch.float32, device=self.device ) #print(penalties) # Compute probabilities for each token (batch_size, seq_length) sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device) # scale penalties by softmax probability of sampled tokens scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length) return scaled_penalty.to(self.device) ### DIFFUSION LOSS ### def sample_t(self, n, device): """ Sample random time steps for batch training """ # sample values uniformly at random from [0, 1) eps_t = torch.rand(n, device=device) # antithetic sampling: reduce variance by pairing each sample with complementary sample if self.config.training.antithetic_sampling: # compute interval between sampled time steps offset = torch.arange(n, device=device) / n # ensure that each eps value is evenly spaced between [0, 1) eps_t = ((eps_t / n) + offset) % 1 # ensures values are not exactly 0 or 1 t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps return t """def mask_samples(self, x0, mask_prob): # generate array of values in range [0, 1] uniformly at random # will be used to determine which tokens are masked mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L) # select tokens to mask if the random value in mask_indices is less than mask_prob # this will mask approximately the fraction of tokens indicated by mask_prob zt = torch.where(mask_indices < mask_prob, self.mask_index, x0) return zt""" def q_xt(self, x, mask_prob): """Computes the noisy sample xt. Args: x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input. move_chance: float torch.Tensor with shape (batch_size, 1). """ actual_seq_length = (x != 0).sum(dim=-1, keepdim=True) #print(actual_seq_length) max_mask_length = (actual_seq_length * 0.75).long() mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob restricted_move_indices = torch.zeros_like(mask_indices, dtype=torch.bool) for i in range(x.shape[0]): true_positions = torch.where(mask_indices[i])[0] if len(true_positions) > max_mask_length[i]: selected_positions = true_positions[:max_mask_length[i].item()] restricted_move_indices[i, selected_positions] = True else: restricted_move_indices[i] = mask_indices[i] xt = torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x) return xt def sample_prior(self, *batch_dims): """ Returns array of fully masked sequences with same shape as input """ return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64) ### COMPUTING LOSS ### def compute_diffusion_loss(self, model_output, xt, x0, t): """ Computes diffusion loss term in ELBO (evaluates how accurately the model predicts the token probabilities at each time step) Inputs: - model_output: [sequence length, vocab size, vocab size] array of logits for each token at each sequence position - zt: corrupted version of original input x0 at timestep t - x0: original input sequence - t: timestep """ # compute interval between each timestep dt = 1 / self.T # compute vectorized alpha scaling terms for the logits at timestep s and t alpha_t = 1 - t + torch.zeros_like(x0) # s = t - dt alpha_s = 1 - (t - dt) + torch.zeros_like(x0) # gather vector of log-probabilities for each token in x0 # log log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]) # shape (B, L, vocab_size) # gather log-probabillities for assigning a masked token at each position in the sequence at time t # log log_x_theta_at_m = model_output[:, :, self.mask_index] # obtain non-log probability of assigning a masked token # x_theta_at_m = log_x_theta_at_m.exp() # first term of diffusion loss term_1_coef = dt / t term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1) term_1_log_denom = log_x_theta_at_x0 # second term of diffusion loss term_2_coef = 1 - (dt / t) term_2_log_numerator = term_1_log_numerator term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1) L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) + term_2_coef * (term_2_log_numerator - term_2_log_denom)) # multiply by term L_vb = L_vb_masked * (xt == self.mask_index) # scale by T and return return self.T * L_vb def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None): """ Training reverse diffusion model x_theta to reconstruct samples x0 bond_mask: (batch, seq_length) """ # randomly sample time steps to start the denoising process for each x0 in batch t = self.sample_t(x0.shape[0], self.device) # if we are training the intermediate transition blocks if self.T > 0: # scale by total timesteps T and cast to integer t = (t * self.T).to(torch.int) # scale down by T to get a multiple of 1/T t = t / self.T # add 1/T to ensure no 0 values t += (1 / self.T) # get noise and rate of noise at timestep t # sigma = -log(1-t); dsigma = 1 / (1-t) sigma, dsigma = self.noise(t) time_conditioning = sigma[:, None] # Get masking probabilities for all tokens for each batch # log-linear: 1 - alpha = t base_mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L) if self.config.noise.state_dependent and (bond_mask is not None): # log-polynomial masking schedule: alpha = 1 - t^w # bond_sigma = -log(1-t^w) for w = 3 (default) # bond_dsigma = -wt^(w-1) / (1-t^w) bond_sigma, bond_dsigma = self.bond_noise(t) # scalar # expand dimensions for broadcasting to (B, L) bond_sigma = bond_sigma[:, None] bond_dsigma = bond_dsigma[:, None] sigma = sigma[:, None] dsigma = dsigma[:, None] # compute masking probability for peptide bonds 1 - bond_alpha = t^w bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device) # piece together (B, L) tensor with modified masking prob at peptide-bond locations mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device) #print(mask_prob) dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device) sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device) else: mask_prob = base_mask_prob.to(self.device) # get masked samples at different timesteps if mask is None: zt = self.q_xt(x0, mask_prob).to(self.device) else: zt = x0.where(mask==1, torch.full_like(x0, self.mask_index)).to(self.device) model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device) # debugging assert not torch.isnan(model_output).any() assert model_output.is_cuda utils.print_nans(model_output, 'model_output') # compute invalid loss invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device) # (B, L) #print(invalid_loss) if self.T > 0: # compute diffusion loss diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t) return diffusion_loss # compute loss for the final that converts from z0 to x0 # -log(p_theta) # get (batch_size, L) array of log-probabilities log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device) # (B, L) if self.config.noise.state_dependent and (bond_mask is not None): return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device) else: return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device) def _loss(self, x0, attn_mask, bond_mask=None, mask=None): loss = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask) # negative log loss nlls = loss * attn_mask # count number of tokens num_tokens = attn_mask.sum() # compute batch loss batch_nll = nlls.sum() # compute per token loss token_nll = batch_nll / num_tokens # return losses return Loss(loss = token_nll.to(self.device), nlls = nlls.to(self.device), attn_mask = attn_mask.to(self.device)) def _compute_loss(self, batch, prefix, bond_mask=None): attn_mask = batch['attention_mask'].to(self.device) if 'mask' in batch: mask = batch['mask'].to(self.device) else: mask = None if 'bond_mask' in batch: bond_mask = batch['bond_mask'].to(self.device) else: bond_mask = None losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask) loss = losses.loss if prefix == 'train': self.train_metrics.update( losses.nlls.to(self.device), losses.attn_mask.to(self.device) ) metrics = self.train_metrics elif prefix == 'val': self.valid_metrics.update( losses.nlls.to(self.device), losses.attn_mask.to(self.device) ) metrics = self.valid_metrics elif prefix == 'test': self.test_metrics.update(losses.nlls, losses.attn_mask) metrics = self.test_metrics else: raise ValueError(f'Invalid prefix: {prefix}') self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True) return loss ### SAMPLING ### def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5): # get number of timesteps if sample_steps is None: sample_steps = self.config.sampling.steps if seq_length is None: seq_length = self.config.sampling.seq_length # sample fully masked sequences z = self.sample_prior(num_samples, seq_length).to(self.device) # create vector of sample_steps timesteps timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device) # compute interval between timesteps dt = (1 - eps) / sample_steps for i in range(sample_steps): t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device) z = self.single_reverse_step(z, t, dt) return z ### SAMPLING STEP ### """ def single_reverse_step(self, zt, t, dt, attn_mask=None): # get sigma values that determine masking prob sigma_t, _ = self.noise(t) sigma_s, _ = self.noise(t - dt) # reshape sigmas if sigma_t.ndim > 1: sigma_t = sigma_t.squeeze(-1) if sigma_s.ndim > 1: sigma_s = sigma_s.squeeze(-1) assert sigma_t.ndim == 1, sigma_t.shape assert sigma_s.ndim == 1, sigma_s.shape # compute masking probabilities for each timestep change_prob_t = 1 - torch.exp(-sigma_t) change_prob_s = 1 - torch.exp(-sigma_s) # expand dimensions change_prob_t = change_prob_t[:, None, None] change_prob_s = change_prob_s[:, None, None] # get prodiction model that outputs token probabilities log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t) # check dimensions match assert change_prob_t.ndim == log_p_x0.ndim # compute reverse diffusion probability of being unmasked at timestep s # (sigma_s - sigma_t)*x_theta q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s) # compute reverse diffusion probability of remaining masked at timestep s # (1 - sigma_s)*m q_zs[:, :, self.mask_index] = change_prob_s[:, :, 0] # sample sequence at timestep s from categorical distribution of q_zs z_changed = _sample_categorical(q_zs) copy_flag = (zt != self.mask_index).to(zt.dtype) return (copy_flag * zt) + ((1 - copy_flag) * z_changed)""" def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None): assert self.config.noise.type == 'loglinear' sigma_t, _ = self.noise(t) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 change_prob_t = t[:, None, None] change_prob_s = (t - dt)[:, None, None] assert change_prob_t.ndim == 3, change_prob_t.shape if p_x0 is None: p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp() assert change_prob_t.ndim == p_x0.ndim q_xs = p_x0 * (change_prob_t - change_prob_s) # zero-masking probability q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0] x_changed = _sample_categorical(q_xs) copy_flag = (x != self.mask_index).to(x.dtype) return p_x0, copy_flag * x + (1 - copy_flag) * x_changed # first step in expansion def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None): """ Generates batch_size different samples from the same starting point for the first expansion step of MCTS """ assert self.config.noise.type == 'loglinear' sigma_t, _ = self.noise(t) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 change_prob_t = t[:, None, None] change_prob_s = (t - dt)[:, None, None] assert change_prob_t.ndim == 3, change_prob_t.shape if token_array.dim() == 1: token_array = token_array.unsqueeze(0) #token_array = token_array.repeat(batch_size, 1) attn_mask = torch.ones_like(token_array).to(self.device) if p_x0 is None: p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp() assert change_prob_t.ndim == p_x0.ndim q_xs = p_x0 * (change_prob_t - change_prob_s) # zero-masking probability q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0] # repeat the parent token along the first dimension which will be unmasked into distinct sequences token_array = token_array.repeat(batch_size, 1) if self.config.mcts.sampling == 0: x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size) else: x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling) copy_flag = (token_array != self.mask_index).to(token_array.dtype) return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed def _process_sigma(self, sigma): if sigma.ndim > 1: sigma = sigma.squeeze(-1) if not self.time_conditioning: sigma = torch.zeros_like(sigma) assert sigma.ndim == 1, sigma.shape return sigma def forward(self, zt, attn_mask, sigma): """ Predicts the token log-probabilities from zt at time t with noise schedule sigma """ sigma = self._process_sigma(sigma) with torch.cuda.amp.autocast(dtype=torch.float32): logits = self.backbone(zt, attn_mask).to(self.device) return self.subs_parameterization(logits, zt) def subs_parameterization(self, logits, zt): """ Updates reverse diffusion logits based on SUBS parameterization: - zero masking probabilities: -infinity probability of being masked during reverse diffusion - carry-over unmasking: unmasked input tokens remain unchanged during reverse diffusion Args: logits: vector of token probabilities for unmasking masked tokens zt: partially unmasked sequence at current timestep """ logits[:, :, self.mask_index] += self.neg_infinity # [sequence index, current token, next token] logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device) unmasked_indices = (zt != self.mask_index).to(self.device) # shape: [200, seq_length] batch_idx, seq_idx = torch.where(unmasked_indices) # Get explicit indices batch_idx = batch_idx.to(self.device) seq_idx = seq_idx.to(self.device) tokens = zt[batch_idx, seq_idx].to(self.device) # Get the tokens at those positions #assert logits.is_contiguous(), "logits tensor is not contiguous" #assert unmasked_indices.shape == zt.shape, "same shape" #assert not torch.isnan(logits).any(), "NaN values found in logits" #assert tokens.max() < logits.shape[-1], "token indices out of bounds" #assert batch_idx.max() < logits.shape[0], "batch index out of bounds" #assert seq_idx.max() < logits.shape[1], "seq index out of bounds" #assert batch_idx.device == seq_idx.device == logits.device == tokens.device, "device inconsistent" logits[unmasked_indices] = self.neg_infinity # Set everything to -inf first logits[unmasked_indices, zt[unmasked_indices]] = 0 # Set only the specific token positions to 0 # return logits with SUBS parameterization return logits.to(self.device) """SAMPLING""" @torch.no_grad() def _sample(self, num_steps=None, eps=1e-5, x_input=None): """ Generate samples """ batch_size_per_gpu = self.config.eval.perplexity_batch_size if num_steps is None: num_steps = self.config.sampling.steps if x_input is not None: x = x_input['input_ids'].to(self.device) attn_mask = x_input['attention_mask'].to(self.device) else: x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device) attn_mask = torch.ones_like(x).to(self.device) timesteps = torch.linspace(1, eps, num_steps+1, device=self.device) dt = (1 - eps) / num_steps p_x0_cache = None generation_history = [] # used to track which tokens are unmasked for i in range(num_steps): t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device) if self.sampler == 'ddpm': x = self.single_reverse_step(x, t, dt).to(self.device) elif self.sampler == 'ddpm_cache': p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask) if (not torch.allclose(x_next, x) or self.time_conditioning): # Disable caching p_x0_cache = None x = x_next.to(self.device) #print(self.tokenizer.decode(x.squeeze())) else: x = self._analytic_update(x, t, dt, attn_mask).to(self.device) if self.config.sampling.noise_removal: t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device) if self.sampler == 'analytic': x = self._denoiser_update(x, t).to(self.device) else: time_conditioning = self.noise(t)[0].to(self.device) x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device) #print(self.tokenizer.decode(x.squeeze())) return x.to(self.device) def restore_model_and_sample(self, num_steps, eps=1e-5): """Generate samples from the model.""" self.backbone.eval() self.noise.eval() samples = self._sample(num_steps=num_steps, eps=eps) self.backbone.train() self.noise.train() return samples def get_score(self, zt, sigma, attn_mask=None): # score(x, t) = p_t(y) / p_t(x) # => log score(x, t) = log p_t(y) - log p_t(x) # case 1: x = masked # (i) y = unmasked # log score(x, t) = log p_\theta(x)|_y + log k # where k = exp(- sigma) / (1 - exp(- sigma)) # (ii) y = masked # log score(x, t) = 0 # case 2: x = unmasked # (i) y != masked, y != x # log score(x_i, t) = - inf # (ii) y = x # log score(x_i, t) = 0 # (iii) y = masked token # log score(x_i, t) = - log k # where k = exp(- sigma) / (1 - exp(- sigma)) model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma) log_k = -torch.log(torch.expm1(sigma)).squeeze(-1) assert log_k.ndim == 1 masked_score = model_output + log_k[:, None, None] masked_score[:, :, self.mask_index] = 0 unmasked_score = self.neg_infinity * torch.ones_like(model_output) unmasked_score = torch.scatter( unmasked_score, -1, zt[..., None], torch.zeros_like(unmasked_score[..., :1])) unmasked_score[:, :, self.mask_index] = - (log_k[:, None] * torch.ones_like(zt)) masked_indices = (zt == self.mask_index).to(model_output.dtype)[:, :, None] model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices)) return model_output.exp() def _staggered_score(self, score, dsigma): score = score.clone() extra_const = (1 - dsigma.exp()) * score.sum(dim=-1) score *= dsigma.exp()[:, None] score[..., self.mask_index] += extra_const return score def _analytic_update(self, x, t, step_size, attn_mask=None): curr_sigma, _ = self.noise(t) next_sigma, _ = self.noise(t - step_size) dsigma = curr_sigma - next_sigma score = self.get_score(x, attn_mask, curr_sigma) stag_score = self._staggered_score(score, dsigma) probs = stag_score * self._transp_transition(x, dsigma) return _sample_categorical(probs) def _denoiser_update(self, x, t): sigma, _ = self.noise(t) score = self.get_score(x, sigma) stag_score = self._staggered_score(score, sigma) probs = stag_score * self._transp_transition(x, sigma) probs[..., self.mask_index] = 0 samples = _sample_categorical(probs) return samples def _transp_transition(self, i, sigma): sigma = unsqueeze(sigma, reference=i[..., None]) edge = torch.exp(-sigma) * F.one_hot( i, num_classes=self.vocab_size) edge += torch.where(i == self.mask_index, 1 - torch.exp(-sigma).squeeze(-1), 0)[..., None] return edge """TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py""" def on_train_epoch_start(self): torch.cuda.empty_cache() self.backbone.train() self.noise.train() def training_step(self, batch, batch_idx): # Initialize throughput calculation start_time = time.time() if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles': loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask']) else: loss = self._compute_loss(batch, prefix='train') self.log(name='trainer/loss', value=loss.item(), on_step=True, on_epoch=False, sync_dist=True) # Calculate throughput elapsed_time = time.time() - start_time total_tokens = batch['input_ids'].numel() throughput = total_tokens / elapsed_time self.log(name='trainer/throughput', value=throughput, on_step=True, on_epoch=False, sync_dist=True) return loss def on_load_checkpoint(self, checkpoint): self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed'] self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] ### VALIDATION ### def on_validation_epoch_start(self): gc.collect() torch.cuda.empty_cache() self.backbone.eval() self.noise.eval() assert self.valid_metrics.nll.mean_value == 0 assert self.valid_metrics.nll.weight == 0 def validation_step(self, batch, batch_idx): if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles': loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask']) else: loss = self._compute_loss(batch, prefix='val') self.log(name='trainer/val_loss', value=loss.item(), on_step=True, on_epoch=False, prog_bar=True, sync_dist=True) return loss def on_validation_epoch_end(self): gc.collect() torch.cuda.empty_cache() ### OPTIMIZATION ### def optimizer_step(self, *args, **kwargs): super().optimizer_step(*args, **kwargs) gc.collect() torch.cuda.empty_cache() def configure_optimizers(self): optimizer = torch.optim.AdamW( itertools.chain(self.backbone.parameters(),self.noise.parameters()), lr=self.config.optim.lr, betas=(self.config.optim.beta1, self.config.optim.beta2), eps=self.config.optim.eps, weight_decay=self.config.optim.weight_decay ) self.total_steps = self.config.trainer.max_steps scheduler = CosineWarmup(optimizer, warmup_steps=self.config.lr_scheduler.num_warmup_steps, total_steps=self.total_steps) scheduler_dict = { 'scheduler': scheduler, 'interval': 'step', 'frequency': 1, 'monitor': 'val/loss', 'name': 'trainer/lr' } return [optimizer], [scheduler_dict] @torch.no_grad() def compute_masked_perplexity(self, generated_ids, input_ids): """ Computes masked perplexity between array of generated token ids and masked ids that are converted to logits """ total_nll = 0 total_tokens = 0 input_ids = torch.tensor(input_ids).to(self.device) #print(input_ids) for sequence in generated_ids: # tokenize the sequence gt_ids = torch.tensor(sequence).to(self.device) #print(gt_ids) sys.stdout.flush() # forward pass thorugh backbone peptideclm model attn_mask = torch.ones_like(input_ids).to(self.device) # compute logits using backbone if self.config.mode in ['train', 'ppl_eval']: outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask) elif self.config.mode == 'sample_eval': outputs = self.backbone.forward(input_ids=input_ids) # get logits for each position in sequence across all tokens in vocab #logits = outputs[-1] # (batch_size, seq_length, vocab_size) logits = outputs.view(-1, outputs.size(-1)) gt_ids = gt_ids.view(-1) #print(logits.shape) #print(gt_ids.shape) # compute loss # shift_logits = logits[:, :-1, :].contiguous() # remove eos # shift_labels = input_ids[:, 1:].contiguous() # print(masked) loss = F.cross_entropy(logits, gt_ids.where(input_ids==self.mask_index, torch.full_like(gt_ids, -100)).view(-1), reduction='sum') total_nll += loss.item() # count all non-padding tokens total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() # count in bos and eos # compute pseudo-perplexity # print(total_nll, ",;,", total_tokens) pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens)) self.gen_ppl_metric.update(pseudo_perplexity) return pseudo_perplexity.item() def unsqueeze(x, reference): return x.view(* x.shape, * ((1,) * (len(reference.shape) - len(x.shape)))) class CosineWarmup(_LRScheduler): def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1): self.warmup_steps = warmup_steps self.total_steps = total_steps self.eta_ratio = eta_ratio # The ratio of minimum to maximum learning rate super(CosineWarmup, self).__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch < self.warmup_steps: return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs] progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps) cosine_decay = 0.5 * (1 + np.cos(np.pi * progress)) decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio return [decayed_lr * base_lr for base_lr in self.base_lrs]