Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import json | |
| import typing as T | |
| import numpy as np | |
| import torch | |
| import matchms | |
| import massspecgym.utils as utils | |
| from pathlib import Path | |
| from rdkit import Chem | |
| from torch.utils.data.dataset import Dataset | |
| from torch.utils.data.dataloader import default_collate | |
| from matchms.importing import load_from_mgf | |
| from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey | |
class MassSpecDataset(Dataset):
    """
    Dataset containing mass spectra and their corresponding molecular structures. This class is
    responsible for loading the data from disk and applying transformation steps to the spectra and
    molecules.
    """

    def __init__(
        self,
        spec_transform: T.Optional[T.Union[SpecTransform, T.Dict[str, SpecTransform]]] = None,
        mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]] = None,
        pth: T.Optional[T.Union[Path, str]] = None,
        return_mol_freq: bool = True,
        return_identifier: bool = True,
        dtype: T.Type = torch.float32
    ):
        """
        Args:
            spec_transform: Transform applied to each spectrum in ``__getitem__``. Either a single
                callable (its output is stored under the ``"spec"`` key) or a dict mapping output
                keys to callables (a ``None`` value stores the untransformed spectrum under that
                key). If falsy, the raw spectrum is stored under ``"spec"``.
            mol_transform: Transform applied to each molecule's SMILES string, with the same
                single-callable / dict semantics as ``spec_transform`` (single-callable output key
                is ``"mol"``). If falsy, the raw SMILES is stored under ``"mol"``.
            pth (Optional[Path | str], optional): Path to the .tsv or .mgf file containing the mass
                spectra. Default is None, in which case the MassSpecGym dataset is downloaded from
                HuggingFace Hub.
            return_mol_freq: If True, each item gets a ``"mol_freq"`` entry holding the number of
                rows in the dataset sharing the molecule's ``inchikey`` value (computed from
                ``smiles`` when the column is missing).
            return_identifier: If True, each item gets the row's ``"identifier"`` metadata value.
            dtype: Torch dtype used when converting non-string item values to tensors.

        Raises:
            ValueError: If the file suffix is neither ``.tsv`` nor ``.mgf``.
        """
        self.pth = pth
        self.spec_transform = spec_transform
        self.mol_transform = mol_transform
        self.return_mol_freq = return_mol_freq
        self.return_identifier = return_identifier
        self.dtype = dtype

        if self.pth is None:
            self.pth = utils.hugging_face_download("MassSpecGym.tsv")
        if isinstance(self.pth, str):
            self.pth = Path(self.pth)

        if self.pth.suffix == ".tsv":
            self.metadata = pd.read_csv(self.pth, sep="\t")
            # The "mzs" / "intensities" columns hold comma-separated peak values; build one
            # matchms.Spectrum per row from them.
            self.spectra = self.metadata.apply(
                lambda row: matchms.Spectrum(
                    mz=np.array([float(m) for m in row["mzs"].split(",")]),
                    intensities=np.array(
                        [float(i) for i in row["intensities"].split(",")]
                    ),
                    metadata={"precursor_mz": row["precursor_mz"]},
                ),
                axis=1,
            )
            # The peaks now live in self.spectra, so the raw string columns are dropped.
            self.metadata = self.metadata.drop(columns=["mzs", "intensities"])
        elif self.pth.suffix == ".mgf":
            self.spectra = list(load_from_mgf(str(self.pth)))
            self.metadata = pd.DataFrame([s.metadata for s in self.spectra])
        else:
            raise ValueError(f"{self.pth.suffix} file format not supported.")

        if self.return_mol_freq:
            if "inchikey" not in self.metadata.columns:
                self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
            # Per-row count of rows sharing the same InChIKey.
            self.metadata["mol_freq"] = self.metadata.groupby("inchikey")["inchikey"].transform("count")

    def __len__(self) -> int:
        """Return the number of spectra in the dataset."""
        return len(self.spectra)

    def __getitem__(
        self, i: int, transform_spec: bool = True, transform_mol: bool = True
    ) -> dict:
        """
        Build the i-th sample.

        Args:
            i: Sample index, used for both ``self.spectra[i]`` and ``self.metadata.iloc[i]``.
            transform_spec: If False, skip ``spec_transform`` and store the raw spectrum under
                ``"spec"``.
            transform_mol: If False, skip ``mol_transform`` and store the raw SMILES under
                ``"mol"``.

        Returns:
            Dict with the (transformed) spectrum and molecule entries, plus ``"mol_freq"`` and
            ``"identifier"`` when enabled. Every non-string value that ``torch.as_tensor`` can
            convert becomes a tensor of ``self.dtype``; other values are returned unchanged.
        """
        spec = self.spectra[i]
        metadata = self.metadata.iloc[i]
        mol = metadata["smiles"]

        item = {}
        # Apply all transformations to the spectrum
        if transform_spec and self.spec_transform:
            if isinstance(self.spec_transform, dict):
                for key, transform in self.spec_transform.items():
                    item[key] = transform(spec) if transform is not None else spec
            else:
                item["spec"] = self.spec_transform(spec)
        else:
            item["spec"] = spec

        # Apply all transformations to the molecule
        if transform_mol and self.mol_transform:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)
        else:
            item["mol"] = mol

        # Add other metadata to the item
        if self.return_mol_freq:
            item["mol_freq"] = metadata["mol_freq"]
        if self.return_identifier:
            item["identifier"] = metadata["identifier"]

        # Best-effort tensor conversion: values torch cannot convert (presumably e.g. raw
        # spectrum objects when transform_spec=False) are deliberately kept as-is. Only existing
        # keys are reassigned, so mutating `item` while iterating is safe.
        # TODO: this should be refactored
        for k, v in item.items():
            if isinstance(v, str):
                continue
            try:
                item[k] = torch.as_tensor(v, dtype=self.dtype)
            except Exception:  # narrower than bare except: lets KeyboardInterrupt/SystemExit through
                continue
        return item

    @staticmethod
    def collate_fn(batch: T.Iterable[dict]) -> dict:
        """
        Custom collate function to handle the outputs of __getitem__.

        Declared as a staticmethod so that ``dataset.collate_fn`` can be passed directly to a
        DataLoader without the dataset instance being bound as ``batch``.
        """
        return default_collate(batch)
