{ "cells": [ { "cell_type": "code", "execution_count": 53, "id": "1e8a4ffb", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "import plotly.graph_objects as go\n", "from plotly.subplots import make_subplots\n", "from rdkit import Chem\n", "from rdkit.Chem import rdDepictor\n", "from rdkit.Chem.Draw import rdMolDraw2D\n", "import pickle\n", "import copy" ] }, { "cell_type": "markdown", "id": "4c716c1a", "metadata": {}, "source": [ "## Ranking result" ] }, { "cell_type": "code", "execution_count": 2, "id": "f9555d50", "metadata": {}, "outputs": [], "source": [ "ranking_file = \"/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/result_MassSpecGym_retrieval_candidates_formula.pkl\"\n", "with open(ranking_file, 'rb') as f:\n", " ranking = pickle.load(f)" ] }, { "cell_type": "code", "execution_count": null, "id": "78cfd902", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
1520
rank20.68847.39172.368
\n", "
" ], "text/plain": [ " 1 5 20\n", "rank 20.688 47.391 72.368" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "r='rank'\n", "result = []\n", "\n", "top_k = [1, 5, 20]\n", "rank_result = {}\n", "for k in top_k:\n", " result.append(round(len(ranking[ranking[r]<=k])/len(ranking)*100, 3))\n", "rank_result[r] = result\n", "\n", "pd.DataFrame.from_dict(rank_result, orient='index', columns=['1', '5', '20'])" ] }, { "cell_type": "code", "execution_count": 23, "id": "11bbc0d4", "metadata": {}, "outputs": [], "source": [ "def get_target(candidates, labels):\n", " return np.array(candidates)[labels][0]\n", "\n", "def get_cand_at_1(candidates, scores):\n", " return candidates[np.argmax(scores)]\n", "\n", "def get_top_score(scores):\n", " return np.max(scores)\n", "\n", "def get_target_score(labels, scores):\n", " return np.array(scores)[labels][0]\n", "\n", "def get_n_heavy_atoms(smiles):\n", " mol = Chem.MolFromSmiles(smiles)\n", " return mol.GetNumHeavyAtoms()\n", "\n", "ranking['target'] = ranking.apply(lambda x: get_target(x['candidates'], x['labels']), axis=1)\n", "ranking['target_score'] = ranking.apply(lambda x: get_target_score(x['labels'], x['scores']), axis=1)\n", "\n", "ranking['cand@1'] = ranking.apply(lambda x: get_cand_at_1(x['candidates'], x['scores']), axis=1)\n", "ranking['top_score'] = ranking.apply(lambda x: get_top_score(x['scores']), axis=1)\n", "\n", "ranking['n_heavy_atoms'] = ranking['target'].apply(get_n_heavy_atoms)" ] }, { "cell_type": "code", "execution_count": 24, "id": "763ea617", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
identifiercandidatesscoreslabelsranktargettarget_scorecand@1top_scoren_heavy_atoms
0MassSpecGymID0000201[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(...[0.17369578778743744, 0.12611594796180725, 0.2...[True, False, False, False, False, False, Fals...17CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([...0.173696COCCCN1C(=O)COc2ccc(N(C(=O)[C@H]3CN(C(=O)OC(C)...0.25987857
1MassSpecGymID0000202[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(...[0.05142267048358917, 0.07289629429578781, 0.1...[True, False, False, False, False, False, Fals...24CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([...0.051423COC(=O)/C(C)=C\\CC1(O)C(=O)C2CC(C(C)C)C13Oc1c(C...0.23719557
2MassSpecGymID0000203[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(...[0.09354929625988007, 0.0947718694806099, 0.10...[True, False, False, False, False, False, Fals...23CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([...0.093549C=CCOC12Oc3ccc(OC(=O)NCC)cc3C3C(CCCCO)C(CCCCO)...0.23826857
\n", "
" ], "text/plain": [ " identifier candidates \\\n", "0 MassSpecGymID0000201 [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... \n", "1 MassSpecGymID0000202 [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... \n", "2 MassSpecGymID0000203 [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... \n", "\n", " scores \\\n", "0 [0.17369578778743744, 0.12611594796180725, 0.2... \n", "1 [0.05142267048358917, 0.07289629429578781, 0.1... \n", "2 [0.09354929625988007, 0.0947718694806099, 0.10... \n", "\n", " labels rank \\\n", "0 [True, False, False, False, False, False, Fals... 17 \n", "1 [True, False, False, False, False, False, Fals... 24 \n", "2 [True, False, False, False, False, False, Fals... 23 \n", "\n", " target target_score \\\n", "0 CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... 0.173696 \n", "1 CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... 0.051423 \n", "2 CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... 0.093549 \n", "\n", " cand@1 top_score n_heavy_atoms \n", "0 COCCCN1C(=O)COc2ccc(N(C(=O)[C@H]3CN(C(=O)OC(C)... 0.259878 57 \n", "1 COC(=O)/C(C)=C\\CC1(O)C(=O)C2CC(C(C)C)C13Oc1c(C... 0.237195 57 \n", "2 C=CCOC12Oc3ccc(OC(=O)NCC)cc3C3C(CCCCO)C(CCCCO)... 0.238268 57 " ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ranking.head(3)" ] }, { "cell_type": "markdown", "id": "93ef333e", "metadata": {}, "source": [ "## model" ] }, { "cell_type": "code", "execution_count": 12, "id": "0b4e4250", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data path: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv\n", "Processing formula spectra\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 231104/231104 [00:18<00:00, 12309.47it/s]\n", "/data/yzhouc01/FILIP-MS/mvp/data/datasets.py:221: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loaded Model from checkpoint\n" ] } ], "source": [ "import sys\n", "sys.path.insert(0, \"/data/yzhouc01/MassSpecGym\")\n", "sys.path.insert(0, \"/data/yzhouc01/FILIP-MS\")\n", "\n", "from rdkit import RDLogger\n", "import pytorch_lightning as pl\n", "from pytorch_lightning import Trainer\n", "from massspecgym.models.base import Stage\n", "import os\n", "\n", "from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset\n", "from mvp.utils.models import get_model\n", "\n", "from mvp.definitions import TEST_RESULTS_DIR\n", "import yaml\n", "from functools import partial\n", "# Suppress RDKit warnings and errors\n", "lg = RDLogger.logger()\n", "lg.setLevel(RDLogger.CRITICAL)\n", "\n", "# Load model and data\n", "\n", "param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'\n", "with open(param_pth) as f:\n", " params = yaml.load(f, Loader=yaml.FullLoader)\n", "\n", "spec_featurizer = get_spec_featurizer(params['spectra_view'], params)\n", "mol_featurizer = get_mol_featurizer(params['molecule_view'], params)\n", "dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)\n", "\n", "\n", "# load model\n", "import torch \n", "checkpoint_pth = \"/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt\"\n", "params['checkpoint_pth'] = checkpoint_pth\n", "model = get_model(params['model'], params)" ] }, { "cell_type": "markdown", "id": "bd9dc380", "metadata": {}, "source": [ "## visualization function" ] }, { "cell_type": "code", "execution_count": 16, "id": "0883a1a2", "metadata": {}, "outputs": [], "source": [ "\n", "import torch.nn.functional as F\n", "import numpy as np\n", "\n", "# Atomic masses corresponding to your atom_labels\n", "ATOM_LABELS = ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']\n", "ATOM_MASSES = np.array([\n", " 1.0078, 12.0000, 15.9949, 14.0031, 30.9738, 31.9721, \n", " 35.45, 18.9984, 79.90, 126.90, 10.811, 74.9216, 28.085, 78.96\n", "])\n", "norm_vector = [102.0, 59.0, 25.0, 13.0, 3.0, 6.0, 6.0, 17.0, 4.0, 4.0, 1.0, 1.0, 5.0, 2.0]\n", "\n", "def spectra_from_encoding(spectral_tensor, norm_vector=norm_vector):\n", " \"\"\"\n", " Convert encoded spectra (num_peaks x 15) into m/z, intensities, and molecular formulas.\n", " Can undo normalization if a norm_vector is provided.\n", " \n", " Args:\n", " spectral_tensor (np.ndarray or torch.Tensor): [num_peaks, 15]\n", " norm_vector (np.ndarray or list): length 14, normalization factor for each atom\n", " \n", " Returns:\n", " mzs (list of float): list of m/z values\n", " intensities (list of float): list of intensities\n", " formulas (list of str): molecular formula strings\n", " \"\"\"\n", " if hasattr(spectral_tensor, \"detach\"):\n", " spectral_tensor = spectral_tensor.detach().cpu().numpy()\n", " \n", " counts = spectral_tensor[:, :14] # atom counts\n", " intensities = spectral_tensor[:, 14] # last col = intensity\n", " \n", " # Undo normalization\n", " if norm_vector is not None:\n", " counts = counts * np.array(norm_vector)\n", " \n", " # Compute m/z\n", " mzs = (counts * ATOM_MASSES).sum(axis=1)\n", " \n", " # Build molecular formula strings\n", " formulas = []\n", " for peak_counts in counts:\n", " formula_parts = []\n", " for elem, count in zip(ATOM_LABELS, peak_counts):\n", " n = int(round(count))\n", " if n > 0:\n", " formula_parts.append(f\"{elem}{n if n > 1 else ''}\")\n", " formulas.append(\"\".join(formula_parts) if formula_parts else \"Unknown\")\n", " \n", " return mzs.tolist(), intensities.tolist(), formulas\n", "\n", "\n", "def mol_to_graph_coords(mol):\n", " \"\"\"Return atom coordinates and bond list for a molecule.\"\"\"\n", " rdDepictor.Compute2DCoords(mol)\n", " conf = mol.GetConformer()\n", " coords = {i: conf.GetAtomPosition(i) for i in range(mol.GetNumAtoms())}\n", " bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]\n", " return coords, bonds\n", "\n", "def interactive_attention_visualization(spectral_embeds, graph_embeds, \n", " peak_mzs, peak_intensities, peak_formulas, mol):\n", " \"\"\"\n", " Interactive visualization of peak-node similarity with color scale legend.\n", " - Clicking a peak recolors nodes by similarity\n", " - Clicking a node recolors peaks by similarity\n", " \"\"\"\n", " # Similarity matrix\n", " spectral_embeds = F.normalize(spectral_embeds, p=2, dim=-1)\n", " graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)\n", " \n", " similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()\n", " sim_norm = (similarity - similarity.min()) / (similarity.max() - similarity.min() + 1e-8)\n", " \n", " num_peaks, num_nodes = similarity.shape\n", " \n", " # --- Molecule graph ---\n", " coords, bonds = mol_to_graph_coords(mol)\n", " atom_labels = [a.GetSymbol() for a in mol.GetAtoms()]\n", " atom_x = [coords[i].x for i in range(num_nodes)]\n", " atom_y = [coords[i].y for i in range(num_nodes)]\n", " \n", " # --- Spectrum trace ---\n", " spectrum_trace = go.Bar(\n", " x=peak_mzs,\n", " y=peak_intensities,\n", " name='peak',\n", " marker=dict(color=\"lightgray\", colorscale=\"Viridis\", cmin=0, cmax=1,\n", " colorbar=dict(title=\"Similarity\", len=0.8, y=0.5)),\n", " hovertext=[f\"Formula {f}\" for f in peak_formulas],\n", " customdata=list(range(num_peaks)) # peak index\n", " )\n", " \n", " # --- Graph nodes ---\n", " graph_nodes = go.Scatter(\n", " x=atom_x, y=atom_y,\n", " mode=\"markers+text\",\n", " name='node',\n", " text=atom_labels,\n", " textposition=\"middle center\",\n", " marker=dict(size=20, color=\"lightgray\", colorscale=\"Viridis\", cmin=0, cmax=1,\n", " colorbar=dict(title=\"Similarity\", len=0.8, y=0.5)),\n", " customdata=list(range(num_nodes)),\n", " # hovertext=[f\"Atom {i} ({label})\" for i, label in enumerate(atom_labels)]\n", " )\n", " \n", " # --- Graph bonds ---\n", " edge_x, edge_y = [], []\n", " for i, j in bonds:\n", " edge_x += [coords[i].x, coords[j].x, None]\n", " edge_y += [coords[i].y, coords[j].y, None]\n", " graph_edges = go.Scatter(\n", " x=edge_x, y=edge_y,\n", " mode=\"lines\", line=dict(color=\"gray\", width=2),\n", " hoverinfo=\"none\", showlegend=False\n", " )\n", " \n", " # --- Subplots ---\n", " fig = make_subplots(rows=1, cols=2, subplot_titles=(\"Spectrum\", \"Molecule\"), \n", " column_widths=[0.6, 0.4])\n", " \n", " fig.add_trace(spectrum_trace, row=1, col=1)\n", " fig.add_trace(graph_edges, row=1, col=2)\n", " fig.add_trace(graph_nodes, row=1, col=2)\n", " \n", " fig.update_xaxes(title=\"m/z\", row=1, col=1)\n", " fig.update_yaxes(title=\"Intensity\", row=1, col=1)\n", " fig.update_xaxes(visible=False, row=1, col=2)\n", " fig.update_yaxes(visible=False, row=1, col=2)\n", " \n", " fig.update_layout(title=\"Peak ↔ Node Similarity\", showlegend=False)\n", " \n", " # --- Interactivity ---\n", " from ipywidgets import VBox\n", " fw = go.FigureWidget(fig)\n", "\n", " def highlight_nodes(trace, points, selector):\n", " \"\"\"Click on peak → recolor nodes\"\"\"\n", " if points.point_inds:\n", " peak_idx = points.point_inds[0]\n", " scores = sim_norm[peak_idx, :]\n", " with fw.batch_update():\n", " fw.data[2].marker.color = scores\n", " fw.data[0].marker.color = [\"red\" if i == peak_idx else \"lightgray\" for i in range(num_peaks)]\n", "\n", " def highlight_peaks(trace, points, selector):\n", " \"\"\"Click on node → recolor peaks\"\"\"\n", " if points.point_inds:\n", " node_idx = points.point_inds[0]\n", " scores = sim_norm[:, node_idx]\n", " with fw.batch_update():\n", " fw.data[0].marker.color = scores\n", " fw.data[2].marker.color = [\"red\" if i == node_idx else \"lightgray\" for i in range(num_nodes)]\n", " \n", " fw.data[0].on_click(highlight_nodes) # spectrum\n", " fw.data[2].on_click(highlight_peaks) # nodes\n", " \n", " return fw\n" ] }, { "cell_type": "markdown", "id": "7d1b98d3", "metadata": {}, "source": [ "## Visualization" ] }, { "cell_type": "code", "execution_count": 92, "id": "78f85d94", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MS ID: MassSpecGymID0396247, Target:0.423, Cand@1: 0.49\n", "Target rank: 25\n" ] } ], "source": [ "# sample a case where targte is ranked at 2\n", "sample = ranking[(ranking['rank']>20) & (ranking['n_heavy_atoms'] <=20)].sample(1).iloc[0]\n", "ms_id = sample['identifier']\n", "target = sample['target']\n", "cand_at_1 = sample['cand@1']\n", "print(f\"MS ID: {ms_id}, Target:{sample['target_score']:.3}, Cand@1: {sample['top_score']:.3}\")\n", "print(f\"Target rank: {sample['rank']}\")" ] }, { "cell_type": "code", "execution_count": 96, "id": "20437eb8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CCCCCCCC(=O)NC1=C2C(=CSS2)NC1=O\n", "23\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c60a8c857d4041cdaf96a7f4ba2f4f61", "version_major": 2, "version_minor": 0 }, "text/plain": [ "FigureWidget({\n", " 'data': [{'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", " 16, 17, 18, 19, 20, 21, 22],\n", " 'hovertext': [Formula H10C7, Formula H8C3O2N2, Formula H12C8,\n", " Formula H3C3NS2, Formula H14C8O, Formula HC7NS, Formula\n", " HC4ONS2, Formula C7N2S, Formula H3C6ONS2, Formula\n", " H2C5ON2S2, Formula H3C5ON2S2, Formula H4C5ON2S2,\n", " Formula HC8ON2S, Formula H2C6ON2S2, Formula H5C6ON2S2,\n", " Formula H6C6ON2S2, Formula H6C9ON2S, Formula H4C7ON2S2,\n", " Formula H5C7ON2S2, Formula H2C10ON2S, Formula\n", " H8C9ON2S2, Formula H16C13ON2S2, Formula H18C13O2N2S2],\n", " 'marker': {'cmax': 1,\n", " 'cmin': 0,\n", " 'color': 'lightgray',\n", " 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n", " 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n", " '#482878'], [0.2222222222222222,\n", " '#3e4989'], [0.3333333333333333,\n", " '#31688e'], [0.4444444444444444,\n", " '#26828e'], [0.5555555555555556,\n", " '#1f9e89'], [0.6666666666666666,\n", " '#35b779'], [0.7777777777777778,\n", " '#6ece58'], [0.8888888888888888,\n", " '#b5de2b'], [1.0, '#fde725']]},\n", " 'name': 'peak',\n", " 'type': 'bar',\n", " 'uid': 'c92eb0ab-8030-4f20-a4b0-126e19bad926',\n", " 'x': [94.07799900290073, 104.05839936191737, 108.0936002238661,\n", " 116.97070118690729, 126.10410051210002, 130.98300034787468,\n", " 142.95000219490117, 143.9783008338645, 168.96559957769588,\n", " 169.96090409332814, 170.96870403325858, 171.97650416466072,\n", " 172.98100185312143, 181.96090015942156, 184.9843001706846,\n", " 185.99210011061504, 190.01999790607465, 195.97650157185868,\n", " 196.98430151178914, 197.98880457099676, 224.00769912172186,\n", " 280.07010477147026, 298.08060429381726],\n", " 'xaxis': 'x',\n", " 'y': [0.24127991497516632, 0.16632197797298431, 0.16255460679531097,\n", " 0.2183758169412613, 0.16817252337932587, 0.14077642560005188,\n", " 0.24170850217342377, 0.23078560829162598, 0.12851069867610931,\n", " 0.22988910973072052, 0.35106268525123596, 1.0,\n", " 0.2844732105731964, 0.14581838250160217, 0.1700058877468109,\n", " 0.3804142475128174, 0.13246509432792664, 0.18231017887592316,\n", " 0.24256132543087006, 0.2150418609380722, 0.14223572611808777,\n", " 0.3020711839199066, 1.100000023841858],\n", " 'yaxis': 'y'},\n", " {'hoverinfo': 'none',\n", " 'line': {'color': 'gray', 'width': 2},\n", " 'mode': 'lines',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", " 'uid': 'ef26d482-9ee9-47f4-823a-834c4cdab648',\n", " 'x': [9.232981638284745, 7.913365912100706, None, 7.913365912100706,\n", " 6.635932959325772, None, 6.635932959325772, 5.316317233141732,\n", " None, 5.316317233141732, 4.038884280366799, None,\n", " 4.038884280366799, 2.7192685541827593, None,\n", " 2.7192685541827593, 1.4418356014078257, None,\n", " 1.4418356014078257, 0.12221987522378566, None,\n", " 0.12221987522378566, 0.08003710181467982, None,\n", " 0.12221987522378566, -1.155213077551148, None,\n", " -1.155213077551148, -2.474828803735188, None,\n", " -2.474828803735188, -3.8274477532630358, None,\n", " -3.8274477532630358, -4.862094590572173, None,\n", " -4.862094590572173, -6.214713553764019, None,\n", " -6.214713553764019, -6.016031209535085, None,\n", " -6.016031209535085, -4.540619804645251, None,\n", " -4.862094590572173, -4.148922589885585, None,\n", " -4.148922589885585, -2.6735111792957427, None,\n", " -2.6735111792957427, -1.587460593601559, None,\n", " -2.6735111792957427, -2.474828803735188, None,\n", " -4.540619804645251, -3.8274477532630358, None],\n", " 'xaxis': 'x2',\n", " 'y': [0.4812527731327094, -0.2319192504406855, None,\n", " -0.2319192504406855, 0.5543154798814447, None,\n", " 0.5543154798814447, -0.1588565436919498, None,\n", " -0.1588565436919498, 0.6273781866301811, None,\n", " 0.6273781866301811, -0.08579383694321274, None,\n", " -0.08579383694321274, 0.7004408933789189, None,\n", " 0.7004408933789189, -0.012731130194475332, None,\n", " -0.012731130194475332, -1.5121378840900008, None,\n", " -0.012731130194475332, 0.7735036001276558, None,\n", " 0.7735036001276558, 0.06033157655426158, None,\n", " 0.06033157655426158, 0.708731127277566, None,\n", " 0.708731127277566, -0.3773194802201689, None,\n", " -0.3773194802201689, 0.27108004199882174, None,\n", " 0.27108004199882174, 1.757863592517086, None,\n", " 1.757863592517086, 2.028346838432628, None,\n", " -0.3773194802201689, -1.6969351846006568, None,\n", " -1.6969351846006568, -1.426451969777085, None,\n", " -1.426451969777085, -2.461098829973027, None,\n", " -1.426451969777085, 0.06033157655426158, None,\n", " 2.028346838432628, 0.708731127277566, None],\n", " 'yaxis': 'y2'},\n", " {'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", " 16, 17, 18],\n", " 'marker': {'cmax': 1,\n", " 'cmin': 0,\n", " 'color': 'lightgray',\n", " 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n", " 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n", " '#482878'], [0.2222222222222222,\n", " '#3e4989'], [0.3333333333333333,\n", " '#31688e'], [0.4444444444444444,\n", " '#26828e'], [0.5555555555555556,\n", " '#1f9e89'], [0.6666666666666666,\n", " '#35b779'], [0.7777777777777778,\n", " '#6ece58'], [0.8888888888888888,\n", " '#b5de2b'], [1.0, '#fde725']],\n", " 'size': 20},\n", " 'mode': 'markers+text',\n", " 'name': 'node',\n", " 'text': [C, C, C, C, C, C, C, C, O, N, C, C, C, C, S, S, N, C, O],\n", " 'textposition': 'middle center',\n", " 'type': 'scatter',\n", " 'uid': 'e157b647-8df0-4d9e-a022-9a8e503c555e',\n", " 'x': [9.232981638284745, 7.913365912100706, 6.635932959325772,\n", " 5.316317233141732, 4.038884280366799, 2.7192685541827593,\n", " 1.4418356014078257, 0.12221987522378566, 0.08003710181467982,\n", " -1.155213077551148, -2.474828803735188, -3.8274477532630358,\n", " -4.862094590572173, -6.214713553764019, -6.016031209535085,\n", " -4.540619804645251, -4.148922589885585, -2.6735111792957427,\n", " -1.587460593601559],\n", " 'xaxis': 'x2',\n", " 'y': [0.4812527731327094, -0.2319192504406855, 0.5543154798814447,\n", " -0.1588565436919498, 0.6273781866301811, -0.08579383694321274,\n", " 0.7004408933789189, -0.012731130194475332, -1.5121378840900008,\n", " 0.7735036001276558, 0.06033157655426158, 0.708731127277566,\n", " -0.3773194802201689, 0.27108004199882174, 1.757863592517086,\n", " 2.028346838432628, -1.6969351846006568, -1.426451969777085,\n", " -2.461098829973027],\n", " 'yaxis': 'y2'}],\n", " 'layout': {'annotations': [{'font': {'size': 16},\n", " 'showarrow': False,\n", " 'text': 'Spectrum',\n", " 'x': 0.27,\n", " 'xanchor': 'center',\n", " 'xref': 'paper',\n", " 'y': 1.0,\n", " 'yanchor': 'bottom',\n", " 'yref': 'paper'},\n", " {'font': {'size': 16},\n", " 'showarrow': False,\n", " 'text': 'Molecule',\n", " 'x': 0.8200000000000001,\n", " 'xanchor': 'center',\n", " 'xref': 'paper',\n", " 'y': 1.0,\n", " 'yanchor': 'bottom',\n", " 'yref': 'paper'}],\n", " 'showlegend': False,\n", " 'template': '...',\n", " 'title': {'text': 'Peak ↔ Node Similarity'},\n", " 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.54], 'title': {'text': 'm/z'}},\n", " 'xaxis2': {'anchor': 'y2', 'domain': [0.64, 1.0], 'visible': False},\n", " 'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Intensity'}},\n", " 'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'visible': False}}\n", "})" ] }, "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Target Molecule\n", "i = dataset.metadata[dataset.metadata['identifier'] == ms_id].index[0]\n", "s = target\n", "print(s)\n", "mol = Chem.MolFromSmiles(s)\n", "g = dataset[i]['mol']\n", "spec = dataset[i]['SpecFormula']\n", "\n", "peak_mzs, peak_intensities, peak_formulas = spectra_from_encoding(spec)\n", "\n", "print(len(peak_formulas))\n", "# Embeddings\n", "model = model.to(torch.device('cpu'))\n", "model.eval()\n", "with torch.no_grad():\n", " spec_enc, mol_enc = model.forward(dataset[i], stage='test')\n", "\n", "fw = interactive_attention_visualization(spec_enc, mol_enc, peak_mzs, peak_intensities, peak_formulas, mol)\n", "fw\n" ] }, { "cell_type": "code", "execution_count": 98, "id": "ce3a20ae", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "H10C7\n", "H8C3O2N2\n", "H12C8\n", "H3C3NS2\n", "H14C8O\n", "HC7NS\n", "HC4ONS2\n", "C7N2S\n", "H3C6ONS2\n", "H2C5ON2S2\n", "H3C5ON2S2\n", "H4C5ON2S2\n", "HC8ON2S\n", "H2C6ON2S2\n", "H5C6ON2S2\n", "H6C6ON2S2\n", "H6C9ON2S\n", "H4C7ON2S2\n", "H5C7ON2S2\n", "H2C10ON2S\n", "H8C9ON2S2\n", "H16C13ON2S2\n", "H18C13O2N2S2\n" ] } ], "source": [ "for f in peak_formulas:\n", " print(f)" ] }, { "cell_type": "code", "execution_count": 99, "id": "da93666c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cc1sc2[nH]c(=S)n(CCOC(C)C)c(=O)c2c1C\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "be11e39bfd484f238a5643ffbdff0d78", "version_major": 2, "version_minor": 0 }, "text/plain": [ "FigureWidget({\n", " 'data': [{'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", " 16, 17, 18, 19, 20, 21, 22],\n", " 'hovertext': [Formula H10C7, Formula H8C3O2N2, Formula H12C8,\n", " Formula H3C3NS2, Formula H14C8O, Formula HC7NS, Formula\n", " HC4ONS2, Formula C7N2S, Formula H3C6ONS2, Formula\n", " H2C5ON2S2, Formula H3C5ON2S2, Formula H4C5ON2S2,\n", " Formula HC8ON2S, Formula H2C6ON2S2, Formula H5C6ON2S2,\n", " Formula H6C6ON2S2, Formula H6C9ON2S, Formula H4C7ON2S2,\n", " Formula H5C7ON2S2, Formula H2C10ON2S, Formula\n", " H8C9ON2S2, Formula H16C13ON2S2, Formula H18C13O2N2S2],\n", " 'marker': {'cmax': 1,\n", " 'cmin': 0,\n", " 'color': 'lightgray',\n", " 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n", " 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n", " '#482878'], [0.2222222222222222,\n", " '#3e4989'], [0.3333333333333333,\n", " '#31688e'], [0.4444444444444444,\n", " '#26828e'], [0.5555555555555556,\n", " '#1f9e89'], [0.6666666666666666,\n", " '#35b779'], [0.7777777777777778,\n", " '#6ece58'], [0.8888888888888888,\n", " '#b5de2b'], [1.0, '#fde725']]},\n", " 'name': 'peak',\n", " 'type': 'bar',\n", " 'uid': 'ef26bc6b-dc8e-4cbb-b681-e2343d82ea94',\n", " 'x': [94.07799900290073, 104.05839936191737, 108.0936002238661,\n", " 116.97070118690729, 126.10410051210002, 130.98300034787468,\n", " 142.95000219490117, 143.9783008338645, 168.96559957769588,\n", " 169.96090409332814, 170.96870403325858, 171.97650416466072,\n", " 172.98100185312143, 181.96090015942156, 184.9843001706846,\n", " 185.99210011061504, 190.01999790607465, 195.97650157185868,\n", " 196.98430151178914, 197.98880457099676, 224.00769912172186,\n", " 280.07010477147026, 298.08060429381726],\n", " 'xaxis': 'x',\n", " 'y': [0.24127991497516632, 0.16632197797298431, 0.16255460679531097,\n", " 0.2183758169412613, 0.16817252337932587, 0.14077642560005188,\n", " 0.24170850217342377, 0.23078560829162598, 0.12851069867610931,\n", " 0.22988910973072052, 0.35106268525123596, 1.0,\n", " 0.2844732105731964, 0.14581838250160217, 0.1700058877468109,\n", " 0.3804142475128174, 0.13246509432792664, 0.18231017887592316,\n", " 0.24256132543087006, 0.2150418609380722, 0.14223572611808777,\n", " 0.3020711839199066, 1.100000023841858],\n", " 'yaxis': 'y'},\n", " {'hoverinfo': 'none',\n", " 'line': {'color': 'gray', 'width': 2},\n", " 'mode': 'lines',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", " 'uid': 'aaaf4294-9093-49bb-905c-261137ec6e91',\n", " 'x': [6.049790084552722, 4.550200591327338, None, 4.550200591327338,\n", " 3.6971529128998446, None, 3.6971529128998446,\n", " 2.260114954813708, None, 2.260114954813708, 0.9789776698290348,\n", " None, 0.9789776698290348, -0.33725023794912345, None,\n", " -0.33725023794912345, -1.6183875229337956, None,\n", " -0.33725023794912345, -0.3723408607426099, None,\n", " -0.3723408607426099, -1.6885687685207686, None,\n", " -1.6885687685207686, -2.9697060535054405, None,\n", " -2.9697060535054405, -4.2859339612835985, None,\n", " -4.2859339612835985, -5.567071246268271, None,\n", " -5.567071246268271, -6.88329915404643, None,\n", " -5.567071246268271, -5.531980623474783, None,\n", " -0.3723408607426099, 0.9087964242420628, None,\n", " 0.9087964242420628, 0.8737058014485767, None,\n", " 0.9087964242420628, 2.225024332020221, None, 2.225024332020221,\n", " 3.640375092533582, None, 3.640375092533582, 4.07040056505774,\n", " None, 3.640375092533582, 4.550200591327338, None,\n", " 2.225024332020221, 2.260114954813708, None],\n", " 'xaxis': 'x2',\n", " 'y': [0.3527856091535654, 0.38787623194705184, None,\n", " 0.38787623194705184, 1.6216953671242724, None,\n", " 1.6216953671242724, 1.1916698946001143, None,\n", " 1.1916698946001143, 1.9718540119865837, None,\n", " 1.9718540119865837, 1.2524486361476674, None,\n", " 1.2524486361476674, 2.032632753534136, None,\n", " 1.2524486361476674, -0.24714085707771702, None,\n", " -0.24714085707771702, -0.9665462329166336, None,\n", " -0.9665462329166336, -0.18636211553016513, None,\n", " -0.18636211553016513, -0.9057674913690814, None,\n", " -0.9057674913690814, -0.1255833739826133, None,\n", " -0.1255833739826133, -0.8449887498215296, None,\n", " -0.1255833739826133, 1.3740061192427715, None,\n", " -0.24714085707771702, -1.027324974464186, None,\n", " -1.027324974464186, -2.5269144676895703, None,\n", " -1.027324974464186, -0.30791959862527024, None,\n", " -0.30791959862527024, -0.8046914020866298, None,\n", " -0.8046914020866298, -2.2417293601727675, None,\n", " -0.8046914020866298, 0.38787623194705184, None,\n", " -0.30791959862527024, 1.1916698946001143, None],\n", " 'yaxis': 'y2'},\n", " {'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", " 16, 17, 18],\n", " 'marker': {'cmax': 1,\n", " 'cmin': 0,\n", " 'color': 'lightgray',\n", " 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n", " 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n", " '#482878'], [0.2222222222222222,\n", " '#3e4989'], [0.3333333333333333,\n", " '#31688e'], [0.4444444444444444,\n", " '#26828e'], [0.5555555555555556,\n", " '#1f9e89'], [0.6666666666666666,\n", " '#35b779'], [0.7777777777777778,\n", " '#6ece58'], [0.8888888888888888,\n", " '#b5de2b'], [1.0, '#fde725']],\n", " 'size': 20},\n", " 'mode': 'markers+text',\n", " 'name': 'node',\n", " 'text': [C, C, S, C, N, C, S, N, C, C, O, C, C, C, C, O, C, C, C],\n", " 'textposition': 'middle center',\n", " 'type': 'scatter',\n", " 'uid': '03337be1-04c1-4602-b8b2-bfbe9950f45e',\n", " 'x': [6.049790084552722, 4.550200591327338, 3.6971529128998446,\n", " 2.260114954813708, 0.9789776698290348, -0.33725023794912345,\n", " -1.6183875229337956, -0.3723408607426099, -1.6885687685207686,\n", " -2.9697060535054405, -4.2859339612835985, -5.567071246268271,\n", " -6.88329915404643, -5.531980623474783, 0.9087964242420628,\n", " 0.8737058014485767, 2.225024332020221, 3.640375092533582,\n", " 4.07040056505774],\n", " 'xaxis': 'x2',\n", " 'y': [0.3527856091535654, 0.38787623194705184, 1.6216953671242724,\n", " 1.1916698946001143, 1.9718540119865837, 1.2524486361476674,\n", " 2.032632753534136, -0.24714085707771702, -0.9665462329166336,\n", " -0.18636211553016513, -0.9057674913690814, -0.1255833739826133,\n", " -0.8449887498215296, 1.3740061192427715, -1.027324974464186,\n", " -2.5269144676895703, -0.30791959862527024, -0.8046914020866298,\n", " -2.2417293601727675],\n", " 'yaxis': 'y2'}],\n", " 'layout': {'annotations': [{'font': {'size': 16},\n", " 'showarrow': False,\n", " 'text': 'Spectrum',\n", " 'x': 0.27,\n", " 'xanchor': 'center',\n", " 'xref': 'paper',\n", " 'y': 1.0,\n", " 'yanchor': 'bottom',\n", " 'yref': 'paper'},\n", " {'font': {'size': 16},\n", " 'showarrow': False,\n", " 'text': 'Molecule',\n", " 'x': 0.8200000000000001,\n", " 'xanchor': 'center',\n", " 'xref': 'paper',\n", " 'y': 1.0,\n", " 'yanchor': 'bottom',\n", " 'yref': 'paper'}],\n", " 'showlegend': False,\n", " 'template': '...',\n", " 'title': {'text': 'Peak ↔ Node Similarity'},\n", " 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.54], 'title': {'text': 'm/z'}},\n", " 'xaxis2': {'anchor': 'y2', 'domain': [0.64, 1.0], 'visible': False},\n", " 'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Intensity'}},\n", " 'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'visible': False}}\n", "})" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Cand@1\n", "\n", "# Target Molecule\n", "i = dataset.metadata[dataset.metadata['identifier'] == ms_id].index[0]\n", "s = cand_at_1\n", "print(s)\n", "mol = Chem.MolFromSmiles(s)\n", "g = dataset[i]['mol']\n", "spec = dataset[i]['SpecFormula']\n", "cand_mol= mol_featurizer(cand_at_1)\n", "\n", "peak_mzs, peak_intensities, peak_formulas = spectra_from_encoding(spec)\n", "\n", "# Embeddings\n", "model = model.to(torch.device('cpu'))\n", "model.eval()\n", "with torch.no_grad():\n", " input = copy.deepcopy(dataset[i])\n", " input['mol'] = cand_mol\n", " \n", " spec_enc, mol_enc = model.forward(input, stage='test')\n", "\n", "fw = interactive_attention_visualization(spec_enc, mol_enc, peak_mzs, peak_intensities, peak_formulas, mol)\n", "fw\n" ] }, { "cell_type": "code", "execution_count": null, "id": "aed9d517", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python (spec)", "language": "python", "name": "spec" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }