Anton Bushuiev committed on
Commit
68a450f
·
1 Parent(s): 26dbaa6

Make similarity_threshold adjustable

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -46,9 +46,6 @@ LIBRARY_PATH = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5')
46
  DATA_PATH = Path('./DreaMS/data')
47
  EXAMPLE_PATH = Path('./data')
48
 
49
- # Similarity threshold for filtering results
50
- SIMILARITY_THRESHOLD = 0.75
51
-
52
  # Cache for SMILES images to avoid regeneration
53
  _smiles_cache = {}
54
 
@@ -311,7 +308,7 @@ def _predict_gpu(in_pth, progress):
311
  return embs
312
 
313
 
314
- def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calculate_modified_cosine=False):
315
  """
316
  Create a single result row for the DataFrame
317
 
@@ -324,6 +321,7 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calcula
324
  sims: Similarity matrix
325
  cos_sim: Cosine similarity calculator
326
  embs: Query embeddings
 
327
  calculate_modified_cosine: Whether to calculate modified cosine similarity
328
 
329
  Returns:
@@ -343,9 +341,9 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calcula
343
  'precursor_mz': msdata.get_prec_mzs(i),
344
  'topk': n + 1,
345
  'library_j': j,
346
- 'library_SMILES': smiles_to_html_img(smiles) if dreams_similarity > SIMILARITY_THRESHOLD else None,
347
  'library_SMILES_raw': smiles,
348
- 'Spectrum': spectrum_to_html_img(spec1, spec2) if dreams_similarity > SIMILARITY_THRESHOLD else None,
349
  'Spectrum_raw': su.unpad_peak_list(spec1),
350
  'library_ID': msdata_lib.get_values('IDENTIFIER', j),
351
  'DreaMS_similarity': dreams_similarity,
@@ -367,13 +365,14 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calcula
367
  return row_data
368
 
369
 
370
- def _process_results_dataframe(df, in_pth, calculate_modified_cosine=False):
371
  """
372
  Process and clean the results DataFrame
373
 
374
  Args:
375
  df: Raw results DataFrame
376
  in_pth: Input file path for CSV export
 
377
  calculate_modified_cosine: Whether modified cosine similarity was calculated
378
 
379
  Returns:
@@ -423,7 +422,7 @@ def _process_results_dataframe(df, in_pth, calculate_modified_cosine=False):
423
  df = df.drop(columns=['DreaMS embedding', "SMILES", "Input Spectrum"])
424
  df = df[df['Top k'] == 1].sort_values('DreaMS similarity', ascending=False)
425
  df = df.drop(columns=['Top k'])
426
- df = df[df["DreaMS similarity"] > SIMILARITY_THRESHOLD]
427
 
428
  # Add row numbers
429
  df.insert(0, 'Row', range(1, len(df) + 1))
@@ -431,7 +430,7 @@ def _process_results_dataframe(df, in_pth, calculate_modified_cosine=False):
431
  return df, str(df_path)
432
 
433
 
434
- def _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress):
435
  """
436
  Core prediction function that orchestrates the entire prediction pipeline
437
 
@@ -490,7 +489,7 @@ def _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress):
490
  desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
491
 
492
  for n, j in enumerate(topk):
493
- row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calculate_modified_cosine)
494
  df.append(row_data)
495
 
496
  # Clear cache every 100 spectra to prevent memory buildup
@@ -501,7 +500,7 @@ def _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress):
501
 
502
  # Process and clean results
503
  progress(0.9, desc="Post-processing results...")
504
- df, csv_path = _process_results_dataframe(df, in_pth, calculate_modified_cosine)
505
 
506
  progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
507
 
@@ -515,7 +514,7 @@ def _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress):
515
  temp_in_path.unlink()
516
 
517
 
518
- def predict(lib_pth, in_pth, calculate_modified_cosine=False, progress=gr.Progress(track_tqdm=True)):
519
  """
520
  Main prediction function with error handling
521
 
@@ -540,7 +539,7 @@ def predict(lib_pth, in_pth, calculate_modified_cosine=False, progress=gr.Progre
540
  if not Path(lib_pth).exists():
541
  raise gr.Error("Spectral library not found. Please ensure the library file exists.")
542
 
543
- df, csv_path = _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress)
544
 
545
  return df, csv_path
546
 
@@ -616,6 +615,14 @@ def _create_gradio_interface():
616
 
617
  # Settings section
618
  with gr.Accordion("⚙️ Settings", open=False):
 
 
 
 
 
 
 
 
619
  calculate_modified_cosine = gr.Checkbox(
620
  label="Calculate modified cosine similarity",
621
  value=False,
@@ -647,7 +654,7 @@ def _create_gradio_interface():
647
  )
648
 
649
  # Connect prediction logic
650
- inputs = [in_pth, calculate_modified_cosine]
651
  outputs = [df, df_file]
652
 
653
  # Function to update dataframe headers based on setting
 
46
  DATA_PATH = Path('./DreaMS/data')
47
  EXAMPLE_PATH = Path('./data')
48
 
 
 
 
49
  # Cache for SMILES images to avoid regeneration
50
  _smiles_cache = {}
51
 
 
308
  return embs
309
 
310
 
311
+ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, similarity_threshold, calculate_modified_cosine=False):
312
  """
313
  Create a single result row for the DataFrame
314
 
 
321
  sims: Similarity matrix
322
  cos_sim: Cosine similarity calculator
323
  embs: Query embeddings
324
+ similarity_threshold: Similarity threshold for filtering results
325
  calculate_modified_cosine: Whether to calculate modified cosine similarity
326
 
327
  Returns:
 
341
  'precursor_mz': msdata.get_prec_mzs(i),
342
  'topk': n + 1,
343
  'library_j': j,
344
+ 'library_SMILES': smiles_to_html_img(smiles) if dreams_similarity > similarity_threshold else None,
345
  'library_SMILES_raw': smiles,
346
+ 'Spectrum': spectrum_to_html_img(spec1, spec2) if dreams_similarity > similarity_threshold else None,
347
  'Spectrum_raw': su.unpad_peak_list(spec1),
348
  'library_ID': msdata_lib.get_values('IDENTIFIER', j),
349
  'DreaMS_similarity': dreams_similarity,
 
365
  return row_data
366
 
367
 
368
+ def _process_results_dataframe(df, in_pth, similarity_threshold, calculate_modified_cosine=False):
369
  """
370
  Process and clean the results DataFrame
371
 
372
  Args:
373
  df: Raw results DataFrame
374
  in_pth: Input file path for CSV export
375
+ similarity_threshold: Similarity threshold for filtering results
376
  calculate_modified_cosine: Whether modified cosine similarity was calculated
377
 
378
  Returns:
 
422
  df = df.drop(columns=['DreaMS embedding', "SMILES", "Input Spectrum"])
423
  df = df[df['Top k'] == 1].sort_values('DreaMS similarity', ascending=False)
424
  df = df.drop(columns=['Top k'])
425
+ df = df[df["DreaMS similarity"] > similarity_threshold]
426
 
427
  # Add row numbers
428
  df.insert(0, 'Row', range(1, len(df) + 1))
 
430
  return df, str(df_path)
431
 
432
 
433
+ def _predict_core(lib_pth, in_pth, similarity_threshold, calculate_modified_cosine, progress):
434
  """
435
  Core prediction function that orchestrates the entire prediction pipeline
436
 
 
489
  desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
490
 
491
  for n, j in enumerate(topk):
492
+ row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, similarity_threshold, calculate_modified_cosine)
493
  df.append(row_data)
494
 
495
  # Clear cache every 100 spectra to prevent memory buildup
 
500
 
501
  # Process and clean results
502
  progress(0.9, desc="Post-processing results...")
503
+ df, csv_path = _process_results_dataframe(df, in_pth, similarity_threshold, calculate_modified_cosine)
504
 
505
  progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
506
 
 
514
  temp_in_path.unlink()
515
 
516
 
517
+ def predict(lib_pth, in_pth, similarity_threshold=0.75, calculate_modified_cosine=False, progress=gr.Progress(track_tqdm=True)):
518
  """
519
  Main prediction function with error handling
520
 
 
539
  if not Path(lib_pth).exists():
540
  raise gr.Error("Spectral library not found. Please ensure the library file exists.")
541
 
542
+ df, csv_path = _predict_core(lib_pth, in_pth, similarity_threshold, calculate_modified_cosine, progress)
543
 
544
  return df, csv_path
545
 
 
615
 
616
  # Settings section
617
  with gr.Accordion("⚙️ Settings", open=False):
618
+ similarity_threshold = gr.Slider(
619
+ minimum=-1.0,
620
+ maximum=1.0,
621
+ value=0.75,
622
+ step=0.01,
623
+ label="Similarity threshold",
624
+ info="Only display library matches with DreaMS similarity above this threshold (rendering less results also makes calculation faster)"
625
+ )
626
  calculate_modified_cosine = gr.Checkbox(
627
  label="Calculate modified cosine similarity",
628
  value=False,
 
654
  )
655
 
656
  # Connect prediction logic
657
+ inputs = [in_pth, similarity_threshold, calculate_modified_cosine]
658
  outputs = [df, df_file]
659
 
660
  # Function to update dataframe headers based on setting