Feature Extraction
Transformers
PyTorch
e2d2
custom_code
e2d2-wmt / noise_schedule_noise_schedules.py
yairschiff's picture
Add model and code
b5a4dcb verified
from abc import ABC, abstractmethod
import torch
class Noise(ABC):
"""
Baseline forward method to get noise parameters at a timestep
"""
def __call__(
self, t: torch.Tensor | float
) -> tuple[torch.Tensor | float, torch.Tensor | float]:
# Assume time goes from 0 to 1
pass
@abstractmethod
def inverse(self, alpha_t: torch.Tensor) -> torch.Tensor:
"""
Inverse function to compute the timestep t from the noise schedule param.
"""
raise NotImplementedError("Inverse function not implemented")
class CosineNoise(Noise):
    """Cosine noise schedule.

    The forward pass evaluates ``alpha_t = (1 - eps) * cos(pi * t / 2)`` and
    its derivative with respect to ``t``.
    """

    def __init__(self, eps=1e-3):
        """Store ``eps``, which scales the cosine by ``1 - eps``."""
        super().__init__()
        self.eps = eps
        self.name = "cosine"

    def __call__(self, t):
        """Return ``(alpha_t, alpha_t_prime)`` at timestep ``t``.

        Args:
            t: Timestep(s) as a tensor or a Python number; converted to
                float32.

        Returns:
            Tuple ``(1 - move_chance, alpha_t_prime)`` where
            ``move_chance = 1 - (1 - eps) * cos(pi * t / 2)`` and
            ``alpha_t_prime = -(1 - eps) * (pi / 2) * sin(pi * t / 2)``.
        """
        # as_tensor behaves like ``t.to(torch.float32)`` for tensors and also
        # accepts plain floats.
        t = torch.as_tensor(t, dtype=torch.float32)
        cos = -(1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = -(1 - self.eps) * torch.sin(t * torch.pi / 2)
        move_chance = cos + 1
        alpha_t_prime = sin * torch.pi / 2
        return 1 - move_chance, alpha_t_prime

    def inverse(self, alpha_t):
        """Compute the timestep ``t`` from ``alpha_t``.

        Solves ``alpha_t = (1 - eps) * cos(pi * t / 2)`` for ``t``.

        Args:
            alpha_t: First value returned by ``__call__`` (tensor or number).

        Returns:
            Tensor ``t = (2 / pi) * acos(alpha_t / (1 - eps))``.
        """
        ratio = torch.as_tensor(alpha_t, dtype=torch.float32) / (1 - self.eps)
        # Keep the argument inside acos's domain so values that overshoot
        # 1 - eps (e.g. from round-off) map to t = 0 rather than NaN.
        ratio = torch.clamp(ratio, min=-1.0, max=1.0)
        return torch.acos(ratio) * 2 / torch.pi
class ExponentialNoise(Noise):
    """Power-law noise schedule with ``move_chance = max(t ** exp, eps)``."""

    def __init__(self, exp=2, eps=1e-3):
        """Store the exponent ``exp`` and the lower clamp ``eps``."""
        super().__init__()
        self.eps = eps
        self.exp = exp
        self.name = f"exp_{exp}"

    def __call__(self, t):
        """Return ``(alpha_t, alpha_t_prime)`` at timestep ``t``.

        Args:
            t: Timestep(s) as a tensor or a Python number; converted to
                float32.

        Returns:
            Tuple ``(1 - move_chance, alpha_t_prime)`` where
            ``move_chance = max(t ** exp, eps)`` and
            ``alpha_t_prime = -exp * t ** (exp - 1)``.
        """
        # as_tensor behaves like ``t.to(torch.float32)`` for tensors and also
        # accepts plain floats.
        t = torch.as_tensor(t, dtype=torch.float32)
        move_chance = torch.pow(t, self.exp)
        move_chance = torch.clamp(move_chance, min=self.eps)
        # The derivative is that of the unclamped ``1 - t ** exp``.
        alpha_t_prime = -self.exp * torch.pow(t, self.exp - 1)
        # Order is (alpha_t, alpha_t_prime), matching the other schedules;
        # the two values were previously returned swapped.
        return 1 - move_chance, alpha_t_prime

    def inverse(self, alpha_t):
        """Compute the timestep ``t`` from ``alpha_t``.

        Solves ``alpha_t = 1 - t ** exp`` for ``t``. Because the forward pass
        clamps ``move_chance`` to at least ``eps``, every ``t`` below
        ``eps ** (1 / exp)`` maps to ``alpha_t = 1 - eps``; for that value
        this returns ``eps ** (1 / exp)``.

        Args:
            alpha_t: First value returned by ``__call__`` (tensor or number).

        Returns:
            Tensor ``t = (1 - alpha_t) ** (1 / exp)``.
        """
        move_chance = 1 - torch.as_tensor(alpha_t, dtype=torch.float32)
        # Guard against slightly negative values, which would give NaN for
        # fractional roots.
        move_chance = torch.clamp(move_chance, min=0.0)
        return torch.pow(move_chance, 1 / self.exp)
class LogarithmicNoise(Noise):
    """Logarithmic noise schedule with ``move_chance = log2(1 + t)``."""

    def __init__(self, eps=1e-3):
        """Store ``eps``; the forward pass of this schedule does not read it."""
        super().__init__()
        self.eps = eps
        self.name = "logarithmic"

    def __call__(self, t):
        """Return ``(alpha_t, alpha_t_prime)`` at timestep ``t``.

        Args:
            t: Timestep(s) as a tensor or a Python number; converted to
                float32.

        Returns:
            Tuple ``(1 - move_chance, alpha_t_prime)`` where
            ``move_chance = log(1 + t) / log(2)`` and
            ``alpha_t_prime = -1 / (log(2) * (1 + t))``.
        """
        # as_tensor behaves like ``t.to(torch.float32)`` for tensors and also
        # accepts plain floats.
        t = torch.as_tensor(t, dtype=torch.float32)
        # Computed once and shared by both expressions below.
        log_two = torch.log(torch.tensor(2.0))
        move_chance = torch.log1p(t) / log_two
        alpha_t_prime = -1 / (log_two * (1 + t))
        return 1 - move_chance, alpha_t_prime

    def inverse(self, alpha_t):
        """Compute the timestep ``t`` from ``alpha_t``.

        Solves ``alpha_t = 1 - log2(1 + t)`` for ``t``.

        Args:
            alpha_t: First value returned by ``__call__`` (tensor or number).

        Returns:
            Tensor ``t = 2 ** (1 - alpha_t) - 1``.
        """
        move_chance = 1 - torch.as_tensor(alpha_t, dtype=torch.float32)
        return torch.exp2(move_chance) - 1
class LinearNoise(Noise):
    """Linear noise schedule with ``alpha_t = 1 - t``."""

    def __init__(self):
        super().__init__()
        self.name = "linear"

    def inverse(self, alpha_t):
        """Compute the timestep from ``alpha_t``: ``t = 1 - alpha_t``."""
        return 1 - alpha_t

    def __call__(self, t):
        """Return ``(alpha_t, alpha_t_prime)`` at timestep ``t``.

        Args:
            t: Timestep(s) as a tensor or a Python number; converted to
                float32.

        Returns:
            Tuple ``(1 - t, alpha_t_prime)`` where ``alpha_t_prime`` is a
            tensor of ``-1`` with the same shape as ``t``.
        """
        # as_tensor behaves like ``t.to(torch.float32)`` for tensors and also
        # accepts plain floats, which ``ones_like`` below could not handle.
        t = torch.as_tensor(t, dtype=torch.float32)
        alpha_t_prime = -torch.ones_like(t)
        move_chance = t
        return 1 - move_chance, alpha_t_prime