Spaces:
Sleeping
Sleeping
| import numpy as np | |
| # import seaborn as sns | |
| import matplotlib as mpl | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors | |
| import matplotlib.cm as cm | |
| import matplotlib.colors as mcolors | |
| import matplotlib.ticker as ticker | |
| import pandas as pd | |
| import typing as T | |
| import pulp | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from itertools import groupby | |
| from pathlib import Path | |
| # from myopic_mces.myopic_mces import MCES | |
| from rdkit.Chem import AllChem as Chem | |
| from rdkit.Chem import DataStructs, Draw | |
| from rdkit.Chem.Descriptors import ExactMolWt | |
| # from huggingface_hub import hf_hub_download | |
| # from standardizeUtils.standardizeUtils import ( | |
| # standardize_structure_with_pubchem, | |
| # standardize_structure_list_with_pubchem, | |
| # ) | |
| from torchmetrics.wrappers import BootStrapper | |
| from torchmetrics.metric import Metric | |
def load_massspecgym(fold: T.Optional[str] = None) -> pd.DataFrame:
    """
    Load the MassSpecGym dataset as a data frame indexed by the "identifier" column.

    The "mzs" and "intensities" columns are decoded from comma-separated strings
    into numpy arrays via `parse_spec_array`.

    Args:
        fold (str, optional): Fold name to load. If None, the entire dataset is loaded.

    Returns:
        pd.DataFrame: The whole dataset, or only the rows whose "fold" column equals `fold`.
    """
    tsv_path = hugging_face_download("MassSpecGym.tsv")
    dataset = pd.read_csv(tsv_path, sep="\t").set_index("identifier")

    # Peak lists are stored as text in the TSV; turn them into numeric arrays.
    for peak_column in ("mzs", "intensities"):
        dataset[peak_column] = dataset[peak_column].apply(parse_spec_array)

    if fold is None:
        return dataset
    return dataset[dataset["fold"] == fold]
def load_unlabeled_mols(col_name: str = "smiles") -> pd.Series:
    """
    Load one column of the table of unlabeled molecules.

    Args:
        col_name (str, optional): Name of the column to return. Should be one of ["smiles", "selfies"].

    Returns:
        pd.Series: The requested column of the downloaded molecule table.
    """
    molecules_path = hugging_face_download(
        "molecules/MassSpecGym_molecules_MCES2_disjoint_with_test_fold_4M.tsv"
    )
    molecules = pd.read_csv(molecules_path, sep="\t")
    return molecules[col_name]
def load_train_mols(col_name: str = "smiles") -> pd.Series:
    """
    Load one column of the "train" fold of the MassSpecGym dataset.

    Args:
        col_name (str, optional): Name of the column to return. Should be one of ["smiles", "selfies"].

    Returns:
        pd.Series: The requested column restricted to the training fold.
    """
    train_fold = load_massspecgym("train")
    return train_fold[col_name]
def pad_spectrum(
    spec: np.ndarray, max_n_peaks: int, pad_value: float = 0.0
) -> np.ndarray:
    """
    Pad a spectrum to a fixed number of peaks by appending rows filled with `pad_value`.

    Args:
        spec (np.ndarray): Spectrum to pad represented as numpy array of shape (n_peaks, 2).
        max_n_peaks (int): Number of peaks (rows) in the padded spectrum.
        pad_value (float, optional): Value to use for padding.

    Returns:
        np.ndarray: Array with `max_n_peaks` rows; the original rows come first.

    Raises:
        ValueError: If the spectrum already has more than `max_n_peaks` peaks.
    """
    n_peaks = spec.shape[0]
    n_missing = max_n_peaks - n_peaks
    if n_missing < 0:
        raise ValueError(
            f"Number of peaks in the spectrum ({n_peaks}) is greater than the maximum number of peaks."
        )

    # Pad only at the end of the first (peak) axis; the second axis is left untouched.
    pad_widths = ((0, n_missing), (0, 0))
    return np.pad(spec, pad_widths, mode="constant", constant_values=pad_value)
