In [53]:
import copy
import pickle

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from rdkit import Chem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D

## Ranking result

In [2]:
# Load the saved retrieval result table (the preview further down shows identifier,
# candidates, scores, labels and rank columns).
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only open
# result files that come from a trusted source.
ranking_file = "/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/result_MassSpecGym_retrieval_candidates_formula.pkl"
with open(ranking_file, 'rb') as f:
    ranking = pickle.load(f)

In [None]:
# Top-k summary: percentage of rows whose value in the `r` column is <= k,
# reported for every cutoff in top_k.
r = 'rank'
top_k = [1, 5, 20]

n_rows = len(ranking)
result = [round(len(ranking[ranking[r] <= k]) / n_rows * 100, 3) for k in top_k]
rank_result = {r: result}

# Column headers are derived from top_k so they cannot drift from the cutoffs
# actually used above when the list is edited.
pd.DataFrame.from_dict(rank_result, orient='index', columns=[str(k) for k in top_k])

Unnamed: 0,1,5,20
rank,20.688,47.391,72.368


In [23]:
def get_target(candidates, labels):
    """Return the first entry of `candidates` that is selected by `labels`.

    `labels` is used to index a NumPy array built from `candidates`
    (the notebook passes a list of booleans, i.e. a mask).
    """
    candidate_array = np.array(candidates)
    selected = candidate_array[labels]
    return selected[0]

def get_cand_at_1(candidates, scores):
    """Return the candidate located at the position of the highest score."""
    best_position = np.argmax(scores)
    return candidates[best_position]

def get_top_score(scores):
    """Return the largest value found in `scores`."""
    highest = np.max(scores)
    return highest

def get_target_score(labels, scores):
    """Return the first score that is selected by `labels`.

    `labels` is used to index a NumPy array built from `scores`
    (the notebook passes a list of booleans, i.e. a mask).
    """
    score_array = np.array(scores)
    selected = score_array[labels]
    return selected[0]

