Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import matchms | |
| from typing import Optional | |
| from rdkit.Chem import AllChem as Chem | |
| from mvp.definitions import CHEM_ELEMS_SMALL | |
| from massspecgym.data.transforms import MolTransform, SpecTransform, default_matchms_transforms | |
| from massspecgym.data.transforms import SpecBinner | |
| import dgllife.utils as chemutils | |
| import re | |
class SpecBinnerLog(SpecTransform):
    """Bin a spectrum into fixed-width m/z bins with log-scaled intensities.

    Intensities that land in the same bin are summed. The binned vector is
    then rescaled so that its largest entry equals 999 and mapped through
    ``log10(x + 1) / 3``, which places the output in ``[0, 1]``.
    """

    def __init__(
        self,
        max_mz: float = 1005,
        bin_width: float = 1,
    ) -> None:
        """Store the binning parameters.

        Args:
            max_mz: Upper m/z bound; peaks above it are dropped when binning.
            bin_width: Width of one bin in m/z units.

        Raises:
            ValueError: If ``max_mz`` is not a whole multiple of ``bin_width``.
        """
        self.max_mz = max_mz
        self.bin_width = bin_width
        if not (max_mz / bin_width).is_integer():
            raise ValueError("`max_mz` must be divisible by `bin_width`.")

    def matchms_transforms(self, spec: matchms.Spectrum) -> matchms.Spectrum:
        """Apply the default matchms transforms with ``mz_to=self.max_mz``.

        ``n_max_peaks=None`` is forwarded as-is to
        ``default_matchms_transforms``.
        """
        return default_matchms_transforms(spec, mz_to=self.max_mz, n_max_peaks=None)

    def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
        """Bin the spectrum into a fixed number of bins.

        Returns:
            A ``float32`` tensor with ``ceil(max_mz / bin_width)`` entries.
        """
        binned_spec = self._bin_mass_spectrum(
            mzs=spec.peaks.mz,
            intensities=spec.peaks.intensities,
            max_mz=self.max_mz,
            bin_width=self.bin_width,
        )
        return torch.from_numpy(binned_spec).to(dtype=torch.float32)

    def _bin_mass_spectrum(self, mzs, intensities, max_mz, bin_width):
        """Sum intensities per m/z bin and log-scale the result.

        Args:
            mzs: Peak m/z values (array-like).
            intensities: Peak intensities, aligned with ``mzs``.
            max_mz: Peaks with m/z above this value are ignored.
            bin_width: Width of one bin in m/z units.

        Returns:
            A 1-D float array of length ``ceil(max_mz / bin_width)``. It is all
            zeros when no intensity falls into any bin.
        """
        mzs = np.asarray(mzs)
        intensities = np.asarray(intensities)
        num_bins = int(np.ceil(max_mz / bin_width))

        # Drop peaks above the upper m/z bound.
        keep = mzs <= max_mz

        # Shift by 1 so that m/z 1 lands in bin 0, then scale by the bin width.
        # The parentheses matter: ``mzs - 1 / bin_width`` would subtract
        # ``1 / bin_width`` and never scale the m/z values by the bin width.
        bin_indices = np.floor((mzs[keep] - 1) / bin_width).astype(int)

        # Peaks below m/z 1 (negative index) are folded into the first bin.
        bin_indices = np.clip(bin_indices, 0, num_bins - 1)

        binned_intensities = np.zeros(num_bins)
        # np.add.at accumulates correctly when several peaks share a bin.
        np.add.at(binned_intensities, bin_indices, intensities[keep])

        peak = np.max(binned_intensities)
        if peak == 0:
            # Empty / all-zero spectrum: avoid 0/0 -> NaN in the normalization.
            return binned_intensities

        binned_intensities = binned_intensities / peak * 999
        # log10(999 + 1) / 3 == 1, so the base peak maps to exactly 1.
        return np.log10(binned_intensities + 1) / 3