def morgan_fp(mol: Chem.Mol, fp_size=2048, radius=2, to_np=True):
    """
    Compute the Morgan bit-vector fingerprint of a molecule.

    Args:
        mol (Chem.Mol): RDKit molecule to fingerprint.
        fp_size (int, optional): Size of the fingerprint (number of bits).
        radius (int, optional): Radius of the fingerprint.
        to_np (bool, optional): Convert the fingerprint to numpy array.

    Returns:
        The RDKit bit vector when `to_np` is False, otherwise a numpy array filled
        from that bit vector.
    """
    bit_vect = Chem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=fp_size)
    if not to_np:
        return bit_vect

    # Destination starts empty; presumably RDKit resizes it to the fingerprint
    # length while copying the bits - confirm against the RDKit docs.
    fp_array = np.zeros((0,), dtype=np.int32)
    DataStructs.ConvertToNumpyArray(bit_vect, fp_array)
    return fp_array
def tanimoto_morgan_similarity(mol1: T.Union[Chem.Mol, str], mol2: T.Union[Chem.Mol, str]) -> float:
    """
    Compute Tanimoto similarity between two molecules using Morgan fingerprints.

    Args:
        mol1 (T.Union[Chem.Mol, str]): First molecule as RDKit molecule or SMILES string.
        mol2 (T.Union[Chem.Mol, str]): Second molecule as RDKit molecule or SMILES string.

    Returns:
        float: Tanimoto similarity of the two Morgan bit-vector fingerprints
        (computed by `morgan_fp` with its default size and radius).

    Raises:
        ValueError: If a SMILES string cannot be parsed into a molecule.
    """

    def _as_mol(mol):
        # Molecule objects pass through; only strings are parsed.
        if not isinstance(mol, str):
            return mol
        parsed = Chem.MolFromSmiles(mol)
        if parsed is None:
            # Fail with a clear message instead of handing None to the
            # fingerprint routine, which would fail with an opaque RDKit error.
            raise ValueError("Invalid SMILES string.")
        return parsed

    fp1 = morgan_fp(_as_mol(mol1), to_np=False)
    fp2 = morgan_fp(_as_mol(mol2), to_np=False)
    return DataStructs.TanimotoSimilarity(fp1, fp2)
def standardize_smiles(smiles: T.Union[str, T.List[str]]) -> T.Union[str, T.List[str]]:
    """
    Standardize SMILES representation of a molecule using PubChem standardization.

    Args:
        smiles (T.Union[str, T.List[str]]): A single SMILES string or a list of SMILES strings.

    Returns:
        T.Union[str, T.List[str]]: Whatever the PubChem standardization helper returns
        for the given input (single-structure helper for a string, list helper for a list).

    Raises:
        ValueError: If `smiles` is neither a string nor a list.
    """
    # NOTE(review): the import of `standardize_structure_with_pubchem` and
    # `standardize_structure_list_with_pubchem` is commented out at the top of the
    # visible file; unless they are provided elsewhere, the two branches below
    # raise NameError at call time.
    if isinstance(smiles, str):
        return standardize_structure_with_pubchem(smiles, 'smiles')
    if isinstance(smiles, list):
        return standardize_structure_list_with_pubchem(smiles, 'smiles')
    raise ValueError("Input should be a SMILES string or a list of SMILES strings.")
def mol_to_inchi_key(mol: Chem.Mol, twod: bool = True) -> str:
    """
    Convert a molecule to its InChI Key.

    Args:
        mol (Chem.Mol): RDKit molecule object.
        twod (bool, optional): Return only the part of the InChI Key before the
            first "-" (the 2D InChI Key, i.e. its first 14 characters).

    Returns:
        str: The full InChI Key, or its first hyphen-separated block when `twod` is True.
    """
    full_key = Chem.MolToInchiKey(mol)
    return full_key.split("-")[0] if twod else full_key
def smiles_to_inchi_key(mol: str, twod: bool = True) -> str:
    """
    Convert a SMILES molecule to InChI Key representation.

    Args:
        mol (str): SMILES string.
        twod (bool, optional): Return 2D InChI Key (first 14 characters of InChI Key).

    Returns:
        str: InChI Key produced by `mol_to_inchi_key` for the parsed molecule.

    Raises:
        ValueError: If the SMILES string cannot be parsed into a molecule.
    """
    parsed = Chem.MolFromSmiles(mol)
    if parsed is None:
        # Fail with a clear message instead of passing None on to the InChI
        # conversion, which would fail with an opaque RDKit error.
        raise ValueError("Invalid SMILES string.")
    return mol_to_inchi_key(parsed, twod)