def get_n_heavy_atoms(smiles):
    """Return the number of heavy atoms of the molecule described by `smiles`.

    Args:
        smiles (str): SMILES string to parse with RDKit.

    Returns:
        int: result of ``GetNumHeavyAtoms()`` on the parsed molecule.

    Raises:
        ValueError: if RDKit cannot parse `smiles`.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles signals a parse failure by returning None instead of raising;
    # fail with a message naming the offending SMILES rather than an AttributeError.
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    return mol.GetNumHeavyAtoms()

# Row-wise derived columns, added to `ranking` in this order. Each entry maps the
# new column name to the function that computes it from one row.
row_wise_columns = {
    'target': lambda row: get_target(row['candidates'], row['labels']),
    'target_score': lambda row: get_target_score(row['labels'], row['scores']),
    'cand@1': lambda row: get_cand_at_1(row['candidates'], row['scores']),
    'top_score': lambda row: get_top_score(row['scores']),
}
for column_name, compute in row_wise_columns.items():
    ranking[column_name] = ranking.apply(compute, axis=1)

# Heavy-atom count of the target molecule (needs the 'target' column created above).
ranking['n_heavy_atoms'] = ranking['target'].apply(get_n_heavy_atoms)

In [24]:
# Preview the first three rows of the ranking table.
ranking.head(3)

Unnamed: 0,identifier,candidates,scores,labels,rank,target,target_score,cand@1,top_score,n_heavy_atoms
0,MassSpecGymID0000201,[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(...,"[0.17369578778743744, 0.12611594796180725, 0.2...","[True, False, False, False, False, False, Fals...",17,CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([...,0.173696,COCCCN1C(=O)COc2ccc(N(C(=O)[C@H]3CN(C(=O)OC(C)...,0.259878,57
1,MassSpecGymID0000202,[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(...,"[0.05142267048358917, 0.07289629429578781, 0.1...","[True, False, False, False, False, False, Fals...",24,CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([...,0.051423,COC(=O)/C(C)=C\CC1(O)C(=O)C2CC(C(C)C)C13Oc1c(C...,0.237195,57
2,MassSpecGymID0000203,[CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N(...,"[0.09354929625988007, 0.0947718694806099, 0.10...","[True, False, False, False, False, False, Fals...",23,CC(C)[C@@H]1C(=O)N([C@H](C(=O)O[C@@H](C(=O)N([...,0.093549,C=CCOC12Oc3ccc(OC(=O)NCC)cc3C3C(CCCCO)C(CCCCO)...,0.238268,57


## Model

In [12]:
import sys
# NOTE(review): machine-specific absolute paths. They are put at the front of sys.path
# before the `massspecgym` and `mvp` imports below — presumably so the local checkouts
# are the ones imported; confirm and make configurable before sharing.
sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
sys.path.insert(0, "/data/yzhouc01/FILIP-MS")

from rdkit import RDLogger
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from massspecgym.models.base import Stage
import os

from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
from mvp.utils.models import get_model

from mvp.definitions import TEST_RESULTS_DIR
import yaml
from functools import partial
# Silence RDKit logging below the CRITICAL level
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Load the hyper-parameters stored in the run's hparams.yaml
# NOTE(review): yaml.FullLoader is more permissive than yaml.safe_load; only use it
# on files from a trusted source.

param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'
with open(param_pth) as f:
    params = yaml.load(f, Loader=yaml.FullLoader)

# Build the spectrum / molecule featurizers and the dataset from the same params,
# selected by the 'spectra_view' and 'molecule_view' entries.
spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)


# Load the model
# NOTE(review): torch is already imported in the first cell; this import is redundant.
import torch 
checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt"
# The checkpoint path is stored in params before get_model is called; presumably
# get_model reads 'checkpoint_pth' to restore the weights (the cell output prints
# "Loaded Model from checkpoint") — confirm against mvp.utils.models.
params['checkpoint_pth'] = checkpoint_pth
model = get_model(params['model'], params)

Data path:  /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv
Processing formula spectra


100%|██████████| 231104/231104 [00:18<00:00, 12309.47it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)


Loaded Model from checkpoint


## Visualization functions

In [16]:

import torch.nn.functional as F
import numpy as np

# Element symbols in the column order used by spectra_from_encoding (one count per element).
ATOM_LABELS = ['H', 'C',  'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
# Monoisotopic masses (most abundant isotope) for ATOM_LABELS, in the same order.
# Every entry must be monoisotopic: the previous table used average atomic weights for
# Cl, Br, B, Si and Se, which shifts the computed m/z of peaks containing those
# elements by up to ~1 Da relative to the other (monoisotopic) entries.
# Isotopes used: 1H, 12C, 16O, 14N, 31P, 32S, 35Cl, 19F, 79Br, 127I, 11B, 75As, 28Si, 80Se.
ATOM_MASSES = np.array([
    1.0078, 12.0000, 15.9949, 14.0031, 30.9738, 31.9721,
    34.9689, 18.9984, 78.9183, 126.9045, 11.0093, 74.9216, 27.9769, 79.9165
])
# Per-element factors that spectra_from_encoding multiplies the encoded counts by to
# undo normalization (same order as ATOM_LABELS).
# NOTE(review): presumably these must match the factors used when the spectra were
# encoded — confirm against the featurizer.
norm_vector = [102.0, 59.0, 25.0, 13.0, 3.0, 6.0, 6.0, 17.0, 4.0, 4.0, 1.0, 1.0, 5.0, 2.0]

def spectra_from_encoding(spectral_tensor, norm_vector=norm_vector):
    """
    Decode an encoded spectrum into m/z values, intensities and formula strings.

    Each row of the input is one peak: the first 14 columns are per-element
    counts (ordered like ATOM_LABELS) and column 14 is the intensity.

    Args:
        spectral_tensor (np.ndarray or torch.Tensor): [num_peaks, 15]; objects with a
            ``detach`` method are converted to a NumPy array first.
        norm_vector (np.ndarray or list or None): length 14; when not None the counts
            are multiplied by it element-wise to undo normalization.

    Returns:
        mzs (list of float): sum over elements of count * ATOM_MASSES, per peak
        intensities (list of float): column 14 of the input, per peak
        formulas (list of str): element symbols followed by their rounded count
            (count omitted when it is 1), or "Unknown" if no element has a positive count
    """
    encoded = spectral_tensor
    if hasattr(encoded, "detach"):
        encoded = encoded.detach().cpu().numpy()

    atom_counts, peak_heights = encoded[:, :14], encoded[:, 14]

    # Scale the counts back with the per-element factors, if given
    if norm_vector is not None:
        atom_counts = atom_counts * np.array(norm_vector)

    peak_masses = (atom_counts * ATOM_MASSES).sum(axis=1)

    def _formula_for(row):
        # Elements appear in ATOM_LABELS order; non-positive rounded counts are skipped.
        pieces = []
        for symbol, value in zip(ATOM_LABELS, row):
            n_atoms = int(round(value))
            if n_atoms <= 0:
                continue
            pieces.append(symbol if n_atoms == 1 else f"{symbol}{n_atoms}")
        return "".join(pieces) or "Unknown"

    formula_strings = [_formula_for(row) for row in atom_counts]

    return peak_masses.tolist(), peak_heights.tolist(), formula_strings


def mol_to_graph_coords(mol):
    """Compute 2D coordinates for `mol` and return them with its bond list.

    Returns:
        coords (dict): atom index -> position from the molecule's conformer
        bonds (list of tuple): (begin atom index, end atom index) for every bond
    """
    # Runs RDKit's 2D depiction on the molecule passed in before reading its conformer.
    rdDepictor.Compute2DCoords(mol)
    conformer = mol.GetConformer()

    coords = {}
    for atom_idx in range(mol.GetNumAtoms()):
        coords[atom_idx] = conformer.GetAtomPosition(atom_idx)

    bonds = []
    for bond in mol.GetBonds():
        bonds.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))

    return coords, bonds

def interactive_attention_visualization(spectral_embeds, graph_embeds, 
                                        peak_mzs, peak_intensities, peak_formulas, mol):
    """
    Build an interactive Plotly widget linking spectrum peaks to molecule atoms.

    The left panel is a bar plot of the spectrum and the right panel is a 2D
    drawing of ``mol`` (coordinates come from ``mol_to_graph_coords``). The
    similarity between a peak and a node is the dot product of their
    L2-normalized embeddings, min-max scaled over the whole matrix to [0, 1].

    - Clicking a peak colors the nodes by their similarity to that peak and
      marks the clicked peak in red.
    - Clicking a node colors the peaks by their similarity to that node and
      marks the clicked node in red.

    Args:
        spectral_embeds (torch.Tensor): 2-D, one embedding row per peak.
        graph_embeds (torch.Tensor): 2-D, one embedding row per node; row ``i`` is
            drawn at the position of atom ``i`` of ``mol``.
        peak_mzs (list of float): x positions of the spectrum bars.
        peak_intensities (list of float): heights of the spectrum bars.
        peak_formulas (list of str): shown as hover text on the bars.
        mol (rdkit.Chem.Mol): molecule drawn in the right panel.

    Returns:
        plotly.graph_objects.FigureWidget: the widget with click callbacks attached.

    Raises:
        ValueError: if ``graph_embeds`` has more rows than ``mol`` has atoms.
    """
    # Similarity matrix (rows = peaks, columns = nodes)
    spectral_embeds =  F.normalize(spectral_embeds, p=2, dim=-1)
    graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)
    
    similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()
    # Global min-max scaling; the epsilon avoids division by zero for a constant matrix.
    sim_norm = (similarity - similarity.min()) / (similarity.max() - similarity.min() + 1e-8)
    
    num_peaks, num_nodes = similarity.shape

    # Every node embedding needs an atom to be drawn on; without this check a
    # mismatch only surfaces as a bare KeyError in the coordinate lookup below.
    if num_nodes > mol.GetNumAtoms():
        raise ValueError(
            f"graph_embeds has {num_nodes} nodes but mol has only "
            f"{mol.GetNumAtoms()} atoms"
        )
    
    # --- Molecule graph ---
    coords, bonds = mol_to_graph_coords(mol)
    atom_labels = [a.GetSymbol() for a in mol.GetAtoms()]
    atom_x = [coords[i].x for i in range(num_nodes)]
    atom_y = [coords[i].y for i in range(num_nodes)]
    
    # --- Spectrum trace ---
    spectrum_trace = go.Bar(
        x=peak_mzs,
        y=peak_intensities,
        name='peak',
        marker=dict(color="lightgray", colorscale="Viridis", cmin=0, cmax=1,
                    colorbar=dict(title="Similarity", len=0.8, y=0.5)),
        hovertext=[f"Formula {f}" for f in  peak_formulas],
        customdata=list(range(num_peaks))  # peak index
    )
    
    # --- Graph nodes ---
    graph_nodes = go.Scatter(
        x=atom_x, y=atom_y,
        mode="markers+text",
        name='node',
        text=atom_labels,
        textposition="middle center",
        marker=dict(size=20, color="lightgray", colorscale="Viridis", cmin=0, cmax=1,
                    colorbar=dict(title="Similarity", len=0.8, y=0.5)),
        customdata=list(range(num_nodes)),
        # hovertext=[f"Atom {i} ({label})" for i, label in enumerate(atom_labels)]
    )
    
    # --- Graph bonds ---
    # A None entry after each pair breaks the line so bonds are drawn as separate segments.
    edge_x, edge_y = [], []
    for i, j in bonds:
        edge_x += [coords[i].x, coords[j].x, None]
        edge_y += [coords[i].y, coords[j].y, None]
    graph_edges = go.Scatter(
        x=edge_x, y=edge_y,
        mode="lines", line=dict(color="gray", width=2),
        hoverinfo="none", showlegend=False
    )
    
    # --- Subplots ---
    fig = make_subplots(rows=1, cols=2, subplot_titles=("Spectrum", "Molecule"), 
                        column_widths=[0.6, 0.4])
    
    # Trace order matters: the callbacks below address the traces by position
    # (0 = spectrum bars, 1 = bonds, 2 = atom markers).
    fig.add_trace(spectrum_trace, row=1, col=1)
    fig.add_trace(graph_edges, row=1, col=2)
    fig.add_trace(graph_nodes, row=1, col=2)
    
    fig.update_xaxes(title="m/z", row=1, col=1)
    fig.update_yaxes(title="Intensity", row=1, col=1)
    fig.update_xaxes(visible=False, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)
    
    fig.update_layout(title="Peak ↔ Node Similarity", showlegend=False)
    
    # --- Interactivity ---
    fw = go.FigureWidget(fig)

    def highlight_nodes(trace, points, selector):
        """Click on peak → recolor nodes"""
        if points.point_inds:
            peak_idx = points.point_inds[0]
            scores = sim_norm[peak_idx, :]
            with fw.batch_update():
                fw.data[2].marker.color = scores
                fw.data[0].marker.color = ["red" if i == peak_idx else "lightgray" for i in range(num_peaks)]

    def highlight_peaks(trace, points, selector):
        """Click on node → recolor peaks"""
        if points.point_inds:
            node_idx = points.point_inds[0]
            scores = sim_norm[:, node_idx]
            with fw.batch_update():
                fw.data[0].marker.color = scores
                fw.data[2].marker.color = ["red" if i == node_idx else "lightgray" for i in range(num_nodes)]
    
    fw.data[0].on_click(highlight_nodes)  # spectrum
    fw.data[2].on_click(highlight_peaks)  # nodes
    
    return fw


## Visualization

In [92]:
# Sample one case where the target's rank is greater than 20 and the target
# has at most 20 heavy atoms.
# NOTE(review): .sample(1) is unseeded, so re-running this cell picks a different
# spectrum than the one shown in the recorded outputs; pass random_state to pin it.
sample = ranking[(ranking['rank']>20) & (ranking['n_heavy_atoms'] <=20)].sample(1).iloc[0]
ms_id = sample['identifier']
target = sample['target']
cand_at_1 = sample['cand@1']
print(f"MS ID: {ms_id}, Target:{sample['target_score']:.3}, Cand@1: {sample['top_score']:.3}")
print(f"Target rank: {sample['rank']}")

MS ID: MassSpecGymID0396247, Target:0.423, Cand@1: 0.49
Target rank: 25


In [96]:
# Target molecule: visualize peak/atom similarity for the sampled spectrum paired
# with its true molecule.
# Index label of the metadata row whose identifier matches the sampled spectrum.
# NOTE(review): this label is then used as dataset[i]; assumes the metadata index
# lines up with dataset positions — confirm.
i = dataset.metadata[dataset.metadata['identifier'] == ms_id].index[0]
s = target
print(s)
mol = Chem.MolFromSmiles(s)
# NOTE(review): `g` is not used anywhere else in this cell.
g = dataset[i]['mol']
spec = dataset[i]['SpecFormula']

# Decode the encoded spectrum into m/z values, intensities and formula strings
peak_mzs, peak_intensities, peak_formulas = spectra_from_encoding(spec)

print(len(peak_formulas))
# Embeddings: run the model on CPU in eval mode, without gradient tracking
model = model.to(torch.device('cpu'))
model.eval()
with torch.no_grad():
    spec_enc, mol_enc = model.forward(dataset[i], stage='test')

fw = interactive_attention_visualization(spec_enc, mol_enc, peak_mzs, peak_intensities, peak_formulas, mol)
fw


CCCCCCCC(=O)NC1=C2C(=CSS2)NC1=O
23


FigureWidget({
    'data': [{'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                             16, 17, 18, 19, 20, 21, 22],
              'hovertext': [Formula H10C7, Formula H8C3O2N2, Formula H12C8,
                            Formula H3C3NS2, Formula H14C8O, Formula HC7NS, Formula
                            HC4ONS2, Formula C7N2S, Formula H3C6ONS2, Formula
                            H2C5ON2S2, Formula H3C5ON2S2, Formula H4C5ON2S2,
                            Formula HC8ON2S, Formula H2C6ON2S2, Formula H5C6ON2S2,
                            Formula H6C6ON2S2, Formula H6C9ON2S, Formula H4C7ON2S2,
                            Formula H5C7ON2S2, Formula H2C10ON2S, Formula
                            H8C9ON2S2, Formula H16C13ON2S2, Formula H18C13O2N2S2],
              'marker': {'cmax': 1,
                         'cmin': 0,
                         'color': 'lightgray',
                         'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y'

In [98]:
# Print the decoded formula of every peak, one per line, in peak order.
for formula in peak_formulas:
    print(formula)

H10C7
H8C3O2N2
H12C8
H3C3NS2
H14C8O
HC7NS
HC4ONS2
C7N2S
H3C6ONS2
H2C5ON2S2
H3C5ON2S2
H4C5ON2S2
HC8ON2S
H2C6ON2S2
H5C6ON2S2
H6C6ON2S2
H6C9ON2S
H4C7ON2S2
H5C7ON2S2
H2C10ON2S
H8C9ON2S2
H16C13ON2S2
H18C13O2N2S2


In [99]:
# Cand@1

# Top-scoring candidate: same spectrum as the target cell, but paired with the
# candidate molecule that received the highest score.
i = dataset.metadata[dataset.metadata['identifier'] == ms_id].index[0]
s = cand_at_1
print(s)
mol = Chem.MolFromSmiles(s)
# NOTE(review): `g` is the dataset's own molecule entry for this spectrum and is not
# used below; the candidate's features are in `cand_mol`.
g = dataset[i]['mol']
spec = dataset[i]['SpecFormula']
cand_mol= mol_featurizer(cand_at_1)

peak_mzs, peak_intensities, peak_formulas = spectra_from_encoding(spec)

# Embeddings: run the model on CPU in eval mode, without gradient tracking
model = model.to(torch.device('cpu'))
model.eval()
with torch.no_grad():
    # Named `batch` rather than `input` so the builtin input() is not shadowed
    # for the rest of the notebook session.
    batch = copy.deepcopy(dataset[i])
    # Swap the molecule features for the candidate's before the forward pass.
    batch['mol'] = cand_mol
    
    spec_enc, mol_enc = model.forward(batch, stage='test')

fw = interactive_attention_visualization(spec_enc, mol_enc, peak_mzs, peak_intensities, peak_formulas, mol)
fw


Cc1sc2[nH]c(=S)n(CCOC(C)C)c(=O)c2c1C


FigureWidget({
    'data': [{'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                             16, 17, 18, 19, 20, 21, 22],
              'hovertext': [Formula H10C7, Formula H8C3O2N2, Formula H12C8,
                            Formula H3C3NS2, Formula H14C8O, Formula HC7NS, Formula
                            HC4ONS2, Formula C7N2S, Formula H3C6ONS2, Formula
                            H2C5ON2S2, Formula H3C5ON2S2, Formula H4C5ON2S2,
                            Formula HC8ON2S, Formula H2C6ON2S2, Formula H5C6ON2S2,
                            Formula H6C6ON2S2, Formula H6C9ON2S, Formula H4C7ON2S2,
                            Formula H5C7ON2S2, Formula H2C10ON2S, Formula
                            H8C9ON2S2, Formula H16C13ON2S2, Formula H18C13O2N2S2],
              'marker': {'cmax': 1,
                         'cmin': 0,
                         'color': 'lightgray',
                         'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y'