yzhouchen001 commited on
Commit
19a4dfc
·
1 Parent(s): 60219be
app_utils/model_utils.py CHANGED
@@ -3,8 +3,8 @@ import sys
3
  # sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
4
 
5
  from rdkit import RDLogger
6
- from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
7
- from mvp.utils.models import get_model
8
 
9
  import yaml
10
 
@@ -15,7 +15,7 @@ lg.setLevel(RDLogger.CRITICAL)
15
  # Load model and data
16
 
17
  def load_model_components():
18
- param_pth = 'hparams.yaml'
19
  with open(param_pth) as f:
20
  params = yaml.load(f, Loader=yaml.FullLoader)
21
 
@@ -24,7 +24,7 @@ def load_model_components():
24
 
25
  # load model
26
 
27
- checkpoint_pth = "epoch=1993-train_loss=0.10.ckpt"
28
  params['checkpoint_pth'] = checkpoint_pth
29
  model = get_model(params['model'], params)
30
 
 
3
  # sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
4
 
5
  from rdkit import RDLogger
6
+ from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
7
+ from flare.utils.models import get_model
8
 
9
  import yaml
10
 
 
15
  # Load model and data
16
 
17
  def load_model_components():
18
+ param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'
19
  with open(param_pth) as f:
20
  params = yaml.load(f, Loader=yaml.FullLoader)
21
 
 
24
 
25
  # load model
26
 
27
+ checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt"
28
  params['checkpoint_pth'] = checkpoint_pth
29
  model = get_model(params['model'], params)
30
 
app_utils/viz_utils.py CHANGED
@@ -6,7 +6,9 @@ import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
  from rdkit import Chem
8
  from rdkit.Chem import rdDepictor
9
- import pandas as pd
 
 
10
 
11
  def mol_to_graph_coords(mol):
12
  """Return atom coordinates and bond list for a molecule."""
@@ -16,12 +18,6 @@ def mol_to_graph_coords(mol):
16
  bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
17
  return coords, bonds
18
 
