# MVP / massspecgym / data / transforms.py
# Last commit 94aa6f9 by yzhouchen001: "partial push"
import numpy as np
import torch
import matchms
import matchms.filtering as ms_filters
from rdkit.Chem import AllChem as Chem
from typing import Optional
from abc import ABC, abstractmethod
import massspecgym.utils as utils
from massspecgym.definitions import CHEM_ELEMS
class SpecTransform(ABC):
    """
    Abstract base for spectrum transformations; custom transformations should subclass it.

    Calling an instance runs two stages back to back:
    1. ``matchms_transforms`` - preprocess the input spectrum with matchms filters.
    2. ``matchms_to_torch`` - encode the preprocessed matchms spectrum as a torch tensor.
    """

    @abstractmethod
    def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
        """
        Preprocess ``spec`` with a sequence of matchms filters. Subclasses must implement this.
        """

    @abstractmethod
    def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
        """
        Encode a (preprocessed) matchms spectrum as a torch tensor. Subclasses must implement this.
        """

    def __call__(self, spec: matchms.Spectrum) -> torch.Tensor:
        """
        Run ``matchms_transforms`` on ``spec`` and feed its output to ``matchms_to_torch``.
        """
        filtered = self.matchms_transforms(spec)
        return self.matchms_to_torch(filtered)
def default_matchms_transforms(
    spec: matchms.Spectrum,
    n_max_peaks: Optional[int] = 60,
    mz_from: float = 10,
    mz_to: float = 1000,
) -> matchms.Spectrum:
    """
    Apply the default matchms preprocessing pipeline to a spectrum.

    The filters run in this order:
    1. ``select_by_mz`` restricts peaks to the m/z range given by ``mz_from`` / ``mz_to``.
    2. ``reduce_to_number_of_peaks`` caps the peak count at ``n_max_peaks``.
       This step is skipped entirely when ``n_max_peaks`` is None.
    3. ``normalize_intensities`` normalizes the peak intensities.

    Args:
        spec: input matchms spectrum.
        n_max_peaks: maximum number of peaks to keep, or None to keep all peaks.
        mz_from: lower m/z bound passed to ``select_by_mz``.
        mz_to: upper m/z bound passed to ``select_by_mz``.

    Returns:
        The filtered spectrum.
        NOTE(review): depending on the matchms version, some filters may return
        None for spectra left with too few peaks. Confirm that callers handle that case.
    """
    spec = ms_filters.select_by_mz(spec, mz_from=mz_from, mz_to=mz_to)
    # None means "no cap". SpecBinner relies on this to keep every peak below its max_mz.
    if n_max_peaks is not None:
        spec = ms_filters.reduce_to_number_of_peaks(spec, n_max=n_max_peaks)
    spec = ms_filters.normalize_intensities(spec)
    return spec
class SpecTokenizer(SpecTransform):
    """
    Encode a spectrum as a matrix of ``(m/z, intensity)`` rows, one row per peak.

    Args:
        n_peaks: cap on the number of peaks kept by ``default_matchms_transforms``.
            It is also the number of rows the output is padded to, plus one when a
            precursor row is added. If None, there is no cap and no padding.
        prec_mz_intensity: if not None, a pseudo-peak
            ``(precursor_mz, prec_mz_intensity)`` is prepended as the first row.
        matchms_kwargs: extra keyword arguments forwarded to
            ``default_matchms_transforms``. It must not contain ``n_max_peaks``,
            which is already supplied from ``n_peaks`` (a duplicate raises TypeError).
    """

    def __init__(
        self,
        n_peaks: Optional[int] = 60,
        prec_mz_intensity: Optional[float] = 1.1,
        matchms_kwargs: Optional[dict] = None
    ) -> None:
        self.n_peaks = n_peaks
        self.prec_mz_intensity = prec_mz_intensity
        # None default avoids sharing a mutable dict between instances.
        self.matchms_kwargs = {} if matchms_kwargs is None else matchms_kwargs

    def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
        """
        Run the default matchms filters, capping the peak count at ``n_peaks``.
        """
        return default_matchms_transforms(
            spec, n_max_peaks=self.n_peaks, **self.matchms_kwargs
        )

    def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
        """
        Build a ``(num_rows, 2)`` tensor of ``(m/z, intensity)`` pairs.

        When ``prec_mz_intensity`` is set, the first row is the precursor pseudo-peak.
        When ``n_peaks`` is set, the matrix is padded with zeros by ``utils.pad_spectrum``
        to ``n_peaks`` rows, or ``n_peaks + 1`` rows if a precursor row was added.
        """
        peak_matrix = np.vstack([spec.peaks.mz, spec.peaks.intensities]).T
        has_precursor_row = self.prec_mz_intensity is not None
        if has_precursor_row:
            precursor_row = [spec.metadata["precursor_mz"], self.prec_mz_intensity]
            peak_matrix = np.vstack([precursor_row, peak_matrix])
        if self.n_peaks is not None:
            # Reserve one extra row for the precursor pseudo-peak when present.
            target_rows = self.n_peaks + (1 if has_precursor_row else 0)
            peak_matrix = utils.pad_spectrum(peak_matrix, target_rows)
        return torch.from_numpy(peak_matrix)
