|
|
from abc import ABC, abstractmethod |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class Noise(ABC): |
|
|
""" |
|
|
Baseline forward method to get noise parameters at a timestep |
|
|
""" |
|
|
|
|
|
def __call__( |
|
|
self, t: torch.Tensor | float |
|
|
) -> tuple[torch.Tensor | float, torch.Tensor | float]: |
|
|
|
|
|
pass |
|
|
|
|
|
@abstractmethod |
|
|
def inverse(self, alpha_t: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Inverse function to compute the timestep t from the noise schedule param. |
|
|
""" |
|
|
raise NotImplementedError("Inverse function not implemented") |
|
|
|
|
|
|
|
|
class CosineNoise(Noise):
    """Cosine noise schedule: ``alpha_t = (1 - eps) * cos(pi * t / 2)``.

    ``eps`` keeps ``alpha_t`` strictly below 1 at ``t = 0``.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.name = "cosine"

    def __call__(self, t):
        """Return ``(alpha_t, alpha_t_prime)`` at timestep ``t``.

        Args:
            t: Timestep(s), as a tensor or a plain Python number.  The
                formula is written for ``t`` in ``[0, 1]`` (not enforced).

        Returns:
            A pair of float32 tensors: ``alpha_t = 1 - move_chance`` and its
            derivative with respect to ``t``.
        """
        # as_tensor accepts both tensors and Python scalars; for a tensor it
        # behaves like ``t.to(torch.float32)``.
        t = torch.as_tensor(t, dtype=torch.float32)
        cos = -(1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = -(1 - self.eps) * torch.sin(t * torch.pi / 2)
        move_chance = cos + 1
        # d/dt [(1 - eps) * cos(pi t / 2)] = -(1 - eps) * sin(pi t / 2) * pi / 2
        alpha_t_prime = sin * torch.pi / 2
        return 1 - move_chance, alpha_t_prime

    def inverse(self, alpha_t):
        """Return the timestep ``t`` at which the schedule equals ``alpha_t``.

        Solves ``alpha_t = (1 - eps) * cos(pi * t / 2)`` for ``t``.
        """
        alpha_t = torch.as_tensor(alpha_t, dtype=torch.float32)
        # The forward pass never produces alpha_t above 1 - eps; clamp so that
        # rounding error or out-of-range input cannot leave acos's domain and
        # yield NaN.
        ratio = torch.clamp(alpha_t / (1 - self.eps), min=-1.0, max=1.0)
        return torch.acos(ratio) * 2 / torch.pi
|
|
|
|
|
|
|
|
class ExponentialNoise(Noise):
    """Power-law noise schedule: ``alpha_t = 1 - max(t ** exp, eps)``."""

    def __init__(self, exp=2, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.exp = exp
        self.name = f"exp_{exp}"

    def __call__(self, t):
        """Return ``(alpha_t, alpha_t_prime)`` at timestep ``t``.

        Args:
            t: Timestep(s), as a tensor or a plain Python number.

        Returns:
            A pair of float32 tensors: ``alpha_t = 1 - move_chance`` and its
            derivative with respect to ``t``.
        """
        # as_tensor accepts both tensors and Python scalars; for a tensor it
        # behaves like ``t.to(torch.float32)``.
        t = torch.as_tensor(t, dtype=torch.float32)
        move_chance = torch.pow(t, self.exp)
        # Floor the move chance at eps so alpha_t never reaches exactly 1.
        move_chance = torch.clamp(move_chance, min=self.eps)
        # Derivative of the unclamped schedule 1 - t ** exp.
        alpha_t_prime = -self.exp * torch.pow(t, self.exp - 1)
        # Return alpha_t first, then its derivative -- the same order as the
        # other schedules in this file (the two values used to be swapped).
        return 1 - move_chance, alpha_t_prime

    def inverse(self, alpha_t):
        """Return the timestep ``t`` at which the schedule equals ``alpha_t``.

        Solves ``alpha_t = 1 - t ** exp`` for ``t``.  Because the forward pass
        floors the move chance at ``eps``, every ``t <= eps ** (1 / exp)``
        maps to ``alpha_t = 1 - eps``; the inverse returns the largest such
        ``t`` in that case.
        """
        alpha_t = torch.as_tensor(alpha_t, dtype=torch.float32)
        # Guard against alpha_t marginally above 1, which would give a
        # negative base and NaN for fractional roots.
        move_chance = torch.clamp(1 - alpha_t, min=0.0)
        return torch.pow(move_chance, 1 / self.exp)
|
|
|
|
|
|
|
|
class LogarithmicNoise(Noise):
    """Logarithmic noise schedule: ``alpha_t = 1 - log2(1 + t)``."""

    def __init__(self, eps=1e-3):
        super().__init__()
        # Stored for interface parity with the other schedules; the formulas
        # below do not use it.
        self.eps = eps
        self.name = "logarithmic"

    def __call__(self, t):
        """Return ``(alpha_t, alpha_t_prime)`` at timestep ``t``.

        Args:
            t: Timestep(s), as a tensor or a plain Python number.

        Returns:
            A pair of float32 tensors: ``alpha_t = 1 - move_chance`` and its
            derivative with respect to ``t``.
        """
        # as_tensor accepts both tensors and Python scalars; for a tensor it
        # behaves like ``t.to(torch.float32)``.
        t = torch.as_tensor(t, dtype=torch.float32)
        log_two = torch.log(torch.tensor(2.0))
        # log1p(t) / log(2) == log2(1 + t)
        move_chance = torch.log1p(t) / log_two
        # d/dt [1 - log2(1 + t)] = -1 / (log(2) * (1 + t))
        alpha_t_prime = -1 / (log_two * (1 + t))
        return 1 - move_chance, alpha_t_prime

    def inverse(self, alpha_t):
        """Return the timestep ``t`` at which the schedule equals ``alpha_t``.

        Solves ``alpha_t = 1 - log1p(t) / log(2)`` for ``t``, i.e.
        ``t = expm1((1 - alpha_t) * log(2)) = 2 ** (1 - alpha_t) - 1``.
        """
        alpha_t = torch.as_tensor(alpha_t, dtype=torch.float32)
        log_two = torch.log(torch.tensor(2.0))
        # expm1 is the exact counterpart of the log1p used in the forward pass.
        return torch.expm1((1 - alpha_t) * log_two)
|
|
|
|
|
|
|
|
class LinearNoise(Noise):
    """Linear noise schedule: ``alpha_t = 1 - t``."""

    def __init__(self):
        super().__init__()
        self.name = "linear"

    def inverse(self, alpha_t):
        """Return the timestep ``t`` at which the schedule equals ``alpha_t``.

        Since ``alpha_t = 1 - t``, the inverse is ``t = 1 - alpha_t``.
        """
        return 1 - alpha_t

    def __call__(self, t):
        """Return ``(alpha_t, alpha_t_prime)`` at timestep ``t``.

        Args:
            t: Timestep(s), as a tensor or a plain Python number.

        Returns:
            A pair of float32 tensors: ``alpha_t = 1 - t`` and its constant
            derivative ``-1`` (same shape as ``t``).
        """
        # as_tensor accepts both tensors and Python scalars; for a tensor it
        # behaves like ``t.to(torch.float32)``.
        t = torch.as_tensor(t, dtype=torch.float32)
        alpha_t_prime = -torch.ones_like(t)
        move_chance = t
        return 1 - move_chance, alpha_t_prime
|
|
|