class SpecFormulaFeaturizer(SpecTransform):
    """Featurize each peak by the element counts of its annotated formula.

    Uses the processed m/z values and intensities. The m/z values themselves
    are not part of the output; they are only used to drop peaks above
    ``max_mz``. Formulas that cannot be parsed produce an all-zero row.
    """

    def __init__(
        self,
        add_intensities: bool,
        max_mz: float = 1005,
        element_list: list = CHEM_ELEMS_SMALL,
        formula_normalize_vector: Optional[np.ndarray] = None,
    ) -> None:
        """Store the featurization settings.

        Args:
            add_intensities: If true, append each peak's intensity as an extra
                trailing feature column.
            max_mz: Peaks with m/z above this value are dropped.
            element_list: Supported element symbols; the position of a symbol
                in this list is its column in the feature matrix.
            formula_normalize_vector: Per-element divisor applied to the count
                matrix. Defaults to all ones (no normalization).
        """
        self.max_mz = max_mz
        self.elem_to_pos = {e: i for i, e in enumerate(element_list)}
        self.add_intensities = add_intensities
        if formula_normalize_vector is None:
            formula_normalize_vector = np.ones(len(element_list))
        self.formula_normalize_vector = formula_normalize_vector
        # Matches an element symbol followed by an optional count, e.g. "C6".
        self.CHEM_FORMULA_SIZE = "([A-Z][a-z]*)([0-9]*)"

    def matchms_transforms(self, spec: matchms.Spectrum):
        """Return the spectrum unchanged."""
        return spec

    def matchms_to_torch(self, spec: matchms.Spectrum) -> torch.Tensor:
        """Convert per-peak formulas into a ``float32`` feature tensor.

        Reads ``spec.metadata['formulas']``, which is expected to hold one
        entry per peak, aligned with ``spec.peaks``.

        Returns:
            A tensor with one row per retained peak and one column per
            supported element, plus a final intensity column when
            ``add_intensities`` is set.
        """
        mzs = spec.peaks.mz
        intensities = spec.peaks.intensities
        # A plain Python list cannot be indexed with an index array, so wrap
        # it; dtype=object keeps non-string entries (e.g. None) intact.
        formulas = np.asarray(spec.metadata['formulas'], dtype=object)

        peak_idx = np.where(mzs <= self.max_mz)[0]
        intensities = intensities[peak_idx]
        formulas = formulas[peak_idx]

        features = self._featurize_formula(formulas)
        features = features / self.formula_normalize_vector
        if self.add_intensities:
            features = np.concatenate(
                (features, intensities.reshape(-1, 1)), axis=1
            )
        return torch.from_numpy(features.astype(np.float32))

    def _featurize_formula(self, formulas):
        """Build an element-count matrix from formula strings.

        Args:
            formulas: Sequence of formula strings such as ``"C6H12O6"``.

        Returns:
            A float array of shape ``(len(formulas), len(self.elem_to_pos))``.
            Unsupported elements are skipped, and entries that cannot be parsed
            leave their row as zeros; both cases print a message.
        """
        formula_vector = np.zeros((len(formulas), len(self.elem_to_pos)))
        for i, f in enumerate(formulas):
            try:
                for (e, ct) in re.findall(self.CHEM_FORMULA_SIZE, f):
                    # A symbol with no explicit count occurs once.
                    ct = 1 if ct == "" else int(ct)
                    try:
                        formula_vector[i][self.elem_to_pos[e]] += ct
                    except KeyError:
                        print(f"Couldn't vectorize {f}, element {e} not supported")
                        continue
            except (TypeError, ValueError):
                # Non-string entry (e.g. None) or an unparsable count.
                print(f"Couldn't vectorize {f}, formula not supported")
                continue
        return formula_vector