class SpecBinner(SpecTransform):
    """
    Convert a spectrum into a fixed-length vector of binned intensities.

    The m/z axis from 0 to ``max_mz`` is split into ``max_mz / bin_width`` bins of
    equal width. The intensities of all peaks that fall into the same bin are summed.

    Args:
        max_mz: upper m/z bound. Peaks above it are dropped. Must be an integer
            multiple of ``bin_width``.
        bin_width: width of each bin in m/z units.
        to_rel_intensities: if True, rescale the binned vector so that its
            maximum is 1. An all-zero vector is left unchanged.

    Raises:
        ValueError: if ``max_mz`` is not divisible by ``bin_width``.
    """

    def __init__(
        self,
        max_mz: float = 1005,
        bin_width: float = 1,
        to_rel_intensities: bool = True,
    ) -> None:
        self.max_mz = max_mz
        self.bin_width = bin_width
        self.to_rel_intensities = to_rel_intensities
        if not (max_mz / bin_width).is_integer():
            raise ValueError("`max_mz` must be divisible by `bin_width`.")

    def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
        """
        Run the default matchms filters up to ``max_mz``, without capping the peak count.
        """
        return default_matchms_transforms(spec, mz_to=self.max_mz, n_max_peaks=None)

    def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
        """
        Bin the spectrum into ``max_mz / bin_width`` bins and return them as a tensor.
        """
        binned_spec = self._bin_mass_spectrum(
            mzs=spec.peaks.mz,
            intensities=spec.peaks.intensities,
            max_mz=self.max_mz,
            bin_width=self.bin_width,
            to_rel_intensities=self.to_rel_intensities,
        )
        return torch.from_numpy(binned_spec)

    def _bin_mass_spectrum(
        self, mzs, intensities, max_mz, bin_width, to_rel_intensities=True
    ):
        """
        Sum peak intensities into fixed-width m/z bins.

        Args:
            mzs: numpy array of peak m/z values.
            intensities: numpy array of peak intensities aligned with ``mzs``.
            max_mz: peaks with m/z above this value are ignored.
            bin_width: width of each bin.
            to_rel_intensities: if True, divide by the largest bin value. This is
                skipped when that value is 0 (e.g. an empty spectrum), which would
                otherwise turn the whole vector into NaN.

        Returns:
            A float64 numpy array of length ``ceil(max_mz / bin_width)``.
        """
        num_bins = int(np.ceil(max_mz / bin_width))

        # Drop peaks beyond max_mz before computing bin indices.
        in_range = mzs <= max_mz
        bin_indices = np.floor(mzs[in_range] / bin_width).astype(int)
        # A peak exactly at max_mz maps to index num_bins (one past the end).
        # Clipping folds it into the last bin and also guards against negative indices.
        bin_indices = np.clip(bin_indices, 0, num_bins - 1)

        binned_intensities = np.zeros(num_bins)
        # np.add.at accumulates correctly when several peaks share the same bin.
        np.add.at(binned_intensities, bin_indices, intensities[in_range])

        if to_rel_intensities:
            max_intensity = np.max(binned_intensities)
            # Dividing an all-zero vector by 0 would produce all-NaN output.
            if max_intensity != 0:
                binned_intensities /= max_intensity
        return binned_intensities