def hugging_face_download(file_name: str) -> str:
    """
    Download a file from the MassSpecGym dataset repository on the Hugging Face Hub.

    Args:
        file_name (str): Name of the file to download, relative to the repository's "data/" folder.

    Returns:
        str: Value returned by `hf_hub_download` (the file's location on disk).
    """
    # NOTE(review): the `hf_hub_download` import is commented out at the top of
    # the visible file; unless it is provided elsewhere this call raises NameError.
    path_in_repo = "data/" + file_name
    local_path = hf_hub_download(
        repo_id="roman-bushuiev/MassSpecGym",
        filename=path_in_repo,
        repo_type="dataset",
    )
    return local_path
def init_plotting(figsize=(6, 2), font_scale=1.0, style="whitegrid"):
    """
    Configure global matplotlib/seaborn defaults for the plots of this module.

    Args:
        figsize (tuple, optional): Default figure size passed to seaborn's theme.
        font_scale (float, optional): Font scale of the seaborn "paper" context.
        style (str, optional): Seaborn style name.
    """
    # NOTE(review): the seaborn import (`sns`) is commented out at the top of the
    # visible file; unless it is provided elsewhere the calls below raise NameError.
    palette = ["#009473", "#D94F70", "#5A5B9F", "#F0C05A", "#7BC4C4", "#FF6F61"]

    # The original author notes that the settings do not take effect without this call.
    plt.show()

    # Default figure size.
    sns.set_theme(rc={"figure.figsize": figsize})
    mpl.rcParams['svg.fonttype'] = 'none'

    # Default style, font scale and colour palette.
    sns.set_style(style)
    sns.set_context("paper", font_scale=font_scale)
    sns.set_palette(palette)
def parse_spec_array(arr: str) -> np.ndarray:
    """
    Parse a comma-separated string of numbers into a 1D numpy array of floats.

    Args:
        arr (str): Comma-separated numbers, e.g. "1.0,2.5".
    """
    values = [float(token) for token in arr.split(",")]
    return np.array(values)
def spec_array_to_str(arr: np.ndarray) -> str:
    """
    Serialize the elements of an array into a comma-separated string.

    Args:
        arr (np.ndarray): Values to serialize; each element is converted with `str`.
    """
    return ",".join(str(value) for value in arr)
