from abc import ABC, abstractmethod import torch class Noise(ABC): """ Baseline forward method to get noise parameters at a timestep """ def __call__( self, t: torch.Tensor | float ) -> tuple[torch.Tensor | float, torch.Tensor | float]: # Assume time goes from 0 to 1 pass @abstractmethod def inverse(self, alpha_t: torch.Tensor) -> torch.Tensor: """ Inverse function to compute the timestep t from the noise schedule param. """ raise NotImplementedError("Inverse function not implemented") class CosineNoise(Noise): def __init__(self, eps=1e-3): super().__init__() self.eps = eps self.name = "cosine" def __call__(self, t): t = t.to(torch.float32) cos = -(1 - self.eps) * torch.cos(t * torch.pi / 2) sin = -(1 - self.eps) * torch.sin(t * torch.pi / 2) move_chance = cos + 1 alpha_t_prime = sin * torch.pi / 2 return 1 - move_chance, alpha_t_prime class ExponentialNoise(Noise): def __init__(self, exp=2, eps=1e-3): super().__init__() self.eps = eps self.exp = exp self.name = f"exp_{exp}" def __call__(self, t): t = t.to(torch.float32) move_chance = torch.pow(t, self.exp) move_chance = torch.clamp(move_chance, min=self.eps) alpha_t_prime = -self.exp * torch.pow(t, self.exp - 1) return alpha_t_prime, 1 - move_chance class LogarithmicNoise(Noise): def __init__(self, eps=1e-3): super().__init__() self.eps = eps self.name = "logarithmic" def __call__(self, t): t = t.to(torch.float32) move_chance = torch.log1p(t) / torch.log(torch.tensor(2.0)) alpha_t_prime = -1 / (torch.log(torch.tensor(2.0)) * (1 + t)) return 1 - move_chance, alpha_t_prime class LinearNoise(Noise): def __init__(self): super().__init__() self.name = "linear" def inverse(self, alpha_t): return 1 - alpha_t def __call__(self, t): t = t.to(torch.float32) alpha_t_prime = -torch.ones_like(t) move_chance = t return 1 - move_chance, alpha_t_prime