class RetrievalDataset(MassSpecDataset):
    """
    Dataset containing mass spectra and their corresponding molecular structures, with additional
    candidates of molecules for retrieval based on spectral similarity.
    """

    def __init__(
        self,
        mol_label_transform: T.Optional[MolTransform] = None,
        candidates_pth: T.Optional[T.Union[Path, str]] = None,
        **kwargs,
    ):
        """
        Args:
            mol_label_transform: Callable mapping a SMILES string to a label. The query and each
                candidate are compared by this label to build the positive/negative mask.
                Defaults to a fresh ``MolToInChIKey()`` instance per dataset. (A ``None`` sentinel
                is used instead of a default instance created once at definition time and shared
                by every dataset.)
            candidates_pth: Path to a JSON file mapping each query SMILES to its list of candidate
                SMILES. If None, the MassSpecGym candidates file is downloaded from HuggingFace
                Hub. A string that is not an existing file is treated as a HuggingFace Hub file
                name and downloaded.
            **kwargs: Forwarded to ``MassSpecDataset.__init__``.
        """
        super().__init__(**kwargs)
        self.candidates_pth = candidates_pth
        self.mol_label_transform = (
            MolToInChIKey() if mol_label_transform is None else mol_label_transform
        )

        # Download candidates from HuggingFace Hub if no path to an existing file is passed
        if self.candidates_pth is None:
            self.candidates_pth = utils.hugging_face_download(
                "molecules/MassSpecGym_retrieval_candidates_mass.json"
            )
        elif isinstance(self.candidates_pth, str):
            if Path(self.candidates_pth).is_file():
                self.candidates_pth = Path(self.candidates_pth)
            else:
                self.candidates_pth = utils.hugging_face_download(candidates_pth)

        # Read candidates_pth from json to dict: SMILES -> respective candidate SMILES.
        # JSON text is UTF-8 (RFC 8259), so the encoding is explicit, not platform-dependent.
        with open(self.candidates_pth, "r", encoding="utf-8") as file:
            self.candidates = json.load(file)

    def __getitem__(self, i) -> dict:
        """
        Build the i-th sample together with its retrieval candidates.

        In addition to the base-class entries, the item contains ``"smiles"`` (raw query SMILES),
        ``"candidates"`` (``mol_transform`` applied to each candidate SMILES),
        ``"candidates_smiles"`` (copy of the raw candidate SMILES) and ``"labels"`` (one bool per
        candidate, True where its ``mol_label_transform`` label equals the query's). Requires
        ``mol_transform`` to be a single callable, since it is called directly here.

        Raises:
            ValueError: If the query has no entry in the candidates file, or if no candidate
                matches the query label.
        """
        # transform_mol=False leaves the raw SMILES string in item["mol"]
        item = super().__getitem__(i, transform_mol=False)
        query = item["mol"]

        # Save the original SMILES representation of the query molecule (for evaluation)
        item["smiles"] = query

        # Get candidates
        if query not in self.candidates:
            raise ValueError(f'No candidates for the query molecule {query}.')
        candidates = self.candidates[query]
        item["candidates"] = candidates
        # Save the original SMILES of the candidates (for evaluation). Copied so that mutating
        # the returned item can never alter the shared self.candidates mapping.
        item["candidates_smiles"] = list(candidates)

        # Create neg/pos label mask by matching the query molecule with the candidates
        query_label = self.mol_label_transform(query)
        item["labels"] = [self.mol_label_transform(c) == query_label for c in candidates]
        if not any(item["labels"]):
            raise ValueError(
                f'Query molecule {query} not found in the candidates list.'
            )

        # Transform the query and candidate molecules
        item["mol"] = self.mol_transform(query)
        item["candidates"] = [self.mol_transform(c) for c in candidates]
        if isinstance(item["mol"], np.ndarray):
            item["mol"] = torch.as_tensor(item["mol"], dtype=self.dtype)
        return item

    @staticmethod
    def collate_fn(batch: T.Iterable[dict]) -> dict:
        """
        Collate retrieval items whose candidate lists may differ in length per sample.

        Every key except ``"candidates"``, ``"labels"`` and ``"candidates_smiles"`` goes through
        ``default_collate``. The variable-length entries are concatenated across the batch
        instead. ``"batch_ptr"`` holds the number of candidates of each sample, in batch order,
        so the flat tensors can be split back per sample.

        Declared as a staticmethod so that ``dataset.collate_fn`` can be passed directly to a
        DataLoader without the dataset instance being bound as ``batch``.
        """
        # Materialise once: the batch is indexed and iterated several times below.
        batch = list(batch)
        variable_length_keys = ("candidates", "labels", "candidates_smiles")

        # Standard collate for everything except candidates and their labels
        collated_batch = {
            k: default_collate([item[k] for item in batch])
            for k in batch[0].keys()
            if k not in variable_length_keys
        }

        # Flatten per-sample lists with comprehensions (linear time) rather than
        # sum(..., start=[]), which re-copies the accumulator on every addition.
        collated_batch["candidates"] = torch.as_tensor(
            np.concatenate([item["candidates"] for item in batch])
        )
        collated_batch["labels"] = torch.as_tensor(
            [label for item in batch for label in item["labels"]]
        )
        collated_batch["batch_ptr"] = torch.as_tensor(
            [len(item["candidates"]) for item in batch]
        )
        collated_batch["candidates_smiles"] = [
            smiles for item in batch for smiles in item["candidates_smiles"]
        ]
        return collated_batch
| # TODO: Datasets for unlabeled data. | |