{
"cells": [
{
"cell_type": "code",
"execution_count": 53,
"id": "1e8a4ffb",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import pickle\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import plotly.graph_objects as go\n",
"import torch\n",
"from plotly.subplots import make_subplots\n",
"from rdkit import Chem\n",
"from rdkit.Chem import rdDepictor\n",
"from rdkit.Chem.Draw import rdMolDraw2D"
]
},
{
"cell_type": "markdown",
"id": "4c716c1a",
"metadata": {},
"source": [
"## Ranking results (top-k retrieval accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f9555d50",
"metadata": {},
"outputs": [],
"source": [
"# Per-spectrum retrieval results of the FILIP model (a DataFrame; see `ranking.head(3)` below).\n",
"# NOTE: pickle.load can execute arbitrary code -- only load result files you produced yourself.\n",
"ranking_file = \"/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/result_MassSpecGym_retrieval_candidates_formula.pkl\"\n",
"with open(ranking_file, mode=\"rb\") as result_fh:\n",
"    ranking = pickle.load(result_fh)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78cfd902",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 1 | \n",
" 5 | \n",
" 20 | \n",
"
\n",
" \n",
" \n",
" \n",
" | rank | \n",
" 20.688 | \n",
" 47.391 | \n",
" 72.368 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 1 5 20\n",
"rank 20.688 47.391 72.368"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Top-k retrieval accuracy (%): share of spectra whose true molecule is ranked within the top k candidates.\n",
"rank_col = 'rank'\n",
"top_k = [1, 5, 20]\n",
"\n",
"hit_rates = [round(float((ranking[rank_col] <= k).mean()) * 100, 3) for k in top_k]\n",
"rank_result = {rank_col: hit_rates}\n",
"\n",
"# Derive the column labels from top_k so they cannot drift out of sync with it.\n",
"pd.DataFrame.from_dict(rank_result, orient='index', columns=[str(k) for k in top_k])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "11bbc0d4",
"metadata": {},
"outputs": [],
"source": [
"# Annotate every spectrum with its target, the model's top hit, their scores, and the target's size.\n",
"\n",
"def get_target(candidates, labels):\n",
"    \"\"\"Return the first candidate whose label is True (the ground-truth molecule).\n",
"\n",
"    Raises IndexError if no label is True.\n",
"    \"\"\"\n",
"    # Cast to bool so `labels` is always applied as a mask, never as integer indices.\n",
"    return np.asarray(candidates)[np.asarray(labels, dtype=bool)][0]\n",
"\n",
"def get_cand_at_1(candidates, scores):\n",
"    \"\"\"Return the highest-scoring candidate (the model's rank-1 prediction).\"\"\"\n",
"    return candidates[int(np.argmax(scores))]\n",
"\n",
"def get_top_score(scores):\n",
"    \"\"\"Return the highest candidate score.\"\"\"\n",
"    return np.max(scores)\n",
"\n",
"def get_target_score(labels, scores):\n",
"    \"\"\"Return the score of the first candidate whose label is True (IndexError if none).\"\"\"\n",
"    return np.asarray(scores)[np.asarray(labels, dtype=bool)][0]\n",
"\n",
"def get_n_heavy_atoms(smiles):\n",
"    \"\"\"Return the number of heavy atoms in `smiles`, or NaN if RDKit cannot parse it.\"\"\"\n",
"    mol = Chem.MolFromSmiles(smiles)\n",
"    # MolFromSmiles returns None for unparsable input; don't let one bad SMILES abort the whole column.\n",
"    return np.nan if mol is None else mol.GetNumHeavyAtoms()\n",
"\n",
"ranking['target'] = ranking.apply(lambda x: get_target(x['candidates'], x['labels']), axis=1)\n",
"ranking['target_score'] = ranking.apply(lambda x: get_target_score(x['labels'], x['scores']), axis=1)\n",
"\n",
"ranking['cand@1'] = ranking.apply(lambda x: get_cand_at_1(x['candidates'], x['scores']), axis=1)\n",
"ranking['top_score'] = ranking.apply(lambda x: get_top_score(x['scores']), axis=1)\n",
"\n",
"ranking['n_heavy_atoms'] = ranking['target'].apply(get_n_heavy_atoms)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "763ea617",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" identifier | \n",
" candidates | \n",
" scores | \n",
" labels | \n",
" rank | \n",
" target | \n",
" target_score | \n",
" cand@1 | \n",
" top_score | \n",
" n_heavy_atoms | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" MassSpecGymID0000201 | \n",
" [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... | \n",
" [0.17369578778743744, 0.12611594796180725, 0.2... | \n",
" [True, False, False, False, False, False, Fals... | \n",
" 17 | \n",
" CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... | \n",
" 0.173696 | \n",
" COCCCN1C(=O)COc2ccc(N(C(=O)[C@H]3CN(C(=O)OC(C)... | \n",
" 0.259878 | \n",
" 57 | \n",
"
\n",
" \n",
" | 1 | \n",
" MassSpecGymID0000202 | \n",
" [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... | \n",
" [0.05142267048358917, 0.07289629429578781, 0.1... | \n",
" [True, False, False, False, False, False, Fals... | \n",
" 24 | \n",
" CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... | \n",
" 0.051423 | \n",
" COC(=O)/C(C)=C\\CC1(O)C(=O)C2CC(C(C)C)C13Oc1c(C... | \n",
" 0.237195 | \n",
" 57 | \n",
"
\n",
" \n",
" | 2 | \n",
" MassSpecGymID0000203 | \n",
" [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... | \n",
" [0.09354929625988007, 0.0947718694806099, 0.10... | \n",
" [True, False, False, False, False, False, Fals... | \n",
" 23 | \n",
" CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... | \n",
" 0.093549 | \n",
" C=CCOC12Oc3ccc(OC(=O)NCC)cc3C3C(CCCCO)C(CCCCO)... | \n",
" 0.238268 | \n",
" 57 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" identifier candidates \\\n",
"0 MassSpecGymID0000201 [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... \n",
"1 MassSpecGymID0000202 [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... \n",
"2 MassSpecGymID0000203 [CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(... \n",
"\n",
" scores \\\n",
"0 [0.17369578778743744, 0.12611594796180725, 0.2... \n",
"1 [0.05142267048358917, 0.07289629429578781, 0.1... \n",
"2 [0.09354929625988007, 0.0947718694806099, 0.10... \n",
"\n",
" labels rank \\\n",
"0 [True, False, False, False, False, False, Fals... 17 \n",
"1 [True, False, False, False, False, False, Fals... 24 \n",
"2 [True, False, False, False, False, False, Fals... 23 \n",
"\n",
" target target_score \\\n",
"0 CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... 0.173696 \n",
"1 CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... 0.051423 \n",
"2 CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([... 0.093549 \n",
"\n",
" cand@1 top_score n_heavy_atoms \n",
"0 COCCCN1C(=O)COc2ccc(N(C(=O)[C@H]3CN(C(=O)OC(C)... 0.259878 57 \n",
"1 COC(=O)/C(C)=C\\CC1(O)C(=O)C2CC(C(C)C)C13Oc1c(C... 0.237195 57 \n",
"2 C=CCOC12Oc3ccc(OC(=O)NCC)cc3C3C(CCCCO)C(CCCCO)... 0.238268 57 "
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Peek at the first rows of the annotated ranking table.\n",
"ranking.iloc[:3]"
]
},
{
"cell_type": "markdown",
"id": "93ef333e",
"metadata": {},
"source": [
"## Load model and dataset"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0b4e4250",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data path: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv\n",
"Processing formula spectra\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 231104/231104 [00:18<00:00, 12309.47it/s]\n",
"/data/yzhouc01/FILIP-MS/mvp/data/datasets.py:221: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded Model from checkpoint\n"
]
}
],
"source": [
"import sys\n",
"\n",
"# Make the local MassSpecGym and FILIP-MS checkouts importable; this must run before the project imports below.\n",
"for repo_dir in (\"/data/yzhouc01/MassSpecGym\", \"/data/yzhouc01/FILIP-MS\"):\n",
"    sys.path.insert(0, repo_dir)\n",
"\n",
"import os\n",
"from functools import partial\n",
"\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"import yaml\n",
"from pytorch_lightning import Trainer\n",
"from rdkit import RDLogger\n",
"\n",
"from massspecgym.models.base import Stage\n",
"from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset\n",
"from mvp.utils.models import get_model\n",
"from mvp.definitions import TEST_RESULTS_DIR\n",
"\n",
"# Silence RDKit warnings and errors.\n",
"lg = RDLogger.logger()\n",
"lg.setLevel(RDLogger.CRITICAL)\n",
"\n",
"# --- Dataset: featurizers are built from the experiment's saved hyperparameters ---\n",
"param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'\n",
"with open(param_pth) as params_fh:\n",
"    params = yaml.load(params_fh, Loader=yaml.FullLoader)\n",
"\n",
"spec_featurizer = get_spec_featurizer(params['spectra_view'], params)\n",
"mol_featurizer = get_mol_featurizer(params['molecule_view'], params)\n",
"dataset = get_ms_dataset(\n",
"    params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params\n",
")\n",
"\n",
"# --- Model: the checkpoint path is handed to get_model through params ---\n",
"checkpoint_pth = \"/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt\"\n",
"params['checkpoint_pth'] = checkpoint_pth\n",
"model = get_model(params['model'], params)"
]
},