Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| import numpy as np | |
class NormGPS(nn.Module):
    """Rescale latitude/longitude radians by their half-ranges.

    Column 0 of ``batch[input_key]`` is divided by pi/2 and column 1 by pi,
    which maps latitude in [-pi/2, pi/2] and longitude in [-pi, pi] onto
    [-1, 1] (assuming the inputs lie in those ranges). With
    ``normalize=False`` the input is forwarded unchanged.

    NOTE: the original author marked this transform as not used currently.
    """

    def __init__(self, input_key="gps", output_key="x_0", normalize=True):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key
        self.normalize = normalize
        if not normalize:
            # No scaling requested: no buffer is registered at all.
            return
        # Reciprocal half-ranges, shape (1, 2), so they broadcast over rows and
        # forward() can multiply instead of divide. Registered as a buffer so
        # it follows the module across devices and into the state dict.
        half_ranges = torch.Tensor([[np.pi * 0.5, np.pi]])
        self.register_buffer("gps_normalize", 1 / half_ranges)

    def forward(self, batch):
        """Write the (optionally rescaled) coordinates into ``batch[output_key]``.

        The ``batch`` mapping is mutated in place and also returned.
        """
        coords = batch[self.input_key]
        batch[self.output_key] = (
            coords * self.gps_normalize if self.normalize else coords
        )
        return batch
class GPStoCartesian(nn.Module):
    """Project (latitude, longitude) radians onto 3-D unit-sphere coordinates.

    Reads ``batch[input_key]``, whose last dimension holds latitude at index 0
    and longitude at index 1, and writes a tensor whose last dimension is
    ``(cos(lat)cos(lon), cos(lat)sin(lon), sin(lat))`` to
    ``batch[output_key]``. Any leading batch shape is supported, e.g.
    ``(2,)``, ``(N, 2)`` or ``(B, N, 2)``.
    """

    def __init__(self, input_key="gps", output_key="x_0"):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key

    def forward(self, batch):
        """Convert the coordinates and store them in ``batch[output_key]``.

        The ``batch`` mapping is mutated in place and also returned.
        """
        x = batch[self.input_key]
        # Ellipsis indexing works for any number of leading dims; for (N, 2)
        # input it is identical to the former ``x[:, 0]`` / ``x[:, 1]``.
        lat, lon = x[..., 0], x[..., 1]
        cos_lat = lat.cos()  # shared by the x and y components
        batch[self.output_key] = torch.stack(
            [cos_lat * lon.cos(), cos_lat * lon.sin(), lat.sin()], dim=-1
        )
        return batch
class PrecomputedPreconditioning:
    """Pass-through stage that copies one batch entry to another key.

    ``batch[output_key]`` is set to the very same object found at
    ``batch[input_key]``; nothing is transformed. With the default keys
    (both ``"emb"``) the call leaves the batch unchanged.
    """

    def __init__(self, input_key="emb", output_key="emb"):
        self.input_key, self.output_key = input_key, output_key

    def __call__(self, batch, device=None):
        """Alias ``batch[input_key]`` under ``batch[output_key]`` and return ``batch``.

        ``device`` is accepted but ignored by this implementation (presumably
        kept for signature compatibility with other stages — verify callers).
        """
        value = batch[self.input_key]
        batch[self.output_key] = value
        return batch