MVP / mvp /data /data_module.py
yzhouchen001's picture
model code
d9df210
from torch.utils.data.dataloader import DataLoader
from massspecgym.data.data_module import MassSpecDataModule
from mvp.data.datasets import ContrastiveDataset
from functools import partial
from massspecgym.models.base import Stage
class TestDataModule(MassSpecDataModule):
    """Data module that exposes the wrapped dataset for testing only.

    The entire ``self.dataset`` is served as the test set; no training or
    validation loaders are produced. ``self.dataset``, ``self.batch_size``,
    ``self.num_workers`` and ``self.persistent_workers`` are read here but not
    assigned in this class (presumably set by ``MassSpecDataModule.__init__``
    from ``**kwargs`` -- confirm against the base class).
    """

    def __init__(self, collate_fn, **kwargs):
        """Store the collate function and forward the rest to the base class.

        Args:
            collate_fn: Callable handed to the test ``DataLoader`` as its
                ``collate_fn``.
            **kwargs: Forwarded unchanged to ``MassSpecDataModule.__init__``.
        """
        super().__init__(**kwargs)
        self.collate_fn = collate_fn

    def prepare_data(self):
        """Do nothing; this module has no preparation step."""

    def setup(self, stage=None):
        """Use the full dataset as the test set.

        Args:
            stage: Must compare equal to ``"test"``.

        Raises:
            Exception: If ``stage`` is anything other than ``"test"``
                (including the default ``None``).
        """
        if stage != "test":
            raise Exception("Data module supports test set only")
        self.test_dataset = self.dataset

    def test_dataloader(self):
        """Return an unshuffled ``DataLoader`` over ``self.test_dataset``."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        """Return ``None``: this module provides no training data."""
        return None

    def val_dataloader(self):
        """Return ``None``: this module provides no validation data.

        The original code defined only ``val_dataset`` here, which did not
        override any ``val_dataloader`` hook the base class provides; this
        method supplies the override that was evidently intended.
        """
        return None

    def val_dataset(self):
        """Return ``None``.

        Retained only for backward compatibility with the original (mis-named)
        method; prefer ``val_dataloader``.
        """
        return None
class ContrastiveDataModule(MassSpecDataModule):
    """Data module whose train/val loaders wrap the splits in ``ContrastiveDataset``.

    ``self.train_dataset``, ``self.val_dataset``, ``self.test_dataset``,
    ``self.dataset``, ``self.batch_size``, ``self.num_workers`` and
    ``self.persistent_workers`` are read here but not assigned in this class
    (presumably provided by ``MassSpecDataModule`` -- confirm against the base
    class).
    """

    def __init__(self, collate_fn, **kwargs):
        """Keep the collate function and pass the remaining options upward.

        Args:
            collate_fn: Callable accepting a ``stage`` keyword; it is bound to
                ``Stage.TRAIN`` / ``Stage.VAL`` for the train / val loaders.
            **kwargs: Forwarded unchanged to ``MassSpecDataModule.__init__``.
        """
        super().__init__(**kwargs)
        self.collate_fn = collate_fn
        self.regularization_flag = False

    def _make_loader(self, dataset, shuffle, collate_fn):
        """Build a ``DataLoader`` with this module's shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            collate_fn=collate_fn,
        )

    def train_dataloader(self):
        """Return a shuffled loader over a fresh ``ContrastiveDataset`` of the train split.

        The wrapped dataset is rebuilt on every call and stored on
        ``self.train_contrastive_dataset``.
        """
        wrapped = ContrastiveDataset(self.train_dataset)
        self.train_contrastive_dataset = wrapped
        train_collate = partial(self.collate_fn, stage=Stage.TRAIN)
        return self._make_loader(wrapped, shuffle=True, collate_fn=train_collate)

    def val_dataloader(self):
        """Return an unshuffled loader over a fresh ``ContrastiveDataset`` of the val split.

        The wrapped dataset is rebuilt on every call and stored on
        ``self.val_contrastive_dataset``.
        """
        wrapped = ContrastiveDataset(self.val_dataset)
        self.val_contrastive_dataset = wrapped
        val_collate = partial(self.collate_fn, stage=Stage.VAL)
        return self._make_loader(wrapped, shuffle=False, collate_fn=val_collate)

    def test_dataloader(self):
        """Return an unshuffled loader over ``self.test_dataset``.

        Unlike the train/val loaders, the test split is not wrapped in
        ``ContrastiveDataset`` and batches are collated with the dataset's own
        ``collate_fn`` rather than the one given to the constructor.
        """
        return self._make_loader(
            self.test_dataset,
            shuffle=False,
            collate_fn=self.dataset.collate_fn,
        )