# massspecgym/data/datasets.py
import pandas as pd
import json
import typing as T
import numpy as np
import torch
import matchms
import massspecgym.utils as utils
from pathlib import Path
from rdkit import Chem
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import default_collate
from matchms.importing import load_from_mgf
from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
class MassSpecDataset(Dataset):
"""
Dataset containing mass spectra and their corresponding molecular structures. This class is
responsible for loading the data from disk and applying transformation steps to the spectra and
molecules.
"""
def __init__(
self,
spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]] = None,
mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]] = None,
pth: T.Optional[Path] = None,
return_mol_freq: bool = True,
return_identifier: bool = True,
        dtype: torch.dtype = torch.float32
):
"""
Args:
pth (Optional[Path], optional): Path to the .tsv or .mgf file containing the mass spectra.
Default is None, in which case the MassSpecGym dataset is downloaded from HuggingFace Hub.
"""
self.pth = pth
self.spec_transform = spec_transform
self.mol_transform = mol_transform
self.return_mol_freq = return_mol_freq
if self.pth is None:
self.pth = utils.hugging_face_download("MassSpecGym.tsv")
if isinstance(self.pth, str):
self.pth = Path(self.pth)
if self.pth.suffix == ".tsv":
self.metadata = pd.read_csv(self.pth, sep="\t")
self.spectra = self.metadata.apply(
lambda row: matchms.Spectrum(
mz=np.array([float(m) for m in row["mzs"].split(",")]),
intensities=np.array(
[float(i) for i in row["intensities"].split(",")]
),
metadata={"precursor_mz": row["precursor_mz"]},
),
axis=1,
)
self.metadata = self.metadata.drop(columns=["mzs", "intensities"])
elif self.pth.suffix == ".mgf":
self.spectra = list(load_from_mgf(str(self.pth)))
self.metadata = pd.DataFrame([s.metadata for s in self.spectra])
else:
raise ValueError(f"{self.pth.suffix} file format not supported.")
if self.return_mol_freq:
if "inchikey" not in self.metadata.columns:
self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")
self.return_identifier = return_identifier
self.dtype = dtype
def __len__(self) -> int:
return len(self.spectra)
def __getitem__(
self, i: int, transform_spec: bool = True, transform_mol: bool = True
) -> dict:
spec = self.spectra[i]
metadata = self.metadata.iloc[i]
mol = metadata["smiles"]
# Apply all transformations to the spectrum
item = {}
if transform_spec and self.spec_transform:
if isinstance(self.spec_transform, dict):
for key, transform in self.spec_transform.items():
item[key] = transform(spec) if transform is not None else spec
else:
item["spec"] = self.spec_transform(spec)
else:
item["spec"] = spec
# Apply all transformations to the molecule
if transform_mol and self.mol_transform:
if isinstance(self.mol_transform, dict):
for key, transform in self.mol_transform.items():
item[key] = transform(mol) if transform is not None else mol
else:
item["mol"] = self.mol_transform(mol)
else:
item["mol"] = mol
# Add other metadata to the item
# item.update({
# k: metadata[k] for k in ["precursor_mz", "adduct"]
# })
if self.return_mol_freq:
item["mol_freq"] = metadata["mol_freq"]
if self.return_identifier:
item["identifier"] = metadata["identifier"]
        # TODO: this should be refactored
        for k, v in item.items():
            if not isinstance(v, str):
                try:
                    item[k] = torch.as_tensor(v, dtype=self.dtype)
                except (TypeError, ValueError, RuntimeError):
                    # Keep values that cannot be converted to tensors (e.g. raw spectra or molecules) as they are.
                    continue
return item
@staticmethod
def collate_fn(batch: T.Iterable[dict]) -> dict:
"""
Custom collate function to handle the outputs of __getitem__.
"""
return default_collate(batch)
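

# A minimal usage sketch (commented out), assuming `my_spec_transform` and `my_mol_transform`
# are concrete SpecTransform / MolTransform instances (hypothetical names):
#
#     from torch.utils.data import DataLoader
#
#     dataset = MassSpecDataset(
#         spec_transform=my_spec_transform,
#         mol_transform=my_mol_transform,
#     )
#     loader = DataLoader(dataset, batch_size=32, collate_fn=MassSpecDataset.collate_fn)
#     batch = next(iter(loader))  # dict with e.g. "spec", "mol", "mol_freq", "identifier" keys
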
class RetrievalDataset(MassSpecDataset):
"""
    Dataset containing mass spectra and their corresponding molecular structures, along with additional
    candidate molecules for retrieval based on spectral similarity.
"""
def __init__(
self,
mol_label_transform: MolTransform = MolToInChIKey(),
candidates_pth: T.Optional[T.Union[Path, str]] = None,
**kwargs,
):
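        """
        Args:
            mol_label_transform (MolTransform, optional): Transformation applied to candidate molecules to
                derive the label used for matching against the query molecule. Default is MolToInChIKey().
            candidates_pth (Optional[Union[Path, str]], optional): Path to the JSON file mapping each query
                SMILES to its list of candidate SMILES. Default is None, in which case the MassSpecGym
                retrieval candidates are downloaded from the HuggingFace Hub.
        """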
super().__init__(**kwargs)
self.candidates_pth = candidates_pth
self.mol_label_transform = mol_label_transform
        # Download the candidates from the HuggingFace Hub if a path to an existing file is not passed
if self.candidates_pth is None:
self.candidates_pth = utils.hugging_face_download(
"molecules/MassSpecGym_retrieval_candidates_mass.json"
)
elif isinstance(self.candidates_pth, str):
if Path(self.candidates_pth).is_file():
self.candidates_pth = Path(self.candidates_pth)
else:
self.candidates_pth = utils.hugging_face_download(candidates_pth)
        # Read the candidates JSON file into a dict: query SMILES -> list of candidate SMILES
with open(self.candidates_pth, "r") as file:
self.candidates = json.load(file)
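        # The loaded mapping has the form query SMILES -> list of candidate SMILES,
        # e.g. (illustrative, not real data): {"CCO": ["CCO", "CC=O", "COC"], ...}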
def __getitem__(self, i) -> dict:
item = super().__getitem__(i, transform_mol=False)
# Save the original SMILES representation of the query molecule (for evaluation)
item["smiles"] = item["mol"]
# Get candidates
if item["mol"] not in self.candidates:
raise ValueError(f'No candidates for the query molecule {item["mol"]}.')
item["candidates"] = self.candidates[item["mol"]]
        # Save the original SMILES representations of the candidates (for evaluation)
item["candidates_smiles"] = item["candidates"]
# Create neg/pos label mask by matching the query molecule with the candidates
item_label = self.mol_label_transform(item["mol"])
item["labels"] = [
self.mol_label_transform(c) == item_label for c in item["candidates"]
]
if not any(item["labels"]):
raise ValueError(
f'Query molecule {item["mol"]} not found in the candidates list.'
)
# Transform the query and candidate molecules
item["mol"] = self.mol_transform(item["mol"])
item["candidates"] = [self.mol_transform(c) for c in item["candidates"]]
if isinstance(item["mol"], np.ndarray):
item["mol"] = torch.as_tensor(item["mol"], dtype=self.dtype)
# item["candidates"] = [torch.as_tensor(c, dtype=self.dtype) for c in item["candidates"]]
return item
@staticmethod
def collate_fn(batch: T.Iterable[dict]) -> dict:
        # Standard collate for everything except the candidates, their labels, and their SMILES (which may differ in length per sample)
collated_batch = {}
for k in batch[0].keys():
if k not in ["candidates", "labels", "candidates_smiles"]:
collated_batch[k] = default_collate([item[k] for item in batch])
# Collate candidates and labels by concatenating and storing sizes of each list
collated_batch["candidates"] = torch.as_tensor(
np.concatenate([item["candidates"] for item in batch])
)
collated_batch["labels"] = torch.as_tensor(
sum([item["labels"] for item in batch], start=[])
)
collated_batch["batch_ptr"] = torch.as_tensor(
[len(item["candidates"]) for item in batch]
)
collated_batch["candidates_smiles"] = \
sum([item["candidates_smiles"] for item in batch], start=[])
return collated_batch
# TODO: Datasets for unlabeled data.
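

# A minimal sketch (commented out) of unpacking a batch produced by RetrievalDataset.collate_fn,
# assuming `collated_batch` comes from a DataLoader using that collate function:
#
#     ptr = collated_batch["batch_ptr"]   # number of candidates per sample
#     ends = torch.cumsum(ptr, dim=0)
#     starts = ends - ptr
#     candidates_per_sample = [
#         collated_batch["candidates"][s:e] for s, e in zip(starts.tolist(), ends.tolist())
#     ]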