Spaces:
Sleeping
Sleeping
Commit
·
19a4dfc
1
Parent(s):
60219be
update
Browse files- app_utils/model_utils.py +4 -4
- app_utils/viz_utils.py +3 -12
- flare/data/datasets.py +19 -10
- flare/models/contrastive.py +184 -403
- flare/models/mol_encoder.py +2 -1
- flare/models/spec_encoder.py +2 -3
- flare/params_filipGlobal.yaml +95 -0
- flare/run.sh +3 -3
- flare/subformula_assign/run.sh +8 -3
- flare/subformula_assign/utils/chem_utils.py +4 -0
- flare/test.py +9 -1
- flare/tune.py +1 -1
- flare/utils/case_study_utils.py +193 -0
- flare/utils/general.py +94 -58
- flare/utils/loss.py +95 -4
- flare/utils/models.py +3 -1
- flare/utils/mol_search.py +367 -0
- notebooks/UMAP_spectra_embeddings.ipynb +0 -0
- notebooks/fine-grained_vs_global.ipynb +6 -2
- notebooks/good_vs_bad_instances.ipynb +0 -0
- notebooks/mol-spec_visualization.ipynb +0 -0
- notebooks/results.ipynb +233 -0
- notebooks/spectra_sim.ipynb +0 -0
app_utils/model_utils.py
CHANGED
|
@@ -3,8 +3,8 @@ import sys
|
|
| 3 |
# sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
|
| 4 |
|
| 5 |
from rdkit import RDLogger
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
|
| 9 |
import yaml
|
| 10 |
|
|
@@ -15,7 +15,7 @@ lg.setLevel(RDLogger.CRITICAL)
|
|
| 15 |
# Load model and data
|
| 16 |
|
| 17 |
def load_model_components():
|
| 18 |
-
param_pth = 'hparams.yaml'
|
| 19 |
with open(param_pth) as f:
|
| 20 |
params = yaml.load(f, Loader=yaml.FullLoader)
|
| 21 |
|
|
@@ -24,7 +24,7 @@ def load_model_components():
|
|
| 24 |
|
| 25 |
# load model
|
| 26 |
|
| 27 |
-
checkpoint_pth = "epoch=1993-train_loss=0.10.ckpt"
|
| 28 |
params['checkpoint_pth'] = checkpoint_pth
|
| 29 |
model = get_model(params['model'], params)
|
| 30 |
|
|
|
|
| 3 |
# sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
|
| 4 |
|
| 5 |
from rdkit import RDLogger
|
| 6 |
+
from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
|
| 7 |
+
from flare.utils.models import get_model
|
| 8 |
|
| 9 |
import yaml
|
| 10 |
|
|
|
|
| 15 |
# Load model and data
|
| 16 |
|
| 17 |
def load_model_components():
|
| 18 |
+
param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'
|
| 19 |
with open(param_pth) as f:
|
| 20 |
params = yaml.load(f, Loader=yaml.FullLoader)
|
| 21 |
|
|
|
|
| 24 |
|
| 25 |
# load model
|
| 26 |
|
| 27 |
+
checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt"
|
| 28 |
params['checkpoint_pth'] = checkpoint_pth
|
| 29 |
model = get_model(params['model'], params)
|
| 30 |
|
app_utils/viz_utils.py
CHANGED
|
@@ -6,7 +6,9 @@ import plotly.graph_objects as go
|
|
| 6 |
from plotly.subplots import make_subplots
|
| 7 |
from rdkit import Chem
|
| 8 |
from rdkit.Chem import rdDepictor
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def mol_to_graph_coords(mol):
|
| 12 |
"""Return atom coordinates and bond list for a molecule."""
|
|
@@ -16,12 +18,6 @@ def mol_to_graph_coords(mol):
|
|
| 16 |
bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
|
| 17 |
return coords, bonds
|
| 18 |
|
| 19 |
-
import torch
|
| 20 |
-
import torch.nn.functional as F
|
| 21 |
-
import plotly.graph_objects as go
|
| 22 |
-
from plotly.subplots import make_subplots
|
| 23 |
-
|
| 24 |
-
|
| 25 |
def interactive_attention_visualization(
|
| 26 |
spectral_embeds,
|
| 27 |
graph_embeds,
|
|
@@ -68,7 +64,6 @@ def interactive_attention_visualization(
|
|
| 68 |
hoverinfo='text',
|
| 69 |
customdata=list(range(num_peaks)), # actual peak indices
|
| 70 |
)
|
| 71 |
-
|
| 72 |
# --- Graph nodes ---
|
| 73 |
graph_nodes = go.Scatter(
|
| 74 |
x=atom_x,
|
|
@@ -127,10 +122,6 @@ def interactive_attention_visualization(
|
|
| 127 |
# ------------------------
|
| 128 |
# Model set up
|
| 129 |
# ------------------------
|
| 130 |
-
|
| 131 |
-
from mvp.subformula_assign.utils.spectra_utils import assign_subforms
|
| 132 |
-
import matchms
|
| 133 |
-
|
| 134 |
def run(ms, smiles, formula, precursor_mz, adduct, spec_featurizer, mol_featurizer,model, mass_diff_thresh=20, precursor_intensity=1.1):
|
| 135 |
|
| 136 |
# step 1 - label peaks with formula, setup matchms spectrum
|
|
|
|
| 6 |
from plotly.subplots import make_subplots
|
| 7 |
from rdkit import Chem
|
| 8 |
from rdkit.Chem import rdDepictor
|
| 9 |
+
|
| 10 |
+
from flare.subformula_assign.utils.spectra_utils import assign_subforms
|
| 11 |
+
import matchms
|
| 12 |
|
| 13 |
def mol_to_graph_coords(mol):
|
| 14 |
"""Return atom coordinates and bond list for a molecule."""
|
|
|
|
| 18 |
bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
|
| 19 |
return coords, bonds
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def interactive_attention_visualization(
|
| 22 |
spectral_embeds,
|
| 23 |
graph_embeds,
|
|
|
|
| 64 |
hoverinfo='text',
|
| 65 |
customdata=list(range(num_peaks)), # actual peak indices
|
| 66 |
)
|
|
|
|
| 67 |
# --- Graph nodes ---
|
| 68 |
graph_nodes = go.Scatter(
|
| 69 |
x=atom_x,
|
|
|
|
| 122 |
# ------------------------
|
| 123 |
# Model set up
|
| 124 |
# ------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
def run(ms, smiles, formula, precursor_mz, adduct, spec_featurizer, mol_featurizer,model, mass_diff_thresh=20, precursor_intensity=1.1):
|
| 126 |
|
| 127 |
# step 1 - label peaks with formula, setup matchms spectrum
|
flare/data/datasets.py
CHANGED
|
@@ -83,7 +83,7 @@ class JESTR1_MassSpecDataset(MassSpecDataset):
|
|
| 83 |
|
| 84 |
spec = self.spectra[i]
|
| 85 |
metadata = self.metadata.iloc[i]
|
| 86 |
-
mol = metadata["smiles"]
|
| 87 |
|
| 88 |
# Apply all transformations to the spectrum
|
| 89 |
item = {}
|
|
@@ -254,7 +254,7 @@ class ContrastiveDataset(Dataset):
|
|
| 254 |
return item
|
| 255 |
|
| 256 |
@staticmethod
|
| 257 |
-
def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None) -> dict:
|
| 258 |
mol_key = 'cand' if stage == Stage.TEST else 'mol'
|
| 259 |
non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
|
| 260 |
require_pad = False
|
|
@@ -277,15 +277,16 @@ class ContrastiveDataset(Dataset):
|
|
| 277 |
raise
|
| 278 |
|
| 279 |
# batch graphs
|
| 280 |
-
batch_mol
|
| 281 |
-
|
|
|
|
| 282 |
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
| 289 |
|
| 290 |
# pad peaks/formulas
|
| 291 |
if require_pad:
|
|
@@ -347,7 +348,15 @@ class ExpandedRetrievalDataset:
|
|
| 347 |
|
| 348 |
self.candidates = {}
|
| 349 |
for s, cand in candidates.items():
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
self.spec_cand = [] #(spec index, cand_smiles, true_label)
|
| 353 |
|
|
|
|
| 83 |
|
| 84 |
spec = self.spectra[i]
|
| 85 |
metadata = self.metadata.iloc[i]
|
| 86 |
+
mol = metadata["smiles"] if 'smiles' in metadata else metadata["identifier"]
|
| 87 |
|
| 88 |
# Apply all transformations to the spectrum
|
| 89 |
item = {}
|
|
|
|
| 254 |
return item
|
| 255 |
|
| 256 |
@staticmethod
|
| 257 |
+
def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None, batch_mol: bool = True) -> dict:
|
| 258 |
mol_key = 'cand' if stage == Stage.TEST else 'mol'
|
| 259 |
non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
|
| 260 |
require_pad = False
|
|
|
|
| 277 |
raise
|
| 278 |
|
| 279 |
# batch graphs
|
| 280 |
+
if batch_mol:
|
| 281 |
+
batch_mol = []
|
| 282 |
+
batch_mol_nodes= []
|
| 283 |
|
| 284 |
+
for item in batch:
|
| 285 |
+
batch_mol.append(item[mol_key])
|
| 286 |
+
batch_mol_nodes.append(item[mol_key].num_nodes())
|
| 287 |
|
| 288 |
+
collated_batch[mol_key] = dgl.batch(batch_mol)
|
| 289 |
+
collated_batch['mol_n_nodes'] = batch_mol_nodes
|
| 290 |
|
| 291 |
# pad peaks/formulas
|
| 292 |
if require_pad:
|
|
|
|
| 348 |
|
| 349 |
self.candidates = {}
|
| 350 |
for s, cand in candidates.items():
|
| 351 |
+
clean_cands = []
|
| 352 |
+
for c in cand:
|
| 353 |
+
try:
|
| 354 |
+
if '.' not in c:
|
| 355 |
+
clean_cands.append(c)
|
| 356 |
+
except:
|
| 357 |
+
print(f"Error in processing candidate {c} for smiles {s}")
|
| 358 |
+
pass
|
| 359 |
+
self.candidates[s] = clean_cands
|
| 360 |
|
| 361 |
self.spec_cand = [] #(spec index, cand_smiles, true_label)
|
| 362 |
|
flare/models/contrastive.py
CHANGED
|
@@ -10,7 +10,7 @@ from massspecgym.models.base import Stage
|
|
| 10 |
from massspecgym import utils
|
| 11 |
from torch.nn.utils.rnn import pad_sequence
|
| 12 |
|
| 13 |
-
from flare.utils.loss import contrastive_loss,
|
| 14 |
import flare.utils.models as model_utils
|
| 15 |
from flare.utils.general import pad_graph_nodes, filip_similarity_batch
|
| 16 |
|
|
@@ -18,14 +18,17 @@ from flare.models.encoders import CrossAttention
|
|
| 18 |
import torch.nn.functional as F
|
| 19 |
|
| 20 |
from torch_geometric.nn import global_mean_pool
|
|
|
|
| 21 |
|
| 22 |
class ContrastiveModel(RetrievalMassSpecGymModel):
|
| 23 |
def __init__(
|
| 24 |
self,
|
|
|
|
| 25 |
**kwargs
|
| 26 |
):
|
| 27 |
super().__init__(**kwargs)
|
| 28 |
self.save_hyperparameters()
|
|
|
|
| 29 |
|
| 30 |
if 'use_fp' not in self.hparams:
|
| 31 |
self.hparams.use_fp = False
|
|
@@ -42,13 +45,26 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 42 |
self.result_dct = defaultdict(lambda: defaultdict(list))
|
| 43 |
|
| 44 |
def forward(self, batch, stage):
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
spec = batch[self.spec_view]
|
| 48 |
n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
|
| 49 |
spec_enc = self.spec_enc_model(spec, n_peaks)
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
fp = batch['fp'] if self.hparams.use_fp else None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
mol_enc = self.mol_enc_model(g, fp=fp)
|
| 53 |
|
| 54 |
return spec_enc, mol_enc
|
|
@@ -61,20 +77,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 61 |
losses['contr_loss'] = contr_loss.detach().item()
|
| 62 |
|
| 63 |
loss+=contr_loss
|
| 64 |
-
# if self.hparams.pred_fp:
|
| 65 |
-
# fp_loss_val = self.loss_wts['fp_wt'] *self.fp_loss(output['fp'], batch['fp'])
|
| 66 |
-
# loss+= fp_loss_val
|
| 67 |
-
# losses['fp_loss'] = fp_loss_val.detach().item()
|
| 68 |
-
|
| 69 |
-
# if 'aug_cand_enc' in output:
|
| 70 |
-
# aug_cand_loss = self.loss_wts['aug_cand_wt'] * cand_spec_sim_loss(spec_enc, output['aug_cand_enc'])
|
| 71 |
-
# loss+= aug_cand_loss
|
| 72 |
-
# losses['aug_cand_loss'] = aug_cand_loss.detach().item()
|
| 73 |
-
|
| 74 |
-
# if 'ind_spec' in output:
|
| 75 |
-
# spec_loss = self.loss_wts['cons_spec_wt'] * self.cons_loss(spec_enc, output['ind_spec'])
|
| 76 |
-
# loss+=spec_loss
|
| 77 |
-
# losses['cons_spec_loss'] = spec_loss.detach().item()
|
| 78 |
|
| 79 |
losses['loss'] = loss
|
| 80 |
|
|
@@ -108,7 +110,7 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 108 |
# total loss
|
| 109 |
self.log(
|
| 110 |
f'{stage.to_pref()}loss',
|
| 111 |
-
|
| 112 |
batch_size=len(batch['identifier']),
|
| 113 |
sync_dist=True,
|
| 114 |
prog_bar=True,
|
|
@@ -146,11 +148,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 146 |
self.result_dct[i]['candidates'].extend(cands)
|
| 147 |
self.result_dct[i]['scores'].extend(scores.cpu().tolist())
|
| 148 |
self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
|
| 149 |
-
|
| 150 |
-
# # external test case only
|
| 151 |
-
# for i, cands, scores in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores']):
|
| 152 |
-
# self.result_dct[i.cpu().item()]['candidates'].extend(cands)
|
| 153 |
-
# self.result_dct[i.cpu().item()]['scores'].extend(scores.cpu().tolist())
|
| 154 |
|
| 155 |
def _compute_rank(self, scores, labels):
|
| 156 |
if not any(labels):
|
|
@@ -160,12 +157,21 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 160 |
rank = np.count_nonzero(scores >=target_score)
|
| 161 |
return rank
|
| 162 |
|
|
|
|
|
|
|
|
|
|
| 163 |
def on_test_epoch_end(self) -> None:
|
| 164 |
|
| 165 |
self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
|
| 166 |
|
| 167 |
# Compute rank
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
if not self.df_test_path:
|
| 170 |
self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
|
| 171 |
self.df_test.to_pickle(self.df_test_path)
|
|
@@ -176,160 +182,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 176 |
{"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": False}, # monitor val loss
|
| 177 |
]
|
| 178 |
return monitors
|
| 179 |
-
|
| 180 |
-
# class MultiViewContrastive(ContrastiveModel):
|
| 181 |
-
|
| 182 |
-
# def __init__(self,
|
| 183 |
-
# **kwargs):
|
| 184 |
-
|
| 185 |
-
# super().__init__(**kwargs)
|
| 186 |
-
|
| 187 |
-
# # build fingerprint encoder model
|
| 188 |
-
# if self.hparams.use_fp:
|
| 189 |
-
# self.fp_enc_model = model_utils.get_fp_enc_model(self.hparams)
|
| 190 |
-
|
| 191 |
-
# # build NL encoder model
|
| 192 |
-
# if self.hparams.use_NL_spec:
|
| 193 |
-
# self.NL_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
|
| 194 |
-
|
| 195 |
-
# def forward(self, batch, stage):
|
| 196 |
-
# g = batch['cand'] if stage == Stage.TEST else batch['mol']
|
| 197 |
-
|
| 198 |
-
# spec = batch[self.spec_view]
|
| 199 |
-
# n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
|
| 200 |
-
|
| 201 |
-
# spec_enc = self.spec_enc_model(spec, n_peaks)
|
| 202 |
-
# mol_enc = self.mol_enc_model(g)
|
| 203 |
-
# views = {'spec_enc': spec_enc, 'mol_enc': mol_enc}
|
| 204 |
-
|
| 205 |
-
# if self.hparams.use_fp:
|
| 206 |
-
# fp_enc = self.fp_enc_model(batch['fp'])
|
| 207 |
-
# views['fp_enc'] = fp_enc
|
| 208 |
-
|
| 209 |
-
# if self.hparams.use_cons_spec:
|
| 210 |
-
# spec = batch['cons_spec']
|
| 211 |
-
# n_peaks = batch['cons_n_peaks'] if 'cons_n_peaks' in batch else None
|
| 212 |
-
# spec_enc = self.cons_spec_enc_model(spec, n_peaks)
|
| 213 |
-
# views['cons_spec_enc'] = spec_enc
|
| 214 |
-
|
| 215 |
-
# if self.hparams.use_NL_spec:
|
| 216 |
-
# spec = batch['NL_spec']
|
| 217 |
-
# n_peaks = batch['NL_n_peaks'] if 'NL_n_peaks' in batch else None
|
| 218 |
-
# spec_enc = self.NL_enc_model(spec, n_peaks)
|
| 219 |
-
# views['NL_spec_enc'] = spec_enc
|
| 220 |
-
# return views
|
| 221 |
-
|
| 222 |
-
# def step(
|
| 223 |
-
# self, batch: dict, stage= Stage.NONE):
|
| 224 |
-
|
| 225 |
-
# # Compute spectra and mol encoding
|
| 226 |
-
# views = self.forward(batch, stage)
|
| 227 |
-
|
| 228 |
-
# if stage == Stage.TEST:
|
| 229 |
-
# return views
|
| 230 |
-
|
| 231 |
-
# # Calculate loss
|
| 232 |
-
# losses = self.compute_loss(batch, views)
|
| 233 |
-
|
| 234 |
-
# return losses
|
| 235 |
-
|
| 236 |
-
# def compute_loss(self, batch: dict, views: dict):
|
| 237 |
-
# loss = 0
|
| 238 |
-
# losses = {}
|
| 239 |
-
# for v1, v2 in self.hparams.contr_views:
|
| 240 |
-
# contr_loss, cong_loss, noncong_loss = contrastive_loss(views[v1], views[v2], self.hparams.contr_temp)
|
| 241 |
-
# loss+=contr_loss
|
| 242 |
-
|
| 243 |
-
# losses[f'{v1[:-4]}-{v2[:-4]}_contr_loss'] = contr_loss.detach().item()
|
| 244 |
-
# losses[f'{v1[:-4]}-{v2[:-4]}_cong_loss'] = cong_loss.detach().item()
|
| 245 |
-
# losses[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'] = noncong_loss.detach().item()
|
| 246 |
-
|
| 247 |
-
# losses['loss'] = loss
|
| 248 |
-
|
| 249 |
-
# return losses
|
| 250 |
-
|
| 251 |
-
# def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
|
| 252 |
-
# # total loss
|
| 253 |
-
# self.log(
|
| 254 |
-
# f'{stage.to_pref()}loss',
|
| 255 |
-
# outputs['loss'],
|
| 256 |
-
# batch_size=len(batch['identifier']),
|
| 257 |
-
# sync_dist=True,
|
| 258 |
-
# prog_bar=True,
|
| 259 |
-
# on_epoch=True,
|
| 260 |
-
# # on_step=True
|
| 261 |
-
# )
|
| 262 |
-
|
| 263 |
-
# for v1, v2 in self.hparams.contr_views:
|
| 264 |
-
# self.log(
|
| 265 |
-
# f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_contr_loss',
|
| 266 |
-
# outputs[f'{v1[:-4]}-{v2[:-4]}_contr_loss'],
|
| 267 |
-
# batch_size=len(batch['identifier']),
|
| 268 |
-
# sync_dist=True,
|
| 269 |
-
# on_epoch=True,
|
| 270 |
-
# )
|
| 271 |
-
# self.log(
|
| 272 |
-
# f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_cong_loss',
|
| 273 |
-
# outputs[f'{v1[:-4]}-{v2[:-4]}_cong_loss'],
|
| 274 |
-
# batch_size=len(batch['identifier']),
|
| 275 |
-
# sync_dist=True,
|
| 276 |
-
# on_epoch=True,
|
| 277 |
-
# )
|
| 278 |
-
# self.log(
|
| 279 |
-
# f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_noncong_loss',
|
| 280 |
-
# outputs[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'],
|
| 281 |
-
# batch_size=len(batch['identifier']),
|
| 282 |
-
# sync_dist=True,
|
| 283 |
-
# on_epoch=True,
|
| 284 |
-
# )
|
| 285 |
-
|
| 286 |
-
# def test_step(self, batch):
|
| 287 |
-
# # Unpack inputs
|
| 288 |
-
# identifiers = batch['identifier']
|
| 289 |
-
# cand_smiles = batch['cand_smiles']
|
| 290 |
-
# id_to_ct = defaultdict(int)
|
| 291 |
-
# for i in identifiers: id_to_ct[i]+=1
|
| 292 |
-
# batch_ptr = torch.tensor(list(id_to_ct.values()))
|
| 293 |
-
|
| 294 |
-
# outputs = self.step(batch, stage=Stage.TEST)
|
| 295 |
-
# scores = {}
|
| 296 |
-
# for v1, v2 in self.hparams.contr_views:
|
| 297 |
-
# # if 'cons_spec_enc' in (v1, v2):
|
| 298 |
-
# # continue
|
| 299 |
-
# v1_enc = outputs[v1]
|
| 300 |
-
# v2_enc = outputs[v2]
|
| 301 |
-
|
| 302 |
-
# s = nn.functional.cosine_similarity(v1_enc, v2_enc)
|
| 303 |
-
# scores[f'{v1[:-4]}-{v2[:-4]}_scores'] = torch.split(s, list(id_to_ct.values()))
|
| 304 |
-
|
| 305 |
-
# indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
|
| 306 |
-
|
| 307 |
-
# cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
|
| 308 |
-
# labels = utils.unbatch_list(batch['label'], indexes)
|
| 309 |
-
|
| 310 |
-
# return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
|
| 311 |
-
|
| 312 |
-
# def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
|
| 313 |
-
|
| 314 |
-
# # save scores
|
| 315 |
-
# for i, cands, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels']):
|
| 316 |
-
# self.result_dct[i]['candidates'].extend(cands)
|
| 317 |
-
# self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
|
| 318 |
-
|
| 319 |
-
# for v1, v2 in self.hparams.contr_views:
|
| 320 |
-
# for i, scores in zip(outputs['identifiers'], outputs['scores'][f'{v1[:-4]}-{v2[:-4]}_scores']):
|
| 321 |
-
# self.result_dct[i][f'{v1[:-4]}-{v2[:-4]}_scores'].extend(scores.cpu().tolist())
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
# def on_test_epoch_end(self) -> None:
|
| 325 |
-
|
| 326 |
-
# self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
|
| 327 |
-
|
| 328 |
-
# # Compute rank
|
| 329 |
-
# for v1, v2 in self.hparams.contr_views:
|
| 330 |
-
# self.df_test[f'{v1[:-4]}-{v2[:-4]}_rank'] = self.df_test.apply(lambda row: self._compute_rank(row[f'{v1[:-4]}-{v2[:-4]}_scores'], row['labels']), axis=1)
|
| 331 |
-
|
| 332 |
-
# self.df_test.to_pickle(self.df_test_path)
|
| 333 |
|
| 334 |
class FilipContrastive(ContrastiveModel):
|
| 335 |
def __init__(self,
|
|
@@ -381,7 +233,7 @@ class FilipContrastive(ContrastiveModel):
|
|
| 381 |
# Calculate scores
|
| 382 |
indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
|
| 383 |
|
| 384 |
-
scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask,
|
| 385 |
scores = torch.split(scores, list(id_to_ct.values()))
|
| 386 |
|
| 387 |
cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
|
|
@@ -389,248 +241,177 @@ class FilipContrastive(ContrastiveModel):
|
|
| 389 |
|
| 390 |
return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
|
| 391 |
|
| 392 |
-
# class MultiViewFineTuning(MultiViewContrastive):
|
| 393 |
-
# def __init__(self,
|
| 394 |
-
# **kwargs):
|
| 395 |
-
# super().__init__(**kwargs)
|
| 396 |
-
|
| 397 |
-
# # load preptrained spec, mol, fp encoders
|
| 398 |
-
# checkpoint = torch.load(self.hparams.partial_checkpoint)
|
| 399 |
-
# state_dict = state_dict = {k[len("spec_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("spec_enc_model")}
|
| 400 |
-
# self.spec_enc_model.load_state_dict(state_dict) # trained on consensus spectra
|
| 401 |
-
|
| 402 |
-
# state_dict = state_dict = {k[len("mol_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("mol_enc_model")}
|
| 403 |
-
# self.mol_enc_model.load_state_dict(state_dict)
|
| 404 |
-
|
| 405 |
-
# state_dict = state_dict = {k[len("fp_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("fp_enc_model")}
|
| 406 |
-
# self.fp_enc_model.load_state_dict(state_dict)
|
| 407 |
-
|
| 408 |
-
# self.encoding_views = ['spec_enc', 'mol_enc', 'fp_enc']
|
| 409 |
-
# self.loss_fn = nn.BCELoss()
|
| 410 |
-
|
| 411 |
-
# # freeze encoders
|
| 412 |
-
# for param in self.mol_enc_model.parameters():
|
| 413 |
-
# param.requires_grad = False
|
| 414 |
-
# for param in self.spec_enc_model.parameters():
|
| 415 |
-
# param.requires_grad = False
|
| 416 |
-
# for param in self.fp_enc_model.parameters():
|
| 417 |
-
# param.requires_grad = False
|
| 418 |
-
# for param in self.cons_spec_enc_model.parameters():
|
| 419 |
-
# param.requires_grad = False
|
| 420 |
-
|
| 421 |
-
# # n_views = 2
|
| 422 |
-
# # if self.hparams.use_fp:
|
| 423 |
-
# # n_views+=1
|
| 424 |
-
|
| 425 |
-
# # in_dim = self.hparams.final_embedding_dim*n_views
|
| 426 |
-
# in_dim = self.hparams.final_embedding_dim *2 + 2
|
| 427 |
-
|
| 428 |
-
# self.classifier_model = nn.Sequential(
|
| 429 |
-
# nn.Linear(in_dim, 512),
|
| 430 |
-
# nn.ReLU(),
|
| 431 |
-
# nn.BatchNorm1d(512),
|
| 432 |
-
# nn.Dropout(0.3),
|
| 433 |
-
# nn.Linear(512, 256),
|
| 434 |
-
# nn.ReLU(),
|
| 435 |
-
# nn.BatchNorm1d(256),
|
| 436 |
-
# nn.Dropout(0.3),
|
| 437 |
-
# nn.Linear(256, 1),
|
| 438 |
-
# nn.Sigmoid()
|
| 439 |
-
# )
|
| 440 |
-
# self.noise_std = 0.01
|
| 441 |
-
|
| 442 |
-
# def _add_noise(self, x):
|
| 443 |
-
# noise = torch.randn_like(x) * self.noise_std
|
| 444 |
-
# return x + noise
|
| 445 |
-
|
| 446 |
-
# def forward(self, batch, stage):
|
| 447 |
-
|
| 448 |
-
# matching_views = super().forward(batch, stage)
|
| 449 |
-
# # matching_enc = torch.concat((matching_views['spec_enc'], matching_views['mol_enc'], matching_views['fp_enc']), dim=-1)
|
| 450 |
-
# # enc1 = matching_views['spec_enc'] - matching_views['mol_enc']
|
| 451 |
-
# # enc2 = matching_views['spec_enc'] - matching_views['fp_enc']
|
| 452 |
-
# # matching_enc = torch.concat((enc1, enc2), dim=-1)
|
| 453 |
-
# view1 = matching_views['spec_enc']
|
| 454 |
-
# view2 = matching_views['mol_enc']
|
| 455 |
-
# view3 = matching_views['fp_enc']
|
| 456 |
-
|
| 457 |
-
# if stage == Stage.TRAIN:
|
| 458 |
-
# view1, view2, view3 = map(self._add_noise, (view1, view2, view3))
|
| 459 |
-
|
| 460 |
-
# pairwise_diffs = torch.cat([
|
| 461 |
-
# torch.abs(view1 - view2),
|
| 462 |
-
# torch.abs(view1 - view3),
|
| 463 |
-
# ], dim=-1)
|
| 464 |
-
|
| 465 |
-
# pairwise_sims = torch.cat([
|
| 466 |
-
# (view1 * view2).sum(dim=-1, keepdim=True),
|
| 467 |
-
# (view1 * view3).sum(dim=-1, keepdim=True),
|
| 468 |
-
# ], dim=-1)
|
| 469 |
-
|
| 470 |
-
# matching_enc = torch.cat([pairwise_diffs, pairwise_sims], dim=-1)
|
| 471 |
-
# matching_scores = self.classifier_model(matching_enc)
|
| 472 |
-
|
| 473 |
-
# if stage == Stage.TEST:
|
| 474 |
-
# return dict(matching_scores = matching_scores)
|
| 475 |
-
|
| 476 |
-
# view1 = view1.repeat_interleave(self.hparams.aug_cands_size, dim=0)
|
| 477 |
-
# view2 = self.mol_enc_model(batch['aug_cands'])
|
| 478 |
-
# view3= self.fp_enc_model(batch['aug_cands_fp'])
|
| 479 |
-
# if stage == Stage.TRAIN:
|
| 480 |
-
# view1, view2, view3 = map(self._add_noise, (view1, view2, view3))
|
| 481 |
|
| 482 |
-
#
|
| 483 |
-
#
|
| 484 |
-
#
|
| 485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
|
| 487 |
-
#
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
# ], dim=-1)
|
| 491 |
|
| 492 |
-
# nonmatching_enc = torch.cat([pairwise_diffs, pairwise_sims], dim=-1)
|
| 493 |
|
| 494 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
# def compute_loss(self, matching_scores, nonmatching_scores):
|
| 499 |
-
|
| 500 |
-
# matching_loss = self.loss_fn(matching_scores, torch.ones_like(matching_scores).to(matching_scores.device))
|
| 501 |
-
# nonmatching_loss = self.loss_fn(nonmatching_scores, torch.zeros_like(nonmatching_scores).to(nonmatching_scores.device))
|
| 502 |
|
| 503 |
-
|
|
|
|
|
|
|
|
|
|
| 504 |
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
# def step(
|
| 508 |
-
# self, batch: dict, stage= Stage.NONE):
|
| 509 |
-
|
| 510 |
-
# output = self.forward(batch, stage)
|
| 511 |
|
| 512 |
-
# if stage == Stage.TEST:
|
| 513 |
-
# return output
|
| 514 |
|
| 515 |
-
|
| 516 |
-
# losses = self.compute_loss(output['matching_scores'], output['nonmatching_scores'])
|
| 517 |
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
# cand_smiles = batch['cand_smiles']
|
| 524 |
-
# id_to_ct = defaultdict(int)
|
| 525 |
-
# for i in identifiers: id_to_ct[i]+=1
|
| 526 |
-
# batch_ptr = torch.tensor(list(id_to_ct.values()))
|
| 527 |
|
| 528 |
-
|
| 529 |
-
|
| 530 |
|
| 531 |
-
#
|
|
|
|
|
|
|
|
|
|
| 532 |
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
# def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
|
| 539 |
-
# # total loss
|
| 540 |
-
# self.log(
|
| 541 |
-
# f'{stage.to_pref()}loss',
|
| 542 |
-
# outputs['loss'],
|
| 543 |
-
# batch_size=len(batch['identifier']),
|
| 544 |
-
# sync_dist=True,
|
| 545 |
-
# prog_bar=True,
|
| 546 |
-
# on_epoch=True,
|
| 547 |
-
# # on_step=True
|
| 548 |
-
# )
|
| 549 |
-
|
| 550 |
-
# def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
|
| 551 |
-
# ContrastiveModel.on_test_batch_end(self, outputs, batch, batch_idx, stage)
|
| 552 |
-
|
| 553 |
-
# def on_test_epoch_end(self):
|
| 554 |
-
# self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
|
| 555 |
-
# # self.df_test.to_csv(self.hparams.resutl)
|
| 556 |
-
# print(self.df_test_path)
|
| 557 |
-
# self.df_test.to_pickle(self.df_test_path)
|
| 558 |
-
# # ContrastiveModel.on_test_epoch_end(self)
|
| 559 |
-
|
| 560 |
-
# def get_checkpoint_monitors(self) -> T.List[dict]:
|
| 561 |
-
# monitors = [
|
| 562 |
-
# {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
|
| 563 |
-
# ]
|
| 564 |
-
# return monitors
|
| 565 |
-
# def configure_optimizers(self):
|
| 566 |
-
# return torch.optim.Adam(
|
| 567 |
-
# self.classifier_model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
|
| 568 |
-
# )
|
| 569 |
-
|
| 570 |
-
# class IndSpecEncoder(ContrastiveModel):
|
| 571 |
-
# """ Trains a spectra encoder that maps to a pretrained spec encoder"""
|
| 572 |
-
# def __init__(
|
| 573 |
-
# self,
|
| 574 |
-
# **kwargs
|
| 575 |
-
# ):
|
| 576 |
-
# super().__init__(**kwargs)
|
| 577 |
-
|
| 578 |
-
# # initialize ind_spec_encoder and loss
|
| 579 |
-
# self.ind_spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
|
| 580 |
-
# self.cons_loss = cons_spec_loss(self.hparams.cons_loss_type)
|
| 581 |
-
|
| 582 |
-
# # load preptrained spec and mol encoders
|
| 583 |
-
# checkpoint = torch.load(self.hparams.partial_checkpoint)
|
| 584 |
-
# state_dict = state_dict = {k[len("spec_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("spec_enc_model")}
|
| 585 |
-
# self.spec_enc_model.load_state_dict(state_dict) # trained on consensus spectra
|
| 586 |
-
|
| 587 |
-
# state_dict = state_dict = {k[len("mol_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("mol_enc_model")}
|
| 588 |
-
# self.mol_enc_model.load_state_dict(state_dict)
|
| 589 |
-
|
| 590 |
-
# # freeze cons spec and mol encoders
|
| 591 |
-
# for param in self.mol_enc_model.parameters():
|
| 592 |
-
# param.requires_grad = False
|
| 593 |
-
# for param in self.spec_enc_model.parameters():
|
| 594 |
-
# param.requires_grad = False
|
| 595 |
-
|
| 596 |
-
# def forward(self, batch, stage):
|
| 597 |
-
|
| 598 |
-
# spec = batch[self.spec_view]
|
| 599 |
-
# n_peaks = batch['n_peaks']
|
| 600 |
-
# spec_enc = self.ind_spec_enc_model(spec, n_peaks)
|
| 601 |
-
|
| 602 |
-
# return spec_enc
|
| 603 |
-
|
| 604 |
-
# def compute_loss(self, spec_enc, cons_spec_enc):
|
| 605 |
-
# loss = self.cons_loss(spec_enc, cons_spec_enc)
|
| 606 |
-
# return dict(loss=loss)
|
| 607 |
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
|
|
|
|
|
|
| 611 |
|
| 612 |
-
|
| 613 |
|
| 614 |
-
#
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
|
| 635 |
class CrossAttenContrastive(ContrastiveModel):
|
| 636 |
def __init__(
|
|
|
|
| 10 |
from massspecgym import utils
|
| 11 |
from torch.nn.utils.rnn import pad_sequence
|
| 12 |
|
| 13 |
+
from flare.utils.loss import contrastive_loss, filip_loss_with_mask, global_infonce_loss, pcgrad_combine
|
| 14 |
import flare.utils.models as model_utils
|
| 15 |
from flare.utils.general import pad_graph_nodes, filip_similarity_batch
|
| 16 |
|
|
|
|
| 18 |
import torch.nn.functional as F
|
| 19 |
|
| 20 |
from torch_geometric.nn import global_mean_pool
|
| 21 |
+
import torch, dgllife
|
| 22 |
|
| 23 |
class ContrastiveModel(RetrievalMassSpecGymModel):
|
| 24 |
def __init__(
|
| 25 |
self,
|
| 26 |
+
external_test: bool = False,
|
| 27 |
**kwargs
|
| 28 |
):
|
| 29 |
super().__init__(**kwargs)
|
| 30 |
self.save_hyperparameters()
|
| 31 |
+
self.external_test = external_test
|
| 32 |
|
| 33 |
if 'use_fp' not in self.hparams:
|
| 34 |
self.hparams.use_fp = False
|
|
|
|
| 45 |
self.result_dct = defaultdict(lambda: defaultdict(list))
|
| 46 |
|
| 47 |
def forward(self, batch, stage):
|
| 48 |
+
if 'cand' in batch:
|
| 49 |
+
g = batch['cand']
|
| 50 |
+
elif 'mol' in batch:
|
| 51 |
+
g = batch['mol']
|
| 52 |
+
else:
|
| 53 |
+
g = None
|
| 54 |
+
|
| 55 |
spec = batch[self.spec_view]
|
| 56 |
n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
|
| 57 |
spec_enc = self.spec_enc_model(spec, n_peaks)
|
| 58 |
|
| 59 |
+
if g is None:
|
| 60 |
+
mol_enc = None
|
| 61 |
+
return spec_enc, mol_enc
|
| 62 |
+
|
| 63 |
fp = batch['fp'] if self.hparams.use_fp else None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
f = self.mol_enc_model.GNN(g, g.ndata['h'])
|
| 67 |
+
|
| 68 |
mol_enc = self.mol_enc_model(g, fp=fp)
|
| 69 |
|
| 70 |
return spec_enc, mol_enc
|
|
|
|
| 77 |
losses['contr_loss'] = contr_loss.detach().item()
|
| 78 |
|
| 79 |
loss+=contr_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
losses['loss'] = loss
|
| 82 |
|
|
|
|
| 110 |
# total loss
|
| 111 |
self.log(
|
| 112 |
f'{stage.to_pref()}loss',
|
| 113 |
+
outputs['loss'],
|
| 114 |
batch_size=len(batch['identifier']),
|
| 115 |
sync_dist=True,
|
| 116 |
prog_bar=True,
|
|
|
|
| 148 |
self.result_dct[i]['candidates'].extend(cands)
|
| 149 |
self.result_dct[i]['scores'].extend(scores.cpu().tolist())
|
| 150 |
self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
def _compute_rank(self, scores, labels):
|
| 153 |
if not any(labels):
|
|
|
|
| 157 |
rank = np.count_nonzero(scores >=target_score)
|
| 158 |
return rank
|
| 159 |
|
| 160 |
+
def _get_top_cand(self, scores, candidates):
|
| 161 |
+
return candidates[np.argmax(np.array(scores))]
|
| 162 |
+
|
| 163 |
def on_test_epoch_end(self) -> None:
|
| 164 |
|
| 165 |
self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
|
| 166 |
|
| 167 |
# Compute rank
|
| 168 |
+
if not self.external_test:
|
| 169 |
+
self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
|
| 170 |
+
|
| 171 |
+
if self.external_test:
|
| 172 |
+
self.df_test.drop('labels', axis=1, inplace=True)
|
| 173 |
+
self.df_test['top_cand'] = self.df_test.apply(lambda row: self._get_top_cand(row['scores'], row['candidates']), axis=1)
|
| 174 |
+
|
| 175 |
if not self.df_test_path:
|
| 176 |
self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
|
| 177 |
self.df_test.to_pickle(self.df_test_path)
|
|
|
|
| 182 |
{"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": False}, # monitor val loss
|
| 183 |
]
|
| 184 |
return monitors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
class FilipContrastive(ContrastiveModel):
|
| 187 |
def __init__(self,
|
|
|
|
| 233 |
# Calculate scores
|
| 234 |
indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
|
| 235 |
|
| 236 |
+
scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
|
| 237 |
scores = torch.split(scores, list(id_to_ct.values()))
|
| 238 |
|
| 239 |
cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
|
|
|
|
| 241 |
|
| 242 |
return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
# ============================================================
|
| 246 |
+
# Combined FILIP + Global InfoNCE
|
| 247 |
+
# ============================================================
|
| 248 |
+
class FilipGlobalContrastive(ContrastiveModel):
|
| 249 |
+
def __init__(self, loss_mode="sum", loss_weight=1.0, agg_fn="mean", **kwargs):
|
| 250 |
+
"""
|
| 251 |
+
Args:
|
| 252 |
+
loss_mode: str, one of ["sum", "weighted", "pcgrad"]
|
| 253 |
+
loss_weight: weight for global loss if using weighted sum
|
| 254 |
+
agg_fn: aggregation function for global InfoNCE ("mean", "max", "cls")
|
| 255 |
+
"""
|
| 256 |
+
super().__init__(**kwargs)
|
| 257 |
+
self.loss_mode = loss_mode
|
| 258 |
+
self.loss_weight = loss_weight
|
| 259 |
+
self.agg_fn = agg_fn
|
| 260 |
|
| 261 |
+
# -------------- loss computation --------------
|
| 262 |
+
def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask, stage=Stage.NONE):
|
| 263 |
+
losses = {}
|
|
|
|
| 264 |
|
|
|
|
| 265 |
|
| 266 |
+
# fine-grained FILIP loss
|
| 267 |
+
loss_fine = filip_loss_with_mask(spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp)
|
| 268 |
+
# global InfoNCE loss
|
| 269 |
+
loss_global = global_infonce_loss(spec_enc, mol_enc, spec_mask, mol_mask,
|
| 270 |
+
temperature=self.hparams.contr_temp, agg_fn=self.agg_fn)
|
| 271 |
+
|
| 272 |
+
# choose combination mode
|
| 273 |
+
if self.loss_mode == "sum":
|
| 274 |
+
loss = loss_fine + loss_global
|
| 275 |
+
elif self.loss_mode == "weighted":
|
| 276 |
+
loss = loss_fine + self.loss_weight * loss_global
|
| 277 |
+
elif self.loss_mode == "pcgrad":
|
| 278 |
+
|
| 279 |
+
if stage == Stage.TRAIN:
|
| 280 |
+
# PCGrad over both losses (training only)
|
| 281 |
+
shared_params = list(self.spec_enc_model.parameters()) + list(self.mol_enc_model.parameters())
|
| 282 |
+
self.zero_grad(set_to_none=True)
|
| 283 |
+
loss = pcgrad_combine([loss_fine, loss_global], shared_params)
|
| 284 |
+
else:
|
| 285 |
+
|
| 286 |
+
loss = (loss_fine + loss_global).detach()
|
| 287 |
|
| 288 |
+
else:
|
| 289 |
+
raise ValueError(f"Unsupported loss_mode: {self.loss_mode}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
+
losses["loss"] = loss
|
| 292 |
+
losses["loss_fine"] = loss_fine.detach()
|
| 293 |
+
losses["loss_global"] = loss_global.detach()
|
| 294 |
+
return losses
|
| 295 |
|
| 296 |
+
def step(self, batch: dict, stage=Stage.NONE):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
|
|
|
|
|
|
| 298 |
|
| 299 |
+
spec_enc, mol_enc = self.forward(batch, stage)
|
|
|
|
| 300 |
|
| 301 |
+
mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch["mol_n_nodes"])
|
| 302 |
+
spec_mask = ~torch.all((spec_enc == -5), dim=-1)
|
| 303 |
+
|
| 304 |
+
if stage == Stage.TEST:
|
| 305 |
+
return dict(spec_enc=spec_enc, mol_enc=mol_enc, spec_mask=spec_mask, mol_mask=mol_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
+
losses = self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask, stage=stage)
|
| 308 |
+
return losses
|
| 309 |
|
| 310 |
+
# -------------- TEST step with different score variants --------------
|
| 311 |
+
def test_step(self, batch, batch_idx):
|
| 312 |
+
identifiers = batch["identifier"]
|
| 313 |
+
cand_smiles = batch["cand_smiles"]
|
| 314 |
|
| 315 |
+
id_to_ct = defaultdict(int)
|
| 316 |
+
for i in identifiers:
|
| 317 |
+
id_to_ct[i] += 1
|
| 318 |
+
batch_ptr = torch.tensor(list(id_to_ct.values()), device=self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
+
outputs = self.step(batch, stage=Stage.TEST)
|
| 321 |
+
spec_enc = outputs["spec_enc"]
|
| 322 |
+
mol_enc = outputs["mol_enc"]
|
| 323 |
+
spec_mask = outputs["spec_mask"]
|
| 324 |
+
mol_mask = outputs["mol_mask"]
|
| 325 |
|
| 326 |
+
indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
|
| 327 |
|
| 328 |
+
# --- fine-grained score ---
|
| 329 |
+
fine_scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
|
| 330 |
+
|
| 331 |
+
# --- global cosine score ---
|
| 332 |
+
spec_global = (spec_enc * spec_mask.unsqueeze(-1)).sum(1) / spec_mask.sum(1, keepdim=True).clamp(min=1)
|
| 333 |
+
mol_global = (mol_enc * mol_mask.unsqueeze(-1)).sum(1) / mol_mask.sum(1, keepdim=True).clamp(min=1)
|
| 334 |
+
global_scores = F.cosine_similarity(spec_global, mol_global, dim=-1)
|
| 335 |
+
|
| 336 |
+
# --- combined scores (for evaluation) ---
|
| 337 |
+
combined_sum = fine_scores + global_scores
|
| 338 |
+
combined_weighted = fine_scores + self.loss_weight * global_scores
|
| 339 |
+
combined_pc = 0.5 * (fine_scores + global_scores) # simple average baseline
|
| 340 |
+
|
| 341 |
+
scores_dict = {
|
| 342 |
+
"fine": fine_scores,
|
| 343 |
+
"global": global_scores,
|
| 344 |
+
"sum": combined_sum,
|
| 345 |
+
"weighted": combined_weighted,
|
| 346 |
+
"avg": combined_pc,
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
# split back per identifier
|
| 350 |
+
for key in scores_dict:
|
| 351 |
+
scores_dict[key] = torch.split(scores_dict[key], list(id_to_ct.values()))
|
| 352 |
+
|
| 353 |
+
cand_smiles = utils.unbatch_list(batch["cand_smiles"], indexes)
|
| 354 |
+
labels = utils.unbatch_list(batch["label"], indexes)
|
| 355 |
+
|
| 356 |
+
return dict(
|
| 357 |
+
identifiers=list(id_to_ct.keys()),
|
| 358 |
+
scores=scores_dict,
|
| 359 |
+
cand_smiles=cand_smiles,
|
| 360 |
+
labels=labels,
|
| 361 |
+
)
|
| 362 |
|
| 363 |
+
def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
|
| 364 |
+
"""
|
| 365 |
+
Collects test batch outputs and stores them in self.result_dct.
|
| 366 |
+
Supports both:
|
| 367 |
+
- Single score list format (legacy)
|
| 368 |
+
- Dict of multiple score variants (new)
|
| 369 |
+
"""
|
| 370 |
+
identifiers = outputs["identifiers"]
|
| 371 |
+
cand_smiles = outputs["cand_smiles"]
|
| 372 |
+
labels = outputs["labels"]
|
| 373 |
+
scores_out = outputs["scores"]
|
| 374 |
+
|
| 375 |
+
for k, (i, cands, l) in enumerate(zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels'])):
|
| 376 |
+
self.result_dct[i]['candidates'].extend(cands)
|
| 377 |
+
self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
|
| 378 |
|
| 379 |
+
for variant_name, score_list in scores_out.items():
|
| 380 |
+
self.result_dct[i][f"scores_{variant_name}"].extend(score_list[k].cpu().tolist())
|
| 381 |
+
|
| 382 |
+
def on_test_epoch_end(self) -> None:
|
| 383 |
+
"""
|
| 384 |
+
Combine results into one DataFrame with one row per identifier.
|
| 385 |
+
Adds rank/top_cand columns for each score variant.
|
| 386 |
+
"""
|
| 387 |
+
records = []
|
| 388 |
+
for identifier, val in self.result_dct.items():
|
| 389 |
+
row = {"identifier": identifier, "candidates": val["candidates"]}
|
| 390 |
+
if not self.external_test:
|
| 391 |
+
row["labels"] = val["labels"]
|
| 392 |
+
|
| 393 |
+
# For every scores_* key, compute rank or top candidate
|
| 394 |
+
for key, scores in val.items():
|
| 395 |
+
if not key.startswith("scores_"):
|
| 396 |
+
continue
|
| 397 |
+
variant = key.replace("scores_", "")
|
| 398 |
+
if not self.external_test:
|
| 399 |
+
row[f"rank_{variant}"] = self._compute_rank(scores, val["labels"])
|
| 400 |
+
else:
|
| 401 |
+
row[f"top_cand_{variant}"] = self._get_top_cand(scores, val["candidates"])
|
| 402 |
+
row[key] = scores
|
| 403 |
+
records.append(row)
|
| 404 |
+
|
| 405 |
+
self.df_test = pd.DataFrame(records)
|
| 406 |
+
|
| 407 |
+
if self.external_test and "labels" in self.df_test.columns:
|
| 408 |
+
self.df_test.drop(columns=["labels"], inplace=True)
|
| 409 |
+
|
| 410 |
+
# Save once
|
| 411 |
+
if not getattr(self, "df_test_path", None):
|
| 412 |
+
self.df_test_path = os.path.join(self.hparams["experiment_dir"], "result_combined.pkl")
|
| 413 |
+
|
| 414 |
+
self.df_test.to_pickle(self.df_test_path)
|
| 415 |
|
| 416 |
class CrossAttenContrastive(ContrastiveModel):
|
| 417 |
def __init__(
|
flare/models/mol_encoder.py
CHANGED
|
@@ -12,7 +12,7 @@ class MolEnc(nn.Module):
|
|
| 12 |
|
| 13 |
self.return_emb = False
|
| 14 |
|
| 15 |
-
if args.model in ('filipContrastive', 'crossAttenContrastive'):
|
| 16 |
self.return_emb = True
|
| 17 |
|
| 18 |
dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
|
|
@@ -46,4 +46,5 @@ class MolEnc(nn.Module):
|
|
| 46 |
h1 = self.dropout(h1)
|
| 47 |
|
| 48 |
return h1
|
|
|
|
| 49 |
|
|
|
|
| 12 |
|
| 13 |
self.return_emb = False
|
| 14 |
|
| 15 |
+
if args.model in ('filipContrastive', 'crossAttenContrastive', 'filipGlobalContrastive'):
|
| 16 |
self.return_emb = True
|
| 17 |
|
| 18 |
dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
|
|
|
|
| 46 |
h1 = self.dropout(h1)
|
| 47 |
|
| 48 |
return h1
|
| 49 |
+
|
| 50 |
|
flare/models/spec_encoder.py
CHANGED
|
@@ -111,7 +111,7 @@ class SpecFormulaTransformer(nn.Module):
|
|
| 111 |
in_dim+=1
|
| 112 |
|
| 113 |
self.returnEmb = False
|
| 114 |
-
if args.model in ('crossAttenContrastive', 'filipContrastive'):
|
| 115 |
self.returnEmb = True
|
| 116 |
assert(args.use_cls == False)
|
| 117 |
|
|
@@ -128,7 +128,7 @@ class SpecFormulaTransformer(nn.Module):
|
|
| 128 |
out_dim = args.final_embedding_dim
|
| 129 |
self.fc = nn.Linear(args.formula_dims[-1], out_dim)
|
| 130 |
|
| 131 |
-
def forward(self, spec, n_peaks):
|
| 132 |
h = self.formulaEnc(spec)
|
| 133 |
pad = (spec == -5)
|
| 134 |
pad = torch.all(pad, -1)
|
|
@@ -154,7 +154,6 @@ class SpecFormulaTransformer(nn.Module):
|
|
| 154 |
h = self.fc(h)
|
| 155 |
|
| 156 |
return h
|
| 157 |
-
|
| 158 |
class SpecFormula_mz_Encoder(nn.Module):
|
| 159 |
'''
|
| 160 |
Encodes formula and mz_int
|
|
|
|
| 111 |
in_dim+=1
|
| 112 |
|
| 113 |
self.returnEmb = False
|
| 114 |
+
if args.model in ('crossAttenContrastive', 'filipContrastive', 'filipGlobalContrastive'):
|
| 115 |
self.returnEmb = True
|
| 116 |
assert(args.use_cls == False)
|
| 117 |
|
|
|
|
| 128 |
out_dim = args.final_embedding_dim
|
| 129 |
self.fc = nn.Linear(args.formula_dims[-1], out_dim)
|
| 130 |
|
| 131 |
+
def forward(self, spec, n_peaks=None):
|
| 132 |
h = self.formulaEnc(spec)
|
| 133 |
pad = (spec == -5)
|
| 134 |
pad = torch.all(pad, -1)
|
|
|
|
| 154 |
h = self.fc(h)
|
| 155 |
|
| 156 |
return h
|
|
|
|
| 157 |
class SpecFormula_mz_Encoder(nn.Module):
|
| 158 |
'''
|
| 159 |
Encodes formula and mz_int
|
flare/params_filipGlobal.yaml
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Experiment setup
|
| 2 |
+
job_key: ''
|
| 3 |
+
run_name: 'filip-global'
|
| 4 |
+
run_details: ""
|
| 5 |
+
project_name: ''
|
| 6 |
+
wandb_entity_name: 'mass-spec-ml'
|
| 7 |
+
no_wandb: True
|
| 8 |
+
seed: 42
|
| 9 |
+
debug: False
|
| 10 |
+
checkpoint_pth:
|
| 11 |
+
|
| 12 |
+
# Training setup
|
| 13 |
+
max_epochs: 2000
|
| 14 |
+
accelerator: 'gpu'
|
| 15 |
+
devices: [1]
|
| 16 |
+
log_every_n_steps: 250
|
| 17 |
+
val_check_interval: 1.0
|
| 18 |
+
|
| 19 |
+
# Data paths
|
| 20 |
+
candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
|
| 21 |
+
dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # /data/yzhouc01/MVP/data/sample/data.tsv #/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
|
| 22 |
+
subformula_dir_pth: /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
|
| 23 |
+
split_pth:
|
| 24 |
+
fp_dir_pth:
|
| 25 |
+
partial_checkpoint: ""
|
| 26 |
+
|
| 27 |
+
# General hyperparameters
|
| 28 |
+
batch_size: 64 #64
|
| 29 |
+
lr: 2.881339661302105e-05 # 5.0e-05
|
| 30 |
+
weight_decay: 1.8376229667330708e-05
|
| 31 |
+
contr_temp: 0.022772534845886608 # 0.022772534845886608 # 0.05
|
| 32 |
+
num_workers: 50
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# FILIP_GLOBAL model parameters
|
| 36 |
+
loss_mode: "pcgrad"
|
| 37 |
+
agg_fn: "mean"
|
| 38 |
+
loss_weight: 1.1
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
############################## Data transforms ##############################
|
| 42 |
+
# - Spectra
|
| 43 |
+
spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
|
| 44 |
+
formula_source: 'default' # magma_1, magma_all, sirius, default
|
| 45 |
+
# 1. Binner
|
| 46 |
+
max_mz: 1000
|
| 47 |
+
bin_width: 1
|
| 48 |
+
mask_peak_ratio: 0.00
|
| 49 |
+
|
| 50 |
+
# 2. SpecFormula
|
| 51 |
+
element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
|
| 52 |
+
add_intensities: True
|
| 53 |
+
|
| 54 |
+
# - Molecule
|
| 55 |
+
molecule_view: "MolGraph"
|
| 56 |
+
atom_feature: 'full'
|
| 57 |
+
bond_feature: 'full'
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
############################## Task and model ##############################
|
| 61 |
+
task: 'retrieval'
|
| 62 |
+
spec_enc: Transformer_Formula # Transformer_MzInt #Transformer_Formula
|
| 63 |
+
mol_enc: "GNN"
|
| 64 |
+
model: filipGlobalContrastive #filipContrastive # "MultiviewContrastive"
|
| 65 |
+
contr_views: [['spec_enc', 'mol_enc']]
|
| 66 |
+
log_only_loss_at_stages: []
|
| 67 |
+
df_test_path: ""
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# - Formula-based spec encoders
|
| 71 |
+
formula_dropout: 0.2
|
| 72 |
+
formula_dims: [512,256,512] #[512, 256, 512] #[64, 128, 256]
|
| 73 |
+
cross_attn_heads: 2
|
| 74 |
+
use_cls: False
|
| 75 |
+
peak_dropout: 0.2
|
| 76 |
+
formula_attn_heads: 4 # 2
|
| 77 |
+
formula_transformer_layers: 2 #2
|
| 78 |
+
|
| 79 |
+
# -- GAT params
|
| 80 |
+
attn_heads: [12,12,12]
|
| 81 |
+
|
| 82 |
+
# - Molecule encoder (GNN)
|
| 83 |
+
gnn_channels: [128, 256, 512] #[64,128,512]
|
| 84 |
+
gnn_type: "gcn"
|
| 85 |
+
# num_gnn_layers: 3
|
| 86 |
+
# gnn_hidden_dim: 512
|
| 87 |
+
gnn_dropout: 0.23234950970370824 #0.3
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# - Spectra encoder (cross attention model)
|
| 91 |
+
# final_embedding_dim: 512
|
| 92 |
+
# fc_dropout: 0.4
|
| 93 |
+
|
| 94 |
+
# - Spectra Token encoder (mz-int token model)
|
| 95 |
+
# hidden_dims: [64, 256]
|
flare/run.sh
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
# python train.py
|
| 2 |
-
python test.py --param_pth
|
| 3 |
-
|
|
|
|
| 1 |
+
# python train.py --param_pth params_filipGlobal.yaml
|
| 2 |
+
# python test.py --param_pth params_filipGlobal.yaml
|
| 3 |
+
python test.py --param_pth params_filipGlobal.yaml --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
|
flare/subformula_assign/run.sh
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
-
SPEC_FILES="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
|
| 2 |
-
OUTPUT_DIR="/data/yzhouc01/spectra_data/subformulae"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
MAX_FORMULAE=60
|
| 4 |
-
LABELS_FILE="/data/yzhouc01/
|
| 5 |
|
| 6 |
python assign_subformulae.py --spec-files $SPEC_FILES --output-dir $OUTPUT_DIR --max-formulae $MAX_FORMULAE --labels-file $LABELS_FILE
|
|
|
|
| 1 |
+
# SPEC_FILES="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
|
| 2 |
+
# OUTPUT_DIR="/data/yzhouc01/spectra_data/subformulae"
|
| 3 |
+
# MAX_FORMULAE=60
|
| 4 |
+
# LABELS_FILE="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
|
| 5 |
+
|
| 6 |
+
SPEC_FILES="/data/yzhouc01/cancer/breast_cancer_data.tsv"
|
| 7 |
+
OUTPUT_DIR="/data/yzhouc01/cancer/subformulae"
|
| 8 |
MAX_FORMULAE=60
|
| 9 |
+
LABELS_FILE="/data/yzhouc01/cancer/breast_cancer_data.tsv"
|
| 10 |
|
| 11 |
python assign_subformulae.py --spec-files $SPEC_FILES --output-dir $OUTPUT_DIR --max-formulae $MAX_FORMULAE --labels-file $LABELS_FILE
|
flare/subformula_assign/utils/chem_utils.py
CHANGED
|
@@ -181,6 +181,8 @@ def formula_to_dense(chem_formula: str) -> np.ndarray:
|
|
| 181 |
"""
|
| 182 |
total_onehot = []
|
| 183 |
for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
|
|
|
|
|
|
|
| 184 |
# Convert num to int
|
| 185 |
num = 1 if num == "" else int(num)
|
| 186 |
one_hot = element_to_position[chem_symbol].reshape(1, -1)
|
|
@@ -257,6 +259,8 @@ def formula_to_dense(chem_formula: str) -> np.ndarray:
|
|
| 257 |
"""
|
| 258 |
total_onehot = []
|
| 259 |
for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
|
|
|
|
|
|
|
| 260 |
# Convert num to int
|
| 261 |
num = 1 if num == "" else int(num)
|
| 262 |
one_hot = element_to_position[chem_symbol].reshape(1, -1)
|
|
|
|
| 181 |
"""
|
| 182 |
total_onehot = []
|
| 183 |
for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
|
| 184 |
+
if chem_symbol not in VALID_ELEMENTS: # yzc
|
| 185 |
+
continue
|
| 186 |
# Convert num to int
|
| 187 |
num = 1 if num == "" else int(num)
|
| 188 |
one_hot = element_to_position[chem_symbol].reshape(1, -1)
|
|
|
|
| 259 |
"""
|
| 260 |
total_onehot = []
|
| 261 |
for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
|
| 262 |
+
if chem_symbol not in VALID_ELEMENTS: # yzc
|
| 263 |
+
continue
|
| 264 |
# Convert num to int
|
| 265 |
num = 1 if num == "" else int(num)
|
| 266 |
one_hot = element_to_position[chem_symbol].reshape(1, -1)
|
flare/test.py
CHANGED
|
@@ -29,6 +29,8 @@ parser.add_argument('--checkpoint_choice', type=str, default='train', choices=['
|
|
| 29 |
parser.add_argument('--df_test_pth', type=str, help='result file name')
|
| 30 |
parser.add_argument('--exp_dir', type=str)
|
| 31 |
parser.add_argument('--candidates_pth', type=str)
|
|
|
|
|
|
|
| 32 |
def main(params):
|
| 33 |
# Seed everything
|
| 34 |
pl.seed_everything(params['seed'])
|
|
@@ -58,6 +60,7 @@ def main(params):
|
|
| 58 |
|
| 59 |
model = get_model(params['model'], params)
|
| 60 |
model.df_test_path = params['df_test_path']
|
|
|
|
| 61 |
|
| 62 |
# Init trainer
|
| 63 |
trainer = Trainer(
|
|
@@ -109,7 +112,12 @@ if __name__ == "__main__":
|
|
| 109 |
params['checkpoint_pth'] = checkpoint_path
|
| 110 |
break
|
| 111 |
assert(params['checkpoint_pth'] != '')
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
if args.candidates_pth:
|
| 114 |
params['candidates_pth'] = args.candidates_pth
|
| 115 |
if args.df_test_pth:
|
|
|
|
| 29 |
parser.add_argument('--df_test_pth', type=str, help='result file name')
|
| 30 |
parser.add_argument('--exp_dir', type=str)
|
| 31 |
parser.add_argument('--candidates_pth', type=str)
|
| 32 |
+
parser.add_argument('--external_test', action='store_true', help='whether the test set is external data without labels')
|
| 33 |
+
|
| 34 |
def main(params):
|
| 35 |
# Seed everything
|
| 36 |
pl.seed_everything(params['seed'])
|
|
|
|
| 60 |
|
| 61 |
model = get_model(params['model'], params)
|
| 62 |
model.df_test_path = params['df_test_path']
|
| 63 |
+
model.external_test = params['external_test']
|
| 64 |
|
| 65 |
# Init trainer
|
| 66 |
trainer = Trainer(
|
|
|
|
| 112 |
params['checkpoint_pth'] = checkpoint_path
|
| 113 |
break
|
| 114 |
assert(params['checkpoint_pth'] != '')
|
| 115 |
+
|
| 116 |
+
if args.external_test:
|
| 117 |
+
params['external_test'] = True
|
| 118 |
+
else:
|
| 119 |
+
params['external_test'] = False
|
| 120 |
+
|
| 121 |
if args.candidates_pth:
|
| 122 |
params['candidates_pth'] = args.candidates_pth
|
| 123 |
if args.df_test_pth:
|
flare/tune.py
CHANGED
|
@@ -231,7 +231,7 @@ def main(args):
|
|
| 231 |
|
| 232 |
# now = datetime.datetime.now().strftime("%Y%m%d")
|
| 233 |
# base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
|
| 234 |
-
base_dir = "
|
| 235 |
os.makedirs(base_dir, exist_ok=True)
|
| 236 |
params["experiment_dir"] = base_dir
|
| 237 |
|
|
|
|
| 231 |
|
| 232 |
# now = datetime.datetime.now().strftime("%Y%m%d")
|
| 233 |
# base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
|
| 234 |
+
base_dir = "../experiments/20250916_simple_model_optuna"
|
| 235 |
os.makedirs(base_dir, exist_ok=True)
|
| 236 |
params["experiment_dir"] = base_dir
|
| 237 |
|
flare/utils/case_study_utils.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from rdkit import Chem
|
| 4 |
+
import multiprocessing as mp
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
| 11 |
+
if parent_dir not in sys.path:
|
| 12 |
+
sys.path.insert(0, parent_dir)
|
| 13 |
+
|
| 14 |
+
database_to_path = {'fdb':"/data/yzhouc01/molecule_data/foodb_2020_04_07_csv/Compound.csv",
|
| 15 |
+
'hmdb':"/data/yzhouc01/molecule_data/metabolites-2025-09-18.csv",
|
| 16 |
+
'spectra_db':"/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex_processed.tsv",
|
| 17 |
+
'bio_db':"/data/yzhouc01/molecule_data/bio_2023_07_11_smiles.csv",
|
| 18 |
+
'coconut':"/data/yzhouc01/molecule_data/coconut_csv-05-2025.csv"}
|
| 19 |
+
|
| 20 |
+
db_to_mass_col = {'fdb':'exact_molecular_weight',
|
| 21 |
+
'hmdb':'MONO_MASS',
|
| 22 |
+
'spectra_db':'exact_molecular_weight',
|
| 23 |
+
'bio_db':'exact_molecular_weight',
|
| 24 |
+
'coconut':'exact_molecular_weight'}
|
| 25 |
+
|
| 26 |
+
db_to_smiles_col = {'fdb':'CANONICAL_SMILES',
|
| 27 |
+
'hmdb':'CANONICAL_SMILES',
|
| 28 |
+
'spectra_db':'CANONICAL_SMILES',
|
| 29 |
+
'bio_db':'canonical_smiles',
|
| 30 |
+
'coconut':'rdkit_canonical_smiles'}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
_worker_instance = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _init_worker(databases, threshold):
|
| 37 |
+
"""Run once per worker process to initialize shared CandidateAssignment."""
|
| 38 |
+
global _worker_instance
|
| 39 |
+
_worker_instance = CandidateAssignment(databases, threshold)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _worker_retrieve_candidates(parent_mass):
|
| 43 |
+
"""Use the global CandidateAssignment instance inside each worker."""
|
| 44 |
+
return _worker_instance.retrieve_candidates(parent_mass)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_worker_instance = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _init_worker(databases, threshold):
|
| 51 |
+
"""Initialize global CandidateAssignment in each worker (silent)."""
|
| 52 |
+
global _worker_instance
|
| 53 |
+
_worker_instance = CandidateAssignment(databases, threshold, verbose=False)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _worker_retrieve_candidates(parent_mass):
|
| 57 |
+
"""Retrieve candidates using the worker's global CandidateAssignment."""
|
| 58 |
+
return _worker_instance.retrieve_candidates(parent_mass)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class CandidateAssignment:
|
| 62 |
+
def __init__(self, databases=None, threshold=0.01, verbose=True):
|
| 63 |
+
self.threshold = threshold
|
| 64 |
+
self.databases = []
|
| 65 |
+
self.verbose = verbose
|
| 66 |
+
|
| 67 |
+
for db in databases:
|
| 68 |
+
if db not in database_to_path:
|
| 69 |
+
raise ValueError(
|
| 70 |
+
f"Database {db} not recognized. Available: {list(database_to_path.keys())}"
|
| 71 |
+
)
|
| 72 |
+
if not os.path.exists(database_to_path[db]):
|
| 73 |
+
raise ValueError(f"Database file for {db} not found at {database_to_path[db]}")
|
| 74 |
+
self.databases.append(db)
|
| 75 |
+
|
| 76 |
+
# Only print in main process
|
| 77 |
+
if self.verbose and mp.current_process().name == "MainProcess":
|
| 78 |
+
print(f"[{os.getpid()}] Loading databases: {self.databases}")
|
| 79 |
+
|
| 80 |
+
self.db_dfs = {}
|
| 81 |
+
self._load_databases()
|
| 82 |
+
|
| 83 |
+
def _load_databases(self):
|
| 84 |
+
for db in self.databases:
|
| 85 |
+
path = database_to_path[db]
|
| 86 |
+
if path.endswith("tsv"):
|
| 87 |
+
df = pd.read_csv(path, sep="\t", low_memory=False)
|
| 88 |
+
elif path.endswith("csv"):
|
| 89 |
+
df = pd.read_csv(path, low_memory=False)
|
| 90 |
+
else:
|
| 91 |
+
if self.verbose and mp.current_process().name == "MainProcess":
|
| 92 |
+
print(f"Unable to load database: {db}")
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
# make sure required columns exist
|
| 96 |
+
required_cols = [db_to_mass_col[db], db_to_smiles_col[db]]
|
| 97 |
+
for col in required_cols:
|
| 98 |
+
if col not in df.columns:
|
| 99 |
+
raise ValueError(f"Column {col} not found in database {db}. {db} columns: {df.columns.tolist()}")
|
| 100 |
+
|
| 101 |
+
# convert to proper types
|
| 102 |
+
df[db_to_mass_col[db]] = pd.to_numeric(df[db_to_mass_col[db]], errors='coerce')
|
| 103 |
+
|
| 104 |
+
self.db_dfs[db] = df
|
| 105 |
+
|
| 106 |
+
# Only print in main process
|
| 107 |
+
if self.verbose and mp.current_process().name == "MainProcess":
|
| 108 |
+
print(f"[{os.getpid()}] Loaded {db} with {len(df)} entries.")
|
| 109 |
+
|
| 110 |
+
def retrieve_candidates(self, parent_mass):
|
| 111 |
+
"""Retrieve SMILES candidates for a single parent mass."""
|
| 112 |
+
ub = parent_mass + self.threshold
|
| 113 |
+
lb = parent_mass - self.threshold
|
| 114 |
+
|
| 115 |
+
smiles_list = []
|
| 116 |
+
for db_name, df in self.db_dfs.items():
|
| 117 |
+
select_rows = df[
|
| 118 |
+
(df[db_to_mass_col[db_name]] >= lb)
|
| 119 |
+
& (df[db_to_mass_col[db_name]] <= ub)
|
| 120 |
+
]
|
| 121 |
+
smiles_list.extend(select_rows[db_to_smiles_col[db_name]].tolist())
|
| 122 |
+
|
| 123 |
+
smiles_list = list(set(smiles_list))
|
| 124 |
+
return parent_mass, smiles_list
|
| 125 |
+
|
| 126 |
+
def retrieve_candidates_batch(self, parent_masses, n_workers=25, chunksize=10):
|
| 127 |
+
"""Parallel batch retrieval with silent workers."""
|
| 128 |
+
with mp.Pool(
|
| 129 |
+
processes=n_workers,
|
| 130 |
+
initializer=_init_worker,
|
| 131 |
+
initargs=(self.databases, self.threshold),
|
| 132 |
+
) as pool:
|
| 133 |
+
results = list(
|
| 134 |
+
tqdm(
|
| 135 |
+
pool.imap(_worker_retrieve_candidates, parent_masses, chunksize=chunksize),
|
| 136 |
+
total=len(parent_masses),
|
| 137 |
+
desc="Retrieving candidates",
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
return {r[0]: r[1] for r in results}
|
| 141 |
+
|
| 142 |
+
# P_TBL = Chem.GetPeriodicTable()
|
| 143 |
+
# ELECTRON_MASS = 0.00054858
|
| 144 |
+
# VALID_ELEMENTS = [
|
| 145 |
+
# "C",
|
| 146 |
+
# "H",
|
| 147 |
+
# "As",
|
| 148 |
+
# "B",
|
| 149 |
+
# "Br",
|
| 150 |
+
# "Cl",
|
| 151 |
+
# "Co",
|
| 152 |
+
# "F",
|
| 153 |
+
# "Fe",
|
| 154 |
+
# "I",
|
| 155 |
+
# "K",
|
| 156 |
+
# "N",
|
| 157 |
+
# "Na",
|
| 158 |
+
# "O",
|
| 159 |
+
# "P",
|
| 160 |
+
# "S",
|
| 161 |
+
# "Se",
|
| 162 |
+
# "Si",
|
| 163 |
+
# ]
|
| 164 |
+
# VALID_MONO_MASSES = np.array(
|
| 165 |
+
# [P_TBL.GetMostCommonIsotopeMass(i) for i in VALID_ELEMENTS]
|
| 166 |
+
# )
|
| 167 |
+
# CHEM_MASSES = VALID_MONO_MASSES[:, None]
|
| 168 |
+
# ELEMENT_TO_MASS = dict(zip(VALID_ELEMENTS, CHEM_MASSES.squeeze()))
|
| 169 |
+
|
| 170 |
+
# adduct_to_mass = {
|
| 171 |
+
# "[M+H]+": ELEMENT_TO_MASS["H"] - ELECTRON_MASS,
|
| 172 |
+
# "[M+Na]+": ELEMENT_TO_MASS["Na"] - ELECTRON_MASS,
|
| 173 |
+
# "[M+K]+": ELEMENT_TO_MASS["K"] - ELECTRON_MASS,
|
| 174 |
+
# "[M-H2O+H]+": -ELEMENT_TO_MASS["O"] - ELEMENT_TO_MASS["H"] - ELECTRON_MASS,
|
| 175 |
+
# "[M+H3N+H]+": ELEMENT_TO_MASS["N"] + ELEMENT_TO_MASS["H"] * 4 - ELECTRON_MASS,
|
| 176 |
+
# "[M]+": 0 - ELECTRON_MASS,
|
| 177 |
+
# "[M-H4O2+H]+": -ELEMENT_TO_MASS["O"] * 2 - ELEMENT_TO_MASS["H"] * 3 - ELECTRON_MASS,
|
| 178 |
+
# "[M-H]-": ELEMENT_TO_MASS["H"] + ELECTRON_MASS,
|
| 179 |
+
# "[M+H2O+H]+":ELEMENT_TO_MASS["O"] * 2 + ELEMENT_TO_MASS["H"] * 2 - ELECTRON_MASS,
|
| 180 |
+
# }
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# def calculate_parent_mass(precursor_mz, adduct):
|
| 184 |
+
# if adduct not in adduct_to_mass:
|
| 185 |
+
# print(f'{adduct} not supported, returning original precursor_mz')
|
| 186 |
+
# return precursor_mz + adduct_to_mass[adduct]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
if __name__ == "__main__":
|
| 190 |
+
# get_mol_mass_for_combined()
|
| 191 |
+
ca = CandidateAssignment(databases=['hmdb'])
|
| 192 |
+
candidates = ca.retrieve_candidates(parent_mass=180.0634, threshold=0.01)
|
| 193 |
+
print(candidates)
|
flare/utils/general.py
CHANGED
|
@@ -2,37 +2,69 @@ import torch
|
|
| 2 |
from torch import nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
|
|
|
|
|
|
| 5 |
def pad_graph_nodes(mol_enc, g_n_nodes):
|
| 6 |
"""
|
| 7 |
Args:
|
| 8 |
-
mol_enc:
|
| 9 |
-
|
| 10 |
-
g_n_nodes: list[int] Number of nodes per graph (len = B)
|
| 11 |
|
| 12 |
Returns:
|
| 13 |
-
padded: (B, max_nodes, D) tensor
|
| 14 |
mask: (B, max_nodes) bool tensor, True for valid nodes
|
| 15 |
"""
|
| 16 |
-
|
| 17 |
-
# Already concatenated: shape (sum_nodes, D)
|
| 18 |
B = len(g_n_nodes)
|
| 19 |
D = mol_enc.shape[1]
|
| 20 |
max_nodes = max(g_n_nodes)
|
| 21 |
-
padded = mol_enc.new_zeros((B, max_nodes, D))
|
| 22 |
-
mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
idx = 0
|
| 25 |
for i, n in enumerate(g_n_nodes):
|
| 26 |
padded[i, :n] = mol_enc[idx:idx+n]
|
| 27 |
mask[i, :n] = True
|
| 28 |
idx += n
|
|
|
|
| 29 |
return padded, mask
|
| 30 |
|
| 31 |
-
import torch
|
| 32 |
-
import torch.nn.functional as F
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def filip_similarity_batch(
|
| 38 |
image_tokens,
|
|
@@ -127,60 +159,64 @@ def filip_similarity_batch(
|
|
| 127 |
return similarity
|
| 128 |
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
# mask_image: (B, N_img) bool tensor
|
| 139 |
-
# mask_text: (B, N_text) bool tensor
|
| 140 |
-
|
| 141 |
-
# Returns:
|
| 142 |
-
# similarities: (B,) float tensor of similarity scores
|
| 143 |
-
# """
|
| 144 |
-
# B, N_img, D = image_tokens.shape
|
| 145 |
-
# N_text = text_tokens.shape[1]
|
| 146 |
-
|
| 147 |
-
# # Normalize tokens
|
| 148 |
-
# image_norm = F.normalize(image_tokens, p=2, dim=-1) # (B, N_img, D)
|
| 149 |
-
# text_norm = F.normalize(text_tokens, p=2, dim=-1) # (B, N_text, D)
|
| 150 |
-
|
| 151 |
-
# # Compute batched cosine similarity matrices
|
| 152 |
-
# # Result shape: (B, N_img, N_text)
|
| 153 |
-
# sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
#
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
#
|
| 161 |
-
|
| 162 |
|
| 163 |
-
#
|
| 164 |
-
|
|
|
|
| 165 |
|
| 166 |
-
#
|
| 167 |
-
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
|
|
|
| 179 |
|
| 180 |
-
|
| 181 |
-
|
| 182 |
|
| 183 |
-
#
|
| 184 |
-
|
|
|
|
| 185 |
|
| 186 |
-
#
|
|
|
|
|
|
|
|
|
| 2 |
from torch import nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
| 5 |
+
|
| 6 |
+
|
| 7 |
def pad_graph_nodes(mol_enc, g_n_nodes):
|
| 8 |
"""
|
| 9 |
Args:
|
| 10 |
+
mol_enc: (sum_nodes, D) tensor, node embeddings concatenated for all graphs
|
| 11 |
+
g_n_nodes: list[int], number of nodes per graph
|
|
|
|
| 12 |
|
| 13 |
Returns:
|
| 14 |
+
padded: (B, max_nodes, D) tensor with requires_grad=True for original nodes
|
| 15 |
mask: (B, max_nodes) bool tensor, True for valid nodes
|
| 16 |
"""
|
|
|
|
|
|
|
| 17 |
B = len(g_n_nodes)
|
| 18 |
D = mol_enc.shape[1]
|
| 19 |
max_nodes = max(g_n_nodes)
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
# Create output with same requires_grad as input
|
| 22 |
+
padded = torch.zeros(B, max_nodes, D, dtype=mol_enc.dtype, device=mol_enc.device)
|
| 23 |
+
|
| 24 |
+
# Force gradient tracking by making this a non-leaf tensor
|
| 25 |
+
padded = padded + mol_enc.new_zeros(1).requires_grad_(True)
|
| 26 |
+
|
| 27 |
+
mask = torch.zeros(B, max_nodes, dtype=torch.bool, device=mol_enc.device)
|
| 28 |
+
|
| 29 |
idx = 0
|
| 30 |
for i, n in enumerate(g_n_nodes):
|
| 31 |
padded[i, :n] = mol_enc[idx:idx+n]
|
| 32 |
mask[i, :n] = True
|
| 33 |
idx += n
|
| 34 |
+
|
| 35 |
return padded, mask
|
| 36 |
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# def pad_graph_nodes(mol_enc, g_n_nodes):
|
| 41 |
+
# """
|
| 42 |
+
# Args:
|
| 43 |
+
# mol_enc: 2D tensor of shape (sum_nodes, D)
|
| 44 |
+
# Node embeddings for each molecule.
|
| 45 |
+
# g_n_nodes: list[int] Number of nodes per graph (len = B)
|
| 46 |
+
|
| 47 |
+
# Returns:
|
| 48 |
+
# padded: (B, max_nodes, D) tensor
|
| 49 |
+
# mask: (B, max_nodes) bool tensor, True for valid nodes
|
| 50 |
+
# """
|
| 51 |
+
|
| 52 |
+
# # Already concatenated: shape (sum_nodes, D)
|
| 53 |
+
# B = len(g_n_nodes)
|
| 54 |
+
# D = mol_enc.shape[1]
|
| 55 |
+
# max_nodes = max(g_n_nodes)
|
| 56 |
+
# padded = mol_enc.new_zeros((B, max_nodes, D))
|
| 57 |
+
# mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
|
| 58 |
+
|
| 59 |
+
# idx = 0
|
| 60 |
+
# for i, n in enumerate(g_n_nodes):
|
| 61 |
+
# padded[i, :n] = mol_enc[idx:idx+n]
|
| 62 |
+
# mask[i, :n] = True
|
| 63 |
+
# idx += n
|
| 64 |
+
# return padded, mask
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
|
| 69 |
def filip_similarity_batch(
|
| 70 |
image_tokens,
|
|
|
|
| 159 |
return similarity
|
| 160 |
|
| 161 |
|
| 162 |
+
def filip_similarity_single(
|
| 163 |
+
image_tokens,
|
| 164 |
+
text_tokens,
|
| 165 |
+
reduction="mean", # "mean", "topk", "softmax", or "geom"
|
| 166 |
+
k=5,
|
| 167 |
+
temperature=0.05,
|
| 168 |
+
eps=1e-6
|
| 169 |
+
):
|
| 170 |
+
"""
|
| 171 |
+
Compute FILIP similarity for a single image and text pair (no masks).
|
| 172 |
|
| 173 |
+
Args:
|
| 174 |
+
image_tokens: (N_img, D) float tensor
|
| 175 |
+
text_tokens: (N_text, D) float tensor
|
| 176 |
+
reduction: str, aggregation strategy: "mean", "topk", "softmax", or "geom"
|
| 177 |
+
k: int, used if reduction == "topk"
|
| 178 |
+
temperature: float, used if reduction == "softmax"
|
| 179 |
+
eps: float, small constant for numerical stability
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
+
Returns:
|
| 182 |
+
similarity: float scalar tensor
|
| 183 |
+
"""
|
| 184 |
+
# Normalize tokens
|
| 185 |
+
image_norm = F.normalize(image_tokens, p=2, dim=-1)
|
| 186 |
+
text_norm = F.normalize(text_tokens, p=2, dim=-1)
|
| 187 |
|
| 188 |
+
# (N_img, N_text) cosine similarity matrix
|
| 189 |
+
sim_matrix = torch.matmul(image_norm, text_norm.t())
|
| 190 |
|
| 191 |
+
# Max similarity for each token (image->text and text->image)
|
| 192 |
+
max_sim_img, _ = sim_matrix.max(dim=1) # (N_img,)
|
| 193 |
+
max_sim_text, _ = sim_matrix.max(dim=0) # (N_text,)
|
| 194 |
|
| 195 |
+
# Aggregation helper
|
| 196 |
+
def aggregate(max_sim):
|
| 197 |
+
if reduction == "mean":
|
| 198 |
+
return max_sim.mean()
|
| 199 |
|
| 200 |
+
elif reduction == "topk":
|
| 201 |
+
k_eff = min(k, max_sim.numel())
|
| 202 |
+
topk_vals, _ = torch.topk(max_sim, k_eff)
|
| 203 |
+
return topk_vals.mean()
|
| 204 |
|
| 205 |
+
elif reduction == "softmax":
|
| 206 |
+
weights = torch.softmax(max_sim / temperature, dim=0)
|
| 207 |
+
return (weights * max_sim).sum()
|
| 208 |
|
| 209 |
+
elif reduction == "geom":
|
| 210 |
+
vals = max_sim.clamp(min=eps)
|
| 211 |
+
return torch.exp(torch.log(vals).mean())
|
| 212 |
|
| 213 |
+
else:
|
| 214 |
+
raise ValueError(f"Unknown reduction type: {reduction}")
|
| 215 |
|
| 216 |
+
# Aggregate both directions
|
| 217 |
+
avg_img = aggregate(max_sim_img)
|
| 218 |
+
avg_text = aggregate(max_sim_text)
|
| 219 |
|
| 220 |
+
# Final similarity (scalar)
|
| 221 |
+
similarity = (avg_img + avg_text) / 2
|
| 222 |
+
return similarity
|
flare/utils/loss.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
|
|
|
| 4 |
|
| 5 |
def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
|
| 6 |
v1_norm = torch.norm(v1, dim=1, keepdim=True)
|
|
@@ -76,10 +77,6 @@ class fp_loss:
|
|
| 76 |
return 1 - torch.mean(sim)
|
| 77 |
|
| 78 |
|
| 79 |
-
import torch
|
| 80 |
-
import torch.nn.functional as F
|
| 81 |
-
import torch.distributed as dist
|
| 82 |
-
|
| 83 |
# ---------- Utility ----------
|
| 84 |
def _safe_divide(num, denom, eps=1e-8):
|
| 85 |
return num / (denom + eps)
|
|
@@ -154,3 +151,97 @@ def filip_loss_with_mask(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07):
|
|
| 154 |
|
| 155 |
return 0.5 * (loss_a2b + loss_b2a)
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
|
| 6 |
def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
|
| 7 |
v1_norm = torch.norm(v1, dim=1, keepdim=True)
|
|
|
|
| 77 |
return 1 - torch.mean(sim)
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
# ---------- Utility ----------
|
| 81 |
def _safe_divide(num, denom, eps=1e-8):
|
| 82 |
return num / (denom + eps)
|
|
|
|
| 151 |
|
| 152 |
return 0.5 * (loss_a2b + loss_b2a)
|
| 153 |
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def global_infonce_loss(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07, agg_fn="mean"):
|
| 157 |
+
"""
|
| 158 |
+
Global InfoNCE loss (CLIP-style) for modalities A and B.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
a_tokens: (B, N_a, D)
|
| 162 |
+
b_tokens: (B, N_b, D)
|
| 163 |
+
mask_a: (B, N_a) bool (True = valid)
|
| 164 |
+
mask_b: (B, N_b) bool (True = valid)
|
| 165 |
+
temperature: scalar
|
| 166 |
+
agg_fn: "mean" | "max" | "cls" | callable -> how to aggregate tokens into one vector
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
scalar loss
|
| 170 |
+
"""
|
| 171 |
+
device = a_tokens.device
|
| 172 |
+
B, N_a, D = a_tokens.shape
|
| 173 |
+
N_b = b_tokens.shape[1]
|
| 174 |
+
|
| 175 |
+
# ---- Normalize token embeddings ----
|
| 176 |
+
a = F.normalize(a_tokens, dim=-1)
|
| 177 |
+
b = F.normalize(b_tokens, dim=-1)
|
| 178 |
+
|
| 179 |
+
# ---- Aggregate per sample ----
|
| 180 |
+
if callable(agg_fn):
|
| 181 |
+
a_global = agg_fn(a, mask_a) # custom aggregation
|
| 182 |
+
b_global = agg_fn(b, mask_b)
|
| 183 |
+
elif agg_fn == "mean":
|
| 184 |
+
# masked mean
|
| 185 |
+
a_global = (a * mask_a.unsqueeze(-1)).sum(dim=1) / mask_a.sum(dim=1, keepdim=True).clamp(min=1)
|
| 186 |
+
b_global = (b * mask_b.unsqueeze(-1)).sum(dim=1) / mask_b.sum(dim=1, keepdim=True).clamp(min=1)
|
| 187 |
+
elif agg_fn == "max":
|
| 188 |
+
a_global = (a.masked_fill(~mask_a.unsqueeze(-1), float('-inf'))).max(dim=1).values
|
| 189 |
+
b_global = (b.masked_fill(~mask_b.unsqueeze(-1), float('-inf'))).max(dim=1).values
|
| 190 |
+
elif agg_fn == "cls":
|
| 191 |
+
# use first valid token as "cls"
|
| 192 |
+
a_global = a[:, 0, :]
|
| 193 |
+
b_global = b[:, 0, :]
|
| 194 |
+
else:
|
| 195 |
+
raise ValueError(f"Unknown agg_fn: {agg_fn}")
|
| 196 |
+
|
| 197 |
+
# ---- Compute cosine similarity matrix ----
|
| 198 |
+
a_global = F.normalize(a_global, dim=-1)
|
| 199 |
+
b_global = F.normalize(b_global, dim=-1)
|
| 200 |
+
logits = (a_global @ b_global.T) / temperature # (B, B)
|
| 201 |
+
|
| 202 |
+
# ---- InfoNCE loss ----
|
| 203 |
+
labels = torch.arange(B, device=device)
|
| 204 |
+
loss_a2b = F.cross_entropy(logits, labels)
|
| 205 |
+
loss_b2a = F.cross_entropy(logits.T, labels)
|
| 206 |
+
loss = 0.5 * (loss_a2b + loss_b2a)
|
| 207 |
+
|
| 208 |
+
return loss
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# ---------- PCGrad utility ----------
|
| 212 |
+
def pcgrad_combine(losses, shared_params):
|
| 213 |
+
"""
|
| 214 |
+
Compute PCGrad combined gradient for a list of scalar losses.
|
| 215 |
+
losses: list of scalar loss tensors
|
| 216 |
+
shared_params: list of parameters to project/aggregate gradients for
|
| 217 |
+
returns: scalar combined loss for logging (mean)
|
| 218 |
+
"""
|
| 219 |
+
grads_list = [torch.autograd.grad(l, shared_params, retain_graph=True, allow_unused=True)
|
| 220 |
+
for l in losses]
|
| 221 |
+
|
| 222 |
+
# flatten
|
| 223 |
+
flat_grads = [torch.cat([g.reshape(-1) for g in grads if g is not None]) for grads in grads_list]
|
| 224 |
+
projected = [fg.clone() for fg in flat_grads]
|
| 225 |
+
|
| 226 |
+
# project conflicting grads
|
| 227 |
+
for i in range(len(flat_grads)):
|
| 228 |
+
for j in range(len(flat_grads)):
|
| 229 |
+
if i == j:
|
| 230 |
+
continue
|
| 231 |
+
dot = (projected[i] * projected[j]).sum()
|
| 232 |
+
if dot < 0:
|
| 233 |
+
proj = dot / (projected[j].norm() ** 2 + 1e-12)
|
| 234 |
+
projected[i] = projected[i] - proj * projected[j]
|
| 235 |
+
|
| 236 |
+
# sum projected grads
|
| 237 |
+
final_grad = sum(projected)
|
| 238 |
+
# assign to params
|
| 239 |
+
pointer = 0
|
| 240 |
+
for p in shared_params:
|
| 241 |
+
if p.requires_grad:
|
| 242 |
+
numel = p.numel()
|
| 243 |
+
p.grad = final_grad[pointer:pointer + numel].view_as(p).clone()
|
| 244 |
+
pointer += numel
|
| 245 |
+
|
| 246 |
+
# return average loss for logging only
|
| 247 |
+
return sum(losses) / len(losses)
|
flare/utils/models.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
|
| 2 |
from flare.models.mol_encoder import MolEnc
|
| 3 |
from flare.models.encoders import MLP
|
| 4 |
-
from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive
|
| 5 |
|
| 6 |
def get_spec_encoder(spec_enc:str, args):
|
| 7 |
return {"MLP_BIN": SpecEncMLP_BIN,
|
|
@@ -28,6 +28,8 @@ def get_model(model:str,
|
|
| 28 |
model = CrossAttenContrastive(**params)
|
| 29 |
elif model == "filipContrastive":
|
| 30 |
model = FilipContrastive(**params)
|
|
|
|
|
|
|
| 31 |
else:
|
| 32 |
raise Exception(f"Model {model} not implemented.")
|
| 33 |
|
|
|
|
| 1 |
from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
|
| 2 |
from flare.models.mol_encoder import MolEnc
|
| 3 |
from flare.models.encoders import MLP
|
| 4 |
+
from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive, FilipGlobalContrastive
|
| 5 |
|
| 6 |
def get_spec_encoder(spec_enc:str, args):
|
| 7 |
return {"MLP_BIN": SpecEncMLP_BIN,
|
|
|
|
| 28 |
model = CrossAttenContrastive(**params)
|
| 29 |
elif model == "filipContrastive":
|
| 30 |
model = FilipContrastive(**params)
|
| 31 |
+
elif model == "filipGlobalContrastive":
|
| 32 |
+
model = FilipGlobalContrastive(**params)
|
| 33 |
else:
|
| 34 |
raise Exception(f"Model {model} not implemented.")
|
| 35 |
|
flare/utils/mol_search.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pickle
|
| 4 |
+
from typing import Callable, List, Dict, Any, Optional
|
| 5 |
+
from rdkit import Chem
|
| 6 |
+
import faiss
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import dgl
|
| 11 |
+
|
| 12 |
+
class MoleculeDataset(Dataset):
|
| 13 |
+
"""Converts SMILES to DGL graphs in parallel via DataLoader workers."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, smiles_dict, smiles_preprocess):
|
| 16 |
+
self.items = list(smiles_dict.items())
|
| 17 |
+
self.smiles_preprocess = smiles_preprocess
|
| 18 |
+
|
| 19 |
+
def __len__(self):
|
| 20 |
+
return len(self.items)
|
| 21 |
+
|
| 22 |
+
def __getitem__(self, idx):
|
| 23 |
+
mol_id, smi = self.items[idx]
|
| 24 |
+
try:
|
| 25 |
+
graph = self.smiles_preprocess(smi)
|
| 26 |
+
return mol_id, graph, None
|
| 27 |
+
except Exception as e:
|
| 28 |
+
return mol_id, None, str(e)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def collate_graphs(batch):
|
| 32 |
+
"""Custom collation: keep only valid graphs."""
|
| 33 |
+
valid = [(mid, g) for mid, g, err in batch if g is not None]
|
| 34 |
+
if not valid:
|
| 35 |
+
return [], None
|
| 36 |
+
mol_ids, graphs = zip(*valid)
|
| 37 |
+
batched_graph = dgl.batch(graphs)
|
| 38 |
+
return mol_ids, batched_graph
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SpectraMoleculeRetriever:
|
| 43 |
+
"""
|
| 44 |
+
Two-stage spectra–molecule retrieval system with hierarchical metadata filtering:
|
| 45 |
+
1. Coarse retrieval via FAISS on global embeddings.
|
| 46 |
+
2. Fine-grained reranking via custom similarity (e.g., FILIP alignment).
|
| 47 |
+
3. Supports fast subset search by class, superclass, or pathway.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
molecule_encoder,
|
| 53 |
+
spectra_encoder,
|
| 54 |
+
fine_similarity_fn: Callable[[Any, Any], float],
|
| 55 |
+
smiles_preprocess: Callable[[str], Any],
|
| 56 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
+
):
|
| 58 |
+
"""
|
| 59 |
+
Args:
|
| 60 |
+
molecule_encoder: callable with methods:
|
| 61 |
+
- global_embedding(mol)
|
| 62 |
+
- node_embeddings(mol)
|
| 63 |
+
spectra_encoder: callable with methods:
|
| 64 |
+
- global_embedding(spectrum)
|
| 65 |
+
- token_embeddings(spectrum)
|
| 66 |
+
fine_similarity_fn: function for fine-grained similarity.
|
| 67 |
+
smiles_preprocess: preprocessing function for SMILES → molecule object.
|
| 68 |
+
device: where to run encoders.
|
| 69 |
+
"""
|
| 70 |
+
self.molecule_encoder = molecule_encoder
|
| 71 |
+
self.spectra_encoder = spectra_encoder
|
| 72 |
+
self.fine_similarity_fn = fine_similarity_fn
|
| 73 |
+
self.smiles_preprocess = smiles_preprocess
|
| 74 |
+
self.device = device
|
| 75 |
+
|
| 76 |
+
# Storage
|
| 77 |
+
self.molecule_db: Dict[str, Any] = {} # mol_id → mol object
|
| 78 |
+
self.node_cache: Dict[str, Any] = {} # mol_id → node embeddings
|
| 79 |
+
self.metadata: Dict[str, Dict[str, List[str]]] = {} # e.g. {"class": {"lipid": [mol1, mol2], ...}}
|
| 80 |
+
|
| 81 |
+
self.molecule_ids: Optional[np.ndarray] = None
|
| 82 |
+
self.global_embeddings: Optional[np.ndarray] = None
|
| 83 |
+
self.index: Optional[faiss.Index] = None
|
| 84 |
+
self.smiles_dict: Optional[Dict[str, str]] = None # mol_id → smiles
|
| 85 |
+
|
| 86 |
+
self.failed_mols = []
|
| 87 |
+
|
| 88 |
+
# set model to eval mode and move to device
|
| 89 |
+
self.molecule_encoder.eval()
|
| 90 |
+
self.spectra_encoder.eval()
|
| 91 |
+
|
| 92 |
+
self.molecule_encoder.to(self.device)
|
| 93 |
+
self.spectra_encoder.to(self.device)
|
| 94 |
+
|
| 95 |
+
# -------------------------------
|
| 96 |
+
# Database building & saving
|
| 97 |
+
# -------------------------------
|
| 98 |
+
def build_database(
|
| 99 |
+
self,
|
| 100 |
+
smiles_dict: dict,
|
| 101 |
+
metadata=None,
|
| 102 |
+
cache_nodes: bool = False,
|
| 103 |
+
batch_size: int = 64,
|
| 104 |
+
num_workers: int = 25,
|
| 105 |
+
pooling: str = "max", # or "sum", "mean"
|
| 106 |
+
):
|
| 107 |
+
"""
|
| 108 |
+
Parallelized database construction using PyTorch DataLoader for
|
| 109 |
+
SMILES → DGLGraph conversion and batched GPU encoding.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
smiles_dict: dict {mol_id: smiles}
|
| 113 |
+
metadata: hierarchical dict for class/superclass/pathway
|
| 114 |
+
cache_nodes: if True, store node embeddings for fine-grained search
|
| 115 |
+
batch_size: number of molecules per GPU batch
|
| 116 |
+
num_workers: parallel CPU workers for SMILES parsing
|
| 117 |
+
pooling: global pooling type ("max" | "sum" | "mean")
|
| 118 |
+
"""
|
| 119 |
+
print("Building molecule database with PyTorch DataLoader parallelization...")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# set up pooling
|
| 123 |
+
if pooling == "max":
|
| 124 |
+
self.pooling = dgl.nn.pytorch.glob.MaxPooling()
|
| 125 |
+
elif pooling == "sum":
|
| 126 |
+
self.pooling = dgl.nn.pytorch.glob.SumPooling()
|
| 127 |
+
elif pooling == "mean":
|
| 128 |
+
self.pooling = dgl.nn.pytorch.glob.MeanPooling()
|
| 129 |
+
else:
|
| 130 |
+
raise ValueError(f"Unsupported pooling: {pooling}")
|
| 131 |
+
|
| 132 |
+
dataset = MoleculeDataset(smiles_dict, self.smiles_preprocess)
|
| 133 |
+
loader = DataLoader(
|
| 134 |
+
dataset,
|
| 135 |
+
batch_size=batch_size,
|
| 136 |
+
shuffle=False,
|
| 137 |
+
num_workers=num_workers,
|
| 138 |
+
collate_fn=collate_graphs,
|
| 139 |
+
pin_memory=True,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
mol_ids_all, mol_objs, mol_embs = [], [], []
|
| 143 |
+
failed_mols = []
|
| 144 |
+
node_cache = {}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
with torch.no_grad():
|
| 148 |
+
for mol_ids, batched_graph in tqdm(loader, desc="Encoding molecules"):
|
| 149 |
+
if batched_graph is None:
|
| 150 |
+
# All failed in this batch
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
try:
|
| 154 |
+
batched_graph = batched_graph.to(self.device)
|
| 155 |
+
node_repr = self.molecule_encoder(batched_graph, batched_graph.ndata['h'])
|
| 156 |
+
global_emb = self.pooling(batched_graph,node_repr)
|
| 157 |
+
|
| 158 |
+
# Normalize embeddings
|
| 159 |
+
emb_np = global_emb.detach().cpu().numpy()
|
| 160 |
+
emb_np /= np.linalg.norm(emb_np, axis=1, keepdims=True)
|
| 161 |
+
|
| 162 |
+
mol_ids_all.extend(mol_ids)
|
| 163 |
+
mol_objs.extend([batched_graph] * len(mol_ids))
|
| 164 |
+
mol_embs.append(emb_np)
|
| 165 |
+
|
| 166 |
+
# Optionally store node embeddings for fine-grained search
|
| 167 |
+
if cache_nodes:
|
| 168 |
+
# Split batched node embeddings into per-graph chunks
|
| 169 |
+
node_embs = dgl.unbatch(batched_graph)
|
| 170 |
+
for mol_id, mol_graph in zip(mol_ids, node_embs):
|
| 171 |
+
node_cache[mol_id] = mol_graph.ndata['h'].detach().cpu()
|
| 172 |
+
except Exception as e:
|
| 173 |
+
failed_mols.extend(mol_ids)
|
| 174 |
+
print(f"[Warning] Failed to encode batch with molecules {mol_ids}: {e}")
|
| 175 |
+
continue
|
| 176 |
+
|
| 177 |
+
if not mol_embs:
|
| 178 |
+
raise RuntimeError("No valid molecules were successfully encoded.")
|
| 179 |
+
|
| 180 |
+
self.failed_mols = failed_mols
|
| 181 |
+
self.smiles_dict = smiles_dict
|
| 182 |
+
self.molecule_db = dict(zip(mol_ids_all, mol_objs))
|
| 183 |
+
self.molecule_ids = np.array(mol_ids_all)
|
| 184 |
+
self.global_embeddings = np.concatenate(mol_embs, axis=0)
|
| 185 |
+
self.metadata = metadata or {}
|
| 186 |
+
self.node_cache.update(node_cache)
|
| 187 |
+
|
| 188 |
+
self._build_faiss_index()
|
| 189 |
+
|
| 190 |
+
print(f"Database built with {len(self.molecule_ids)} molecules "
|
| 191 |
+
f"({len(self.failed_mols) + (len(smiles_dict) - len(self.molecule_ids))} failed).")
|
| 192 |
+
|
| 193 |
+
def _build_faiss_index(self):
|
| 194 |
+
d = self.global_embeddings.shape[1]
|
| 195 |
+
self.index = faiss.IndexFlatIP(d)
|
| 196 |
+
self.index.add(self.global_embeddings)
|
| 197 |
+
print(f"FAISS index built with {len(self.molecule_ids)} embeddings.")
|
| 198 |
+
|
| 199 |
+
def save_database(self, path: str):
|
| 200 |
+
"""Save molecule database and embeddings."""
|
| 201 |
+
data = {
|
| 202 |
+
"molecule_ids": self.molecule_ids,
|
| 203 |
+
"global_embeddings": self.global_embeddings,
|
| 204 |
+
"metadata": self.metadata,
|
| 205 |
+
"node_cache": self.node_cache,
|
| 206 |
+
"smiles_dict": self.smiles_dict,
|
| 207 |
+
}
|
| 208 |
+
with open(path, "wb") as f:
|
| 209 |
+
pickle.dump(data, f)
|
| 210 |
+
print(f"Database saved to {path}")
|
| 211 |
+
|
| 212 |
+
def load_database(self, path: str):
|
| 213 |
+
"""Load molecule database and rebuild FAISS index."""
|
| 214 |
+
with open(path, "rb") as f:
|
| 215 |
+
data = pickle.load(f)
|
| 216 |
+
self.molecule_ids = data["molecule_ids"]
|
| 217 |
+
self.global_embeddings = data["global_embeddings"]
|
| 218 |
+
self.metadata = data.get("metadata", {})
|
| 219 |
+
self.node_cache = data.get("node_cache", {})
|
| 220 |
+
self.smiles_dict = data.get("smiles_dict", {})
|
| 221 |
+
self._build_faiss_index()
|
| 222 |
+
print(f"Database loaded from {path}")
|
| 223 |
+
|
| 224 |
+
# -------------------------------
|
| 225 |
+
# Filtering utilities
|
| 226 |
+
# -------------------------------
|
| 227 |
+
def _get_filtered_indices(self, subset: Optional[Dict[str, str]] = None) -> np.ndarray:
|
| 228 |
+
"""
|
| 229 |
+
Retrieve indices for molecules matching a given metadata subset.
|
| 230 |
+
Example subset: {"class": "lipid"} or {"pathway": "glycolysis"}
|
| 231 |
+
"""
|
| 232 |
+
if not subset:
|
| 233 |
+
return np.arange(len(self.molecule_ids))
|
| 234 |
+
|
| 235 |
+
key, value = next(iter(subset.items()))
|
| 236 |
+
if key not in self.metadata or value not in self.metadata[key]:
|
| 237 |
+
print(f"[Warning] No molecules found for {key}={value}")
|
| 238 |
+
return np.array([], dtype=int)
|
| 239 |
+
|
| 240 |
+
mol_ids = self.metadata[key][value]
|
| 241 |
+
id_to_idx = {m: i for i, m in enumerate(self.molecule_ids)}
|
| 242 |
+
selected = [id_to_idx[m] for m in mol_ids if m in id_to_idx]
|
| 243 |
+
return np.array(selected, dtype=int)
|
| 244 |
+
|
| 245 |
+
# -------------------------------
|
| 246 |
+
# Retrieval
|
| 247 |
+
# -------------------------------
|
| 248 |
+
def coarse_search(self, spectrum, top_k: int = 256, subset: Optional[Dict[str, str]] = None):
|
| 249 |
+
"""
|
| 250 |
+
Retrieve top-k candidates using FAISS, optionally restricted to subset metadata.
|
| 251 |
+
"""
|
| 252 |
+
with torch.no_grad():
|
| 253 |
+
spectrum = spectrum.to(self.device)
|
| 254 |
+
z_spec = self.spectra_encoder(spectrum).sum(axis=0)
|
| 255 |
+
z_spec = z_spec.detach().cpu().numpy() if hasattr(z_spec, "detach") else np.asarray(z_spec)
|
| 256 |
+
z_spec = z_spec / np.linalg.norm(z_spec)
|
| 257 |
+
|
| 258 |
+
subset_idx = self._get_filtered_indices(subset)
|
| 259 |
+
if subset_idx.size == 0:
|
| 260 |
+
return [], []
|
| 261 |
+
|
| 262 |
+
# subset FAISS index
|
| 263 |
+
emb_subset = self.global_embeddings[subset_idx]
|
| 264 |
+
index_subset = faiss.IndexFlatIP(emb_subset.shape[1])
|
| 265 |
+
index_subset.add(emb_subset)
|
| 266 |
+
sims, idxs = index_subset.search(z_spec[None, :], min(top_k, len(subset_idx)))
|
| 267 |
+
|
| 268 |
+
candidate_ids = self.molecule_ids[subset_idx[idxs[0]]]
|
| 269 |
+
return candidate_ids, sims[0]
|
| 270 |
+
|
| 271 |
+
def fine_rerank(self, spectrum, candidate_ids: List[str], top_k: int = 50):
|
| 272 |
+
"""
|
| 273 |
+
Compute fine-grained similarity for the candidates and rerank.
|
| 274 |
+
"""
|
| 275 |
+
spectrum = spectrum.to(self.device)
|
| 276 |
+
with torch.no_grad():
|
| 277 |
+
z_spec_tokens = self.spectra_encoder(spectrum)
|
| 278 |
+
scores = []
|
| 279 |
+
for mol_id in candidate_ids:
|
| 280 |
+
if mol_id in self.node_cache:
|
| 281 |
+
mol_tokens = self.node_cache[mol_id]
|
| 282 |
+
elif mol_id in self.molecule_db:
|
| 283 |
+
mol = self.molecule_db[mol_id].to(self.device)
|
| 284 |
+
mol_tokens = self.molecule_encoder(mol)
|
| 285 |
+
else:
|
| 286 |
+
mol = self.smiles_preprocess(self.smiles_dict[mol_id])
|
| 287 |
+
mol = mol.to(self.device)
|
| 288 |
+
mol_tokens = self.molecule_encoder(mol)
|
| 289 |
+
|
| 290 |
+
s = self.fine_similarity_fn(z_spec_tokens, mol_tokens).item()
|
| 291 |
+
scores.append((mol_id, s))
|
| 292 |
+
scores.sort(key=lambda x: x[1], reverse=True)
|
| 293 |
+
return scores[:top_k]
|
| 294 |
+
|
| 295 |
+
def search(
|
| 296 |
+
self,
|
| 297 |
+
spectrum,
|
| 298 |
+
coarse_k: int = 256,
|
| 299 |
+
fine_k: int = 50,
|
| 300 |
+
subset: Optional[Dict[str, str]] = None,
|
| 301 |
+
):
|
| 302 |
+
"""
|
| 303 |
+
Full two-stage search pipeline with optional subset filtering.
|
| 304 |
+
"""
|
| 305 |
+
candidate_ids, _ = self.coarse_search(spectrum, top_k=coarse_k, subset=subset)
|
| 306 |
+
if len(candidate_ids) == 0:
|
| 307 |
+
return []
|
| 308 |
+
ranked = self.fine_rerank(spectrum, candidate_ids, top_k=fine_k)
|
| 309 |
+
return ranked
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
if __name__ == "__main__":
|
| 314 |
+
import sys
|
| 315 |
+
sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
|
| 316 |
+
|
| 317 |
+
from flare.utils.data import get_spec_featurizer, get_mol_featurizer
|
| 318 |
+
from flare.utils.models import get_model
|
| 319 |
+
from flare.utils.mol_search import SpectraMoleculeRetriever
|
| 320 |
+
from flare.utils.general import filip_similarity_single
|
| 321 |
+
import yaml
|
| 322 |
+
|
| 323 |
+
metadata = {
|
| 324 |
+
"class": {
|
| 325 |
+
"lipid": ["mol1", "mol2"],
|
| 326 |
+
"peptide": ["mol3"]
|
| 327 |
+
},
|
| 328 |
+
"pathway": {
|
| 329 |
+
"beta-oxidation": ["mol1"],
|
| 330 |
+
"glycolysis": ["mol2", "mol3"]
|
| 331 |
+
}
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
smiles_dict = {
|
| 335 |
+
"mol1": "CCO",
|
| 336 |
+
"mol2": "CCN",
|
| 337 |
+
"mol3": "CCC"
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
# Load model and data
|
| 341 |
+
param_pth = '/data/yzhouc01/cancer/flare.yaml'
|
| 342 |
+
with open(param_pth) as f:
|
| 343 |
+
params = yaml.load(f, Loader=yaml.FullLoader)
|
| 344 |
+
|
| 345 |
+
spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
|
| 346 |
+
mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# load model
|
| 350 |
+
checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250930_optimized_flare_42/epoch=1959-train_loss=0.08.ckpt"
|
| 351 |
+
params['checkpoint_pth'] = checkpoint_pth
|
| 352 |
+
model = get_model(params['model'], params)
|
| 353 |
+
|
| 354 |
+
specMolRetriever = SpectraMoleculeRetriever(
|
| 355 |
+
molecule_encoder=model.mol_enc_model,
|
| 356 |
+
spectra_encoder=model.spec_enc_model,
|
| 357 |
+
fine_similarity_fn=filip_similarity_single,
|
| 358 |
+
smiles_preprocess=mol_featurizer
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
specMolRetriever.build_database(smiles_dict, metadata=metadata, cache_nodes=True)
|
| 362 |
+
|
| 363 |
+
# Filter search to molecules in a specific pathway
|
| 364 |
+
# results = specMolRetriever.search(spectrum, subset={"pathway": "beta-oxidation"})
|
| 365 |
+
|
| 366 |
+
# for mol_id, score in results[:10]:
|
| 367 |
+
# print(f"{mol_id}: {score:.3f}")
|
notebooks/UMAP_spectra_embeddings.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/fine-grained_vs_global.ipynb
CHANGED
|
@@ -29819,9 +29819,13 @@
|
|
| 29819 |
],
|
| 29820 |
"metadata": {
|
| 29821 |
"kernelspec": {
|
| 29822 |
-
"display_name": "
|
| 29823 |
"language": "python",
|
| 29824 |
-
"name": "
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29825 |
}
|
| 29826 |
},
|
| 29827 |
"nbformat": 4,
|
|
|
|
| 29819 |
],
|
| 29820 |
"metadata": {
|
| 29821 |
"kernelspec": {
|
| 29822 |
+
"display_name": "spec",
|
| 29823 |
"language": "python",
|
| 29824 |
+
"name": "python3"
|
| 29825 |
+
},
|
| 29826 |
+
"language_info": {
|
| 29827 |
+
"name": "python",
|
| 29828 |
+
"version": "3.11.7"
|
| 29829 |
}
|
| 29830 |
},
|
| 29831 |
"nbformat": 4,
|
notebooks/good_vs_bad_instances.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/mol-spec_visualization.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/results.ipynb
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "2cd3303a",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import pickle\n",
|
| 11 |
+
"import pandas as pd"
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": 2,
|
| 17 |
+
"id": "8ccc0bc1",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"source": [
|
| 21 |
+
"with open(\"/data/yzhouc01/FILIP-MS/experiments/20251110_filip-global/result_MassSpecGym_retrieval_candidates_formula.pkl\", \"rb\") as f:\n",
|
| 22 |
+
" result = pickle.load(f)"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "code",
|
| 27 |
+
"execution_count": 3,
|
| 28 |
+
"id": "8e517777",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"outputs": [
|
| 31 |
+
{
|
| 32 |
+
"data": {
|
| 33 |
+
"text/html": [
|
| 34 |
+
"<div>\n",
|
| 35 |
+
"<style scoped>\n",
|
| 36 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 37 |
+
" vertical-align: middle;\n",
|
| 38 |
+
" }\n",
|
| 39 |
+
"\n",
|
| 40 |
+
" .dataframe tbody tr th {\n",
|
| 41 |
+
" vertical-align: top;\n",
|
| 42 |
+
" }\n",
|
| 43 |
+
"\n",
|
| 44 |
+
" .dataframe thead th {\n",
|
| 45 |
+
" text-align: right;\n",
|
| 46 |
+
" }\n",
|
| 47 |
+
"</style>\n",
|
| 48 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 49 |
+
" <thead>\n",
|
| 50 |
+
" <tr style=\"text-align: right;\">\n",
|
| 51 |
+
" <th></th>\n",
|
| 52 |
+
" <th>rank_fine</th>\n",
|
| 53 |
+
" <th>rank_global</th>\n",
|
| 54 |
+
" <th>rank_sum</th>\n",
|
| 55 |
+
" <th>rank_weighted</th>\n",
|
| 56 |
+
" <th>rank_avg</th>\n",
|
| 57 |
+
" </tr>\n",
|
| 58 |
+
" </thead>\n",
|
| 59 |
+
" <tbody>\n",
|
| 60 |
+
" <tr>\n",
|
| 61 |
+
" <th>R@1</th>\n",
|
| 62 |
+
" <td>0.214571</td>\n",
|
| 63 |
+
" <td>0.163306</td>\n",
|
| 64 |
+
" <td>0.192869</td>\n",
|
| 65 |
+
" <td>0.191274</td>\n",
|
| 66 |
+
" <td>0.192869</td>\n",
|
| 67 |
+
" </tr>\n",
|
| 68 |
+
" <tr>\n",
|
| 69 |
+
" <th>R@5</th>\n",
|
| 70 |
+
" <td>0.483140</td>\n",
|
| 71 |
+
" <td>0.403566</td>\n",
|
| 72 |
+
" <td>0.447425</td>\n",
|
| 73 |
+
" <td>0.444862</td>\n",
|
| 74 |
+
" <td>0.447425</td>\n",
|
| 75 |
+
" </tr>\n",
|
| 76 |
+
" <tr>\n",
|
| 77 |
+
" <th>R@20</th>\n",
|
| 78 |
+
" <td>0.747095</td>\n",
|
| 79 |
+
" <td>0.694350</td>\n",
|
| 80 |
+
" <td>0.728355</td>\n",
|
| 81 |
+
" <td>0.726361</td>\n",
|
| 82 |
+
" <td>0.728355</td>\n",
|
| 83 |
+
" </tr>\n",
|
| 84 |
+
" </tbody>\n",
|
| 85 |
+
"</table>\n",
|
| 86 |
+
"</div>"
|
| 87 |
+
],
|
| 88 |
+
"text/plain": [
|
| 89 |
+
" rank_fine rank_global rank_sum rank_weighted rank_avg\n",
|
| 90 |
+
"R@1 0.214571 0.163306 0.192869 0.191274 0.192869\n",
|
| 91 |
+
"R@5 0.483140 0.403566 0.447425 0.444862 0.447425\n",
|
| 92 |
+
"R@20 0.747095 0.694350 0.728355 0.726361 0.728355"
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
"execution_count": 3,
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"output_type": "execute_result"
|
| 98 |
+
}
|
| 99 |
+
],
|
| 100 |
+
"source": [
|
| 101 |
+
"data = []\n",
|
| 102 |
+
"for i in [1, 5, 20]:\n",
|
| 103 |
+
" curr_d = {}\n",
|
| 104 |
+
" for c in result.columns.tolist():\n",
|
| 105 |
+
" if c.startswith('rank'):\n",
|
| 106 |
+
" curr_d[c] = result[result[c] <= i].shape[0] / result.shape[0]\n",
|
| 107 |
+
" data.append(curr_d)\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"data_df = pd.DataFrame(data, index=['R@1', 'R@5', 'R@20'])\n",
|
| 110 |
+
"data_df\n"
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"cell_type": "code",
|
| 115 |
+
"execution_count": 7,
|
| 116 |
+
"id": "10493857",
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"outputs": [
|
| 119 |
+
{
|
| 120 |
+
"data": {
|
| 121 |
+
"text/html": [
|
| 122 |
+
"<div>\n",
|
| 123 |
+
"<style scoped>\n",
|
| 124 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 125 |
+
" vertical-align: middle;\n",
|
| 126 |
+
" }\n",
|
| 127 |
+
"\n",
|
| 128 |
+
" .dataframe tbody tr th {\n",
|
| 129 |
+
" vertical-align: top;\n",
|
| 130 |
+
" }\n",
|
| 131 |
+
"\n",
|
| 132 |
+
" .dataframe thead th {\n",
|
| 133 |
+
" text-align: right;\n",
|
| 134 |
+
" }\n",
|
| 135 |
+
"</style>\n",
|
| 136 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 137 |
+
" <thead>\n",
|
| 138 |
+
" <tr style=\"text-align: right;\">\n",
|
| 139 |
+
" <th></th>\n",
|
| 140 |
+
" <th>rank_fine</th>\n",
|
| 141 |
+
" <th>rank_global</th>\n",
|
| 142 |
+
" <th>rank_sum</th>\n",
|
| 143 |
+
" <th>rank_weighted</th>\n",
|
| 144 |
+
" <th>rank_avg</th>\n",
|
| 145 |
+
" </tr>\n",
|
| 146 |
+
" </thead>\n",
|
| 147 |
+
" <tbody>\n",
|
| 148 |
+
" <tr>\n",
|
| 149 |
+
" <th>R@1</th>\n",
|
| 150 |
+
" <td>0.420882</td>\n",
|
| 151 |
+
" <td>0.369731</td>\n",
|
| 152 |
+
" <td>0.412907</td>\n",
|
| 153 |
+
" <td>0.411939</td>\n",
|
| 154 |
+
" <td>0.412907</td>\n",
|
| 155 |
+
" </tr>\n",
|
| 156 |
+
" <tr>\n",
|
| 157 |
+
" <th>R@5</th>\n",
|
| 158 |
+
" <td>0.744475</td>\n",
|
| 159 |
+
" <td>0.707052</td>\n",
|
| 160 |
+
" <td>0.738893</td>\n",
|
| 161 |
+
" <td>0.737412</td>\n",
|
| 162 |
+
" <td>0.738893</td>\n",
|
| 163 |
+
" </tr>\n",
|
| 164 |
+
" <tr>\n",
|
| 165 |
+
" <th>R@20</th>\n",
|
| 166 |
+
" <td>0.927660</td>\n",
|
| 167 |
+
" <td>0.916325</td>\n",
|
| 168 |
+
" <td>0.926407</td>\n",
|
| 169 |
+
" <td>0.926122</td>\n",
|
| 170 |
+
" <td>0.926407</td>\n",
|
| 171 |
+
" </tr>\n",
|
| 172 |
+
" </tbody>\n",
|
| 173 |
+
"</table>\n",
|
| 174 |
+
"</div>"
|
| 175 |
+
],
|
| 176 |
+
"text/plain": [
|
| 177 |
+
" rank_fine rank_global rank_sum rank_weighted rank_avg\n",
|
| 178 |
+
"R@1 0.420882 0.369731 0.412907 0.411939 0.412907\n",
|
| 179 |
+
"R@5 0.744475 0.707052 0.738893 0.737412 0.738893\n",
|
| 180 |
+
"R@20 0.927660 0.916325 0.926407 0.926122 0.926407"
|
| 181 |
+
]
|
| 182 |
+
},
|
| 183 |
+
"execution_count": 7,
|
| 184 |
+
"metadata": {},
|
| 185 |
+
"output_type": "execute_result"
|
| 186 |
+
}
|
| 187 |
+
],
|
| 188 |
+
"source": [
|
| 189 |
+
"data = []\n",
|
| 190 |
+
"for i in [1, 5, 20]:\n",
|
| 191 |
+
" curr_d = {}\n",
|
| 192 |
+
" for c in result.columns.tolist():\n",
|
| 193 |
+
" if c.startswith('rank'):\n",
|
| 194 |
+
" curr_d[c] = result[result[c] <= i].shape[0] / result.shape[0]\n",
|
| 195 |
+
" data.append(curr_d)\n",
|
| 196 |
+
"\n",
|
| 197 |
+
"data_df = pd.DataFrame(data, index=['R@1', 'R@5', 'R@20'])\n",
|
| 198 |
+
"data_df\n"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "code",
|
| 203 |
+
"execution_count": null,
|
| 204 |
+
"id": "1e4201db",
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [],
|
| 207 |
+
"source": [
|
| 208 |
+
"x"
|
| 209 |
+
]
|
| 210 |
+
}
|
| 211 |
+
],
|
| 212 |
+
"metadata": {
|
| 213 |
+
"kernelspec": {
|
| 214 |
+
"display_name": "spec",
|
| 215 |
+
"language": "python",
|
| 216 |
+
"name": "python3"
|
| 217 |
+
},
|
| 218 |
+
"language_info": {
|
| 219 |
+
"codemirror_mode": {
|
| 220 |
+
"name": "ipython",
|
| 221 |
+
"version": 3
|
| 222 |
+
},
|
| 223 |
+
"file_extension": ".py",
|
| 224 |
+
"mimetype": "text/x-python",
|
| 225 |
+
"name": "python",
|
| 226 |
+
"nbconvert_exporter": "python",
|
| 227 |
+
"pygments_lexer": "ipython3",
|
| 228 |
+
"version": "3.11.7"
|
| 229 |
+
}
|
| 230 |
+
},
|
| 231 |
+
"nbformat": 4,
|
| 232 |
+
"nbformat_minor": 5
|
| 233 |
+
}
|
notebooks/spectra_sim.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|