Spaces:
Sleeping
Sleeping
| import torch | |
| from massspecgym.models.base import Stage | |
| from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel | |
class RandomRetrieval(RetrievalMassSpecGymModel):
    """Retrieval baseline that scores candidates uniformly at random.

    Nothing is learned: ``step`` returns random scores together with a
    constant zero loss, and ``configure_optimizers`` returns ``None``.
    """

    def step(
        self, batch: dict, stage: Stage = Stage.NONE
    ) -> dict[str, torch.Tensor]:
        """Produce random retrieval scores and a dummy loss for one batch.

        Args:
            batch: Mapping that must contain a ``"candidates"`` entry exposing
                ``.shape``; one score is drawn per element along its first
                axis. (NOTE(review): presumably a tensor of all candidates in
                the batch -- confirm against the dataset's collate function.)
            stage: Stage the step is executed in. Unused by this baseline.

        Returns:
            A dict with two entries:
                ``"loss"``: scalar zero tensor with ``requires_grad=True``.
                ``"scores"``: 1-D tensor of samples from ``U[0, 1)`` with
                length ``batch["candidates"].shape[0]``.
            Both tensors are placed on ``self.device``.
        """
        num_candidates = batch["candidates"].shape[0]

        # Sample first, then move: the values drawn for a given seed stay
        # tied to the default generator rather than to the module's device.
        scores = torch.rand(num_candidates).to(self.device)

        # Dummy loss for a model with nothing to optimize. It is created
        # directly on the module's device so that it lives alongside
        # ``scores`` and remains a leaf tensor; ``requires_grad=True`` is
        # presumably there so that a caller invoking ``backward()`` on it
        # does not fail.
        loss = torch.tensor(0.0, requires_grad=True, device=self.device)

        return dict(loss=loss, scores=scores)

    def configure_optimizers(self):
        """Return ``None``: a random baseline has no optimizer."""
        return None