
{
"cell_type": "markdown",
"id": "bd9dc380",
"metadata": {},
"source": [
"## Visualization helpers"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0883a1a2",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"\n",
"# Element order of the 14 atom-count columns in the formula-spectrum encoding (see spectra_from_encoding).\n",
"ATOM_LABELS = ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']\n",
"# Monoisotopic masses (most abundant isotope) aligned with ATOM_LABELS. Use monoisotopic values for\n",
"# every element: mixing in average atomic weights (e.g. Cl 35.45) shifts the computed m/z of\n",
"# fragments containing those elements relative to everything else.\n",
"ATOM_MASSES = np.array([\n",
"    1.0078, 12.0000, 15.9949, 14.0031, 30.9738, 31.9721,\n",
"    34.9689, 18.9984, 78.9183, 126.9045, 11.0093, 74.9216, 27.9769, 79.9165\n",
"])\n",
"# Per-element scale factors applied to the atom counts; multiplying by them recovers raw counts.\n",
"# NOTE(review): presumably must match the featurizer's normalization -- TODO confirm.\n",
"norm_vector = [102.0, 59.0, 25.0, 13.0, 3.0, 6.0, 6.0, 17.0, 4.0, 4.0, 1.0, 1.0, 5.0, 2.0]\n",
"\n",
"def spectra_from_encoding(spectral_tensor, norm_vector=norm_vector):\n",
"    \"\"\"\n",
"    Decode a formula-annotated spectrum encoding into m/z values, intensities and formulas.\n",
"\n",
"    Each row encodes one peak: columns 0-13 are atom counts in ATOM_LABELS order and\n",
"    column 14 is the peak intensity.\n",
"\n",
"    Args:\n",
"        spectral_tensor (np.ndarray, torch.Tensor or nested list): [num_peaks, >=15]\n",
"        norm_vector (array-like or None): length-14 per-element scale factors; when given,\n",
"            counts are multiplied by it to undo normalization.\n",
"\n",
"    Returns:\n",
"        mzs (list of float): mass of each peak's formula, i.e. ATOM_MASSES weighted by the\n",
"            rounded atom counts (no adduct or charge correction is applied)\n",
"        intensities (list of float): peak intensities\n",
"        formulas (list of str): formula strings (\"Unknown\" for rows with no atoms)\n",
"\n",
"    Raises:\n",
"        ValueError: if the input is not 2-D with at least 15 columns.\n",
"    \"\"\"\n",
"    if hasattr(spectral_tensor, \"detach\"):\n",
"        spectral_tensor = spectral_tensor.detach().cpu().numpy()\n",
"    spectral_tensor = np.asarray(spectral_tensor, dtype=float)\n",
"    if spectral_tensor.ndim != 2 or spectral_tensor.shape[1] < 15:\n",
"        raise ValueError(f\"expected a [num_peaks, 15] encoding, got shape {spectral_tensor.shape}\")\n",
"\n",
"    counts = spectral_tensor[:, :14]\n",
"    intensities = spectral_tensor[:, 14]\n",
"\n",
"    # Undo normalization, then snap to whole atoms so the m/z and the formula describe the same\n",
"    # composition (de-normalized float counts carry rounding noise, e.g. 6.9999999 carbons).\n",
"    if norm_vector is not None:\n",
"        counts = counts * np.asarray(norm_vector, dtype=float)\n",
"    counts = np.rint(counts).astype(int)\n",
"\n",
"    mzs = (counts * ATOM_MASSES).sum(axis=1)\n",
"\n",
"    formulas = []\n",
"    for peak_counts in counts:\n",
"        parts = [f\"{elem}{n if n > 1 else ''}\" for elem, n in zip(ATOM_LABELS, peak_counts) if n > 0]\n",
"        formulas.append(\"\".join(parts) if parts else \"Unknown\")\n",
"\n",
"    return mzs.tolist(), intensities.tolist(), formulas\n",
"\n",
"\n",
"def mol_to_graph_coords(mol):\n",
"    \"\"\"\n",
"    Lay out `mol` in 2D and return its drawing geometry.\n",
"\n",
"    Side effect: rdDepictor.Compute2DCoords is called on `mol` itself, so the molecule\n",
"    carries the computed 2D conformer afterwards.\n",
"\n",
"    Returns:\n",
"        coords (dict): atom index -> conformer position (exposes .x / .y)\n",
"        bonds (list of tuple): (begin_atom_idx, end_atom_idx) for every bond\n",
"    \"\"\"\n",
"    rdDepictor.Compute2DCoords(mol)\n",
"    conformer = mol.GetConformer()\n",
"    coords = {}\n",
"    for atom_idx in range(mol.GetNumAtoms()):\n",
"        coords[atom_idx] = conformer.GetAtomPosition(atom_idx)\n",
"    bonds = []\n",
"    for bond in mol.GetBonds():\n",
"        bonds.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))\n",
"    return coords, bonds\n",
"\n",
"def interactive_attention_visualization(spectral_embeds, graph_embeds,\n",
"                                        peak_mzs, peak_intensities, peak_formulas, mol):\n",
"    \"\"\"\n",
"    Interactive peak <-> atom similarity explorer (plotly FigureWidget).\n",
"\n",
"    - Clicking a peak highlights it in red and colors the atoms by their similarity to it.\n",
"    - Clicking an atom highlights it in red and colors the peaks by their similarity to it.\n",
"    Similarities are cosine similarities, min-max scaled over the whole peak x atom matrix\n",
"    to [0, 1]; the colorbar shows on whichever trace currently carries similarity colors.\n",
"\n",
"    Args:\n",
"        spectral_embeds (torch.Tensor): [num_peaks, dim] per-peak embeddings\n",
"        graph_embeds (torch.Tensor): [num_nodes, dim] per-node embeddings; node i is drawn as atom i of `mol`\n",
"        peak_mzs, peak_intensities, peak_formulas (sequences): one entry per row of spectral_embeds\n",
"        mol (rdkit Mol): molecule drawn in the right-hand panel\n",
"\n",
"    Returns:\n",
"        go.FigureWidget with click handlers attached.\n",
"\n",
"    Raises:\n",
"        ValueError: if the embedding counts don't match the number of peaks or atoms.\n",
"    \"\"\"\n",
"    # Cosine similarity matrix [num_peaks, num_nodes], min-max scaled to [0, 1] for the colorscale.\n",
"    spectral_embeds = F.normalize(spectral_embeds, p=2, dim=-1)\n",
"    graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)\n",
"    similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()\n",
"    sim_norm = (similarity - similarity.min()) / (similarity.max() - similarity.min() + 1e-8)\n",
"    num_peaks, num_nodes = similarity.shape\n",
"\n",
"    # Peaks and atoms are addressed by embedding row, so the counts must line up exactly;\n",
"    # otherwise the plot raises KeyError or silently pairs scores with the wrong atom/peak.\n",
"    if num_nodes != mol.GetNumAtoms():\n",
"        raise ValueError(f\"graph_embeds has {num_nodes} nodes but mol has {mol.GetNumAtoms()} atoms\")\n",
"    if not (num_peaks == len(peak_mzs) == len(peak_intensities) == len(peak_formulas)):\n",
"        raise ValueError(\n",
"            f\"spectral_embeds has {num_peaks} rows but got {len(peak_mzs)} m/z values, \"\n",
"            f\"{len(peak_intensities)} intensities and {len(peak_formulas)} formulas\"\n",
"        )\n",
"\n",
"    # --- Molecule graph ---\n",
"    coords, bonds = mol_to_graph_coords(mol)\n",
"    atom_labels = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
"    atom_x = [coords[i].x for i in range(num_nodes)]\n",
"    atom_y = [coords[i].y for i in range(num_nodes)]\n",
"\n",
"    # Shared colorscale settings. showscale=True is required for a colorbar to render at all;\n",
"    # plotly only draws it while marker.color is numeric, so one colorbar is visible after a click.\n",
"    scale_kwargs = dict(colorscale=\"Viridis\", cmin=0, cmax=1, showscale=True,\n",
"                        colorbar=dict(title=\"Similarity\", len=0.8, y=0.5))\n",
"\n",
"    # --- Spectrum trace ---\n",
"    spectrum_trace = go.Bar(\n",
"        x=peak_mzs,\n",
"        y=peak_intensities,\n",
"        name='peak',\n",
"        marker=dict(color=\"lightgray\", **scale_kwargs),\n",
"        hovertext=[f\"Formula {f}\" for f in peak_formulas],\n",
"        customdata=list(range(num_peaks)),  # peak index\n",
"    )\n",
"\n",
"    # --- Graph nodes ---\n",
"    graph_nodes = go.Scatter(\n",
"        x=atom_x, y=atom_y,\n",
"        mode=\"markers+text\",\n",
"        name='node',\n",
"        text=atom_labels,\n",
"        textposition=\"middle center\",\n",
"        marker=dict(size=20, color=\"lightgray\", **scale_kwargs),\n",
"        customdata=list(range(num_nodes)),  # atom index\n",
"    )\n",
"\n",
"    # --- Graph bonds (None breaks the polyline between bond segments) ---\n",
"    edge_x, edge_y = [], []\n",
"    for i, j in bonds:\n",
"        edge_x += [coords[i].x, coords[j].x, None]\n",
"        edge_y += [coords[i].y, coords[j].y, None]\n",
"    graph_edges = go.Scatter(\n",
"        x=edge_x, y=edge_y,\n",
"        mode=\"lines\", line=dict(color=\"gray\", width=2),\n",
"        hoverinfo=\"none\", showlegend=False,\n",
"    )\n",
"\n",
"    # --- Subplots: spectrum left, molecule right ---\n",
"    fig = make_subplots(rows=1, cols=2, subplot_titles=(\"Spectrum\", \"Molecule\"),\n",
"                        column_widths=[0.6, 0.4])\n",
"    # Trace order matters: data[0] is the spectrum and data[2] the atoms (used by the click handlers).\n",
"    fig.add_trace(spectrum_trace, row=1, col=1)\n",
"    fig.add_trace(graph_edges, row=1, col=2)\n",
"    fig.add_trace(graph_nodes, row=1, col=2)\n",
"\n",
"    fig.update_xaxes(title=\"m/z\", row=1, col=1)\n",
"    fig.update_yaxes(title=\"Intensity\", row=1, col=1)\n",
"    fig.update_xaxes(visible=False, row=1, col=2)\n",
"    fig.update_yaxes(visible=False, row=1, col=2)\n",
"    fig.update_layout(title=\"Peak ↔ Node Similarity\", showlegend=False)\n",
"\n",
"    # --- Interactivity (FigureWidget supports on_click callbacks) ---\n",
"    fw = go.FigureWidget(fig)\n",
"    peak_trace, node_trace = fw.data[0], fw.data[2]\n",
"\n",
"    def highlight_nodes(trace, points, selector):\n",
"        \"\"\"Click on a peak -> color atoms by similarity to that peak.\"\"\"\n",
"        if points.point_inds:\n",
"            peak_idx = points.point_inds[0]\n",
"            with fw.batch_update():\n",
"                node_trace.marker.color = sim_norm[peak_idx, :]\n",
"                peak_trace.marker.color = [\"red\" if i == peak_idx else \"lightgray\" for i in range(num_peaks)]\n",
"\n",
"    def highlight_peaks(trace, points, selector):\n",
"        \"\"\"Click on an atom -> color peaks by similarity to that atom.\"\"\"\n",
"        if points.point_inds:\n",
"            node_idx = points.point_inds[0]\n",
"            with fw.batch_update():\n",
"                peak_trace.marker.color = sim_norm[:, node_idx]\n",
"                node_trace.marker.color = [\"red\" if i == node_idx else \"lightgray\" for i in range(num_nodes)]\n",
"\n",
"    peak_trace.on_click(highlight_nodes)\n",
"    node_trace.on_click(highlight_peaks)\n",
"\n",
"    return fw\n"
]
},