yzhouchen001 commited on
Commit
b1aa639
·
1 Parent(s): 7b7a7b6
mvp/data/datasets.py CHANGED
@@ -19,6 +19,9 @@ import math
19
  import itertools
20
  from rdkit.Chem import AllChem
21
  from rdkit import Chem
 
 
 
22
  class JESTR1_MassSpecDataset(MassSpecDataset):
23
  def __init__(
24
  self,
@@ -90,8 +93,6 @@ class JESTR1_MassSpecDataset(MassSpecDataset):
90
  item[key] = transform(spec) if transform is not None else spec
91
  else:
92
  item["spec"] = self.spec_transform(spec)
93
- else:
94
- item["spec"] = spec
95
 
96
  if self.return_mol_freq:
97
  item["mol_freq"] = metadata["mol_freq"]
@@ -132,7 +133,9 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
132
  cons_spec_dir_pth: str = None,
133
  return_mol_freq: bool = False,
134
  return_identifier: bool = True,
135
- dtype: T.Type = torch.float32
 
 
136
  ):
137
  """
138
  Args:
@@ -146,6 +149,8 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
146
  self.use_cons_spec = False
147
  self.use_NL_spec = False
148
  self.spectra_view = spectra_view
 
 
149
 
150
  if isinstance(self.pth, str):
151
  self.pth = Path(self.pth)
@@ -155,19 +160,7 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
155
  self.metadata = pd.read_csv(self.pth, sep="\t")
156
 
157
  # load subformulas
158
- all_spec_ids = self.metadata['identifier'].tolist()
159
- subformulaLoader = data_utils.Subformula_Loader(spectra_view=spectra_view, dir_path=subformula_dir_pth)
160
-
161
- form_list = self.metadata['formula'].tolist()
162
- prec_mz_list = self.metadata['precursor_mz'].tolist()
163
- id_to_spec = subformulaLoader(all_spec_ids, form_list, prec_mz_list)
164
-
165
- # create subformula spectra if no subformula is available
166
- tmp_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
167
- tmp_df = self.metadata[self.metadata['identifier'].isin(tmp_ids)]
168
- tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)
169
- id_to_spec.update(dict(zip(tmp_df['identifier'].tolist(), tmp_df['spec'].tolist())))
170
-
171
 
172
  # load fingerprints
173
  self._load_fp(fp_dir_pth)
@@ -179,6 +172,7 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
179
  self._load_NL_spec(NL_spec_dir_pth)
180
 
181
  self.metadata = self.metadata[self.metadata['identifier'].isin(id_to_spec)]
 
182
  formula_df = pd.DataFrame.from_dict(id_to_spec, orient='index').reset_index().rename(columns={'index': 'identifier'})
183
  self.metadata = self.metadata.merge(formula_df, on='identifier')
184
 
@@ -208,6 +202,27 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
208
 
209
  return item
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  class ContrastiveDataset(Dataset):
212
  def __init__(
213
  self,
@@ -255,7 +270,11 @@ class ContrastiveDataset(Dataset):
255
  # standard collate
256
  for k in batch[0].keys():
257
  if k not in non_standard_collate:
258
- collated_batch[k] = default_collate([item[k] for item in batch])
 
 
 
 
259
 
260
  # batch graphs
261
  batch_mol = []
@@ -327,10 +346,13 @@ class ExpandedRetrievalDataset:
327
  candidates_pth: T.Optional[T.Union[Path, str]] = None,
328
  fp_size: int = None,
329
  fp_radius: int = None,
 
330
  **kwargs):
 
331
 
332
- self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
333
- # super().__init__(**kwargs)
 
334
 
335
  if self.use_fp:
336
  self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius,fpSize=fp_size)
@@ -348,9 +370,10 @@ class ExpandedRetrievalDataset:
348
 
349
  self.spec_cand = [] #(spec index, cand_smiles, true_label)
350
  test_smiles = self.metadata[self.metadata['fold'] == "test"]['smiles'].tolist()
351
- test_ms_id = self.metadata[self.metadata['fold'] == "test"]['identifier'].tolist()
 
 
352
 
353
- spec_id_to_index = dict(zip(self.metadata['identifier'], self.metadata.index))
354
  for spec_id, s in zip(test_ms_id, test_smiles):
355
  candidates = self.candidates[s]
356
  # mol_label = self.mol_label_transform(s)
@@ -363,7 +386,7 @@ class ExpandedRetrievalDataset:
363
  print(f"Target smiles not in candidate set")
364
 
365
 
366
- self.spec_cand.extend([(spec_id_to_index[spec_id], candidates[j], k) for j, k in enumerate(labels)])
367
 
368
  def __getattr__(self, name):
369
  return self.instance.__getattribute__(name)
@@ -376,7 +399,33 @@ class ExpandedRetrievalDataset:
376
  cand_smiles = self.spec_cand[i][1]
377
  label = self.spec_cand[i][2]
378
 
379
- item = self.instance.__getitem__(spec_i, transform_mol=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  item['cand'] = self.mol_transform(cand_smiles)
381
  item['cand_smiles'] = cand_smiles
382
  item['label'] = label
 
19
  import itertools
20
  from rdkit.Chem import AllChem
21
  from rdkit import Chem
22
+ from magma.run_magma import run_magma
23
+ import matchms
24
+
25
  class JESTR1_MassSpecDataset(MassSpecDataset):
26
  def __init__(
27
  self,
 
93
  item[key] = transform(spec) if transform is not None else spec
94
  else:
95
  item["spec"] = self.spec_transform(spec)
 
 
96
 
97
  if self.return_mol_freq:
98
  item["mol_freq"] = metadata["mol_freq"]
 
133
  cons_spec_dir_pth: str = None,
134
  return_mol_freq: bool = False,
135
  return_identifier: bool = True,
136
+ dtype: T.Type = torch.float32,
137
+ formula_source = 'default',
138
+ stage: Stage = Stage.TRAIN
139
  ):
140
  """