19
- import torch
20
- import torch.nn.functional as F
21
- import plotly.graph_objects as go
22
- from plotly.subplots import make_subplots
23
-
24
-
25
  def interactive_attention_visualization(
26
  spectral_embeds,
27
  graph_embeds,
@@ -68,7 +64,6 @@ def interactive_attention_visualization(
68
  hoverinfo='text',
69
  customdata=list(range(num_peaks)), # actual peak indices
70
  )
71
-
72
  # --- Graph nodes ---
73
  graph_nodes = go.Scatter(
74
  x=atom_x,
@@ -127,10 +122,6 @@ def interactive_attention_visualization(
127
  # ------------------------
128
  # Model set up
129
  # ------------------------
130
-
131
- from mvp.subformula_assign.utils.spectra_utils import assign_subforms
132
- import matchms
133
-
134
  def run(ms, smiles, formula, precursor_mz, adduct, spec_featurizer, mol_featurizer,model, mass_diff_thresh=20, precursor_intensity=1.1):
135
 
136
  # step 1 - label peaks with formula, setup matchms spectrum
 
6
  from plotly.subplots import make_subplots
7
  from rdkit import Chem
8
  from rdkit.Chem import rdDepictor
9
+
10
+ from flare.subformula_assign.utils.spectra_utils import assign_subforms
11
+ import matchms
12
 
13
  def mol_to_graph_coords(mol):
14
  """Return atom coordinates and bond list for a molecule."""
 
18
  bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
19
  return coords, bonds
20
 
 
 
 
 
 
 
21
  def interactive_attention_visualization(
22
  spectral_embeds,
23
  graph_embeds,
 
64
  hoverinfo='text',
65
  customdata=list(range(num_peaks)), # actual peak indices
66
  )
 
67
  # --- Graph nodes ---
68
  graph_nodes = go.Scatter(
69
  x=atom_x,
 
122
  # ------------------------
123
  # Model set up
124
  # ------------------------
 
 
 
 
125
  def run(ms, smiles, formula, precursor_mz, adduct, spec_featurizer, mol_featurizer,model, mass_diff_thresh=20, precursor_intensity=1.1):
126
 
127
  # step 1 - label peaks with formula, setup matchms spectrum
flare/data/datasets.py CHANGED
@@ -83,7 +83,7 @@ class JESTR1_MassSpecDataset(MassSpecDataset):
83
 
84
  spec = self.spectra[i]
85
  metadata = self.metadata.iloc[i]
86
- mol = metadata["smiles"]
87
 
88
  # Apply all transformations to the spectrum
89
  item = {}
@@ -254,7 +254,7 @@ class ContrastiveDataset(Dataset):
254
  return item
255
 
256
  @staticmethod
257
- def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None) -> dict:
258
  mol_key = 'cand' if stage == Stage.TEST else 'mol'
259
  non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
260
  require_pad = False
@@ -277,15 +277,16 @@ class ContrastiveDataset(Dataset):
277
  raise
278
 
279
  # batch graphs
280
- batch_mol = []
281
- batch_mol_nodes= []
 
282
 
283
- for item in batch:
284
- batch_mol.append(item[mol_key])
285
- batch_mol_nodes.append(item[mol_key].num_nodes())
286
 
287
- collated_batch[mol_key] = dgl.batch(batch_mol)
288
- collated_batch['mol_n_nodes'] = batch_mol_nodes
289
 
290
  # pad peaks/formulas
291
  if require_pad:
@@ -347,7 +348,15 @@ class ExpandedRetrievalDataset:
347
 
348
  self.candidates = {}
349
  for s, cand in candidates.items():
350
- self.candidates[s] = [c for c in cand if '.' not in c]
 
 
 
 
 
 
 
 
351
 
352
  self.spec_cand = [] #(spec index, cand_smiles, true_label)
353
 
 
83
 
84
  spec = self.spectra[i]
85
  metadata = self.metadata.iloc[i]
86
+ mol = metadata["smiles"] if 'smiles' in metadata else metadata["identifier"]
87
 
88
  # Apply all transformations to the spectrum
89
  item = {}
 
254
  return item
255
 
256
  @staticmethod
257
+ def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None, batch_mol: bool = True) -> dict:
258
  mol_key = 'cand' if stage == Stage.TEST else 'mol'
259
  non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
260
  require_pad = False
 
277
  raise
278
 
279
  # batch graphs
280
+ if batch_mol:
281
+ batch_mol = []
282
+ batch_mol_nodes= []
283
 
284
+ for item in batch:
285
+ batch_mol.append(item[mol_key])
286
+ batch_mol_nodes.append(item[mol_key].num_nodes())
287
 
288
+ collated_batch[mol_key] = dgl.batch(batch_mol)
289
+ collated_batch['mol_n_nodes'] = batch_mol_nodes
290
 
291
  # pad peaks/formulas
292
  if require_pad:
 
348
 
349
  self.candidates = {}
350
  for s, cand in candidates.items():
351
+ clean_cands = []
352
+ for c in cand:
353
+ try:
354
+ if '.' not in c:
355
+ clean_cands.append(c)
356
+ except:
357
+ print(f"Error in processing candidate {c} for smiles {s}")
358
+ pass
359
+ self.candidates[s] = clean_cands
360
 
361
  self.spec_cand = [] #(spec index, cand_smiles, true_label)
362
 
flare/models/contrastive.py CHANGED
@@ -10,7 +10,7 @@ from massspecgym.models.base import Stage
10
  from massspecgym import utils
11
  from torch.nn.utils.rnn import pad_sequence
12
 
13
- from flare.utils.loss import contrastive_loss, cand_spec_sim_loss, fp_loss, cons_spec_loss, filip_loss_with_mask
14
  import flare.utils.models as model_utils
15
  from flare.utils.general import pad_graph_nodes, filip_similarity_batch
16
 
@@ -18,14 +18,17 @@ from flare.models.encoders import CrossAttention
18
  import torch.nn.functional as F
19
 
20
  from torch_geometric.nn import global_mean_pool
 
21
 
22
  class ContrastiveModel(RetrievalMassSpecGymModel):
23
  def __init__(
24
  self,
 
25
  **kwargs
26
  ):
27
  super().__init__(**kwargs)
28
  self.save_hyperparameters()
 
29
 
30
  if 'use_fp' not in self.hparams:
31
  self.hparams.use_fp = False
@@ -42,13 +45,26 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
42
  self.result_dct = defaultdict(lambda: defaultdict(list))
43
 
44
  def forward(self, batch, stage):
45
- g = batch['cand'] if stage == Stage.TEST else batch['mol']
46
-
 
 
 
 
 
47
  spec = batch[self.spec_view]
48
  n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
49
  spec_enc = self.spec_enc_model(spec, n_peaks)
50
 
 
 
 
 
51
  fp = batch['fp'] if self.hparams.use_fp else None
 
 
 
 
52
  mol_enc = self.mol_enc_model(g, fp=fp)
53
 
54
  return spec_enc, mol_enc
@@ -61,20 +77,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
61
  losses['contr_loss'] = contr_loss.detach().item()
62
 
63
  loss+=contr_loss
64
- # if self.hparams.pred_fp:
65
- # fp_loss_val = self.loss_wts['fp_wt'] *self.fp_loss(output['fp'], batch['fp'])
66
- # loss+= fp_loss_val
67
- # losses['fp_loss'] = fp_loss_val.detach().item()
68
-
69
- # if 'aug_cand_enc' in output:
70
- # aug_cand_loss = self.loss_wts['aug_cand_wt'] * cand_spec_sim_loss(spec_enc, output['aug_cand_enc'])
71
- # loss+= aug_cand_loss
72
- # losses['aug_cand_loss'] = aug_cand_loss.detach().item()
73
-
74
- # if 'ind_spec' in output:
75
- # spec_loss = self.loss_wts['cons_spec_wt'] * self.cons_loss(spec_enc, output['ind_spec'])
76
- # loss+=spec_loss
77
- # losses['cons_spec_loss'] = spec_loss.detach().item()
78
 
79
  losses['loss'] = loss
80
 
@@ -108,7 +110,7 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
108
  # total loss
109
  self.log(
110
  f'{stage.to_pref()}loss',
111
- outputs['loss'],
112
  batch_size=len(batch['identifier']),
113
  sync_dist=True,
114
  prog_bar=True,
@@ -146,11 +148,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
146
  self.result_dct[i]['candidates'].extend(cands)
147
  self.result_dct[i]['scores'].extend(scores.cpu().tolist())
148
  self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
149
-
150
- # # external test case only
151
- # for i, cands, scores in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores']):
152
- # self.result_dct[i.cpu().item()]['candidates'].extend(cands)
153
- # self.result_dct[i.cpu().item()]['scores'].extend(scores.cpu().tolist())
154
 
155
  def _compute_rank(self, scores, labels):
156
  if not any(labels):
@@ -160,12 +157,21 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
160
  rank = np.count_nonzero(scores >=target_score)
161
  return rank
162
 
 
 
 
163
  def on_test_epoch_end(self) -> None:
164
 
165
  self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
166
 
167
  # Compute rank
168
- self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
 
 
 
 
 
 
169
  if not self.df_test_path:
170
  self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
171
  self.df_test.to_pickle(self.df_test_path)
@@ -176,160 +182,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
176
  {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": False}, # monitor val loss
177
  ]
178
  return monitors
179
-
180
- # class MultiViewContrastive(ContrastiveModel):
181
-
182
- # def __init__(self,
183
- # **kwargs):
184
-
185
- # super().__init__(**kwargs)
186
-
187
- # # build fingerprint encoder model
188
- # if self.hparams.use_fp:
189
- # self.fp_enc_model = model_utils.get_fp_enc_model(self.hparams)
190
-
191
- # # build NL encoder model
192
- # if self.hparams.use_NL_spec:
193
- # self.NL_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
194
-
195
- # def forward(self, batch, stage):
196
- # g = batch['cand'] if stage == Stage.TEST else batch['mol']
197
-
198
- # spec = batch[self.spec_view]
199
- # n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
200
-
201
- # spec_enc = self.spec_enc_model(spec, n_peaks)
202
- # mol_enc = self.mol_enc_model(g)
203
- # views = {'spec_enc': spec_enc, 'mol_enc': mol_enc}
204
-
205
- # if self.hparams.use_fp:
206
- # fp_enc = self.fp_enc_model(batch['fp'])
207
- # views['fp_enc'] = fp_enc
208
-
209
- # if self.hparams.use_cons_spec:
210
- # spec = batch['cons_spec']
211
- # n_peaks = batch['cons_n_peaks'] if 'cons_n_peaks' in batch else None
212
- # spec_enc = self.cons_spec_enc_model(spec, n_peaks)
213
- # views['cons_spec_enc'] = spec_enc
214
-
215
- # if self.hparams.use_NL_spec:
216
- # spec = batch['NL_spec']
217
- # n_peaks = batch['NL_n_peaks'] if 'NL_n_peaks' in batch else None
218
- # spec_enc = self.NL_enc_model(spec, n_peaks)
219
- # views['NL_spec_enc'] = spec_enc
220
- # return views
221
-
222
- # def step(
223
- # self, batch: dict, stage= Stage.NONE):
224
-
225
- # # Compute spectra and mol encoding
226
- # views = self.forward(batch, stage)
227
-
228
- # if stage == Stage.TEST:
229
- # return views
230
-
231
- # # Calculate loss
232
- # losses = self.compute_loss(batch, views)
233
-
234
- # return losses
235
-
236
- # def compute_loss(self, batch: dict, views: dict):
237
- # loss = 0
238
- # losses = {}
239
- # for v1, v2 in self.hparams.contr_views:
240
- # contr_loss, cong_loss, noncong_loss = contrastive_loss(views[v1], views[v2], self.hparams.contr_temp)
241
- # loss+=contr_loss
242
-
243
- # losses[f'{v1[:-4]}-{v2[:-4]}_contr_loss'] = contr_loss.detach().item()
244
- # losses[f'{v1[:-4]}-{v2[:-4]}_cong_loss'] = cong_loss.detach().item()
245
- # losses[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'] = noncong_loss.detach().item()
246
-
247
- # losses['loss'] = loss
248
-
249
- # return losses
250
-
251
- # def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
252
- # # total loss
253
- # self.log(
254
- # f'{stage.to_pref()}loss',
255
- # outputs['loss'],
256
- # batch_size=len(batch['identifier']),
257
- # sync_dist=True,
258
- # prog_bar=True,
259
- # on_epoch=True,
260
- # # on_step=True
261
- # )
262
-
263
- # for v1, v2 in self.hparams.contr_views:
264
- # self.log(
265
- # f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_contr_loss',
266
- # outputs[f'{v1[:-4]}-{v2[:-4]}_contr_loss'],
267
- # batch_size=len(batch['identifier']),
268
- # sync_dist=True,
269
- # on_epoch=True,
270
- # )
271
- # self.log(
272
- # f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_cong_loss',
273
- # outputs[f'{v1[:-4]}-{v2[:-4]}_cong_loss'],
274
- # batch_size=len(batch['identifier']),
275
- # sync_dist=True,
276
- # on_epoch=True,
277
- # )
278
- # self.log(
279
- # f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_noncong_loss',
280
- # outputs[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'],
281
- # batch_size=len(batch['identifier']),
282
- # sync_dist=True,
283
- # on_epoch=True,
284
- # )
285
-
286
- # def test_step(self, batch):
287
- # # Unpack inputs
288
- # identifiers = batch['identifier']
289
- # cand_smiles = batch['cand_smiles']
290
- # id_to_ct = defaultdict(int)
291
- # for i in identifiers: id_to_ct[i]+=1
292
- # batch_ptr = torch.tensor(list(id_to_ct.values()))
293
-
294
- # outputs = self.step(batch, stage=Stage.TEST)
295
- # scores = {}
296
- # for v1, v2 in self.hparams.contr_views:
297
- # # if 'cons_spec_enc' in (v1, v2):
298
- # # continue
299
- # v1_enc = outputs[v1]
300
- # v2_enc = outputs[v2]
301
-
302
- # s = nn.functional.cosine_similarity(v1_enc, v2_enc)
303
- # scores[f'{v1[:-4]}-{v2[:-4]}_scores'] = torch.split(s, list(id_to_ct.values()))
304
-
305
- # indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
306
-
307
- # cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
308
- # labels = utils.unbatch_list(batch['label'], indexes)
309
-
310
- # return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
311
-
312
- # def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
313
-
314
- # # save scores
315
- # for i, cands, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels']):
316
- # self.result_dct[i]['candidates'].extend(cands)
317
- # self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
318
-
319
- # for v1, v2 in self.hparams.contr_views:
320
- # for i, scores in zip(outputs['identifiers'], outputs['scores'][f'{v1[:-4]}-{v2[:-4]}_scores']):
321
- # self.result_dct[i][f'{v1[:-4]}-{v2[:-4]}_scores'].extend(scores.cpu().tolist())
322
-
323
-
324
- # def on_test_epoch_end(self) -> None:
325
-
326
- # self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
327
-
328
- # # Compute rank
329
- # for v1, v2 in self.hparams.contr_views:
330
- # self.df_test[f'{v1[:-4]}-{v2[:-4]}_rank'] = self.df_test.apply(lambda row: self._compute_rank(row[f'{v1[:-4]}-{v2[:-4]}_scores'], row['labels']), axis=1)
331
-
332
- # self.df_test.to_pickle(self.df_test_path)
333
 
334
  class FilipContrastive(ContrastiveModel):
335
  def __init__(self,
@@ -381,7 +233,7 @@ class FilipContrastive(ContrastiveModel):
381
  # Calculate scores
382
  indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
383
 
384
- scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_masks)
385
  scores = torch.split(scores, list(id_to_ct.values()))
386
 
387
  cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
@@ -389,248 +241,177 @@ class FilipContrastive(ContrastiveModel):
389
 
390
  return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
391
 
392
- # class MultiViewFineTuning(MultiViewContrastive):
393
- # def __init__(self,
394
- # **kwargs):
395
- # super().__init__(**kwargs)
396
-
397
- # # load preptrained spec, mol, fp encoders
398
- # checkpoint = torch.load(self.hparams.partial_checkpoint)
399
- # state_dict = state_dict = {k[len("spec_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("spec_enc_model")}
400
- # self.spec_enc_model.load_state_dict(state_dict) # trained on consensus spectra
401
-
402
- # state_dict = state_dict = {k[len("mol_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("mol_enc_model")}
403
- # self.mol_enc_model.load_state_dict(state_dict)
404
-
405
- # state_dict = state_dict = {k[len("fp_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("fp_enc_model")}
406
- # self.fp_enc_model.load_state_dict(state_dict)
407
-
408
- # self.encoding_views = ['spec_enc', 'mol_enc', 'fp_enc']
409
- # self.loss_fn = nn.BCELoss()
410
-
411
- # # freeze encoders
412
- # for param in self.mol_enc_model.parameters():
413
- # param.requires_grad = False
414
- # for param in self.spec_enc_model.parameters():
415
- # param.requires_grad = False
416
- # for param in self.fp_enc_model.parameters():
417
- # param.requires_grad = False
418
- # for param in self.cons_spec_enc_model.parameters():
419
- # param.requires_grad = False
420
-
421
- # # n_views = 2
422
- # # if self.hparams.use_fp:
423
- # # n_views+=1
424
-
425
- # # in_dim = self.hparams.final_embedding_dim*n_views
426
- # in_dim = self.hparams.final_embedding_dim *2 + 2
427
-
428
- # self.classifier_model = nn.Sequential(
429
- # nn.Linear(in_dim, 512),
430
- # nn.ReLU(),
431
- # nn.BatchNorm1d(512),
432
- # nn.Dropout(0.3),
433
- # nn.Linear(512, 256),
434
- # nn.ReLU(),
435
- # nn.BatchNorm1d(256),
436
- # nn.Dropout(0.3),
437
- # nn.Linear(256, 1),
438
- # nn.Sigmoid()
439
- # )
440
- # self.noise_std = 0.01
441
-
442
- # def _add_noise(self, x):
443
- # noise = torch.randn_like(x) * self.noise_std
444
- # return x + noise
445
-
446
- # def forward(self, batch, stage):
447
-
448
- # matching_views = super().forward(batch, stage)
449
- # # matching_enc = torch.concat((matching_views['spec_enc'], matching_views['mol_enc'], matching_views['fp_enc']), dim=-1)
450
- # # enc1 = matching_views['spec_enc'] - matching_views['mol_enc']
451
- # # enc2 = matching_views['spec_enc'] - matching_views['fp_enc']
452
- # # matching_enc = torch.concat((enc1, enc2), dim=-1)
453
- # view1 = matching_views['spec_enc']
454
- # view2 = matching_views['mol_enc']
455
- # view3 = matching_views['fp_enc']
456
-
457
- # if stage == Stage.TRAIN:
458
- # view1, view2, view3 = map(self._add_noise, (view1, view2, view3))
459
-
460
- # pairwise_diffs = torch.cat([
461
- # torch.abs(view1 - view2),
462
- # torch.abs(view1 - view3),
463
- # ], dim=-1)
464
-
465
- # pairwise_sims = torch.cat([
466
- # (view1 * view2).sum(dim=-1, keepdim=True),
467
- # (view1 * view3).sum(dim=-1, keepdim=True),
468
- # ], dim=-1)
469
-
470
- # matching_enc = torch.cat([pairwise_diffs, pairwise_sims], dim=-1)
471
- # matching_scores = self.classifier_model(matching_enc)
472
-
473
- # if stage == Stage.TEST:
474
- # return dict(matching_scores = matching_scores)
475
-
476
- # view1 = view1.repeat_interleave(self.hparams.aug_cands_size, dim=0)
477
- # view2 = self.mol_enc_model(batch['aug_cands'])
478
- # view3= self.fp_enc_model(batch['aug_cands_fp'])
479
- # if stage == Stage.TRAIN:
480
- # view1, view2, view3 = map(self._add_noise, (view1, view2, view3))
481
 
482
- # pairwise_diffs = torch.cat([
483
- # torch.abs(view1 - view2),
484
- # torch.abs(view1 - view3),
485
- # ], dim=-1)
 
 
 
 
 
 
 
 
 
 
 
486
 
487
- # pairwise_sims = torch.cat([
488
- # (view1 * view2).sum(dim=-1, keepdim=True),
489
- # (view1 * view3).sum(dim=-1, keepdim=True),
490
- # ], dim=-1)
491
 
492
- # nonmatching_enc = torch.cat([pairwise_diffs, pairwise_sims], dim=-1)
493
 
494
- # nonmatching_scores = self.classifier_model(nonmatching_enc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
 
496
- # return dict(matching_scores=matching_scores, nonmatching_scores=nonmatching_scores)
497
-
498
- # def compute_loss(self, matching_scores, nonmatching_scores):
499
-
500
- # matching_loss = self.loss_fn(matching_scores, torch.ones_like(matching_scores).to(matching_scores.device))
501
- # nonmatching_loss = self.loss_fn(nonmatching_scores, torch.zeros_like(nonmatching_scores).to(nonmatching_scores.device))
502
 
503
- # loss = matching_loss + (1/self.hparams.aug_cands_size)*nonmatching_loss
 
 
 
504
 
505
- # return dict(loss=loss)
506
-
507
- # def step(
508
- # self, batch: dict, stage= Stage.NONE):
509
-
510
- # output = self.forward(batch, stage)
511
 
512
- # if stage == Stage.TEST:
513
- # return output
514
 
515
- # # Calculate loss
516
- # losses = self.compute_loss(output['matching_scores'], output['nonmatching_scores'])
517
 
518
- # return losses
519
-
520
- # def test_step(self, batch):
521
- # # Unpack inputs
522
- # identifiers = batch['identifier']
523
- # cand_smiles = batch['cand_smiles']
524
- # id_to_ct = defaultdict(int)
525
- # for i in identifiers: id_to_ct[i]+=1
526
- # batch_ptr = torch.tensor(list(id_to_ct.values()))
527
 
528
- # outputs = self.step(batch, stage=Stage.TEST)
529
- # scores = outputs['matching_scores']
530
 
531
- # indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
 
 
 
532
 
533
- # cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
534
- # labels = utils.unbatch_list(batch['label'], indexes)
535
-
536
- # return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
537
-
538
- # def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
539
- # # total loss
540
- # self.log(
541
- # f'{stage.to_pref()}loss',
542
- # outputs['loss'],
543
- # batch_size=len(batch['identifier']),
544
- # sync_dist=True,
545
- # prog_bar=True,
546
- # on_epoch=True,
547
- # # on_step=True
548
- # )
549
-
550
- # def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
551
- # ContrastiveModel.on_test_batch_end(self, outputs, batch, batch_idx, stage)
552
-
553
- # def on_test_epoch_end(self):
554
- # self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
555
- # # self.df_test.to_csv(self.hparams.resutl)
556
- # print(self.df_test_path)
557
- # self.df_test.to_pickle(self.df_test_path)
558
- # # ContrastiveModel.on_test_epoch_end(self)
559
-
560
- # def get_checkpoint_monitors(self) -> T.List[dict]:
561
- # monitors = [
562
- # {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
563
- # ]
564
- # return monitors
565
- # def configure_optimizers(self):
566
- # return torch.optim.Adam(
567
- # self.classifier_model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
568
- # )
569
-
570
- # class IndSpecEncoder(ContrastiveModel):
571
- # """ Trains a spectra encoder that maps to a pretrained spec encoder"""
572
- # def __init__(
573
- # self,
574
- # **kwargs
575
- # ):
576
- # super().__init__(**kwargs)
577
-
578
- # # initialize ind_spec_encoder and loss
579
- # self.ind_spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
580
- # self.cons_loss = cons_spec_loss(self.hparams.cons_loss_type)
581
-
582
- # # load preptrained spec and mol encoders
583
- # checkpoint = torch.load(self.hparams.partial_checkpoint)
584
- # state_dict = state_dict = {k[len("spec_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("spec_enc_model")}
585
- # self.spec_enc_model.load_state_dict(state_dict) # trained on consensus spectra
586
-
587
- # state_dict = state_dict = {k[len("mol_enc_model."):]: v for k, v in checkpoint['state_dict'].items() if k.startswith("mol_enc_model")}
588
- # self.mol_enc_model.load_state_dict(state_dict)
589
-
590
- # # freeze cons spec and mol encoders
591
- # for param in self.mol_enc_model.parameters():
592
- # param.requires_grad = False
593
- # for param in self.spec_enc_model.parameters():
594
- # param.requires_grad = False
595
-
596
- # def forward(self, batch, stage):
597
-
598
- # spec = batch[self.spec_view]
599
- # n_peaks = batch['n_peaks']
600
- # spec_enc = self.ind_spec_enc_model(spec, n_peaks)
601
-
602
- # return spec_enc
603
-
604
- # def compute_loss(self, spec_enc, cons_spec_enc):
605
- # loss = self.cons_loss(spec_enc, cons_spec_enc)
606
- # return dict(loss=loss)
607
 
608
- # def step(self, batch: dict, stage=Stage.NONE):
609
- # self.spec_enc_model.eval()
610
- # self.mol_enc_model.eval()
 
 
611
 
612
- # spec_enc = self.forward(batch, stage)
613
 
614
- # if stage == Stage.TEST:
615
- # mol_enc = self.mol_enc_model(batch['cand'])
616
- # return dict(spec_enc=spec_enc, mol_enc=mol_enc)
617
-
618
- # cons_spec_enc = self.spec_enc_model(batch['cons_spec'], batch['cons_n_peaks'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
 
620
- # losses = self.compute_loss(spec_enc, cons_spec_enc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
 
622
- # return losses
623
-
624
-
625
- # def configure_optimizers(self):
626
- # return torch.optim.Adam(
627
- # self.ind_spec_enc_model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
628
- # )
629
- # def get_checkpoint_monitors(self) -> T.List[dict]:
630
- # monitors = [
631
- # {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
632
- # ]
633
- # return monitors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
 
635
  class CrossAttenContrastive(ContrastiveModel):
636
  def __init__(
 
10
  from massspecgym import utils
11
  from torch.nn.utils.rnn import pad_sequence
12
 
13
+ from flare.utils.loss import contrastive_loss, filip_loss_with_mask, global_infonce_loss, pcgrad_combine
14
  import flare.utils.models as model_utils
15
  from flare.utils.general import pad_graph_nodes, filip_similarity_batch
16
 
 
18
  import torch.nn.functional as F
19
 
20
  from torch_geometric.nn import global_mean_pool
21
+ import torch, dgllife
22
 
23
  class ContrastiveModel(RetrievalMassSpecGymModel):
24
  def __init__(
25
  self,
26
+ external_test: bool = False,
27
  **kwargs
28
  ):
29
  super().__init__(**kwargs)
30
  self.save_hyperparameters()
31
+ self.external_test = external_test
32
 
33
  if 'use_fp' not in self.hparams:
34
  self.hparams.use_fp = False
 
45
  self.result_dct = defaultdict(lambda: defaultdict(list))
46
 
47
  def forward(self, batch, stage):
48
+ if 'cand' in batch:
49
+ g = batch['cand']
50
+ elif 'mol' in batch:
51
+ g = batch['mol']
52
+ else:
53
+ g = None
54
+
55
  spec = batch[self.spec_view]
56
  n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
57
  spec_enc = self.spec_enc_model(spec, n_peaks)
58
 
59
+ if g is None:
60
+ mol_enc = None
61
+ return spec_enc, mol_enc
62
+
63
  fp = batch['fp'] if self.hparams.use_fp else None
64
+
65
+
66
+ f = self.mol_enc_model.GNN(g, g.ndata['h'])
67
+
68
  mol_enc = self.mol_enc_model(g, fp=fp)
69
 
70
  return spec_enc, mol_enc
 
77
  losses['contr_loss'] = contr_loss.detach().item()
78
 
79
  loss+=contr_loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  losses['loss'] = loss
82
 
 
110
  # total loss
111
  self.log(
112
  f'{stage.to_pref()}loss',
113
+ outputs['loss'],
114
  batch_size=len(batch['identifier']),
115
  sync_dist=True,
116
  prog_bar=True,
 
148
  self.result_dct[i]['candidates'].extend(cands)
149
  self.result_dct[i]['scores'].extend(scores.cpu().tolist())
150
  self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
 
 
 
 
 
151
 
152
  def _compute_rank(self, scores, labels):
153
  if not any(labels):
 
157
  rank = np.count_nonzero(scores >=target_score)
158
  return rank
159
 
160
+ def _get_top_cand(self, scores, candidates):
161
+ return candidates[np.argmax(np.array(scores))]
162
+
163
  def on_test_epoch_end(self) -> None:
164
 
165
  self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
166
 
167
  # Compute rank
168
+ if not self.external_test:
169
+ self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
170
+
171
+ if self.external_test:
172
+ self.df_test.drop('labels', axis=1, inplace=True)
173
+ self.df_test['top_cand'] = self.df_test.apply(lambda row: self._get_top_cand(row['scores'], row['candidates']), axis=1)
174
+
175
  if not self.df_test_path:
176
  self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
177
  self.df_test.to_pickle(self.df_test_path)
 
182
  {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": False}, # monitor val loss
183
  ]
184
  return monitors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  class FilipContrastive(ContrastiveModel):
187
  def __init__(self,
 
233
  # Calculate scores
234
  indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
235
 
236
+ scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
237
  scores = torch.split(scores, list(id_to_ct.values()))
238
 
239
  cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
 
241
 
242
  return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ # ============================================================
246
+ # Combined FILIP + Global InfoNCE
247
+ # ============================================================
248
+ class FilipGlobalContrastive(ContrastiveModel):
249
+ def __init__(self, loss_mode="sum", loss_weight=1.0, agg_fn="mean", **kwargs):
250
+ """
251
+ Args:
252
+ loss_mode: str, one of ["sum", "weighted", "pcgrad"]
253
+ loss_weight: weight for global loss if using weighted sum
254
+ agg_fn: aggregation function for global InfoNCE ("mean", "max", "cls")
255
+ """
256
+ super().__init__(**kwargs)
257
+ self.loss_mode = loss_mode
258
+ self.loss_weight = loss_weight
259
+ self.agg_fn = agg_fn
260
 
261
+ # -------------- loss computation --------------
262
+ def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask, stage=Stage.NONE):
263
+ losses = {}
 
264
 
 
265
 
266
+ # fine-grained FILIP loss
267
+ loss_fine = filip_loss_with_mask(spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp)
268
+ # global InfoNCE loss
269
+ loss_global = global_infonce_loss(spec_enc, mol_enc, spec_mask, mol_mask,
270
+ temperature=self.hparams.contr_temp, agg_fn=self.agg_fn)
271
+
272
+ # choose combination mode
273
+ if self.loss_mode == "sum":
274
+ loss = loss_fine + loss_global
275
+ elif self.loss_mode == "weighted":
276
+ loss = loss_fine + self.loss_weight * loss_global
277
+ elif self.loss_mode == "pcgrad":
278
+
279
+ if stage == Stage.TRAIN:
280
+ # PCGrad over both losses (training only)
281
+ shared_params = list(self.spec_enc_model.parameters()) + list(self.mol_enc_model.parameters())
282
+ self.zero_grad(set_to_none=True)
283
+ loss = pcgrad_combine([loss_fine, loss_global], shared_params)
284
+ else:
285
+
286
+ loss = (loss_fine + loss_global).detach()
287
 
288
+ else:
289
+ raise ValueError(f"Unsupported loss_mode: {self.loss_mode}")
 
 
 
 
290
 
291
+ losses["loss"] = loss
292
+ losses["loss_fine"] = loss_fine.detach()
293
+ losses["loss_global"] = loss_global.detach()
294
+ return losses
295
 
296
+ def step(self, batch: dict, stage=Stage.NONE):
 
 
 
 
 
297
 
 
 
298
 
299
+ spec_enc, mol_enc = self.forward(batch, stage)
 
300
 
301
+ mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch["mol_n_nodes"])
302
+ spec_mask = ~torch.all((spec_enc == -5), dim=-1)
303
+
304
+ if stage == Stage.TEST:
305
+ return dict(spec_enc=spec_enc, mol_enc=mol_enc, spec_mask=spec_mask, mol_mask=mol_mask)
 
 
 
 
306
 
307
+ losses = self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask, stage=stage)
308
+ return losses
309
 
310
+ # -------------- TEST step with different score variants --------------
311
+ def test_step(self, batch, batch_idx):
312
+ identifiers = batch["identifier"]
313
+ cand_smiles = batch["cand_smiles"]
314
 
315
+ id_to_ct = defaultdict(int)
316
+ for i in identifiers:
317
+ id_to_ct[i] += 1
318
+ batch_ptr = torch.tensor(list(id_to_ct.values()), device=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
+ outputs = self.step(batch, stage=Stage.TEST)
321
+ spec_enc = outputs["spec_enc"]
322
+ mol_enc = outputs["mol_enc"]
323
+ spec_mask = outputs["spec_mask"]
324
+ mol_mask = outputs["mol_mask"]
325
 
326
+ indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
327
 
328
+ # --- fine-grained score ---
329
+ fine_scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
330
+
331
+ # --- global cosine score ---
332
+ spec_global = (spec_enc * spec_mask.unsqueeze(-1)).sum(1) / spec_mask.sum(1, keepdim=True).clamp(min=1)
333
+ mol_global = (mol_enc * mol_mask.unsqueeze(-1)).sum(1) / mol_mask.sum(1, keepdim=True).clamp(min=1)
334
+ global_scores = F.cosine_similarity(spec_global, mol_global, dim=-1)
335
+
336
+ # --- combined scores (for evaluation) ---
337
+ combined_sum = fine_scores + global_scores
338
+ combined_weighted = fine_scores + self.loss_weight * global_scores
339
+ combined_pc = 0.5 * (fine_scores + global_scores) # simple average baseline
340
+
341
+ scores_dict = {
342
+ "fine": fine_scores,
343
+ "global": global_scores,
344
+ "sum": combined_sum,
345
+ "weighted": combined_weighted,
346
+ "avg": combined_pc,
347
+ }
348
+
349
+ # split back per identifier
350
+ for key in scores_dict:
351
+ scores_dict[key] = torch.split(scores_dict[key], list(id_to_ct.values()))
352
+
353
+ cand_smiles = utils.unbatch_list(batch["cand_smiles"], indexes)
354
+ labels = utils.unbatch_list(batch["label"], indexes)
355
+
356
+ return dict(
357
+ identifiers=list(id_to_ct.keys()),
358
+ scores=scores_dict,
359
+ cand_smiles=cand_smiles,
360
+ labels=labels,
361
+ )
362
 
363
+ def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
364
+ """
365
+ Collects test batch outputs and stores them in self.result_dct.
366
+ Supports both:
367
+ - Single score list format (legacy)
368
+ - Dict of multiple score variants (new)
369
+ """
370
+ identifiers = outputs["identifiers"]
371
+ cand_smiles = outputs["cand_smiles"]
372
+ labels = outputs["labels"]
373
+ scores_out = outputs["scores"]
374
+
375
+ for k, (i, cands, l) in enumerate(zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels'])):
376
+ self.result_dct[i]['candidates'].extend(cands)
377
+ self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
378
 
379
+ for variant_name, score_list in scores_out.items():
380
+ self.result_dct[i][f"scores_{variant_name}"].extend(score_list[k].cpu().tolist())
381
+
382
+ def on_test_epoch_end(self) -> None:
383
+ """
384
+ Combine results into one DataFrame with one row per identifier.
385
+ Adds rank/top_cand columns for each score variant.
386
+ """
387
+ records = []
388
+ for identifier, val in self.result_dct.items():
389
+ row = {"identifier": identifier, "candidates": val["candidates"]}
390
+ if not self.external_test:
391
+ row["labels"] = val["labels"]
392
+
393
+ # For every scores_* key, compute rank or top candidate
394
+ for key, scores in val.items():
395
+ if not key.startswith("scores_"):
396
+ continue
397
+ variant = key.replace("scores_", "")
398
+ if not self.external_test:
399
+ row[f"rank_{variant}"] = self._compute_rank(scores, val["labels"])
400
+ else:
401
+ row[f"top_cand_{variant}"] = self._get_top_cand(scores, val["candidates"])
402
+ row[key] = scores
403
+ records.append(row)
404
+
405
+ self.df_test = pd.DataFrame(records)
406
+
407
+ if self.external_test and "labels" in self.df_test.columns:
408
+ self.df_test.drop(columns=["labels"], inplace=True)
409
+
410
+ # Save once
411
+ if not getattr(self, "df_test_path", None):
412
+ self.df_test_path = os.path.join(self.hparams["experiment_dir"], "result_combined.pkl")
413
+
414
+ self.df_test.to_pickle(self.df_test_path)
415
 
416
  class CrossAttenContrastive(ContrastiveModel):
417
  def __init__(
flare/models/mol_encoder.py CHANGED
@@ -12,7 +12,7 @@ class MolEnc(nn.Module):
12
 
13
  self.return_emb = False
14
 
15
- if args.model in ('filipContrastive', 'crossAttenContrastive'):
16
  self.return_emb = True
17
 
18
  dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
@@ -46,4 +46,5 @@ class MolEnc(nn.Module):
46
  h1 = self.dropout(h1)
47
 
48
  return h1
 
49
 
 
12
 
13
  self.return_emb = False
14
 
15
+ if args.model in ('filipContrastive', 'crossAttenContrastive', 'filipGlobalContrastive'):
16
  self.return_emb = True
17
 
18
  dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
 
46
  h1 = self.dropout(h1)
47
 
48
  return h1
49
+
50
 
flare/models/spec_encoder.py CHANGED
@@ -111,7 +111,7 @@ class SpecFormulaTransformer(nn.Module):
111
  in_dim+=1
112
 
113
  self.returnEmb = False
114
- if args.model in ('crossAttenContrastive', 'filipContrastive'):
115
  self.returnEmb = True
116
  assert(args.use_cls == False)
117
 
@@ -128,7 +128,7 @@ class SpecFormulaTransformer(nn.Module):
128
  out_dim = args.final_embedding_dim
129
  self.fc = nn.Linear(args.formula_dims[-1], out_dim)
130
 
131
- def forward(self, spec, n_peaks):
132
  h = self.formulaEnc(spec)
133
  pad = (spec == -5)
134
  pad = torch.all(pad, -1)
@@ -154,7 +154,6 @@ class SpecFormulaTransformer(nn.Module):
154
  h = self.fc(h)
155
 
156
  return h
157
-
158
  class SpecFormula_mz_Encoder(nn.Module):
159
  '''
160
  Encodes formula and mz_int
 
111
  in_dim+=1
112
 
113
  self.returnEmb = False
114
+ if args.model in ('crossAttenContrastive', 'filipContrastive', 'filipGlobalContrastive'):
115
  self.returnEmb = True
116
  assert(args.use_cls == False)
117
 
 
128
  out_dim = args.final_embedding_dim
129
  self.fc = nn.Linear(args.formula_dims[-1], out_dim)
130
 
131
+ def forward(self, spec, n_peaks=None):
132
  h = self.formulaEnc(spec)
133
  pad = (spec == -5)
134
  pad = torch.all(pad, -1)
 
154
  h = self.fc(h)
155
 
156
  return h
 
157
  class SpecFormula_mz_Encoder(nn.Module):
158
  '''
159
  Encodes formula and mz_int
flare/params_filipGlobal.yaml ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Experiment setup
2
+ job_key: ''
3
+ run_name: 'filip-global'
4
+ run_details: ""
5
+ project_name: ''
6
+ wandb_entity_name: 'mass-spec-ml'
7
+ no_wandb: True
8
+ seed: 42
9
+ debug: False
10
+ checkpoint_pth:
11
+
12
+ # Training setup
13
+ max_epochs: 2000
14
+ accelerator: 'gpu'
15
+ devices: [1]
16
+ log_every_n_steps: 250
17
+ val_check_interval: 1.0
18
+
19
+ # Data paths
20
+ candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
21
+ dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # /data/yzhouc01/MVP/data/sample/data.tsv #/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
22
+ subformula_dir_pth: /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
23
+ split_pth:
24
+ fp_dir_pth:
25
+ partial_checkpoint: ""
26
+
27
+ # General hyperparameters
28
+ batch_size: 64 #64
29
+ lr: 2.881339661302105e-05 # 5.0e-05
30
+ weight_decay: 1.8376229667330708e-05
31
+ contr_temp: 0.022772534845886608 # 0.022772534845886608 # 0.05
32
+ num_workers: 50
33
+
34
+
35
+ # FILIP_GLOBAL model parameters
36
+ loss_mode: "pcgrad"
37
+ agg_fn: "mean"
38
+ loss_weight: 1.1
39
+
40
+
41
+ ############################## Data transforms ##############################
42
+ # - Spectra
43
+ spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
44
+ formula_source: 'default' # magma_1, magma_all, sirius, default
45
+ # 1. Binner
46
+ max_mz: 1000
47
+ bin_width: 1
48
+ mask_peak_ratio: 0.00
49
+
50
+ # 2. SpecFormula
51
+ element_list: ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']
52
+ add_intensities: True
53
+
54
+ # - Molecule
55
+ molecule_view: "MolGraph"
56
+ atom_feature: 'full'
57
+ bond_feature: 'full'
58
+
59
+
60
+ ############################## Task and model ##############################
61
+ task: 'retrieval'
62
+ spec_enc: Transformer_Formula # Transformer_MzInt #Transformer_Formula
63
+ mol_enc: "GNN"
64
+ model: filipGlobalContrastive #filipContrastive # "MultiviewContrastive"
65
+ contr_views: [['spec_enc', 'mol_enc']]
66
+ log_only_loss_at_stages: []
67
+ df_test_path: ""
68
+
69
+
70
+ # - Formula-based spec encoders
71
+ formula_dropout: 0.2
72
+ formula_dims: [512,256,512] #[512, 256, 512] #[64, 128, 256]
73
+ cross_attn_heads: 2
74
+ use_cls: False
75
+ peak_dropout: 0.2
76
+ formula_attn_heads: 4 # 2
77
+ formula_transformer_layers: 2 #2
78
+
79
+ # -- GAT params
80
+ attn_heads: [12,12,12]
81
+
82
+ # - Molecule encoder (GNN)
83
+ gnn_channels: [128, 256, 512] #[64,128,512]
84
+ gnn_type: "gcn"
85
+ # num_gnn_layers: 3
86
+ # gnn_hidden_dim: 512
87
+ gnn_dropout: 0.23234950970370824 #0.3
88
+
89
+
90
+ # - Spectra encoder (cross attention model)
91
+ # final_embedding_dim: 512
92
+ # fc_dropout: 0.4
93
+
94
+ # - Spectra Token encoder (mz-int token model)
95
+ # hidden_dims: [64, 256]
flare/run.sh CHANGED
@@ -1,3 +1,3 @@
1
- # python train.py
2
- python test.py --param_pth ../hparams.yaml
3
- # python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
 
1
+ # python train.py --param_pth params_filipGlobal.yaml
2
+ # python test.py --param_pth params_filipGlobal.yaml
3
+ python test.py --param_pth params_filipGlobal.yaml --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
flare/subformula_assign/run.sh CHANGED
@@ -1,6 +1,11 @@
1
- SPEC_FILES="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
2
- OUTPUT_DIR="/data/yzhouc01/spectra_data/subformulae"
 
 
 
 
 
3
  MAX_FORMULAE=60
4
- LABELS_FILE="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
5
 
6
  python assign_subformulae.py --spec-files $SPEC_FILES --output-dir $OUTPUT_DIR --max-formulae $MAX_FORMULAE --labels-file $LABELS_FILE
 
1
+ # SPEC_FILES="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
2
+ # OUTPUT_DIR="/data/yzhouc01/spectra_data/subformulae"
3
+ # MAX_FORMULAE=60
4
+ # LABELS_FILE="/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv"
5
+
6
+ SPEC_FILES="/data/yzhouc01/cancer/breast_cancer_data.tsv"
7
+ OUTPUT_DIR="/data/yzhouc01/cancer/subformulae"
8
  MAX_FORMULAE=60
9
+ LABELS_FILE="/data/yzhouc01/cancer/breast_cancer_data.tsv"
10
 
11
  python assign_subformulae.py --spec-files $SPEC_FILES --output-dir $OUTPUT_DIR --max-formulae $MAX_FORMULAE --labels-file $LABELS_FILE
flare/subformula_assign/utils/chem_utils.py CHANGED
@@ -181,6 +181,8 @@ def formula_to_dense(chem_formula: str) -> np.ndarray:
181
  """
182
  total_onehot = []
183
  for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
 
 
184
  # Convert num to int
185
  num = 1 if num == "" else int(num)
186
  one_hot = element_to_position[chem_symbol].reshape(1, -1)
@@ -257,6 +259,8 @@ def formula_to_dense(chem_formula: str) -> np.ndarray:
257
  """
258
  total_onehot = []
259
  for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
 
 
260
  # Convert num to int
261
  num = 1 if num == "" else int(num)
262
  one_hot = element_to_position[chem_symbol].reshape(1, -1)
 
181
  """
182
  total_onehot = []
183
  for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
184
+ if chem_symbol not in VALID_ELEMENTS: # yzc
185
+ continue
186
  # Convert num to int
187
  num = 1 if num == "" else int(num)
188
  one_hot = element_to_position[chem_symbol].reshape(1, -1)
 
259
  """
260
  total_onehot = []
261
  for (chem_symbol, num) in re.findall(CHEM_FORMULA_SIZE, chem_formula):
262
+ if chem_symbol not in VALID_ELEMENTS: # yzc
263
+ continue
264
  # Convert num to int
265
  num = 1 if num == "" else int(num)
266
  one_hot = element_to_position[chem_symbol].reshape(1, -1)
flare/test.py CHANGED
@@ -29,6 +29,8 @@ parser.add_argument('--checkpoint_choice', type=str, default='train', choices=['
29
  parser.add_argument('--df_test_pth', type=str, help='result file name')
30
  parser.add_argument('--exp_dir', type=str)
31
  parser.add_argument('--candidates_pth', type=str)
 
 
32
  def main(params):
33
  # Seed everything
34
  pl.seed_everything(params['seed'])
@@ -58,6 +60,7 @@ def main(params):
58
 
59
  model = get_model(params['model'], params)
60
  model.df_test_path = params['df_test_path']
 
61
 
62
  # Init trainer
63
  trainer = Trainer(
@@ -109,7 +112,12 @@ if __name__ == "__main__":
109
  params['checkpoint_pth'] = checkpoint_path
110
  break
111
  assert(params['checkpoint_pth'] != '')
112
-
 
 
 
 
 
113
  if args.candidates_pth:
114
  params['candidates_pth'] = args.candidates_pth
115
  if args.df_test_pth:
 
29
  parser.add_argument('--df_test_pth', type=str, help='result file name')
30
  parser.add_argument('--exp_dir', type=str)
31
  parser.add_argument('--candidates_pth', type=str)
32
+ parser.add_argument('--external_test', action='store_true', help='whether the test set is external data without labels')
33
+
34
  def main(params):
35
  # Seed everything
36
  pl.seed_everything(params['seed'])
 
60
 
61
  model = get_model(params['model'], params)
62
  model.df_test_path = params['df_test_path']
63
+ model.external_test = params['external_test']
64
 
65
  # Init trainer
66
  trainer = Trainer(
 
112
  params['checkpoint_pth'] = checkpoint_path
113
  break
114
  assert(params['checkpoint_pth'] != '')
115
+
116
+ if args.external_test:
117
+ params['external_test'] = True
118
+ else:
119
+ params['external_test'] = False
120
+
121
  if args.candidates_pth:
122
  params['candidates_pth'] = args.candidates_pth
123
  if args.df_test_pth:
flare/tune.py CHANGED
@@ -231,7 +231,7 @@ def main(args):
231
 
232
  # now = datetime.datetime.now().strftime("%Y%m%d")
233
  # base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
234
- base_dir = "/data/yzhouc01/FILIP-MS/experiments/20250916_simple_model_optuna"
235
  os.makedirs(base_dir, exist_ok=True)
236
  params["experiment_dir"] = base_dir
237
 
 
231
 
232
  # now = datetime.datetime.now().strftime("%Y%m%d")
233
  # base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
234
+ base_dir = "../experiments/20250916_simple_model_optuna"
235
  os.makedirs(base_dir, exist_ok=True)
236
  params["experiment_dir"] = base_dir
237
 
flare/utils/case_study_utils.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from tqdm import tqdm
3
+ from rdkit import Chem
4
+ import multiprocessing as mp
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+
8
+ import sys
9
+ import os
10
+ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
11
+ if parent_dir not in sys.path:
12
+ sys.path.insert(0, parent_dir)
13
+
14
+ database_to_path = {'fdb':"/data/yzhouc01/molecule_data/foodb_2020_04_07_csv/Compound.csv",
15
+ 'hmdb':"/data/yzhouc01/molecule_data/metabolites-2025-09-18.csv",
16
+ 'spectra_db':"/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex_processed.tsv",
17
+ 'bio_db':"/data/yzhouc01/molecule_data/bio_2023_07_11_smiles.csv",
18
+ 'coconut':"/data/yzhouc01/molecule_data/coconut_csv-05-2025.csv"}
19
+
20
+ db_to_mass_col = {'fdb':'exact_molecular_weight',
21
+ 'hmdb':'MONO_MASS',
22
+ 'spectra_db':'exact_molecular_weight',
23
+ 'bio_db':'exact_molecular_weight',
24
+ 'coconut':'exact_molecular_weight'}
25
+
26
+ db_to_smiles_col = {'fdb':'CANONICAL_SMILES',
27
+ 'hmdb':'CANONICAL_SMILES',
28
+ 'spectra_db':'CANONICAL_SMILES',
29
+ 'bio_db':'canonical_smiles',
30
+ 'coconut':'rdkit_canonical_smiles'}
31
+
32
+
33
+ _worker_instance = None
34
+
35
+
36
+ def _init_worker(databases, threshold):
37
+ """Run once per worker process to initialize shared CandidateAssignment."""
38
+ global _worker_instance
39
+ _worker_instance = CandidateAssignment(databases, threshold)
40
+
41
+
42
+ def _worker_retrieve_candidates(parent_mass):
43
+ """Use the global CandidateAssignment instance inside each worker."""
44
+ return _worker_instance.retrieve_candidates(parent_mass)
45
+
46
+
47
+ _worker_instance = None
48
+
49
+
50
+ def _init_worker(databases, threshold):
51
+ """Initialize global CandidateAssignment in each worker (silent)."""
52
+ global _worker_instance
53
+ _worker_instance = CandidateAssignment(databases, threshold, verbose=False)
54
+
55
+
56
+ def _worker_retrieve_candidates(parent_mass):
57
+ """Retrieve candidates using the worker's global CandidateAssignment."""
58
+ return _worker_instance.retrieve_candidates(parent_mass)
59
+
60
+
61
+ class CandidateAssignment:
62
+ def __init__(self, databases=None, threshold=0.01, verbose=True):
63
+ self.threshold = threshold
64
+ self.databases = []
65
+ self.verbose = verbose
66
+
67
+ for db in databases:
68
+ if db not in database_to_path:
69
+ raise ValueError(
70
+ f"Database {db} not recognized. Available: {list(database_to_path.keys())}"
71
+ )
72
+ if not os.path.exists(database_to_path[db]):
73
+ raise ValueError(f"Database file for {db} not found at {database_to_path[db]}")
74
+ self.databases.append(db)
75
+
76
+ # Only print in main process
77
+ if self.verbose and mp.current_process().name == "MainProcess":
78
+ print(f"[{os.getpid()}] Loading databases: {self.databases}")
79
+
80
+ self.db_dfs = {}
81
+ self._load_databases()
82
+
83
+ def _load_databases(self):
84
+ for db in self.databases:
85
+ path = database_to_path[db]
86
+ if path.endswith("tsv"):
87
+ df = pd.read_csv(path, sep="\t", low_memory=False)
88
+ elif path.endswith("csv"):
89
+ df = pd.read_csv(path, low_memory=False)
90
+ else:
91
+ if self.verbose and mp.current_process().name == "MainProcess":
92
+ print(f"Unable to load database: {db}")
93
+ continue
94
+
95
+ # make sure required columns exist
96
+ required_cols = [db_to_mass_col[db], db_to_smiles_col[db]]
97
+ for col in required_cols:
98
+ if col not in df.columns:
99
+ raise ValueError(f"Column {col} not found in database {db}. {db} columns: {df.columns.tolist()}")
100
+
101
+ # convert to proper types
102
+ df[db_to_mass_col[db]] = pd.to_numeric(df[db_to_mass_col[db]], errors='coerce')
103
+
104
+ self.db_dfs[db] = df
105
+
106
+ # Only print in main process
107
+ if self.verbose and mp.current_process().name == "MainProcess":
108
+ print(f"[{os.getpid()}] Loaded {db} with {len(df)} entries.")
109
+
110
+ def retrieve_candidates(self, parent_mass):
111
+ """Retrieve SMILES candidates for a single parent mass."""
112
+ ub = parent_mass + self.threshold
113
+ lb = parent_mass - self.threshold
114
+
115
+ smiles_list = []
116
+ for db_name, df in self.db_dfs.items():
117
+ select_rows = df[
118
+ (df[db_to_mass_col[db_name]] >= lb)
119
+ & (df[db_to_mass_col[db_name]] <= ub)
120
+ ]
121
+ smiles_list.extend(select_rows[db_to_smiles_col[db_name]].tolist())
122
+
123
+ smiles_list = list(set(smiles_list))
124
+ return parent_mass, smiles_list
125
+
126
+ def retrieve_candidates_batch(self, parent_masses, n_workers=25, chunksize=10):
127
+ """Parallel batch retrieval with silent workers."""
128
+ with mp.Pool(
129
+ processes=n_workers,
130
+ initializer=_init_worker,
131
+ initargs=(self.databases, self.threshold),
132
+ ) as pool:
133
+ results = list(
134
+ tqdm(
135
+ pool.imap(_worker_retrieve_candidates, parent_masses, chunksize=chunksize),
136
+ total=len(parent_masses),
137
+ desc="Retrieving candidates",
138
+ )
139
+ )
140
+ return {r[0]: r[1] for r in results}
141
+
142
+ # P_TBL = Chem.GetPeriodicTable()
143
+ # ELECTRON_MASS = 0.00054858
144
+ # VALID_ELEMENTS = [
145
+ # "C",
146
+ # "H",
147
+ # "As",
148
+ # "B",
149
+ # "Br",
150
+ # "Cl",
151
+ # "Co",
152
+ # "F",
153
+ # "Fe",
154
+ # "I",
155
+ # "K",
156
+ # "N",
157
+ # "Na",
158
+ # "O",
159
+ # "P",
160
+ # "S",
161
+ # "Se",
162
+ # "Si",
163
+ # ]
164
+ # VALID_MONO_MASSES = np.array(
165
+ # [P_TBL.GetMostCommonIsotopeMass(i) for i in VALID_ELEMENTS]
166
+ # )
167
+ # CHEM_MASSES = VALID_MONO_MASSES[:, None]
168
+ # ELEMENT_TO_MASS = dict(zip(VALID_ELEMENTS, CHEM_MASSES.squeeze()))
169
+
170
+ # adduct_to_mass = {
171
+ # "[M+H]+": ELEMENT_TO_MASS["H"] - ELECTRON_MASS,
172
+ # "[M+Na]+": ELEMENT_TO_MASS["Na"] - ELECTRON_MASS,
173
+ # "[M+K]+": ELEMENT_TO_MASS["K"] - ELECTRON_MASS,
174
+ # "[M-H2O+H]+": -ELEMENT_TO_MASS["O"] - ELEMENT_TO_MASS["H"] - ELECTRON_MASS,
175
+ # "[M+H3N+H]+": ELEMENT_TO_MASS["N"] + ELEMENT_TO_MASS["H"] * 4 - ELECTRON_MASS,
176
+ # "[M]+": 0 - ELECTRON_MASS,
177
+ # "[M-H4O2+H]+": -ELEMENT_TO_MASS["O"] * 2 - ELEMENT_TO_MASS["H"] * 3 - ELECTRON_MASS,
178
+ # "[M-H]-": ELEMENT_TO_MASS["H"] + ELECTRON_MASS,
179
+ # "[M+H2O+H]+":ELEMENT_TO_MASS["O"] * 2 + ELEMENT_TO_MASS["H"] * 2 - ELECTRON_MASS,
180
+ # }
181
+
182
+
183
+ # def calculate_parent_mass(precursor_mz, adduct):
184
+ # if adduct not in adduct_to_mass:
185
+ # print(f'{adduct} not supported, returning original precursor_mz')
186
+ # return precursor_mz + adduct_to_mass[adduct]
187
+
188
+
189
+ if __name__ == "__main__":
190
+ # get_mol_mass_for_combined()
191
+ ca = CandidateAssignment(databases=['hmdb'])
192
+ candidates = ca.retrieve_candidates(parent_mass=180.0634, threshold=0.01)
193
+ print(candidates)
flare/utils/general.py CHANGED
@@ -2,37 +2,69 @@ import torch
2
  from torch import nn
3
  import torch.nn.functional as F
4
 
 
 
5
  def pad_graph_nodes(mol_enc, g_n_nodes):
6
  """
7
  Args:
8
- mol_enc: 2D tensor of shape (sum_nodes, D)
9
- Node embeddings for each molecule.
10
- g_n_nodes: list[int] Number of nodes per graph (len = B)
11
 
12
  Returns:
13
- padded: (B, max_nodes, D) tensor
14
  mask: (B, max_nodes) bool tensor, True for valid nodes
15
  """
16
-
17
- # Already concatenated: shape (sum_nodes, D)
18
  B = len(g_n_nodes)
19
  D = mol_enc.shape[1]
20
  max_nodes = max(g_n_nodes)
21
- padded = mol_enc.new_zeros((B, max_nodes, D))
22
- mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
23
 
 
 
 
 
 
 
 
 
24
  idx = 0
25
  for i, n in enumerate(g_n_nodes):
26
  padded[i, :n] = mol_enc[idx:idx+n]
27
  mask[i, :n] = True
28
  idx += n
 
29
  return padded, mask
30
 
31
- import torch
32
- import torch.nn.functional as F
33
 
34
- import torch
35
- import torch.nn.functional as F
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def filip_similarity_batch(
38
  image_tokens,
@@ -127,60 +159,64 @@ def filip_similarity_batch(
127
  return similarity
128
 
129
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # def filip_similarity_batch(image_tokens, text_tokens, mask_image, mask_text):
132
- # """
133
- # Compute FILIP similarity for batches of image and text token embeddings.
134
-
135
- # Args:
136
- # image_tokens: (B, N_img, D) float tensor
137
- # text_tokens: (B, N_text, D) float tensor
138
- # mask_image: (B, N_img) bool tensor
139
- # mask_text: (B, N_text) bool tensor
140
-
141
- # Returns:
142
- # similarities: (B,) float tensor of similarity scores
143
- # """
144
- # B, N_img, D = image_tokens.shape
145
- # N_text = text_tokens.shape[1]
146
-
147
- # # Normalize tokens
148
- # image_norm = F.normalize(image_tokens, p=2, dim=-1) # (B, N_img, D)
149
- # text_norm = F.normalize(text_tokens, p=2, dim=-1) # (B, N_text, D)
150
-
151
- # # Compute batched cosine similarity matrices
152
- # # Result shape: (B, N_img, N_text)
153
- # sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))
154
 
155
- # # Expand masks for broadcasting
156
- # mask_image_exp = mask_image.unsqueeze(2) # (B, N_img, 1)
157
- # mask_text_exp = mask_text.unsqueeze(1) # (B, 1, N_text)
158
- # valid_mask = mask_image_exp & mask_text_exp # (B, N_img, N_text)
 
 
159
 
160
- # # Mask invalid positions by setting them to -inf
161
- # sim_matrix_masked = sim_matrix.masked_fill(~valid_mask, float('-inf'))
162
 
163
- # # Max over text tokens per image token: (B, N_img)
164
- # max_sim_img, _ = sim_matrix_masked.max(dim=2)
 
165
 
166
- # # Max over image tokens per text token: (B, N_text)
167
- # max_sim_text, _ = sim_matrix_masked.max(dim=1)
 
 
168
 
169
- # # Replace -inf (no valid tokens) with zeros to avoid NaNs
170
- # max_sim_img[max_sim_img == float('-inf')] = 0
171
- # max_sim_text[max_sim_text == float('-inf')] = 0
 
172
 
173
- # # Sum over valid tokens and divide by number of valid tokens (avoid division by zero)
174
- # sum_img = (max_sim_img * mask_image).sum(dim=1)
175
- # count_img = mask_image.sum(dim=1).clamp(min=1).float()
176
 
177
- # sum_text = (max_sim_text * mask_text).sum(dim=1)
178
- # count_text = mask_text.sum(dim=1).clamp(min=1).float()
 
179
 
180
- # avg_img = sum_img / count_img
181
- # avg_text = sum_text / count_text
182
 
183
- # # Final similarity per batch element
184
- # similarity = (avg_img + avg_text) / 2
 
185
 
186
- # return similarity
 
 
 
2
  from torch import nn
3
  import torch.nn.functional as F
4
 
5
+
6
+
7
  def pad_graph_nodes(mol_enc, g_n_nodes):
8
  """
9
  Args:
10
+ mol_enc: (sum_nodes, D) tensor, node embeddings concatenated for all graphs
11
+ g_n_nodes: list[int], number of nodes per graph
 
12
 
13
  Returns:
14
+ padded: (B, max_nodes, D) tensor with requires_grad=True for original nodes
15
  mask: (B, max_nodes) bool tensor, True for valid nodes
16
  """
 
 
17
  B = len(g_n_nodes)
18
  D = mol_enc.shape[1]
19
  max_nodes = max(g_n_nodes)
 
 
20
 
21
+ # Create output with same requires_grad as input
22
+ padded = torch.zeros(B, max_nodes, D, dtype=mol_enc.dtype, device=mol_enc.device)
23
+
24
+ # Force gradient tracking by making this a non-leaf tensor
25
+ padded = padded + mol_enc.new_zeros(1).requires_grad_(True)
26
+
27
+ mask = torch.zeros(B, max_nodes, dtype=torch.bool, device=mol_enc.device)
28
+
29
  idx = 0
30
  for i, n in enumerate(g_n_nodes):
31
  padded[i, :n] = mol_enc[idx:idx+n]
32
  mask[i, :n] = True
33
  idx += n
34
+
35
  return padded, mask
36
 
 
 
37
 
38
+
39
+
40
+ # def pad_graph_nodes(mol_enc, g_n_nodes):
41
+ # """
42
+ # Args:
43
+ # mol_enc: 2D tensor of shape (sum_nodes, D)
44
+ # Node embeddings for each molecule.
45
+ # g_n_nodes: list[int] Number of nodes per graph (len = B)
46
+
47
+ # Returns:
48
+ # padded: (B, max_nodes, D) tensor
49
+ # mask: (B, max_nodes) bool tensor, True for valid nodes
50
+ # """
51
+
52
+ # # Already concatenated: shape (sum_nodes, D)
53
+ # B = len(g_n_nodes)
54
+ # D = mol_enc.shape[1]
55
+ # max_nodes = max(g_n_nodes)
56
+ # padded = mol_enc.new_zeros((B, max_nodes, D))
57
+ # mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
58
+
59
+ # idx = 0
60
+ # for i, n in enumerate(g_n_nodes):
61
+ # padded[i, :n] = mol_enc[idx:idx+n]
62
+ # mask[i, :n] = True
63
+ # idx += n
64
+ # return padded, mask
65
+
66
+
67
+
68
 
69
  def filip_similarity_batch(
70
  image_tokens,
 
159
  return similarity
160
 
161
 
162
+ def filip_similarity_single(
163
+ image_tokens,
164
+ text_tokens,
165
+ reduction="mean", # "mean", "topk", "softmax", or "geom"
166
+ k=5,
167
+ temperature=0.05,
168
+ eps=1e-6
169
+ ):
170
+ """
171
+ Compute FILIP similarity for a single image and text pair (no masks).
172
 
173
+ Args:
174
+ image_tokens: (N_img, D) float tensor
175
+ text_tokens: (N_text, D) float tensor
176
+ reduction: str, aggregation strategy: "mean", "topk", "softmax", or "geom"
177
+ k: int, used if reduction == "topk"
178
+ temperature: float, used if reduction == "softmax"
179
+ eps: float, small constant for numerical stability
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
+ Returns:
182
+ similarity: float scalar tensor
183
+ """
184
+ # Normalize tokens
185
+ image_norm = F.normalize(image_tokens, p=2, dim=-1)
186
+ text_norm = F.normalize(text_tokens, p=2, dim=-1)
187
 
188
+ # (N_img, N_text) cosine similarity matrix
189
+ sim_matrix = torch.matmul(image_norm, text_norm.t())
190
 
191
+ # Max similarity for each token (image->text and text->image)
192
+ max_sim_img, _ = sim_matrix.max(dim=1) # (N_img,)
193
+ max_sim_text, _ = sim_matrix.max(dim=0) # (N_text,)
194
 
195
+ # Aggregation helper
196
+ def aggregate(max_sim):
197
+ if reduction == "mean":
198
+ return max_sim.mean()
199
 
200
+ elif reduction == "topk":
201
+ k_eff = min(k, max_sim.numel())
202
+ topk_vals, _ = torch.topk(max_sim, k_eff)
203
+ return topk_vals.mean()
204
 
205
+ elif reduction == "softmax":
206
+ weights = torch.softmax(max_sim / temperature, dim=0)
207
+ return (weights * max_sim).sum()
208
 
209
+ elif reduction == "geom":
210
+ vals = max_sim.clamp(min=eps)
211
+ return torch.exp(torch.log(vals).mean())
212
 
213
+ else:
214
+ raise ValueError(f"Unknown reduction type: {reduction}")
215
 
216
+ # Aggregate both directions
217
+ avg_img = aggregate(max_sim_img)
218
+ avg_text = aggregate(max_sim_text)
219
 
220
+ # Final similarity (scalar)
221
+ similarity = (avg_img + avg_text) / 2
222
+ return similarity
flare/utils/loss.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
4
 
5
  def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
6
  v1_norm = torch.norm(v1, dim=1, keepdim=True)
@@ -76,10 +77,6 @@ class fp_loss:
76
  return 1 - torch.mean(sim)
77
 
78
 
79
- import torch
80
- import torch.nn.functional as F
81
- import torch.distributed as dist
82
-
83
  # ---------- Utility ----------
84
  def _safe_divide(num, denom, eps=1e-8):
85
  return num / (denom + eps)
@@ -154,3 +151,97 @@ def filip_loss_with_mask(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07):
154
 
155
  return 0.5 * (loss_a2b + loss_b2a)
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ import torch.distributed as dist
5
 
6
  def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
7
  v1_norm = torch.norm(v1, dim=1, keepdim=True)
 
77
  return 1 - torch.mean(sim)
78
 
79
 
 
 
 
 
80
  # ---------- Utility ----------
81
  def _safe_divide(num, denom, eps=1e-8):
82
  return num / (denom + eps)
 
151
 
152
  return 0.5 * (loss_a2b + loss_b2a)
153
 
154
+
155
+
156
+ def global_infonce_loss(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07, agg_fn="mean"):
157
+ """
158
+ Global InfoNCE loss (CLIP-style) for modalities A and B.
159
+
160
+ Args:
161
+ a_tokens: (B, N_a, D)
162
+ b_tokens: (B, N_b, D)
163
+ mask_a: (B, N_a) bool (True = valid)
164
+ mask_b: (B, N_b) bool (True = valid)
165
+ temperature: scalar
166
+ agg_fn: "mean" | "max" | "cls" | callable -> how to aggregate tokens into one vector
167
+
168
+ Returns:
169
+ scalar loss
170
+ """
171
+ device = a_tokens.device
172
+ B, N_a, D = a_tokens.shape
173
+ N_b = b_tokens.shape[1]
174
+
175
+ # ---- Normalize token embeddings ----
176
+ a = F.normalize(a_tokens, dim=-1)
177
+ b = F.normalize(b_tokens, dim=-1)
178
+
179
+ # ---- Aggregate per sample ----
180
+ if callable(agg_fn):
181
+ a_global = agg_fn(a, mask_a) # custom aggregation
182
+ b_global = agg_fn(b, mask_b)
183
+ elif agg_fn == "mean":
184
+ # masked mean
185
+ a_global = (a * mask_a.unsqueeze(-1)).sum(dim=1) / mask_a.sum(dim=1, keepdim=True).clamp(min=1)
186
+ b_global = (b * mask_b.unsqueeze(-1)).sum(dim=1) / mask_b.sum(dim=1, keepdim=True).clamp(min=1)
187
+ elif agg_fn == "max":
188
+ a_global = (a.masked_fill(~mask_a.unsqueeze(-1), float('-inf'))).max(dim=1).values
189
+ b_global = (b.masked_fill(~mask_b.unsqueeze(-1), float('-inf'))).max(dim=1).values
190
+ elif agg_fn == "cls":
191
+ # use first valid token as "cls"
192
+ a_global = a[:, 0, :]
193
+ b_global = b[:, 0, :]
194
+ else:
195
+ raise ValueError(f"Unknown agg_fn: {agg_fn}")
196
+
197
+ # ---- Compute cosine similarity matrix ----
198
+ a_global = F.normalize(a_global, dim=-1)
199
+ b_global = F.normalize(b_global, dim=-1)
200
+ logits = (a_global @ b_global.T) / temperature # (B, B)
201
+
202
+ # ---- InfoNCE loss ----
203
+ labels = torch.arange(B, device=device)
204
+ loss_a2b = F.cross_entropy(logits, labels)
205
+ loss_b2a = F.cross_entropy(logits.T, labels)
206
+ loss = 0.5 * (loss_a2b + loss_b2a)
207
+
208
+ return loss
209
+
210
+
211
+ # ---------- PCGrad utility ----------
212
+ def pcgrad_combine(losses, shared_params):
213
+ """
214
+ Compute PCGrad combined gradient for a list of scalar losses.
215
+ losses: list of scalar loss tensors
216
+ shared_params: list of parameters to project/aggregate gradients for
217
+ returns: scalar combined loss for logging (mean)
218
+ """
219
+ grads_list = [torch.autograd.grad(l, shared_params, retain_graph=True, allow_unused=True)
220
+ for l in losses]
221
+
222
+ # flatten
223
+ flat_grads = [torch.cat([g.reshape(-1) for g in grads if g is not None]) for grads in grads_list]
224
+ projected = [fg.clone() for fg in flat_grads]
225
+
226
+ # project conflicting grads
227
+ for i in range(len(flat_grads)):
228
+ for j in range(len(flat_grads)):
229
+ if i == j:
230
+ continue
231
+ dot = (projected[i] * projected[j]).sum()
232
+ if dot < 0:
233
+ proj = dot / (projected[j].norm() ** 2 + 1e-12)
234
+ projected[i] = projected[i] - proj * projected[j]
235
+
236
+ # sum projected grads
237
+ final_grad = sum(projected)
238
+ # assign to params
239
+ pointer = 0
240
+ for p in shared_params:
241
+ if p.requires_grad:
242
+ numel = p.numel()
243
+ p.grad = final_grad[pointer:pointer + numel].view_as(p).clone()
244
+ pointer += numel
245
+
246
+ # return average loss for logging only
247
+ return sum(losses) / len(losses)
flare/utils/models.py CHANGED
@@ -1,7 +1,7 @@
1
  from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
2
  from flare.models.mol_encoder import MolEnc
3
  from flare.models.encoders import MLP
4
- from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive
5
 
6
  def get_spec_encoder(spec_enc:str, args):
7
  return {"MLP_BIN": SpecEncMLP_BIN,
@@ -28,6 +28,8 @@ def get_model(model:str,
28
  model = CrossAttenContrastive(**params)
29
  elif model == "filipContrastive":
30
  model = FilipContrastive(**params)
 
 
31
  else:
32
  raise Exception(f"Model {model} not implemented.")
33
 
 
1
  from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
2
  from flare.models.mol_encoder import MolEnc
3
  from flare.models.encoders import MLP
4
+ from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive, FilipGlobalContrastive
5
 
6
  def get_spec_encoder(spec_enc:str, args):
7
  return {"MLP_BIN": SpecEncMLP_BIN,
 
28
  model = CrossAttenContrastive(**params)
29
  elif model == "filipContrastive":
30
  model = FilipContrastive(**params)
31
+ elif model == "filipGlobalContrastive":
32
+ model = FilipGlobalContrastive(**params)
33
  else:
34
  raise Exception(f"Model {model} not implemented.")
35
 
flare/utils/mol_search.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pickle
4
+ from typing import Callable, List, Dict, Any, Optional
5
+ from rdkit import Chem
6
+ import faiss
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from tqdm import tqdm
10
+ import dgl
11
+
12
+ class MoleculeDataset(Dataset):
13
+ """Converts SMILES to DGL graphs in parallel via DataLoader workers."""
14
+
15
+ def __init__(self, smiles_dict, smiles_preprocess):
16
+ self.items = list(smiles_dict.items())
17
+ self.smiles_preprocess = smiles_preprocess
18
+
19
+ def __len__(self):
20
+ return len(self.items)
21
+
22
+ def __getitem__(self, idx):
23
+ mol_id, smi = self.items[idx]
24
+ try:
25
+ graph = self.smiles_preprocess(smi)
26
+ return mol_id, graph, None
27
+ except Exception as e:
28
+ return mol_id, None, str(e)
29
+
30
+
31
+ def collate_graphs(batch):
32
+ """Custom collation: keep only valid graphs."""
33
+ valid = [(mid, g) for mid, g, err in batch if g is not None]
34
+ if not valid:
35
+ return [], None
36
+ mol_ids, graphs = zip(*valid)
37
+ batched_graph = dgl.batch(graphs)
38
+ return mol_ids, batched_graph
39
+
40
+
41
+
42
+ class SpectraMoleculeRetriever:
43
+ """
44
+ Two-stage spectra–molecule retrieval system with hierarchical metadata filtering:
45
+ 1. Coarse retrieval via FAISS on global embeddings.
46
+ 2. Fine-grained reranking via custom similarity (e.g., FILIP alignment).
47
+ 3. Supports fast subset search by class, superclass, or pathway.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ molecule_encoder,
53
+ spectra_encoder,
54
+ fine_similarity_fn: Callable[[Any, Any], float],
55
+ smiles_preprocess: Callable[[str], Any],
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ ):
58
+ """
59
+ Args:
60
+ molecule_encoder: callable with methods:
61
+ - global_embedding(mol)
62
+ - node_embeddings(mol)
63
+ spectra_encoder: callable with methods:
64
+ - global_embedding(spectrum)
65
+ - token_embeddings(spectrum)
66
+ fine_similarity_fn: function for fine-grained similarity.
67
+ smiles_preprocess: preprocessing function for SMILES → molecule object.
68
+ device: where to run encoders.
69
+ """
70
+ self.molecule_encoder = molecule_encoder
71
+ self.spectra_encoder = spectra_encoder
72
+ self.fine_similarity_fn = fine_similarity_fn
73
+ self.smiles_preprocess = smiles_preprocess
74
+ self.device = device
75
+
76
+ # Storage
77
+ self.molecule_db: Dict[str, Any] = {} # mol_id → mol object
78
+ self.node_cache: Dict[str, Any] = {} # mol_id → node embeddings
79
+ self.metadata: Dict[str, Dict[str, List[str]]] = {} # e.g. {"class": {"lipid": [mol1, mol2], ...}}
80
+
81
+ self.molecule_ids: Optional[np.ndarray] = None
82
+ self.global_embeddings: Optional[np.ndarray] = None
83
+ self.index: Optional[faiss.Index] = None
84
+ self.smiles_dict: Optional[Dict[str, str]] = None # mol_id → smiles
85
+
86
+ self.failed_mols = []
87
+
88
+ # set model to eval mode and move to device
89
+ self.molecule_encoder.eval()
90
+ self.spectra_encoder.eval()
91
+
92
+ self.molecule_encoder.to(self.device)
93
+ self.spectra_encoder.to(self.device)
94
+
95
+ # -------------------------------
96
+ # Database building & saving
97
+ # -------------------------------
98
+ def build_database(
99
+ self,
100
+ smiles_dict: dict,
101
+ metadata=None,
102
+ cache_nodes: bool = False,
103
+ batch_size: int = 64,
104
+ num_workers: int = 25,
105
+ pooling: str = "max", # or "sum", "mean"
106
+ ):
107
+ """
108
+ Parallelized database construction using PyTorch DataLoader for
109
+ SMILES → DGLGraph conversion and batched GPU encoding.
110
+
111
+ Args:
112
+ smiles_dict: dict {mol_id: smiles}
113
+ metadata: hierarchical dict for class/superclass/pathway
114
+ cache_nodes: if True, store node embeddings for fine-grained search
115
+ batch_size: number of molecules per GPU batch
116
+ num_workers: parallel CPU workers for SMILES parsing
117
+ pooling: global pooling type ("max" | "sum" | "mean")
118
+ """
119
+ print("Building molecule database with PyTorch DataLoader parallelization...")
120
+
121
+
122
+ # set up pooling
123
+ if pooling == "max":
124
+ self.pooling = dgl.nn.pytorch.glob.MaxPooling()
125
+ elif pooling == "sum":
126
+ self.pooling = dgl.nn.pytorch.glob.SumPooling()
127
+ elif pooling == "mean":
128
+ self.pooling = dgl.nn.pytorch.glob.MeanPooling()
129
+ else:
130
+ raise ValueError(f"Unsupported pooling: {pooling}")
131
+
132
+ dataset = MoleculeDataset(smiles_dict, self.smiles_preprocess)
133
+ loader = DataLoader(
134
+ dataset,
135
+ batch_size=batch_size,
136
+ shuffle=False,
137
+ num_workers=num_workers,
138
+ collate_fn=collate_graphs,
139
+ pin_memory=True,
140
+ )
141
+
142
+ mol_ids_all, mol_objs, mol_embs = [], [], []
143
+ failed_mols = []
144
+ node_cache = {}
145
+
146
+
147
+ with torch.no_grad():
148
+ for mol_ids, batched_graph in tqdm(loader, desc="Encoding molecules"):
149
+ if batched_graph is None:
150
+ # All failed in this batch
151
+ continue
152
+
153
+ try:
154
+ batched_graph = batched_graph.to(self.device)
155
+ node_repr = self.molecule_encoder(batched_graph, batched_graph.ndata['h'])
156
+ global_emb = self.pooling(batched_graph,node_repr)
157
+
158
+ # Normalize embeddings
159
+ emb_np = global_emb.detach().cpu().numpy()
160
+ emb_np /= np.linalg.norm(emb_np, axis=1, keepdims=True)
161
+
162
+ mol_ids_all.extend(mol_ids)
163
+ mol_objs.extend([batched_graph] * len(mol_ids))
164
+ mol_embs.append(emb_np)
165
+
166
+ # Optionally store node embeddings for fine-grained search
167
+ if cache_nodes:
168
+ # Split batched node embeddings into per-graph chunks
169
+ node_embs = dgl.unbatch(batched_graph)
170
+ for mol_id, mol_graph in zip(mol_ids, node_embs):
171
+ node_cache[mol_id] = mol_graph.ndata['h'].detach().cpu()
172
+ except Exception as e:
173
+ failed_mols.extend(mol_ids)
174
+ print(f"[Warning] Failed to encode batch with molecules {mol_ids}: {e}")
175
+ continue
176
+
177
+ if not mol_embs:
178
+ raise RuntimeError("No valid molecules were successfully encoded.")
179
+
180
+ self.failed_mols = failed_mols
181
+ self.smiles_dict = smiles_dict
182
+ self.molecule_db = dict(zip(mol_ids_all, mol_objs))
183
+ self.molecule_ids = np.array(mol_ids_all)
184
+ self.global_embeddings = np.concatenate(mol_embs, axis=0)
185
+ self.metadata = metadata or {}
186
+ self.node_cache.update(node_cache)
187
+
188
+ self._build_faiss_index()
189
+
190
+ print(f"Database built with {len(self.molecule_ids)} molecules "
191
+ f"({len(self.failed_mols) + (len(smiles_dict) - len(self.molecule_ids))} failed).")
192
+
193
+ def _build_faiss_index(self):
194
+ d = self.global_embeddings.shape[1]
195
+ self.index = faiss.IndexFlatIP(d)
196
+ self.index.add(self.global_embeddings)
197
+ print(f"FAISS index built with {len(self.molecule_ids)} embeddings.")
198
+
199
+ def save_database(self, path: str):
200
+ """Save molecule database and embeddings."""
201
+ data = {
202
+ "molecule_ids": self.molecule_ids,
203
+ "global_embeddings": self.global_embeddings,
204
+ "metadata": self.metadata,
205
+ "node_cache": self.node_cache,
206
+ "smiles_dict": self.smiles_dict,
207
+ }
208
+ with open(path, "wb") as f:
209
+ pickle.dump(data, f)
210
+ print(f"Database saved to {path}")
211
+
212
+ def load_database(self, path: str):
213
+ """Load molecule database and rebuild FAISS index."""
214
+ with open(path, "rb") as f:
215
+ data = pickle.load(f)
216
+ self.molecule_ids = data["molecule_ids"]
217
+ self.global_embeddings = data["global_embeddings"]
218
+ self.metadata = data.get("metadata", {})
219
+ self.node_cache = data.get("node_cache", {})
220
+ self.smiles_dict = data.get("smiles_dict", {})
221
+ self._build_faiss_index()
222
+ print(f"Database loaded from {path}")
223
+
224
+ # -------------------------------
225
+ # Filtering utilities
226
+ # -------------------------------
227
+ def _get_filtered_indices(self, subset: Optional[Dict[str, str]] = None) -> np.ndarray:
228
+ """
229
+ Retrieve indices for molecules matching a given metadata subset.
230
+ Example subset: {"class": "lipid"} or {"pathway": "glycolysis"}
231
+ """
232
+ if not subset:
233
+ return np.arange(len(self.molecule_ids))
234
+
235
+ key, value = next(iter(subset.items()))
236
+ if key not in self.metadata or value not in self.metadata[key]:
237
+ print(f"[Warning] No molecules found for {key}={value}")
238
+ return np.array([], dtype=int)
239
+
240
+ mol_ids = self.metadata[key][value]
241
+ id_to_idx = {m: i for i, m in enumerate(self.molecule_ids)}
242
+ selected = [id_to_idx[m] for m in mol_ids if m in id_to_idx]
243
+ return np.array(selected, dtype=int)
244
+
245
+ # -------------------------------
246
+ # Retrieval
247
+ # -------------------------------
248
+ def coarse_search(self, spectrum, top_k: int = 256, subset: Optional[Dict[str, str]] = None):
249
+ """
250
+ Retrieve top-k candidates using FAISS, optionally restricted to subset metadata.
251
+ """
252
+ with torch.no_grad():
253
+ spectrum = spectrum.to(self.device)
254
+ z_spec = self.spectra_encoder(spectrum).sum(axis=0)
255
+ z_spec = z_spec.detach().cpu().numpy() if hasattr(z_spec, "detach") else np.asarray(z_spec)
256
+ z_spec = z_spec / np.linalg.norm(z_spec)
257
+
258
+ subset_idx = self._get_filtered_indices(subset)
259
+ if subset_idx.size == 0:
260
+ return [], []
261
+
262
+ # subset FAISS index
263
+ emb_subset = self.global_embeddings[subset_idx]
264
+ index_subset = faiss.IndexFlatIP(emb_subset.shape[1])
265
+ index_subset.add(emb_subset)
266
+ sims, idxs = index_subset.search(z_spec[None, :], min(top_k, len(subset_idx)))
267
+
268
+ candidate_ids = self.molecule_ids[subset_idx[idxs[0]]]
269
+ return candidate_ids, sims[0]
270
+
271
+ def fine_rerank(self, spectrum, candidate_ids: List[str], top_k: int = 50):
272
+ """
273
+ Compute fine-grained similarity for the candidates and rerank.
274
+ """
275
+ spectrum = spectrum.to(self.device)
276
+ with torch.no_grad():
277
+ z_spec_tokens = self.spectra_encoder(spectrum)
278
+ scores = []
279
+ for mol_id in candidate_ids:
280
+ if mol_id in self.node_cache:
281
+ mol_tokens = self.node_cache[mol_id]
282
+ elif mol_id in self.molecule_db:
283
+ mol = self.molecule_db[mol_id].to(self.device)
284
+ mol_tokens = self.molecule_encoder(mol)
285
+ else:
286
+ mol = self.smiles_preprocess(self.smiles_dict[mol_id])
287
+ mol = mol.to(self.device)
288
+ mol_tokens = self.molecule_encoder(mol)
289
+
290
+ s = self.fine_similarity_fn(z_spec_tokens, mol_tokens).item()
291
+ scores.append((mol_id, s))
292
+ scores.sort(key=lambda x: x[1], reverse=True)
293
+ return scores[:top_k]
294
+
295
+ def search(
296
+ self,
297
+ spectrum,
298
+ coarse_k: int = 256,
299
+ fine_k: int = 50,
300
+ subset: Optional[Dict[str, str]] = None,
301
+ ):
302
+ """
303
+ Full two-stage search pipeline with optional subset filtering.
304
+ """
305
+ candidate_ids, _ = self.coarse_search(spectrum, top_k=coarse_k, subset=subset)
306
+ if len(candidate_ids) == 0:
307
+ return []
308
+ ranked = self.fine_rerank(spectrum, candidate_ids, top_k=fine_k)
309
+ return ranked
310
+
311
+
312
+
313
+ if __name__ == "__main__":
314
+ import sys
315
+ sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
316
+
317
+ from flare.utils.data import get_spec_featurizer, get_mol_featurizer
318
+ from flare.utils.models import get_model
319
+ from flare.utils.mol_search import SpectraMoleculeRetriever
320
+ from flare.utils.general import filip_similarity_single
321
+ import yaml
322
+
323
+ metadata = {
324
+ "class": {
325
+ "lipid": ["mol1", "mol2"],
326
+ "peptide": ["mol3"]
327
+ },
328
+ "pathway": {
329
+ "beta-oxidation": ["mol1"],
330
+ "glycolysis": ["mol2", "mol3"]
331
+ }
332
+ }
333
+
334
+ smiles_dict = {
335
+ "mol1": "CCO",
336
+ "mol2": "CCN",
337
+ "mol3": "CCC"
338
+ }
339
+
340
+ # Load model and data
341
+ param_pth = '/data/yzhouc01/cancer/flare.yaml'
342
+ with open(param_pth) as f:
343
+ params = yaml.load(f, Loader=yaml.FullLoader)
344
+
345
+ spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
346
+ mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
347
+
348
+
349
+ # load model
350
+ checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250930_optimized_flare_42/epoch=1959-train_loss=0.08.ckpt"
351
+ params['checkpoint_pth'] = checkpoint_pth
352
+ model = get_model(params['model'], params)
353
+
354
+ specMolRetriever = SpectraMoleculeRetriever(
355
+ molecule_encoder=model.mol_enc_model,
356
+ spectra_encoder=model.spec_enc_model,
357
+ fine_similarity_fn=filip_similarity_single,
358
+ smiles_preprocess=mol_featurizer
359
+ )
360
+
361
+ specMolRetriever.build_database(smiles_dict, metadata=metadata, cache_nodes=True)
362
+
363
+ # Filter search to molecules in a specific pathway
364
+ # results = specMolRetriever.search(spectrum, subset={"pathway": "beta-oxidation"})
365
+
366
+ # for mol_id, score in results[:10]:
367
+ # print(f"{mol_id}: {score:.3f}")
notebooks/UMAP_spectra_embeddings.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/fine-grained_vs_global.ipynb CHANGED
@@ -29819,9 +29819,13 @@
29819
  ],
29820
  "metadata": {
29821
  "kernelspec": {
29822
- "display_name": "Python (spec)",
29823
  "language": "python",
29824
- "name": "spec"
 
 
 
 
29825
  }
29826
  },
29827
  "nbformat": 4,
 
29819
  ],
29820
  "metadata": {
29821
  "kernelspec": {
29822
+ "display_name": "spec",
29823
  "language": "python",
29824
+ "name": "python3"
29825
+ },
29826
+ "language_info": {
29827
+ "name": "python",
29828
+ "version": "3.11.7"
29829
  }
29830
  },
29831
  "nbformat": 4,
notebooks/good_vs_bad_instances.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/mol-spec_visualization.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/results.ipynb ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "2cd3303a",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pickle\n",
11
+ "import pandas as pd"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "id": "8ccc0bc1",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "with open(\"/data/yzhouc01/FILIP-MS/experiments/20251110_filip-global/result_MassSpecGym_retrieval_candidates_formula.pkl\", \"rb\") as f:\n",
22
+ " result = pickle.load(f)"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 3,
28
+ "id": "8e517777",
29
+ "metadata": {},
30
+ "outputs": [
31
+ {
32
+ "data": {
33
+ "text/html": [
34
+ "<div>\n",
35
+ "<style scoped>\n",
36
+ " .dataframe tbody tr th:only-of-type {\n",
37
+ " vertical-align: middle;\n",
38
+ " }\n",
39
+ "\n",
40
+ " .dataframe tbody tr th {\n",
41
+ " vertical-align: top;\n",
42
+ " }\n",
43
+ "\n",
44
+ " .dataframe thead th {\n",
45
+ " text-align: right;\n",
46
+ " }\n",
47
+ "</style>\n",
48
+ "<table border=\"1\" class=\"dataframe\">\n",
49
+ " <thead>\n",
50
+ " <tr style=\"text-align: right;\">\n",
51
+ " <th></th>\n",
52
+ " <th>rank_fine</th>\n",
53
+ " <th>rank_global</th>\n",
54
+ " <th>rank_sum</th>\n",
55
+ " <th>rank_weighted</th>\n",
56
+ " <th>rank_avg</th>\n",
57
+ " </tr>\n",
58
+ " </thead>\n",
59
+ " <tbody>\n",
60
+ " <tr>\n",
61
+ " <th>R@1</th>\n",
62
+ " <td>0.214571</td>\n",
63
+ " <td>0.163306</td>\n",
64
+ " <td>0.192869</td>\n",
65
+ " <td>0.191274</td>\n",
66
+ " <td>0.192869</td>\n",
67
+ " </tr>\n",
68
+ " <tr>\n",
69
+ " <th>R@5</th>\n",
70
+ " <td>0.483140</td>\n",
71
+ " <td>0.403566</td>\n",
72
+ " <td>0.447425</td>\n",
73
+ " <td>0.444862</td>\n",
74
+ " <td>0.447425</td>\n",
75
+ " </tr>\n",
76
+ " <tr>\n",
77
+ " <th>R@20</th>\n",
78
+ " <td>0.747095</td>\n",
79
+ " <td>0.694350</td>\n",
80
+ " <td>0.728355</td>\n",
81
+ " <td>0.726361</td>\n",
82
+ " <td>0.728355</td>\n",
83
+ " </tr>\n",
84
+ " </tbody>\n",
85
+ "</table>\n",
86
+ "</div>"
87
+ ],
88
+ "text/plain": [
89
+ " rank_fine rank_global rank_sum rank_weighted rank_avg\n",
90
+ "R@1 0.214571 0.163306 0.192869 0.191274 0.192869\n",
91
+ "R@5 0.483140 0.403566 0.447425 0.444862 0.447425\n",
92
+ "R@20 0.747095 0.694350 0.728355 0.726361 0.728355"
93
+ ]
94
+ },
95
+ "execution_count": 3,
96
+ "metadata": {},
97
+ "output_type": "execute_result"
98
+ }
99
+ ],
100
+ "source": [
101
+ "data = []\n",
102
+ "for i in [1, 5, 20]:\n",
103
+ " curr_d = {}\n",
104
+ " for c in result.columns.tolist():\n",
105
+ " if c.startswith('rank'):\n",
106
+ " curr_d[c] = result[result[c] <= i].shape[0] / result.shape[0]\n",
107
+ " data.append(curr_d)\n",
108
+ "\n",
109
+ "data_df = pd.DataFrame(data, index=['R@1', 'R@5', 'R@20'])\n",
110
+ "data_df\n"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 7,
116
+ "id": "10493857",
117
+ "metadata": {},
118
+ "outputs": [
119
+ {
120
+ "data": {
121
+ "text/html": [
122
+ "<div>\n",
123
+ "<style scoped>\n",
124
+ " .dataframe tbody tr th:only-of-type {\n",
125
+ " vertical-align: middle;\n",
126
+ " }\n",
127
+ "\n",
128
+ " .dataframe tbody tr th {\n",
129
+ " vertical-align: top;\n",
130
+ " }\n",
131
+ "\n",
132
+ " .dataframe thead th {\n",
133
+ " text-align: right;\n",
134
+ " }\n",
135
+ "</style>\n",
136
+ "<table border=\"1\" class=\"dataframe\">\n",
137
+ " <thead>\n",
138
+ " <tr style=\"text-align: right;\">\n",
139
+ " <th></th>\n",
140
+ " <th>rank_fine</th>\n",
141
+ " <th>rank_global</th>\n",
142
+ " <th>rank_sum</th>\n",
143
+ " <th>rank_weighted</th>\n",
144
+ " <th>rank_avg</th>\n",
145
+ " </tr>\n",
146
+ " </thead>\n",
147
+ " <tbody>\n",
148
+ " <tr>\n",
149
+ " <th>R@1</th>\n",
150
+ " <td>0.420882</td>\n",
151
+ " <td>0.369731</td>\n",
152
+ " <td>0.412907</td>\n",
153
+ " <td>0.411939</td>\n",
154
+ " <td>0.412907</td>\n",
155
+ " </tr>\n",
156
+ " <tr>\n",
157
+ " <th>R@5</th>\n",
158
+ " <td>0.744475</td>\n",
159
+ " <td>0.707052</td>\n",
160
+ " <td>0.738893</td>\n",
161
+ " <td>0.737412</td>\n",
162
+ " <td>0.738893</td>\n",
163
+ " </tr>\n",
164
+ " <tr>\n",
165
+ " <th>R@20</th>\n",
166
+ " <td>0.927660</td>\n",
167
+ " <td>0.916325</td>\n",
168
+ " <td>0.926407</td>\n",
169
+ " <td>0.926122</td>\n",
170
+ " <td>0.926407</td>\n",
171
+ " </tr>\n",
172
+ " </tbody>\n",
173
+ "</table>\n",
174
+ "</div>"
175
+ ],
176
+ "text/plain": [
177
+ " rank_fine rank_global rank_sum rank_weighted rank_avg\n",
178
+ "R@1 0.420882 0.369731 0.412907 0.411939 0.412907\n",
179
+ "R@5 0.744475 0.707052 0.738893 0.737412 0.738893\n",
180
+ "R@20 0.927660 0.916325 0.926407 0.926122 0.926407"
181
+ ]
182
+ },
183
+ "execution_count": 7,
184
+ "metadata": {},
185
+ "output_type": "execute_result"
186
+ }
187
+ ],
188
+ "source": [
189
+ "data = []\n",
190
+ "for i in [1, 5, 20]:\n",
191
+ " curr_d = {}\n",
192
+ " for c in result.columns.tolist():\n",
193
+ " if c.startswith('rank'):\n",
194
+ " curr_d[c] = result[result[c] <= i].shape[0] / result.shape[0]\n",
195
+ " data.append(curr_d)\n",
196
+ "\n",
197
+ "data_df = pd.DataFrame(data, index=['R@1', 'R@5', 'R@20'])\n",
198
+ "data_df\n"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "id": "1e4201db",
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "x"
209
+ ]
210
+ }
211
+ ],
212
+ "metadata": {
213
+ "kernelspec": {
214
+ "display_name": "spec",
215
+ "language": "python",
216
+ "name": "python3"
217
+ },
218
+ "language_info": {
219
+ "codemirror_mode": {
220
+ "name": "ipython",
221
+ "version": 3
222
+ },
223
+ "file_extension": ".py",
224
+ "mimetype": "text/x-python",
225
+ "name": "python",
226
+ "nbconvert_exporter": "python",
227
+ "pygments_lexer": "ipython3",
228
+ "version": "3.11.7"
229
+ }
230
+ },
231
+ "nbformat": 4,
232
+ "nbformat_minor": 5
233
+ }
notebooks/spectra_sim.ipynb CHANGED
The diff for this file is too large to render. See raw diff