class MolTransform(ABC):
    """
    Abstract base for molecule transformations that take a SMILES string as input.

    Subclasses implement ``from_smiles``. Calling an instance simply delegates to it.
    """

    @abstractmethod
    def from_smiles(self, mol: str):
        """
        Map a SMILES string to its transformed representation. Subclasses must implement this.
        """

    def __call__(self, mol: str):
        """
        Shorthand for ``self.from_smiles(mol)``.
        """
        transformed = self.from_smiles(mol)
        return transformed
class MolFingerprinter(MolTransform):
    """
    Compute a Morgan fingerprint from a SMILES string via ``utils.morgan_fp``.

    Args:
        type: fingerprint family. Only ``"morgan"`` is supported. The name shadows
            the builtin but is part of the public signature, so it is kept.
        fp_size: fingerprint size forwarded to ``utils.morgan_fp``.
        radius: Morgan radius forwarded to ``utils.morgan_fp``.

    Raises:
        NotImplementedError: if ``type`` is not ``"morgan"``.
    """

    def __init__(self, type: str = "morgan", fp_size: int = 2048, radius: int = 2):
        # Validate before storing anything, so a rejected config leaves no attributes behind.
        if type != "morgan":
            raise NotImplementedError(
                "Only Morgan fingerprints are implemented at the moment."
            )
        self.type, self.fp_size, self.radius = type, fp_size, radius

    def from_smiles(self, mol: str):
        """
        Parse ``mol`` as SMILES and return its fingerprint as produced by ``utils.morgan_fp(..., to_np=True)``.
        """
        # NOTE(review): RDKit's MolFromSmiles returns None for unparsable SMILES.
        # Whether utils.morgan_fp handles None is not visible here; confirm.
        rdkit_mol = Chem.MolFromSmiles(mol)
        return utils.morgan_fp(
            rdkit_mol, radius=self.radius, fp_size=self.fp_size, to_np=True
        )
class MolToInChIKey(MolTransform):
    """
    Convert a SMILES string to an InChIKey using ``utils.mol_to_inchi_key``.

    Args:
        twod: forwarded as ``twod`` to ``utils.mol_to_inchi_key``. Presumably this
            selects a 2D (stereo-agnostic) key; confirm against utils.
    """

    def __init__(self, twod: bool = True) -> None:
        self.twod = twod

    def from_smiles(self, mol: str) -> str:
        """
        Parse ``mol`` as SMILES and return the InChIKey computed by ``utils.mol_to_inchi_key``.
        """
        # NOTE(review): MolFromSmiles returns None for unparsable SMILES. That None is
        # passed straight to utils.mol_to_inchi_key; confirm that helper copes with it.
        rdkit_mol = Chem.MolFromSmiles(mol)
        inchi_key = utils.mol_to_inchi_key(rdkit_mol, twod=self.twod)
        return inchi_key
class MolToFormulaVector(MolTransform):
    """
    Encode a molecule's elemental composition as a vector of atom counts.

    Entry ``i`` of the output counts the atoms whose symbol is ``CHEM_ELEMS[i]``,
    including explicit hydrogens. The vector length is ``num_elements()``, so it
    always agrees with the element table and with ``element_index``.
    """

    def __init__(self):
        # Map element symbol -> position in the output vector.
        self.element_index = {element: i for i, element in enumerate(CHEM_ELEMS)}

    def from_smiles(self, smiles: str):
        """
        Count the atoms of each element in the molecule described by ``smiles``.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            An int32 numpy array of length ``num_elements()``.

        Raises:
            ValueError: if ``smiles`` cannot be parsed, or if the molecule contains
                an element symbol that is not in ``CHEM_ELEMS``.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES string: {smiles}")
        # Add explicit hydrogens so that H atoms are visited, and counted, by the loop below.
        mol = Chem.AddHs(mol)
        # Size the vector from the element table rather than a hard-coded 118.
        # This keeps every index produced by element_index in range and matches num_elements().
        formula_vector = np.zeros(self.num_elements(), dtype=np.int32)
        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            index = self.element_index.get(symbol)
            if index is None:
                raise ValueError(
                    f"Element '{symbol}' not found in the list of "
                    f"{self.num_elements()} elements."
                )
            formula_vector[index] += 1
        return formula_vector

    @staticmethod
    def num_elements():
        """
        Return the length of the formula vector, i.e. the number of entries in ``CHEM_ELEMS``.
        """
        return len(CHEM_ELEMS)