141
  Args:
 
149
  self.use_cons_spec = False
150
  self.use_NL_spec = False
151
  self.spectra_view = spectra_view
152
+ self.formula_source = formula_source
153
+ self.subformula_dir_pth = subformula_dir_pth
154
 
155
  if isinstance(self.pth, str):
156
  self.pth = Path(self.pth)
 
160
  self.metadata = pd.read_csv(self.pth, sep="\t")
161
 
162
  # load subformulas
163
+ id_to_spec = self._load_id_to_spec(stage)
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  # load fingerprints
166
  self._load_fp(fp_dir_pth)
 
172
  self._load_NL_spec(NL_spec_dir_pth)
173
 
174
  self.metadata = self.metadata[self.metadata['identifier'].isin(id_to_spec)]
175
+
176
  formula_df = pd.DataFrame.from_dict(id_to_spec, orient='index').reset_index().rename(columns={'index': 'identifier'})
177
  self.metadata = self.metadata.merge(formula_df, on='identifier')
178
 
 
202
 
203
  return item
204
 
205
+ def _load_id_to_spec(self, stage):
206
+ if stage == Stage.TRAIN:
207
+ self.metadata = self.metadata[self.metadata['fold'] != Stage.TEST.value]
208
+ else:
209
+ self.metadata = self.metadata[self.metadata['fold'] == Stage.TEST.value]
210
+
211
+ all_spec_ids = self.metadata['identifier'].tolist()
212
+ self.subformulaLoader = data_utils.Subformula_Loader(spectra_view=self.spectra_view, dir_path=self.subformula_dir_pth, formula_source=self.formula_source)
213
+
214
+ form_list = self.metadata['formula'].tolist()
215
+ prec_mz_list = self.metadata['precursor_mz'].tolist()
216
+ id_to_spec = self.subformulaLoader(all_spec_ids, form_list, prec_mz_list)
217
+
218
+ # create subformula spectra if no subformula is available
219
+ tmp_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
220
+ tmp_df = self.metadata[self.metadata['identifier'].isin(tmp_ids)]
221
+ tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)
222
+ id_to_spec.update(dict(zip(tmp_df['identifier'].tolist(), tmp_df['spec'].tolist())))
223
+
224
+ return id_to_spec
225
+
226
  class ContrastiveDataset(Dataset):
227
  def __init__(
228
  self,
 
270
  # standard collate
271
  for k in batch[0].keys():
272
  if k not in non_standard_collate:
273
+ try:
274
+ collated_batch[k] = default_collate([item[k] for item in batch])
275
+ except:
276
+ print(f"Error in collating key {k}")
277
+ raise
278
 
279
  # batch graphs
280
  batch_mol = []
 
346
  candidates_pth: T.Optional[T.Union[Path, str]] = None,
347
  fp_size: int = None,
348
  fp_radius: int = None,
349
+ use_magma = False,
350
  **kwargs):
