File size: 4,387 Bytes
5e90249 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import abc
import torch
import torch.nn as nn
# Import-time TorchScript configuration through private ``torch._C`` hooks:
# switch off the profiling mode and profiling executor, then allow fusion on
# both CPU and GPU.
# NOTE(review): these are private, version-dependent entry points -- confirm
# they exist in the torch version this project pins.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
def get_noise(config, dtype=torch.float32):
    """Build the noise schedule selected by ``config.noise.type``.

    Args:
        config: Object exposing ``config.noise.type``. For the
            ``'geometric'`` and ``'linear'`` schedules,
            ``config.noise.sigma_min`` and ``config.noise.sigma_max`` are
            read as well.
        dtype: Forwarded to ``Linear`` only; the other schedules ignore it.

    Returns:
        A freshly constructed schedule object for the requested type.

    Raises:
        ValueError: If ``config.noise.type`` is not one of ``'geometric'``,
            ``'loglinear'``, ``'cosine'``, ``'cosinesqr'`` or ``'linear'``.
    """
    noise_cfg = config.noise
    kind = noise_cfg.type

    # Guard clauses: each recognised type returns immediately, so anything
    # that falls through is an unknown schedule name.
    if kind == 'geometric':
        return GeometricNoise(noise_cfg.sigma_min, noise_cfg.sigma_max)
    if kind == 'loglinear':
        return LogLinearNoise()
    if kind == 'cosine':
        return CosineNoise()
    if kind == 'cosinesqr':
        return CosineSqrNoise()
    if kind == 'linear':
        return Linear(noise_cfg.sigma_min, noise_cfg.sigma_max, dtype)

    raise ValueError(f'{kind} is not a valid noise')
def binary_discretization(z):
    """Discretize ``z`` to its sign while keeping a differentiable path.

    The returned value is ``soft + (hard - soft).detach()``, where ``hard``
    is ``sign(z)`` and ``soft`` is ``z`` divided by its L2 norm over the
    last dimension. Numerically the result equals ``hard`` (up to float
    rounding), but because the correction term is detached, gradients are
    those of ``soft``.

    Args:
        z: Input tensor; the norm is taken over its last dimension.

    Returns:
        Tensor with the same shape as ``z``.
    """
    soft = z / z.norm(dim=-1, keepdim=True)
    hard = z.sign()
    # Only ``soft`` stays attached to the autograd graph.
    correction = (hard - soft).detach()
    return soft + correction
class Noise(abc.ABC, nn.Module):
    """Common base for noise schedules.

    Calling the module evaluates both quantities of the schedule at ``t``.
    ``forward`` relies on ``self.total_noise(t)`` and ``self.rate_noise(t)``,
    which this base class does not define itself.
    """

    def forward(self, t):
        """Return ``(total_noise(t), rate_noise(t))``.

        Time is assumed to run from 0 to 1.
        """
        total = self.total_noise(t)
        rate = self.rate_noise(t)
        return total, rate
class CosineNoise(Noise):
    """Cosine noise schedule.

    ``total_noise(t) = -log(eps + (1 - eps) * cos(pi * t / 2))`` and
    ``rate_noise`` is its derivative with respect to ``t``.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        # Added inside the log / denominator so neither reaches zero when
        # the cosine term vanishes.
        self.eps = eps

    def rate_noise(self, t):
        """Return d/dt of ``total_noise`` evaluated at ``t``."""
        keep = 1 - self.eps
        angle = t * torch.pi / 2
        numerator = keep * torch.sin(angle)
        denominator = keep * torch.cos(angle) + self.eps
        return (torch.pi / 2) * numerator / denominator

    def total_noise(self, t):
        """Return ``-log(eps + (1 - eps) * cos(pi * t / 2))``."""
        cos_term = torch.cos(t * torch.pi / 2)
        return -torch.log(self.eps + (1 - self.eps) * cos_term)
class CosineSqrNoise(Noise):
    """Squared-cosine noise schedule.

    ``total_noise(t) = -log(eps + (1 - eps) * cos(pi * t / 2) ** 2)`` and
    ``rate_noise`` is its derivative with respect to ``t``.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        # Added inside the log / denominator so neither reaches zero when
        # the cosine term vanishes.
        self.eps = eps

    def rate_noise(self, t):
        """Return d/dt of ``total_noise`` evaluated at ``t``."""
        keep = 1 - self.eps
        cos_sq = torch.cos(t * torch.pi / 2) ** 2
        # 2 * sin(x) * cos(x) with x = pi * t / 2 collapses to sin(pi * t).
        numerator = keep * torch.sin(t * torch.pi)
        denominator = keep * cos_sq + self.eps
        return (torch.pi / 2) * numerator / denominator

    def total_noise(self, t):
        """Return ``-log(eps + (1 - eps) * cos(pi * t / 2) ** 2)``."""
        cos_sq = torch.cos(t * torch.pi / 2) ** 2
        return -torch.log(self.eps + (1 - self.eps) * cos_sq)
