|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
# Configure the root logger once at import time.  Records go to stdout and the
# threshold is taken from the LOGLEVEL environment variable (upper-cased, so
# "debug" and "DEBUG" are equivalent), defaulting to INFO when it is unset.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)

# Module-level logger; the fixed name "srmsnorm" is what fills the %(name)s
# field of the format string above.
logger = logging.getLogger("srmsnorm")
|
|
|
|
|
|
|
|
class SimpleRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, without a gain.

    Each vector along the last axis is divided by
    ``sqrt(mean(x ** 2) + eps)``.  The module registers no parameters or
    buffers, so the output is never rescaled by a learned weight.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the normalization layer.

        Args:
            dim: Size of the feature dimension.  It is not needed by the
                computation (which always reduces over the last axis); it is
                stored only so the configured value can be inspected.
            eps: Constant added to the mean of squares before the reciprocal
                square root, keeping the result finite for all-zero inputs.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by the reciprocal RMS of its last dimension."""
        # keepdim=True keeps the reduced axis so the factor broadcasts over x.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and return a tensor of the same type as ``x``.

        The arithmetic is carried out on a float32 copy of the input and the
        result is converted back with ``type_as`` so that lower-precision
        inputs are not squared and averaged in their own precision.
        """
        return self._norm(x.float()).type_as(x)
|
|
|