def compute_mass(smiles: str) -> float:
    """
    Compute the exact molecular weight of a molecule given as SMILES.

    Args:
        smiles (str): SMILES string of the molecule.

    Returns:
        float: Value of RDKit's `ExactMolWt` for the parsed molecule.

    Raises:
        ValueError: If the SMILES string cannot be parsed.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is not None:
        return ExactMolWt(parsed)
    raise ValueError("Invalid SMILES string.")
def plot_spectrum(spec, hue=None, xlim=None, ylim=None, mirror_spec=None, highl_idx=None,
                  figsize=(6, 2), colors=None, save_pth=None):
    """
    Plot a mass spectrum as vertical stems, optionally mirrored against a second spectrum.

    Intensities are rescaled so that the highest peak equals 100.

    Args:
        spec: 2D array holding m/z values and intensities, either as two rows or as
            two columns (anything whose first dimension is not 2 is transposed).
        hue: Optional per-peak values mapped to stem colours via the "cool" colormap;
            a colorbar is added when given.
        xlim: Optional (min, max) for the x axis; defaults to (0, max m/z + 10).
        ylim: Optional (min, max) for the y axis.
        mirror_spec: Optional second spectrum drawn downwards from zero.
        highl_idx: Not used by this function.
        figsize: Figure size forwarded to `init_plotting`.
        colors: Optional sequence of at least three colours; the first colours the main
            spectrum (when `hue` is None) and the third the mirror spectrum.
        save_pth: Saving is not implemented; any non-None value raises.

    Raises:
        NotImplementedError: If `save_pth` is not None (raised after the plot is built).
    """
    if colors is None:
        colors = ['blue', 'green', 'red']
    else:
        assert len(colors) >= 3

    def norm_spec(spec):
        # Bring the spectrum to (2, n_peaks) layout and rescale intensities to percent.
        assert len(spec.shape) == 2
        peaks = spec if spec.shape[0] == 2 else spec.T
        mz_values, intensities = peaks[0], peaks[1]
        return mz_values, intensities / max(intensities) * 100

    mzs, ins = norm_spec(spec)

    # Global plotting defaults, then a single-axes figure.
    init_plotting(figsize=figsize)
    _fig, ax = plt.subplots(1, 1)

    # Colour mapper for the optional per-peak hue values.
    mapper = None
    if hue is not None:
        hue_norm = matplotlib.colors.Normalize(vmin=min(hue), vmax=max(hue), clip=True)
        mapper = cm.ScalarMappable(norm=hue_norm, cmap=cm.cool)
        plt.colorbar(mapper, ax=ax)

    # Every peak is a two-point line; the marker is drawn only on its second point (the tip).
    stem_kwargs = dict(marker='o', markevery=(1, 2), mfc='white')

    # Main spectrum, drawn upwards.
    for i, (mz, intensity) in enumerate(zip(mzs, ins)):
        if hue is None:
            color = colors[0]
        else:
            color = mcolors.to_hex(mapper.to_rgba(hue[i]))
        plt.plot([mz, mz], [0, intensity], color=color, zorder=2, **stem_kwargs)

    # Mirror spectrum, drawn downwards.
    if mirror_spec is not None:
        mzs_m, ins_m = norm_spec(mirror_spec)

        def major_formatter(x, pos):
            # Show tick labels of the mirrored half without the minus sign.
            return str(round(-x)) if x < 0 else str(round(x))

        for mz, intensity in zip(mzs_m, ins_m):
            plt.plot([mz, mz], [0, -intensity], color=colors[2], zorder=1, **stem_kwargs)
        ax.yaxis.set_major_formatter(major_formatter)

    # Axes limits and labels.
    if xlim is None:
        plt.xlim(0, max(mzs) + 10)
    else:
        plt.xlim(xlim[0], xlim[1])
    if ylim is not None:
        plt.ylim(ylim[0], ylim[1])
    plt.xlabel('m/z')
    plt.ylabel('Intensity [%]')

    if save_pth is not None:
        raise NotImplementedError()
def show_mols(mols, legends='new_indices', smiles_in=False, svg=False, sort_by_legend=False, max_mols=500,
              legend_float_decimals=4, mols_per_row=6, save_pth: T.Optional[Path] = None):
    """
    Draw a grid of skeletal structures of the given molecules and return the image.
    Adapted from https://github.com/pluskal-lab/DreaMS/blob/main/dreams/utils/mols.py

    :param mols: list of rdkit molecules (or SMILES strings when `smiles_in` is True)
    :param legends: labels for the molecules: 'new_indices' (0..n-1), 'masses' (exact molecular
                    weights), a callable applied to every molecule, or a list with one label per molecule
    :param smiles_in: True - SMILES inputs, False - RDKit mols
    :param svg: True - return svg image, False - return png image
    :param sort_by_legend: True - sort molecules by legend values
    :param max_mols: maximum number of molecules to show
    :param legend_float_decimals: number of decimal places to show for float legends
    :param mols_per_row: number of molecules per row to show
    :param save_pth: path to write the image's `data` attribute to, opened in text mode
    :return: image object produced by `Draw.MolsToGridImage`
    """
    molecules = [Chem.MolFromSmiles(smi) for smi in mols] if smiles_in else mols

    # Resolve the legend specification into one label per molecule.
    if legends == 'new_indices':
        labels = list(range(len(molecules)))
    elif legends == 'masses':
        labels = [ExactMolWt(m) for m in molecules]
    elif callable(legends):
        labels = [legends(m) for m in molecules]
    else:
        labels = legends

    if sort_by_legend:
        order = np.argsort(labels).tolist()
        labels = [labels[i] for i in order]
        molecules = [molecules[i] for i in order]

    def _label_text(label):
        # Floats get a fixed number of decimals; everything else is stringified as is.
        if isinstance(label, float):
            return f'{label:.{legend_float_decimals}f}'
        return str(label)

    label_texts = [_label_text(label) for label in labels]

    n_per_row = min(max_mols, mols_per_row)
    img = Draw.MolsToGridImage(molecules, maxMols=max_mols, legends=label_texts, molsPerRow=n_per_row,
                               useSVG=svg, returnPNG=False)

    if save_pth:
        with open(save_pth, 'w') as out_file:
            out_file.write(img.data)

    return img
class MyopicMCES():
    """
    Callable wrapper around the myopic MCES distance between two SMILES strings.

    The instance stores the solver configuration once; calling it with two SMILES
    strings forwards them, together with that configuration, to `MCES`.
    """

    def __init__(
        self,
        ind: int = 0,  # dummy index
        solver: T.Optional[str] = None,  # None -> use the first available PuLP solver
        threshold: int = 15,  # MCES threshold
        always_stronger_bound: bool = True,  # "False" makes computations a lot faster, but leads to overall higher MCES values
        solver_options: T.Optional[dict] = None
    ):
        """
        Args:
            ind (int, optional): Dummy index; stored but not forwarded by `__call__`.
            solver (str, optional): Name of the PuLP solver. If None, the first solver
                reported by `pulp.listSolvers(onlyAvailable=True)` is used.
            threshold (int, optional): MCES threshold.
            always_stronger_bound (bool, optional): Forwarded to `MCES`.
            solver_options (dict, optional): Options forwarded to the solver. If None,
                `dict(msg=0)` is used to keep the ILP solver silent.

        Raises:
            RuntimeError: If `solver` is None and PuLP reports no available solver.
        """
        if solver is None:
            # Resolve the default lazily: evaluating it in the signature would run
            # solver discovery at import time and break the import of the whole
            # module (IndexError) on machines without any solver installed.
            available_solvers = pulp.listSolvers(onlyAvailable=True)
            if not available_solvers:
                raise RuntimeError("No PuLP solver is available; pass `solver` explicitly.")
            solver = available_solvers[0]
        if solver_options is None:
            solver_options = dict(msg=0)  # make ILP solver silent

        self.ind = ind
        self.solver = solver
        self.threshold = threshold
        self.always_stronger_bound = always_stronger_bound
        self.solver_options = solver_options

    def __call__(self, smiles_1: str, smiles_2: str) -> float:
        """
        Compute the MCES distance between two molecules.

        Args:
            smiles_1 (str): SMILES string of the first molecule.
            smiles_2 (str): SMILES string of the second molecule.

        Returns:
            float: Element at index 1 of the value returned by `MCES`, which this
            code treats as the distance.
        """
        # NOTE(review): the `MCES` import is commented out at the top of the visible
        # file; unless it is provided elsewhere this call raises NameError.
        retval = MCES(
            smiles_1,
            smiles_2,
            threshold=self.threshold,
            always_stronger_bound=self.always_stronger_bound,
            solver=self.solver,
            solver_options=self.solver_options
        )
        return retval[1]
class ReturnScalarBootStrapper(BootStrapper):
    """Wrapper for BootStrapper that returns a scalar value in compute instead of a dictionary."""

    def __init__(
        self,
        base_metric: Metric,
        num_bootstraps: int = 10,
        mean: bool = False,
        std: bool = False,
        quantile: T.Optional[T.Union[float, torch.Tensor]] = None,
        raw: bool = False,
        sampling_strategy: str = "poisson",
        **kwargs: T.Any
    ) -> None:
        """
        Args:
            base_metric (Metric): Metric to bootstrap.
            num_bootstraps (int, optional): Forwarded to `BootStrapper`.
            mean (bool, optional): Select the bootstrap mean (not implemented yet).
            std (bool, optional): Select the bootstrap standard deviation.
            quantile (float or Tensor, optional): Select a quantile (not implemented yet).
            raw (bool, optional): Select the raw bootstrap values (not implemented yet).
            sampling_strategy (str, optional): Forwarded to `BootStrapper`.
            **kwargs: Forwarded to `BootStrapper`.

        Raises:
            ValueError: If not exactly one of `mean`, `std`, `quantile`, `raw` is set.
            NotImplementedError: If the selected output is anything but `std`.
        """
        # Booleans add up as integers, so this counts the selected outputs.
        n_selected = mean + std + bool(quantile) + raw
        if n_selected != 1:
            raise ValueError("Exactly one of mean, std, quantile or raw should be True.")
        if not std:
            raise NotImplementedError("Currently only std is implemented.")

        # Key of the entry that `compute` extracts from the parent's result.
        self.compute_key = "std"

        super().__init__(
            base_metric=base_metric,
            num_bootstraps=num_bootstraps,
            mean=mean,
            std=std,
            quantile=quantile,
            raw=raw,
            sampling_strategy=sampling_strategy,
            **kwargs
        )

    def compute(self):
        """Return only the selected entry of the result computed by `BootStrapper`."""
        all_outputs = super().compute()
        return all_outputs[self.compute_key]
def batch_ptr_to_batch_idx(batch_ptr: torch.Tensor) -> torch.Tensor:
    """
    Expand per-batch element counts into one batch index per element.

    For example [1, 3, 2] -> [0, 1, 1, 1, 2, 2]

    Args:
        batch_ptr (Tensor): 1D tensor where entry i is the number of elements of batch i.

    Returns:
        Tensor: Batch index of every element, created on the device of `batch_ptr`.
    """
    batch_ids = torch.arange(batch_ptr.size(0), device=batch_ptr.device)
    # Repeat the id of batch i exactly batch_ptr[i] times.
    return batch_ids.repeat_interleave(batch_ptr)
def unbatch_list(batch_list: list, batch_idx: torch.Tensor) -> list:
    """
    Split a flat list into per-batch sub-lists using the batch index of every item.

    For example items [a, b, c] with batch indexes [0, 1, 1] -> [[a], [b, c]]

    The number of returned groups is the last batch index plus one, so `batch_idx`
    is expected to be sorted in non-decreasing order. Items whose index falls
    outside that range are dropped.

    Args:
        batch_list (list): List of items to unbatch.
        batch_idx (Tensor): Batch index of every item in `batch_list`.

    Returns:
        list: One list of items per batch index (possibly empty for unused indexes).
        An empty `batch_list` with an empty `batch_idx` gives an empty list.

    Raises:
        IndexError: If `batch_list` has more items than `batch_idx` has entries.
    """
    # Convert once to plain Python ints to avoid per-element tensor comparisons.
    indexes = batch_idx.tolist() if hasattr(batch_idx, "tolist") else list(batch_idx)
    if not indexes:
        if batch_list:
            raise IndexError("batch_idx is empty but batch_list is not.")
        return []

    n_groups = indexes[-1] + 1
    groups = [[] for _ in range(n_groups)]
    # Single pass over the items instead of rescanning the whole list per group.
    for position, item in enumerate(batch_list):
        group = indexes[position]
        if 0 <= group < n_groups:
            groups[group].append(item)
    return groups
class CosSimLoss(nn.Module):
    """Loss equal to one minus the mean cosine similarity between inputs and targets."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        """
        Args:
            inputs (Tensor): Predicted vectors.
            targets (Tensor): Target vectors, compared with `inputs` by
                `F.cosine_similarity` using its default dimension.

        Returns:
            Tensor: Scalar loss `1 - mean(cosine_similarity)`.
        """
        similarity = F.cosine_similarity(inputs, targets)
        return 1 - similarity.mean()
def parse_sirius_ms(spectra_file: str) -> T.Tuple[dict, T.List[T.Tuple[str, np.ndarray]]]:
    """
    Parses spectra from the SIRIUS .ms file.

    Copied from the code of Goldman et al.:
    https://github.com/samgoldman97/mist/blob/4c23d34fc82425ad5474a53e10b4622dcdbca479/src/mist/utils/parse_utils.py#LL10C77-L10C77.

    The file is read as alternating groups of header lines (starting with ">" or "#")
    and the lines that follow them. The first header group is parsed as metadata;
    every later header group names a spectrum whose peaks are the lines after it.

    :param spectra_file: path to the .ms file
    :return T.Tuple[dict, T.List[T.Tuple[str, np.ndarray]]]: metadata and list of spectra tuples containing name and array
    """
    # Read everything up front inside a context manager so the handle is always closed.
    with open(spectra_file, "r") as ms_file:
        stripped_lines = [line.strip() for line in ms_file]

    metadata = {}
    spectras = []

    grouped = groupby(
        stripped_lines, lambda line: line.startswith(">") or line.startswith("#")
    )
    for group_num, (_, header_iter) in enumerate(grouped):
        group_lines = list(header_iter)
        # The group following a header group holds its content lines. A file that
        # ends right after a header has no such group, hence the empty default
        # (a bare next() would raise StopIteration here).
        subject_lines = list(next(grouped, (None, ()))[1])

        # Get spectra
        if group_num > 0:
            spectra_header = group_lines[0].split(">")[1]
            # Only the first two whitespace-separated fields of a peak line are used.
            peak_data = [
                [float(x) for x in peak.split()[:2]]
                for peak in subject_lines
                if peak.strip()
            ]
            # Skip spectra without peaks
            if peak_data:
                spectras.append((spectra_header, np.vstack(peak_data)))
        # Get meta data
        else:
            entries = {}
            for header in group_lines:
                if " " not in header:
                    continue
                if header.startswith("#INSTRUMENT TYPE"):
                    # The key itself contains a space, so it cannot go through the generic split.
                    key = "#INSTRUMENT TYPE"
                    entries[key[1:]] = header.split(key)[1].strip()
                else:
                    name, value = header.split(" ", 1)
                    name = name[1:]  # drop the leading ">" or "#"
                    # Repeated keys are kept apart by appending apostrophes.
                    while name in entries:
                        name = f"{name}'"
                    entries[name] = value
            metadata.update(entries)

    metadata["_FILE_PATH"] = spectra_file
    metadata["_FILE"] = Path(spectra_file).stem
    return metadata, spectras