class Linear(Noise):
    """Linear noise schedule.

    Total noise ramps linearly from ``sigma_min`` at ``t = 0`` to
    ``sigma_max`` at ``t = 1``; the rate is therefore the constant
    ``sigma_max - sigma_min``.

    Args:
        sigma_min: Total noise at ``t = 0``.
        sigma_max: Total noise at ``t = 1``.
        dtype: dtype of the tensors that store the two bounds.
    """

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t=None):
        """Return the constant rate ``sigma_max - sigma_min``.

        ``t`` is accepted, and ignored, so that ``Noise.forward`` can call
        ``self.rate_noise(t)`` like it does for every other schedule.
        Without the parameter, calling the module raised ``TypeError``.
        The default keeps existing zero-argument calls working.
        """
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        """Return ``sigma_min + t * (sigma_max - sigma_min)``."""
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        """Remap ``t`` so that ``log(1 - exp(-sigma))`` is linear in ``t``.

        ``log(1 - exp(-sigma))`` is interpolated linearly between its values
        at ``sigma_min`` and ``sigma_max``, the matching ``sigma_t`` is
        recovered, and ``sigma_t`` is converted back to a time by inverting
        ``total_noise``.

        NOTE: with ``sigma_min == 0`` (the default) ``f_0`` is ``-inf``, so
        the interpolation degenerates; use a positive ``sigma_min`` here.
        """
        f_T = torch.log1p(-torch.exp(-self.sigma_max))
        f_0 = torch.log1p(-torch.exp(-self.sigma_min))
        sigma_t = -torch.log1p(-torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)
class GeometricNoise(Noise):
    """Geometric noise schedule.

    ``total_noise(t) = sigma_min ** (1 - t) * sigma_max ** t``, i.e. a
    geometric interpolation between the two bounds; ``rate_noise`` is that
    value multiplied by ``log(sigma_max) - log(sigma_min)``.
    """

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        bounds = torch.tensor([sigma_min, sigma_max])
        # Multiplying by 1.0 promotes integer bounds to floating point.
        self.sigmas = 1.0 * bounds

    def rate_noise(self, t):
        """Return ``total_noise(t) * (log(sigma_max) - log(sigma_min))``."""
        lo, hi = self.sigmas[0], self.sigmas[1]
        log_ratio = hi.log() - lo.log()
        return lo ** (1 - t) * hi ** t * log_ratio

    def total_noise(self, t):
        """Return ``sigma_min ** (1 - t) * sigma_max ** t``."""
        lo, hi = self.sigmas[0], self.sigmas[1]
        return lo ** (1 - t) * hi ** t
class LogLinearNoise(Noise):
    """Log-linear noise schedule.

    Total noise is ``-log(1 - (1 - eps) * t)``, so
    ``1 - exp(-total_noise(t))`` equals ``(1 - eps) * t`` and moves
    linearly from 0 to almost 1 as ``t`` goes from 0 to 1.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        t_start = torch.tensor(0.0)
        t_end = torch.tensor(1.0)
        self.sigma_max = self.total_noise(t_end)
        # Offset by eps so sigma_min is strictly positive.
        self.sigma_min = self.eps + self.total_noise(t_start)

    def rate_noise(self, t):
        """Return ``(1 - eps) / (1 - (1 - eps) * t)``."""
        keep = 1 - self.eps
        return keep / (1 - keep * t)

    def total_noise(self, t):
        """Return ``-log(1 - (1 - eps) * t)``, computed with ``log1p``."""
        scaled = -(1 - self.eps) * t
        return -torch.log1p(scaled)

    def importance_sampling_transformation(self, t):
        """Remap ``t`` so that ``log(1 - exp(-sigma))`` is linear in ``t``.

        ``log(1 - exp(-sigma))`` is interpolated linearly between its values
        at ``sigma_min`` and ``sigma_max``; the matching ``sigma_t`` is then
        mapped back to a time through ``(1 - exp(-sigma_t)) / (1 - eps)``.
        """
        log_end = torch.log1p(-torch.exp(-self.sigma_max))
        log_start = torch.log1p(-torch.exp(-self.sigma_min))
        blended = t * log_end + (1 - t) * log_start
        sigma_t = -torch.log1p(-torch.exp(blended))
        return -torch.expm1(-sigma_t) / (1 - self.eps)
class LogPolyNoise(Noise):
    """Log-polynomial noise schedule.

    Intended for slower masking of peptide bond tokens. Total noise is
    ``-log(1 - (1 - eps) * t ** 3)``, so ``1 - exp(-total_noise(t))``
    equals ``(1 - eps) * t ** 3``.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        # Offset by eps so sigma_min is strictly positive.
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        """Return d/dt of ``total_noise`` evaluated at ``t``.

        d/dt[-log(1 - (1 - eps) * t^3)]
            = 3 * (1 - eps) * t^2 / (1 - (1 - eps) * t^3)

        The previous numerator, ``3 * t^2 - eps``, was not this derivative
        and produced a negative rate (``-eps``) at ``t = 0``.
        """
        keep = 1 - self.eps
        return (3 * keep * (t ** 2)) / (1 - keep * (t ** 3))

    def total_noise(self, t):
        """Return ``-log(1 - (1 - eps) * t ** 3)``, computed with ``log1p``."""
        return -torch.log1p(-(1 - self.eps) * (t ** 3))