
{
"cell_type": "markdown",
"id": "7d1b98d3",
"metadata": {},
"source": [
"## Visualization"
]
},
{
"cell_type": "code",
"execution_count": 92,
"id": "78f85d94",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MS ID: MassSpecGymID0396247, Target:0.423, Cand@1: 0.49\n",
"Target rank: 25\n"
]
}
],
"source": [
"# Sample one hard case: a small target (<= MAX_HEAVY_ATOMS heavy atoms) ranked worse than RANK_THRESHOLD.\n",
"RANK_THRESHOLD = 20\n",
"MAX_HEAVY_ATOMS = 20\n",
"SEED = 0  # fixed so re-running the notebook revisits the same example\n",
"\n",
"hard_cases = ranking[(ranking['rank'] > RANK_THRESHOLD) & (ranking['n_heavy_atoms'] <= MAX_HEAVY_ATOMS)]\n",
"assert len(hard_cases) > 0, \"no spectra match the sampling filter\"\n",
"sample = hard_cases.sample(1, random_state=SEED).iloc[0]\n",
"ms_id = sample['identifier']\n",
"target = sample['target']\n",
"cand_at_1 = sample['cand@1']\n",
"print(f\"MS ID: {ms_id}, Target:{sample['target_score']:.3f}, Cand@1: {sample['top_score']:.3f}\")\n",
"print(f\"Target rank: {sample['rank']}\")"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "20437eb8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CCCCCCCC(=O)NC1=C2C(=CSS2)NC1=O\n",
"23\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c60a8c857d4041cdaf96a7f4ba2f4f61",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FigureWidget({\n",
" 'data': [{'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n",
" 16, 17, 18, 19, 20, 21, 22],\n",
" 'hovertext': [Formula H10C7, Formula H8C3O2N2, Formula H12C8,\n",
" Formula H3C3NS2, Formula H14C8O, Formula HC7NS, Formula\n",
" HC4ONS2, Formula C7N2S, Formula H3C6ONS2, Formula\n",
" H2C5ON2S2, Formula H3C5ON2S2, Formula H4C5ON2S2,\n",
" Formula HC8ON2S, Formula H2C6ON2S2, Formula H5C6ON2S2,\n",
" Formula H6C6ON2S2, Formula H6C9ON2S, Formula H4C7ON2S2,\n",
" Formula H5C7ON2S2, Formula H2C10ON2S, Formula\n",
" H8C9ON2S2, Formula H16C13ON2S2, Formula H18C13O2N2S2],\n",
" 'marker': {'cmax': 1,\n",
" 'cmin': 0,\n",
" 'color': 'lightgray',\n",
" 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n",
" 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n",
" '#482878'], [0.2222222222222222,\n",
" '#3e4989'], [0.3333333333333333,\n",
" '#31688e'], [0.4444444444444444,\n",
" '#26828e'], [0.5555555555555556,\n",
" '#1f9e89'], [0.6666666666666666,\n",
" '#35b779'], [0.7777777777777778,\n",
" '#6ece58'], [0.8888888888888888,\n",
" '#b5de2b'], [1.0, '#fde725']]},\n",
" 'name': 'peak',\n",
" 'type': 'bar',\n",
" 'uid': 'c92eb0ab-8030-4f20-a4b0-126e19bad926',\n",
" 'x': [94.07799900290073, 104.05839936191737, 108.0936002238661,\n",
" 116.97070118690729, 126.10410051210002, 130.98300034787468,\n",
" 142.95000219490117, 143.9783008338645, 168.96559957769588,\n",
" 169.96090409332814, 170.96870403325858, 171.97650416466072,\n",
" 172.98100185312143, 181.96090015942156, 184.9843001706846,\n",
" 185.99210011061504, 190.01999790607465, 195.97650157185868,\n",
" 196.98430151178914, 197.98880457099676, 224.00769912172186,\n",
" 280.07010477147026, 298.08060429381726],\n",
" 'xaxis': 'x',\n",
" 'y': [0.24127991497516632, 0.16632197797298431, 0.16255460679531097,\n",
" 0.2183758169412613, 0.16817252337932587, 0.14077642560005188,\n",
" 0.24170850217342377, 0.23078560829162598, 0.12851069867610931,\n",
" 0.22988910973072052, 0.35106268525123596, 1.0,\n",
" 0.2844732105731964, 0.14581838250160217, 0.1700058877468109,\n",
" 0.3804142475128174, 0.13246509432792664, 0.18231017887592316,\n",
" 0.24256132543087006, 0.2150418609380722, 0.14223572611808777,\n",
" 0.3020711839199066, 1.100000023841858],\n",
" 'yaxis': 'y'},\n",
" {'hoverinfo': 'none',\n",
" 'line': {'color': 'gray', 'width': 2},\n",
" 'mode': 'lines',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': 'ef26d482-9ee9-47f4-823a-834c4cdab648',\n",
" 'x': [9.232981638284745, 7.913365912100706, None, 7.913365912100706,\n",
" 6.635932959325772, None, 6.635932959325772, 5.316317233141732,\n",
" None, 5.316317233141732, 4.038884280366799, None,\n",
" 4.038884280366799, 2.7192685541827593, None,\n",
" 2.7192685541827593, 1.4418356014078257, None,\n",
" 1.4418356014078257, 0.12221987522378566, None,\n",
" 0.12221987522378566, 0.08003710181467982, None,\n",
" 0.12221987522378566, -1.155213077551148, None,\n",
" -1.155213077551148, -2.474828803735188, None,\n",
" -2.474828803735188, -3.8274477532630358, None,\n",
" -3.8274477532630358, -4.862094590572173, None,\n",
" -4.862094590572173, -6.214713553764019, None,\n",
" -6.214713553764019, -6.016031209535085, None,\n",
" -6.016031209535085, -4.540619804645251, None,\n",
" -4.862094590572173, -4.148922589885585, None,\n",
" -4.148922589885585, -2.6735111792957427, None,\n",
" -2.6735111792957427, -1.587460593601559, None,\n",
" -2.6735111792957427, -2.474828803735188, None,\n",
" -4.540619804645251, -3.8274477532630358, None],\n",
" 'xaxis': 'x2',\n",
" 'y': [0.4812527731327094, -0.2319192504406855, None,\n",
" -0.2319192504406855, 0.5543154798814447, None,\n",
" 0.5543154798814447, -0.1588565436919498, None,\n",
" -0.1588565436919498, 0.6273781866301811, None,\n",
" 0.6273781866301811, -0.08579383694321274, None,\n",
" -0.08579383694321274, 0.7004408933789189, None,\n",
" 0.7004408933789189, -0.012731130194475332, None,\n",
" -0.012731130194475332, -1.5121378840900008, None,\n",
" -0.012731130194475332, 0.7735036001276558, None,\n",
" 0.7735036001276558, 0.06033157655426158, None,\n",
" 0.06033157655426158, 0.708731127277566, None,\n",
" 0.708731127277566, -0.3773194802201689, None,\n",
" -0.3773194802201689, 0.27108004199882174, None,\n",
" 0.27108004199882174, 1.757863592517086, None,\n",
" 1.757863592517086, 2.028346838432628, None,\n",
" -0.3773194802201689, -1.6969351846006568, None,\n",
" -1.6969351846006568, -1.426451969777085, None,\n",
" -1.426451969777085, -2.461098829973027, None,\n",
" -1.426451969777085, 0.06033157655426158, None,\n",
" 2.028346838432628, 0.708731127277566, None],\n",
" 'yaxis': 'y2'},\n",
" {'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n",
" 16, 17, 18],\n",
" 'marker': {'cmax': 1,\n",
" 'cmin': 0,\n",
" 'color': 'lightgray',\n",
" 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n",
" 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n",
" '#482878'], [0.2222222222222222,\n",
" '#3e4989'], [0.3333333333333333,\n",
" '#31688e'], [0.4444444444444444,\n",
" '#26828e'], [0.5555555555555556,\n",
" '#1f9e89'], [0.6666666666666666,\n",
" '#35b779'], [0.7777777777777778,\n",
" '#6ece58'], [0.8888888888888888,\n",
" '#b5de2b'], [1.0, '#fde725']],\n",
" 'size': 20},\n",
" 'mode': 'markers+text',\n",
" 'name': 'node',\n",
" 'text': [C, C, C, C, C, C, C, C, O, N, C, C, C, C, S, S, N, C, O],\n",
" 'textposition': 'middle center',\n",
" 'type': 'scatter',\n",
" 'uid': 'e157b647-8df0-4d9e-a022-9a8e503c555e',\n",
" 'x': [9.232981638284745, 7.913365912100706, 6.635932959325772,\n",
" 5.316317233141732, 4.038884280366799, 2.7192685541827593,\n",
" 1.4418356014078257, 0.12221987522378566, 0.08003710181467982,\n",
" -1.155213077551148, -2.474828803735188, -3.8274477532630358,\n",
" -4.862094590572173, -6.214713553764019, -6.016031209535085,\n",
" -4.540619804645251, -4.148922589885585, -2.6735111792957427,\n",
" -1.587460593601559],\n",
" 'xaxis': 'x2',\n",
" 'y': [0.4812527731327094, -0.2319192504406855, 0.5543154798814447,\n",
" -0.1588565436919498, 0.6273781866301811, -0.08579383694321274,\n",
" 0.7004408933789189, -0.012731130194475332, -1.5121378840900008,\n",
" 0.7735036001276558, 0.06033157655426158, 0.708731127277566,\n",
" -0.3773194802201689, 0.27108004199882174, 1.757863592517086,\n",
" 2.028346838432628, -1.6969351846006568, -1.426451969777085,\n",
" -2.461098829973027],\n",
" 'yaxis': 'y2'}],\n",
" 'layout': {'annotations': [{'font': {'size': 16},\n",
" 'showarrow': False,\n",
" 'text': 'Spectrum',\n",
" 'x': 0.27,\n",
" 'xanchor': 'center',\n",
" 'xref': 'paper',\n",
" 'y': 1.0,\n",
" 'yanchor': 'bottom',\n",
" 'yref': 'paper'},\n",
" {'font': {'size': 16},\n",
" 'showarrow': False,\n",
" 'text': 'Molecule',\n",
" 'x': 0.8200000000000001,\n",
" 'xanchor': 'center',\n",
" 'xref': 'paper',\n",
" 'y': 1.0,\n",
" 'yanchor': 'bottom',\n",
" 'yref': 'paper'}],\n",
" 'showlegend': False,\n",
" 'template': '...',\n",
" 'title': {'text': 'Peak ↔ Node Similarity'},\n",
" 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.54], 'title': {'text': 'm/z'}},\n",
" 'xaxis2': {'anchor': 'y2', 'domain': [0.64, 1.0], 'visible': False},\n",
" 'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Intensity'}},\n",
" 'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'visible': False}}\n",
"})"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Target molecule: peak <-> atom similarity for the ground-truth structure.\n",
"# NOTE(review): the metadata index label is used as the dataset position; assumes a default RangeIndex -- confirm.\n",
"i = dataset.metadata[dataset.metadata['identifier'] == ms_id].index[0]\n",
"item = dataset[i]  # fetch the sample once instead of indexing the dataset repeatedly\n",
"s = target\n",
"print(s)\n",
"mol = Chem.MolFromSmiles(s)\n",
"spec = item['SpecFormula']\n",
"\n",
"peak_mzs, peak_intensities, peak_formulas = spectra_from_encoding(spec)\n",
"print(len(peak_formulas))\n",
"\n",
"# Embeddings\n",
"model = model.to(torch.device('cpu'))\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    spec_enc, mol_enc = model.forward(item, stage='test')\n",
"\n",
"fw = interactive_attention_visualization(spec_enc, mol_enc, peak_mzs, peak_intensities, peak_formulas, mol)\n",
"fw\n"
]
},