351
+
352
 
353
+ self.use_magma = use_magma
354
+
355
+ self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False, stage = Stage.TEST) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
356
 
357
  if self.use_fp:
358
  self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius,fpSize=fp_size)
 
370
 
371
  self.spec_cand = [] #(spec index, cand_smiles, true_label)
372
  test_smiles = self.metadata[self.metadata['fold'] == "test"]['smiles'].tolist()
373
+ test_ms_id = self.metadata[self.metadata['fold'] == "test"]['identifier'].tolist()
374
+
375
+ self.spec_id_to_index = dict(zip(self.metadata['identifier'], self.metadata.index))
376
 
 
377
  for spec_id, s in zip(test_ms_id, test_smiles):
378
  candidates = self.candidates[s]
379
  # mol_label = self.mol_label_transform(s)
 
386
  print(f"Target smiles not in candidate set")
387
 
388
 
389
+ self.spec_cand.extend([(self.spec_id_to_index[spec_id], candidates[j], k) for j, k in enumerate(labels)])
390
 
391
  def __getattr__(self, name):
392
  return self.instance.__getattribute__(name)
 
399
  cand_smiles = self.spec_cand[i][1]
400
  label = self.spec_cand[i][2]
401
 
402
+ if self.use_magma:
403
+ item = self.instance.__getitem__(spec_i, transform_mol=False, transform_spec=False)
404
+
405
+ mzs = np.array([float(x) for x in self.metadata.iloc[spec_i]['mzs'].split(',')])
406
+ intensities = np.array([float(x) for x in self.metadata.iloc[spec_i]['intensities'].split(',')])
407
+ adduct = self.metadata.iloc[spec_i]['adduct']
408
+ precursor_mz = self.metadata.iloc[spec_i]['precursor_mz']
409
+ formula = self.metadata.iloc[spec_i]['formula']
410
+ spec_data = run_magma(i, mzs, intensities, cand_smiles, adduct)
411
+
412
+ spec = self.subformulaLoader.load_magma_data(spec_data, formula, precursor_mz)
413
+
414
+ spec = matchms.Spectrum(
415
+ mz = np.array(spec['formula_mzs']),
416
+ intensities = np.array(spec['formula_intensities']),
417
+ metadata = {'precursor_mz': precursor_mz, 'formulas': np.array(spec['formulas'])})
418
+
419
+ if isinstance(self.spec_transform, dict):
420
+
421
+ for key, transform in self.spec_transform.items():
422
+ item[key] = transform(spec) if transform is not None else spec
423
+ else:
424
+ item["spec"] = self.spec_transform(spec)
425
+
426
+ else:
427
+ item = self.instance.__getitem__(spec_i, transform_mol=False)
428
+
429
  item['cand'] = self.mol_transform(cand_smiles)
430
  item['cand_smiles'] = cand_smiles
431
  item['label'] = label
mvp/data/transforms.py CHANGED
@@ -160,7 +160,7 @@ class SpecFormulaMzFeaturizer(SpecTransform):
160
  # print(f"Couldn't vectorize {f}, element {e} not supported")
161
  continue
162
  return formula_vector
163
-
164
  class SpecFormulaFeaturizer(SpecTransform):
165
  ''' Uses processed mz and intensities, excludes mz values, keep peaks with formulas only'''