class MolToGraph(MolTransform):
    """Convert a SMILES string into a featurized bi-directed DGL graph."""

    def __init__(
        self,
        atom_feature: str = "full",
        bond_feature: str = "full",
        element_list: list = CHEM_ELEMS_SMALL,
    ):
        """Build the node and edge featurizers.

        Args:
            atom_feature: Atom feature set: ``'light'``, ``'medium'`` or
                ``'full'``.
            bond_feature: Bond feature set: ``'light'`` or ``'full'``.
            element_list: Element symbols used for the atom-type one-hot.

        Raises:
            ValueError: If ``atom_feature`` is not a supported mode.
        """
        self.atom_feature = atom_feature
        self.bond_feature = bond_feature
        self.node_featurizer = self._get_atom_featurizer(element_list=element_list)
        self.edge_featurizer = self._get_bond_featurizer()

    def from_smiles(self, mol: str):
        """Parse a SMILES string and return the corresponding graph.

        The graph is built with self loops, no virtual nodes, and the atom
        order of the parsed molecule (no canonical reordering).
        """
        mol = Chem.MolFromSmiles(mol)
        g = chemutils.mol_to_bigraph(
            mol,
            node_featurizer=self.node_featurizer,
            edge_featurizer=self.edge_featurizer,
            add_self_loop=True,
            num_virtual_nodes=0,
            canonical_atom_order=False,
        )
        # atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()] # added for visualization
        # g.ndata['atom_id'] = torch.tensor(atom_ids, dtype=torch.long)
        return g

    def _get_atom_featurizer(self, element_list) -> "chemutils.BaseAtomFeaturizer":
        """Create the atom featurizer for ``self.atom_feature``.

        The returned featurizer produces two node fields: ``'h'`` (the
        concatenated features of the selected mode) and ``'m'`` (atom mass).

        Raises:
            ValueError: If ``self.atom_feature`` is not ``'light'``,
                ``'medium'`` or ``'full'``.
        """
        feature_mode = self.atom_feature
        atom_mass_fun = chemutils.ConcatFeaturizer(
            [chemutils.atom_mass]
        )

        # Width of the bond-type one-hot, measured once on a throwaway bond so
        # that atoms without any bond still get a vector of the right length.
        probe_bond = Chem.MolFromSmiles("CC").GetBondWithIdx(0)
        num_bond_types = len(chemutils.bond_type_one_hot(probe_bond))

        def atom_bond_type_one_hot(atom):
            """Flag, per bond type, whether the atom has a bond of that type."""
            flags = [False] * num_bond_types
            for bond in atom.GetBonds():
                for k, hit in enumerate(chemutils.bond_type_one_hot(bond)):
                    if hit:
                        flags[k] = True
            return flags

        def atom_type_one_hot(atom):
            """One-hot of the element, with an extra slot for unknown elements."""
            return chemutils.atom_type_one_hot(
                atom, allowable_set=element_list, encode_unknown=True
            )

        if feature_mode == 'light':
            atom_featurizer_funs = chemutils.ConcatFeaturizer([
                chemutils.atom_mass,
                atom_type_one_hot
            ])
        elif feature_mode == 'full':
            atom_featurizer_funs = chemutils.ConcatFeaturizer([
                chemutils.atom_mass,
                atom_type_one_hot,
                atom_bond_type_one_hot,
                chemutils.atom_degree_one_hot,
                chemutils.atom_total_degree_one_hot,
                chemutils.atom_explicit_valence_one_hot,
                chemutils.atom_implicit_valence_one_hot,
                chemutils.atom_hybridization_one_hot,
                chemutils.atom_total_num_H_one_hot,
                chemutils.atom_formal_charge_one_hot,
                chemutils.atom_num_radical_electrons_one_hot,
                chemutils.atom_is_aromatic_one_hot,
                chemutils.atom_is_in_ring_one_hot,
                chemutils.atom_chiral_tag_one_hot
            ])
        elif feature_mode == 'medium':
            atom_featurizer_funs = chemutils.ConcatFeaturizer([
                chemutils.atom_mass,
                atom_type_one_hot,
                atom_bond_type_one_hot,
                chemutils.atom_total_degree_one_hot,
                chemutils.atom_total_num_H_one_hot,
                chemutils.atom_is_aromatic_one_hot,
                chemutils.atom_is_in_ring_one_hot,
            ])
        else:
            raise ValueError(
                f"Unsupported atom_feature {feature_mode!r}; "
                "expected 'light', 'medium' or 'full'."
            )

        return chemutils.BaseAtomFeaturizer(
            {"h": atom_featurizer_funs,
             "m": atom_mass_fun}
        )

    def _get_bond_featurizer(self, self_loop=True) -> "Optional[chemutils.BaseBondFeaturizer]":
        """Create the bond featurizer for ``self.bond_feature``.

        Both supported modes write their features to the edge field ``'e'``.

        Args:
            self_loop: Forwarded to the dgllife bond featurizer.

        Returns:
            The featurizer for ``'light'`` or ``'full'``; ``None`` for any
            other mode.
        """
        feature_mode = self.bond_feature
        if feature_mode == 'light':
            return chemutils.BaseBondFeaturizer(
                featurizer_funcs={'e': chemutils.ConcatFeaturizer([
                    chemutils.bond_type_one_hot
                ])}, self_loop=self_loop
            )
        if feature_mode == 'full':
            return chemutils.CanonicalBondFeaturizer(
                bond_data_field='e', self_loop=self_loop
            )
        # NOTE(review): an unrecognized mode yields no edge featurizer;
        # presumably the graph is then built without edge features — confirm
        # against dgllife's handling of edge_featurizer=None.
        return None