
{
"cell_type": "code",
"execution_count": 98,
"id": "ce3a20ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"H10C7\n",
"H8C3O2N2\n",
"H12C8\n",
"H3C3NS2\n",
"H14C8O\n",
"HC7NS\n",
"HC4ONS2\n",
"C7N2S\n",
"H3C6ONS2\n",
"H2C5ON2S2\n",
"H3C5ON2S2\n",
"H4C5ON2S2\n",
"HC8ON2S\n",
"H2C6ON2S2\n",
"H5C6ON2S2\n",
"H6C6ON2S2\n",
"H6C9ON2S\n",
"H4C7ON2S2\n",
"H5C7ON2S2\n",
"H2C10ON2S\n",
"H8C9ON2S2\n",
"H16C13ON2S2\n",
"H18C13O2N2S2\n"
]
}
],
"source": [
"# One annotated formula per peak, in the same order as the spectrum plot.\n",
"if peak_formulas:\n",
"    print(\"\\n\".join(peak_formulas))"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "da93666c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cc1sc2[nH]c(=S)n(CCOC(C)C)c(=O)c2c1C\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "be11e39bfd484f238a5643ffbdff0d78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FigureWidget({\n",
" 'data': [{'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n",
" 16, 17, 18, 19, 20, 21, 22],\n",
" 'hovertext': [Formula H10C7, Formula H8C3O2N2, Formula H12C8,\n",
" Formula H3C3NS2, Formula H14C8O, Formula HC7NS, Formula\n",
" HC4ONS2, Formula C7N2S, Formula H3C6ONS2, Formula\n",
" H2C5ON2S2, Formula H3C5ON2S2, Formula H4C5ON2S2,\n",
" Formula HC8ON2S, Formula H2C6ON2S2, Formula H5C6ON2S2,\n",
" Formula H6C6ON2S2, Formula H6C9ON2S, Formula H4C7ON2S2,\n",
" Formula H5C7ON2S2, Formula H2C10ON2S, Formula\n",
" H8C9ON2S2, Formula H16C13ON2S2, Formula H18C13O2N2S2],\n",
" 'marker': {'cmax': 1,\n",
" 'cmin': 0,\n",
" 'color': 'lightgray',\n",
" 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n",
" 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n",
" '#482878'], [0.2222222222222222,\n",
" '#3e4989'], [0.3333333333333333,\n",
" '#31688e'], [0.4444444444444444,\n",
" '#26828e'], [0.5555555555555556,\n",
" '#1f9e89'], [0.6666666666666666,\n",
" '#35b779'], [0.7777777777777778,\n",
" '#6ece58'], [0.8888888888888888,\n",
" '#b5de2b'], [1.0, '#fde725']]},\n",
" 'name': 'peak',\n",
" 'type': 'bar',\n",
" 'uid': 'ef26bc6b-dc8e-4cbb-b681-e2343d82ea94',\n",
" 'x': [94.07799900290073, 104.05839936191737, 108.0936002238661,\n",
" 116.97070118690729, 126.10410051210002, 130.98300034787468,\n",
" 142.95000219490117, 143.9783008338645, 168.96559957769588,\n",
" 169.96090409332814, 170.96870403325858, 171.97650416466072,\n",
" 172.98100185312143, 181.96090015942156, 184.9843001706846,\n",
" 185.99210011061504, 190.01999790607465, 195.97650157185868,\n",
" 196.98430151178914, 197.98880457099676, 224.00769912172186,\n",
" 280.07010477147026, 298.08060429381726],\n",
" 'xaxis': 'x',\n",
" 'y': [0.24127991497516632, 0.16632197797298431, 0.16255460679531097,\n",
" 0.2183758169412613, 0.16817252337932587, 0.14077642560005188,\n",
" 0.24170850217342377, 0.23078560829162598, 0.12851069867610931,\n",
" 0.22988910973072052, 0.35106268525123596, 1.0,\n",
" 0.2844732105731964, 0.14581838250160217, 0.1700058877468109,\n",
" 0.3804142475128174, 0.13246509432792664, 0.18231017887592316,\n",
" 0.24256132543087006, 0.2150418609380722, 0.14223572611808777,\n",
" 0.3020711839199066, 1.100000023841858],\n",
" 'yaxis': 'y'},\n",
" {'hoverinfo': 'none',\n",
" 'line': {'color': 'gray', 'width': 2},\n",
" 'mode': 'lines',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': 'aaaf4294-9093-49bb-905c-261137ec6e91',\n",
" 'x': [6.049790084552722, 4.550200591327338, None, 4.550200591327338,\n",
" 3.6971529128998446, None, 3.6971529128998446,\n",
" 2.260114954813708, None, 2.260114954813708, 0.9789776698290348,\n",
" None, 0.9789776698290348, -0.33725023794912345, None,\n",
" -0.33725023794912345, -1.6183875229337956, None,\n",
" -0.33725023794912345, -0.3723408607426099, None,\n",
" -0.3723408607426099, -1.6885687685207686, None,\n",
" -1.6885687685207686, -2.9697060535054405, None,\n",
" -2.9697060535054405, -4.2859339612835985, None,\n",
" -4.2859339612835985, -5.567071246268271, None,\n",
" -5.567071246268271, -6.88329915404643, None,\n",
" -5.567071246268271, -5.531980623474783, None,\n",
" -0.3723408607426099, 0.9087964242420628, None,\n",
" 0.9087964242420628, 0.8737058014485767, None,\n",
" 0.9087964242420628, 2.225024332020221, None, 2.225024332020221,\n",
" 3.640375092533582, None, 3.640375092533582, 4.07040056505774,\n",
" None, 3.640375092533582, 4.550200591327338, None,\n",
" 2.225024332020221, 2.260114954813708, None],\n",
" 'xaxis': 'x2',\n",
" 'y': [0.3527856091535654, 0.38787623194705184, None,\n",
" 0.38787623194705184, 1.6216953671242724, None,\n",
" 1.6216953671242724, 1.1916698946001143, None,\n",
" 1.1916698946001143, 1.9718540119865837, None,\n",
" 1.9718540119865837, 1.2524486361476674, None,\n",
" 1.2524486361476674, 2.032632753534136, None,\n",
" 1.2524486361476674, -0.24714085707771702, None,\n",
" -0.24714085707771702, -0.9665462329166336, None,\n",
" -0.9665462329166336, -0.18636211553016513, None,\n",
" -0.18636211553016513, -0.9057674913690814, None,\n",
" -0.9057674913690814, -0.1255833739826133, None,\n",
" -0.1255833739826133, -0.8449887498215296, None,\n",
" -0.1255833739826133, 1.3740061192427715, None,\n",
" -0.24714085707771702, -1.027324974464186, None,\n",
" -1.027324974464186, -2.5269144676895703, None,\n",
" -1.027324974464186, -0.30791959862527024, None,\n",
" -0.30791959862527024, -0.8046914020866298, None,\n",
" -0.8046914020866298, -2.2417293601727675, None,\n",
" -0.8046914020866298, 0.38787623194705184, None,\n",
" -0.30791959862527024, 1.1916698946001143, None],\n",
" 'yaxis': 'y2'},\n",
" {'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n",
" 16, 17, 18],\n",
" 'marker': {'cmax': 1,\n",
" 'cmin': 0,\n",
" 'color': 'lightgray',\n",
" 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n",
" 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n",
" '#482878'], [0.2222222222222222,\n",
" '#3e4989'], [0.3333333333333333,\n",
" '#31688e'], [0.4444444444444444,\n",
" '#26828e'], [0.5555555555555556,\n",
" '#1f9e89'], [0.6666666666666666,\n",
" '#35b779'], [0.7777777777777778,\n",
" '#6ece58'], [0.8888888888888888,\n",
" '#b5de2b'], [1.0, '#fde725']],\n",
" 'size': 20},\n",
" 'mode': 'markers+text',\n",
" 'name': 'node',\n",
" 'text': [C, C, S, C, N, C, S, N, C, C, O, C, C, C, C, O, C, C, C],\n",
" 'textposition': 'middle center',\n",
" 'type': 'scatter',\n",
" 'uid': '03337be1-04c1-4602-b8b2-bfbe9950f45e',\n",
" 'x': [6.049790084552722, 4.550200591327338, 3.6971529128998446,\n",
" 2.260114954813708, 0.9789776698290348, -0.33725023794912345,\n",
" -1.6183875229337956, -0.3723408607426099, -1.6885687685207686,\n",
" -2.9697060535054405, -4.2859339612835985, -5.567071246268271,\n",
" -6.88329915404643, -5.531980623474783, 0.9087964242420628,\n",
" 0.8737058014485767, 2.225024332020221, 3.640375092533582,\n",
" 4.07040056505774],\n",
" 'xaxis': 'x2',\n",
" 'y': [0.3527856091535654, 0.38787623194705184, 1.6216953671242724,\n",
" 1.1916698946001143, 1.9718540119865837, 1.2524486361476674,\n",
" 2.032632753534136, -0.24714085707771702, -0.9665462329166336,\n",
" -0.18636211553016513, -0.9057674913690814, -0.1255833739826133,\n",
" -0.8449887498215296, 1.3740061192427715, -1.027324974464186,\n",
" -2.5269144676895703, -0.30791959862527024, -0.8046914020866298,\n",
" -2.2417293601727675],\n",
" 'yaxis': 'y2'}],\n",
" 'layout': {'annotations': [{'font': {'size': 16},\n",
" 'showarrow': False,\n",
" 'text': 'Spectrum',\n",
" 'x': 0.27,\n",
" 'xanchor': 'center',\n",
" 'xref': 'paper',\n",
" 'y': 1.0,\n",
" 'yanchor': 'bottom',\n",
" 'yref': 'paper'},\n",
" {'font': {'size': 16},\n",
" 'showarrow': False,\n",
" 'text': 'Molecule',\n",
" 'x': 0.8200000000000001,\n",
" 'xanchor': 'center',\n",
" 'xref': 'paper',\n",
" 'y': 1.0,\n",
" 'yanchor': 'bottom',\n",
" 'yref': 'paper'}],\n",
" 'showlegend': False,\n",
" 'template': '...',\n",
" 'title': {'text': 'Peak ↔ Node Similarity'},\n",
" 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.54], 'title': {'text': 'm/z'}},\n",
" 'xaxis2': {'anchor': 'y2', 'domain': [0.64, 1.0], 'visible': False},\n",
" 'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Intensity'}},\n",
" 'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'visible': False}}\n",
"})"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Cand@1: the same spectrum paired with the model's top-scoring candidate.\n",
"i = dataset.metadata[dataset.metadata['identifier'] == ms_id].index[0]\n",
"s = cand_at_1\n",
"print(s)\n",
"mol = Chem.MolFromSmiles(s)\n",
"\n",
"# Swap the molecule view for the featurized candidate; deepcopy so the dataset's own item is not mutated.\n",
"# (Named cand_item rather than `input` to avoid shadowing the builtin for the rest of the session.)\n",
"cand_item = copy.deepcopy(dataset[i])\n",
"cand_item['mol'] = mol_featurizer(cand_at_1)\n",
"spec = cand_item['SpecFormula']\n",
"\n",
"peak_mzs, peak_intensities, peak_formulas = spectra_from_encoding(spec)\n",
"\n",
"# Embeddings\n",
"model = model.to(torch.device('cpu'))\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    spec_enc, mol_enc = model.forward(cand_item, stage='test')\n",
"\n",
"fw = interactive_attention_visualization(spec_enc, mol_enc, peak_mzs, peak_intensities, peak_formulas, mol)\n",
"fw\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aed9d517",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (spec)",
"language": "python",
"name": "spec"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}