166
  def __init__(
@@ -208,7 +208,7 @@ class SpecFormulaFeaturizer(SpecTransform):
208
  try:
209
  formula_vector[i][self.elem_to_pos[e]]+=ct
210
  except:
211
- print(f"Couldn't vectorize {f}, element {e} not supported")
212
  continue
213
  except:
214
  print(f"Couldn't vectorize {f}, formula not supported")
 
160
  # print(f"Couldn't vectorize {f}, element {e} not supported")
161
  continue
162
  return formula_vector
163
+
164
  class SpecFormulaFeaturizer(SpecTransform):
165
  ''' Uses processed mz and intensities, excludes mz values, keep peaks with formulas only'''
166
  def __init__(
 
208
  try:
209
  formula_vector[i][self.elem_to_pos[e]]+=ct
210
  except:
211
+ # print(f"Couldn't vectorize {f}, element {e} not supported")
212
  continue
213
  except:
214
  print(f"Couldn't vectorize {f}, formula not supported")
mvp/definitions.py CHANGED
@@ -40,4 +40,6 @@ MSGYM_STANDARD_MH = {
40
  }
41
  MSGYM_STANDARD_all = { # got these from Yinkai
42
  "mz_mean": 80.88304948022557,
43
- "mz_std" : 197.4588028571758}
 
 
 
40
  }
41
  MSGYM_STANDARD_all = { # got these from Yinkai
42
  "mz_mean": 80.88304948022557,
43
+ "mz_std" : 197.4588028571758}
44
+
45
+ PRECURSOR_INTENSITY = 1.1
mvp/params_formSpec.yaml CHANGED
@@ -1,6 +1,6 @@
1
  # Experiment setup
2
  job_key: ''
3
- run_name: 'sirius_labels'
4
  run_details: ""
5
  project_name: ''
6
  wandb_entity_name: 'mass-spec-ml'
@@ -12,14 +12,14 @@ checkpoint_pth:
12
  # Training setup
13
  max_epochs: 2000
14
  accelerator: 'gpu'
15
- devices: [1]
16
  log_every_n_steps: 250
17
  val_check_interval: 1.0
18
 
19
  # Data paths
20
  candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
21
  dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
22
- subformula_dir_pth: /r/hassounlab/msgym_sirius # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
23
  split_pth:
24
  fp_dir_pth:
25
  cons_spec_dir_pth:
@@ -39,6 +39,7 @@ num_workers: 50
39
  ############################## Data transforms ##############################
40
  # - Spectra
41
  spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
 
42
  # 1. Binner
43
  max_mz: 1000
44
  bin_width: 1
 
1
  # Experiment setup
2
  job_key: ''
3
+ run_name: 'magma_all_labels'
4
  run_details: ""
5
  project_name: ''
6
  wandb_entity_name: 'mass-spec-ml'
 
12
  # Training setup
13
  max_epochs: 2000
14
  accelerator: 'gpu'
15
+ devices: [0]
16
  log_every_n_steps: 250
17
  val_check_interval: 1.0
18
 
19
  # Data paths
20
  candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
21
  dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
22
+ subformula_dir_pth: /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
23
  split_pth:
24
  fp_dir_pth:
25
  cons_spec_dir_pth:
 
39
  ############################## Data transforms ##############################
40
  # - Spectra
41
  spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
42
+ formula_source: 'magma_all' # magma_1, magma_all, sirius, default
43
  # 1. Binner
44
  max_mz: 1000
45
  bin_width: 1
mvp/run.sh CHANGED
@@ -1,3 +1,3 @@
1
- python train.py
2
  python test.py
3
  python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
 
1
+ # python train.py
2
  python test.py
3
  python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
mvp/test.py CHANGED
@@ -35,12 +35,14 @@ def main(params):
35
 
36
  # Init paths to data files
37
  if params['debug']:
38
- params['dataset_pth'] = "../data/sample/data.tsv"
 
39
  params['split_pth']=None
40
  params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')
41
 
42
  # Load dataset
43
  spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
 
44
  mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
45
  dataset = get_test_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)
46
 
 
35
 
36
  # Init paths to data files
37
  if params['debug']:
38
+
39
+ params['dataset_pth'] = "/data/yzhouc01/MVP/data/sample/data.tsv"
40
  params['split_pth']=None
41
  params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')
42
 
43
  # Load dataset
44
  spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
45
+
46
  mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
47
  dataset = get_test_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)
48
 
mvp/utils/data.py CHANGED
@@ -7,7 +7,7 @@ from massspecgym.data.transforms import SpecTransform, MolTransform
7
  from mvp.data.transforms import MolToGraph
8
  import mvp.data.datasets as jestr_datasets
9
  import typing as T
10
- from mvp.definitions import MSGYM_FORMULA_VECTOR_NORM, MSGYM_STANDARD_MH
11
  import matchms
12
  import tqdm
13
 
@@ -30,6 +30,7 @@ class Subformula_Loader:
30
 
31
  def __call__(self, ids, form_list, prec_mz_list):
32
  id_to_form_spec = {}
 
33
  for id, curr_form, curr_prec_mz in tqdm.tqdm(zip(ids, form_list, prec_mz_list), total=len(ids)):
34
  data = self.load(id, curr_form, curr_prec_mz)
35
  if data is not None:
@@ -51,10 +52,10 @@ class Subformula_Loader:
51
  if curr_form not in formulas and self.use_prec_mz:
52
  mzs = np.concatenate([mzs, [curr_prec_mz]])
53
  formulas = np.concatenate([formulas, [curr_form]])
54
- intensities = np.concatenate([intensities, [1.1]])
55
  elif curr_form in formulas and self.use_prec_mz:
56
  idx = np.where(formulas == curr_form)[0][0]
57
- intensities[idx] = 1.1
58
 
59
  # sort by mzs
60
  ind = mzs.argsort()
@@ -66,8 +67,75 @@ class Subformula_Loader:
66
  return None
67
 
68
  def load_magma_data(self, data, curr_form, curr_prec_mz):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- return None
71
 
72
  def load_sirius_data(self, data):
73
  try:
@@ -76,9 +144,9 @@ class Subformula_Loader:
76
  formulas = np.array([entry['molecularFormula'] for entry in data['fragments']])
77
  intensities = np.array([entry['relativeIntensity'] for entry in data['fragments'] ])
78
 
79
- intensities[formulas == data['molecularFormula']] = 1.1
80
 
81
- if not self.use_prec_mz:
82
  not_append_prec_mz = np.array([len(entry['peaks']) != 0 for entry in data['fragments']])
83
 
84
  mzs = mzs[not_append_prec_mz]
@@ -102,7 +170,7 @@ class Subformula_Loader:
102
  data = json.load(f)
103
  if self.formula_source == 'sirius':
104
  return self.load_sirius_data(data)
105
- elif self.formula_source == 'magma':
106
  return self.load_magma_data(data, curr_form, curr_prec_mz)
107
  else:
108
  return self.load_mist_data(data, curr_form, curr_prec_mz)
