import numpy as np import torch import matchms from typing import Optional from rdkit.Chem import AllChem as Chem from mvp.definitions import CHEM_ELEMS_SMALL from massspecgym.data.transforms import MolTransform, SpecTransform, default_matchms_transforms from massspecgym.data.transforms import SpecBinner import dgllife.utils as chemutils import re class SpecBinnerLog(SpecTransform): def __init__( self, max_mz: float = 1005, bin_width: float = 1, ) -> None: self.max_mz = max_mz self.bin_width = bin_width if not (max_mz / bin_width).is_integer(): raise ValueError("`max_mz` must be divisible by `bin_width`.") def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum: return default_matchms_transforms(spec, mz_to=self.max_mz, n_max_peaks=None) def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor: """ Bin the spectrum into a fixed number of bins. """ binned_spec = self._bin_mass_spectrum( mzs=spec.peaks.mz, intensities=spec.peaks.intensities, max_mz=self.max_mz, bin_width=self.bin_width, ) return torch.from_numpy(binned_spec).to(dtype=torch.float32) def _bin_mass_spectrum( self, mzs, intensities, max_mz, bin_width ): # Calculate the number of bins num_bins = int(np.ceil(max_mz / bin_width)) # Calculate the bin indices for each mass bin_indices = np.floor(mzs -1 / bin_width).astype(int) # Filter out mzs that exceed max_mz valid_indices = bin_indices[mzs <= max_mz] valid_intensities = intensities[mzs <= max_mz] # Clip bin indices to ensure they are within the valid range valid_indices = np.clip(valid_indices, 0, num_bins - 1) # Initialize an array to store the binned intensities binned_intensities = np.zeros(num_bins) # Use np.add.at to sum intensities in the appropriate bins np.add.at(binned_intensities, valid_indices, valid_intensities) binned_intensities = binned_intensities/np.max(binned_intensities) * 999 binned_intensities = np.log10(binned_intensities + 1) / 3 return binned_intensities class SpecFormulaFeaturizer(SpecTransform): ''' Uses processed mz and intensities, excludes mz values, keep peaks with formulas only''' def __init__( self, add_intensities: bool, max_mz: float = 1005, element_list: list = CHEM_ELEMS_SMALL, formula_normalize_vector: Optional[np.array] = None ) -> None: self.max_mz = max_mz self.elem_to_pos = {e: i for i, e in enumerate(element_list)} self.add_intensities = add_intensities if formula_normalize_vector is None: formula_normalize_vector = np.ones(len(element_list)) self.formula_normalize_vector = formula_normalize_vector self.CHEM_FORMULA_SIZE = "([A-Z][a-z]*)([0-9]*)" def matchms_transforms(self, spec: matchms.Spectrum): return spec def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor: mzs = spec.peaks.mz intensities = spec.peaks.intensities formulas = spec.metadata['formulas'] # list of formulas peak_idx = np.where(mzs <= self.max_mz)[0] intensities = intensities[peak_idx] formulas = formulas[peak_idx] spec = self._featurize_formula(formulas) spec = spec/self.formula_normalize_vector if self.add_intensities: spec = np.concatenate((spec, intensities.reshape(-1,1)), axis=1) spec = spec.astype(np.float32) return torch.from_numpy(spec) def _featurize_formula(self, formulas): formula_vector = np.zeros((len(formulas), len(self.elem_to_pos))) for i, f in enumerate(formulas): try: for (e, ct) in re.findall(self.CHEM_FORMULA_SIZE, f): ct = 1 if ct == "" else int(ct) try: formula_vector[i][self.elem_to_pos[e]]+=ct except: print(f"Couldn't vectorize {f}, element {e} not supported") continue except: print(f"Couldn't vectorize {f}, formula not supported") continue return formula_vector class MolToGraph(MolTransform): def __init__ (self, atom_feature: str = "full", bond_feature: str = "full", element_list: list = CHEM_ELEMS_SMALL): self.atom_feature = atom_feature self.bond_feature = bond_feature self.node_featurizer = self._get_atom_featurizer(element_list=element_list) self.edge_featurizer = self._get_bond_featurizer() def from_smiles(self, mol:str): mol = Chem.MolFromSmiles(mol) g = chemutils.mol_to_bigraph(mol, node_featurizer=self.node_featurizer, edge_featurizer=self.edge_featurizer, add_self_loop = True, num_virtual_nodes = 0, canonical_atom_order=False) # atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()] # added for visualization # g.ndata['atom_id'] = torch.tensor(atom_ids, dtype=torch.long) return g def _get_atom_featurizer(self, element_list) -> dict: feature_mode = self.atom_feature atom_mass_fun = chemutils.ConcatFeaturizer( [chemutils.atom_mass] ) def atom_bond_type_one_hot(atom): bs = atom.GetBonds() bt = np.array([chemutils.bond_type_one_hot(b) for b in bs]) return [any(bt[:, i]) for i in range(bt.shape[1])] def atom_type_one_hot(atom): return chemutils.atom_type_one_hot( atom, allowable_set = element_list, encode_unknown = True ) if feature_mode == 'light': atom_featurizer_funs = chemutils.ConcatFeaturizer([ chemutils.atom_mass, atom_type_one_hot ]) elif feature_mode == 'full': atom_featurizer_funs = chemutils.ConcatFeaturizer([ chemutils.atom_mass, atom_type_one_hot, atom_bond_type_one_hot, chemutils.atom_degree_one_hot, chemutils.atom_total_degree_one_hot, chemutils.atom_explicit_valence_one_hot, chemutils.atom_implicit_valence_one_hot, chemutils.atom_hybridization_one_hot, chemutils.atom_total_num_H_one_hot, chemutils.atom_formal_charge_one_hot, chemutils.atom_num_radical_electrons_one_hot, chemutils.atom_is_aromatic_one_hot, chemutils.atom_is_in_ring_one_hot, chemutils.atom_chiral_tag_one_hot ]) elif feature_mode == 'medium': atom_featurizer_funs = chemutils.ConcatFeaturizer([ chemutils.atom_mass, atom_type_one_hot, atom_bond_type_one_hot, chemutils.atom_total_degree_one_hot, chemutils.atom_total_num_H_one_hot, chemutils.atom_is_aromatic_one_hot, chemutils.atom_is_in_ring_one_hot, ]) return chemutils.BaseAtomFeaturizer( {"h": atom_featurizer_funs, "m": atom_mass_fun} ) def _get_bond_featurizer(self, self_loop=True) -> dict: feature_mode = self.bond_feature if feature_mode == 'light': return chemutils.BaseBondFeaturizer( featurizer_funcs = {'e': chemutils.ConcatFeaturizer([ chemutils.bond_type_one_hot ])}, self_loop = self_loop ) elif feature_mode == 'full': return chemutils.CanonicalBondFeaturizer( bond_data_field='e', self_loop = self_loop )