@@ -200,7 +268,7 @@ def get_test_ms_dataset(spectra_view: T.Union[str, T.List[str]],
200
 
201
  dataset_params = {'spectra_view': spectra_view, 'pth': params['dataset_pth'], 'spec_transform': spectra_featurizer, 'mol_transform': mol_featurizer, "candidates_pth": params['candidates_pth']}
202
  if "SpecFormula" in views or "SpecFormulaMz" in views:
203
- dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth']})
204
  use_formulas = True
205
 
206
  if params['use_cons_spec']:
@@ -223,7 +291,7 @@ def get_ms_dataset(spectra_view: str,
223
  dataset_params = {'pth': params['dataset_pth'], 'spec_transform': spectra_featurizer, 'mol_transform': mol_featurizer, 'spectra_view': spectra_view}
224
  use_formulas = False
225
  if "SpecFormula" in spectra_view:
226
- dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth']})
227
  use_formulas = True
228
 
229
  if params['pred_fp'] or params['use_fp']:
 
7
  from mvp.data.transforms import MolToGraph
8
  import mvp.data.datasets as jestr_datasets
9
  import typing as T
10
+ from mvp.definitions import MSGYM_FORMULA_VECTOR_NORM, MSGYM_STANDARD_MH, PRECURSOR_INTENSITY
11
  import matchms
12
  import tqdm
13
 
 
30
 
31
  def __call__(self, ids, form_list, prec_mz_list):
32
  id_to_form_spec = {}
33
+ print("Processing formula spectra")
34
  for id, curr_form, curr_prec_mz in tqdm.tqdm(zip(ids, form_list, prec_mz_list), total=len(ids)):
35
  data = self.load(id, curr_form, curr_prec_mz)
36
  if data is not None:
 
52
  if curr_form not in formulas and self.use_prec_mz:
53
  mzs = np.concatenate([mzs, [curr_prec_mz]])
54
  formulas = np.concatenate([formulas, [curr_form]])
55
+ intensities = np.concatenate([intensities, [PRECURSOR_INTENSITY]])
56
  elif curr_form in formulas and self.use_prec_mz:
57
  idx = np.where(formulas == curr_form)[0][0]
58
+ intensities[idx] = PRECURSOR_INTENSITY
59
 
60
  # sort by mzs
61
  ind = mzs.argsort()
 
67
  return None
68
 
69
  def load_magma_data(self, data, curr_form, curr_prec_mz):
70
+
71
+ np.random.seed(42)
72
+
73
+ formula_to_intensity = {}
74
+ formula_to_mz = {}
75
+
76
+ # data is None
77
+ if data is None:
78
+ if self.use_prec_mz:
79
+ return {'formulas': [curr_form], 'formula_mzs': [curr_prec_mz], 'formula_intensities': [PRECURSOR_INTENSITY]}
80
+ else:
81
+ return {'formulas': [], 'formula_mzs': [], 'formula_intensities': []}
82
+
83
+ # randomly choose 1 formula for each peak, keep largest intensity for each formula
84
+ if self.formula_source.endswith('1'):
85
+ for f, m, i in zip(data['subformulas'], data['mz'], data['intensities']):
86
+
87
+ if not f:
88
+ continue
89
+ selected_f = np.random.choice(f)
90
+ if selected_f in formula_to_intensity:
91
+ if i > formula_to_intensity[selected_f]:
92
+ formula_to_intensity[selected_f] = i
93
+ formula_to_mz[selected_f] = m
94
+ else:
95
+ formula_to_intensity[selected_f] = i
96
+ formula_to_mz[selected_f] = m
97
+
98
+ # take all formulas, divide intensity by number of formulas, keep largest intensity for each formula
99
+ elif self.formula_source.endswith('all'):
100
+ for f, m, i in zip(data['subformulas'], data['mz'], data['intensities']):
101
+
102
+ if not f:
103
+ continue
104
+ for fi in f:
105
+ if fi in formula_to_intensity:
106
+ if i/len(f) > formula_to_intensity[fi]:
107
+ formula_to_intensity[fi] = i/len(f)
108
+ formula_to_mz[fi] = m
109
+ else:
110
+ formula_to_intensity[fi] = i/len(f)
111
+ formula_to_mz[fi] = m
112
+ else:
113
+ raise Exception(f"Formula source not supported: {self.formula_source}")
114
+
115
+ mzs = list(formula_to_mz.values())
116
+ formulas = list(formula_to_mz.keys())
117
+ intensities = list(formula_to_intensity.values())
118
+
119
+ # add precursor mz
120
+ if self.use_prec_mz:
121
+ if curr_form in formulas:
122
+ intensities[formulas.index(curr_form)] = PRECURSOR_INTENSITY
123
+ else:
124
+ formulas.append(curr_form)
125
+ intensities.append(PRECURSOR_INTENSITY)
126
+ mzs.append(curr_prec_mz)
127
+
128
+ # sort by mzs
129
+ mzs = np.array(mzs)
130
+ formulas = np.array(formulas)
131
+ intensities = np.array(intensities)
132
+
133
+ ind = mzs.argsort()
134
+ mzs = mzs[ind]
135
+ formulas = formulas[ind]
136
+ intensities = intensities[ind]
137
 
138
+ return {'formulas': formulas, 'formula_mzs': mzs, 'formula_intensities': intensities}
139
 
140
  def load_sirius_data(self, data):
141
  try:
 
144
  formulas = np.array([entry['molecularFormula'] for entry in data['fragments']])
145
  intensities = np.array([entry['relativeIntensity'] for entry in data['fragments'] ])
146
 
147
+ intensities[formulas == data['molecularFormula']] = PRECURSOR_INTENSITY
148
 
149
+ if not self.use_prec_mz: # removing precursor formula
150
  not_append_prec_mz = np.array([len(entry['peaks']) != 0 for entry in data['fragments']])
151
 
152
  mzs = mzs[not_append_prec_mz]
 
170
  data = json.load(f)
171
  if self.formula_source == 'sirius':
172
  return self.load_sirius_data(data)
173
+ elif self.formula_source.startswith('magma'):
174
  return self.load_magma_data(data, curr_form, curr_prec_mz)
175
  else:
176
  return self.load_mist_data(data, curr_form, curr_prec_mz)
 
268
 
269
  dataset_params = {'spectra_view': spectra_view, 'pth': params['dataset_pth'], 'spec_transform': spectra_featurizer, 'mol_transform': mol_featurizer, "candidates_pth": params['candidates_pth']}
270
  if "SpecFormula" in views or "SpecFormulaMz" in views:
271
+ dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'use_magma': params['formula_source'].startswith('magma'), 'formula_source':params['formula_source']})
272
  use_formulas = True
273
 
274
  if params['use_cons_spec']:
 
291
  dataset_params = {'pth': params['dataset_pth'], 'spec_transform': spectra_featurizer, 'mol_transform': mol_featurizer, 'spectra_view': spectra_view}
292
  use_formulas = False
293
  if "SpecFormula" in spectra_view:
294
+ dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'formula_source': params['formula_source']})
295
  use_formulas = True
296
 
297
  if params['pred_fp'] or params['use_fp']: