Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import io | |
| import numpy as np | |
| from streamlit_plotly_events import plotly_events | |
| import dgl | |
| from app_utils.viz_utils import run | |
| from app_utils.examples import EXAMPLES | |
| from app_utils.model_utils import load_model_components | |
# Configure the page (browser-tab title, wide layout) and render the header.
st.set_page_config(page_title="Spectra Tool Demo", layout="wide")
st.title("FLARE Peak-to-Node Alignment Visualization")
st.markdown("Provide inputs below or load one of the example datasets.")

# Names of the session-state entries that hold the text-input values.
FIELDS = ['mzs', 'intensities', 'smiles', 'formula', 'adduct', 'precursor_mz']
def reset_fields():
    """Blank out every input field stored in the Streamlit session state."""
    for name in FIELDS:
        st.session_state[name] = ""
# ------------------------
# Session state defaults
# ------------------------
# Seed each entry only when it is missing, so values set on earlier script
# runs of the same session are kept.
_STATE_DEFAULTS = {
    "run_clicked": False,
    "selected_spectrum_idx": None,
    "selected_node_idx": None,
    **{name: "" for name in FIELDS},
}
for state_key, state_default in _STATE_DEFAULTS.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = state_default

# Load the featurizers and the model once per session and keep them in the
# session state for later runs.
if "model" not in st.session_state:
    (
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
    ) = load_model_components()
# ------------------------
# Example loader dropdown
# ------------------------
example_names = [*EXAMPLES.keys()]

# The first dropdown entry is a sentinel meaning "nothing chosen yet".
_NO_EXAMPLE = "-- Select --"
selected_example = st.selectbox("Choose an example:", [_NO_EXAMPLE] + example_names)

# The button is always rendered; the example is only loaded when a real
# entry (not the sentinel) is selected.
load_requested = st.button("Load Example")
if load_requested and selected_example != _NO_EXAMPLE:
    reset_fields()
    example = EXAMPLES[selected_example]
    # Copy the example's values into the session-state entries that feed the
    # text inputs below.
    for name in ("mzs", "intensities", "smiles", "formula", "adduct", "precursor_mz"):
        st.session_state[name] = example[name]

    # Drop any previously computed figure selection so the old graph is hidden.
    st.session_state.run_clicked = False
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
# ------------------------
# Inputs
# ------------------------
# Each text box is pre-filled from the matching session-state entry; the
# returned strings are consumed by the "Run" handler further down.
_stored = st.session_state

st.subheader("Spectra")
mz_input = st.text_input("m/z values (comma-separated):", value=_stored.mzs, placeholder="100,150,200,250,300")
intensity_input = st.text_input("Intensities (comma-separated):", value=_stored.intensities, placeholder="10,50,80,40,20")

st.subheader("SMILES")
smiles_input = st.text_input("Enter SMILES string:", value=_stored.smiles)

st.subheader("Formula")
formula_input = st.text_input("Enter molecular formula:", value=_stored.formula)

st.subheader("Adduct")
adduct_input = st.text_input("Enter adduct:", value=_stored.adduct)

st.subheader("Precursor mz")
precursor_input = st.text_input("Enter precursor mz:", value=_stored.precursor_mz)
# ------------------------
# Run model
# ------------------------
def _parse_float_list(text):
    """Parse a comma-separated string into a list of floats.

    Blank entries (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: if a non-blank entry is not a valid float.
    """
    return [float(token) for token in text.split(",") if token.strip()]


if st.button("Run"):
    # Validate what is currently typed into the widgets. The text inputs are
    # not keyed, so st.session_state only holds values from a previous
    # "Load Example" or "Run"; checking it here would reject freshly typed
    # input and could accept a field the user has just cleared.
    submitted = {
        "mzs": mz_input,
        "intensities": intensity_input,
        "smiles": smiles_input,
        "formula": formula_input,
        "adduct": adduct_input,
        "precursor_mz": precursor_input,
    }
    for name, value in submitted.items():
        if not value or not value.strip():
            st.error(f"Field {name} is empty.")
            reset_fields()
            st.stop()

    # Persist the validated inputs so they survive subsequent reruns.
    for name, value in submitted.items():
        st.session_state[name] = value

    # Non-numeric peak data is reported as a user error instead of surfacing
    # as an uncaught exception.
    try:
        mz_values = _parse_float_list(st.session_state.mzs)
        intensity_values = _parse_float_list(st.session_state.intensities)
    except ValueError:
        st.error("m/z values and intensities must be comma-separated numbers.")
        reset_fields()
        st.stop()

    if len(mz_values) != len(intensity_values):
        st.error("Number of m/z values must match the number of intensity values")
        reset_fields()
        st.stop()

    # One row per peak: (m/z, intensity).
    ms = np.array(list(zip(mz_values, intensity_values)))

    # NOTE(review): precursor_mz is passed as the raw text-input string;
    # confirm that run() converts it to a number.
    st.session_state.fig, st.session_state.sim_norm = run(
        ms,
        st.session_state.smiles,
        st.session_state.formula,
        st.session_state.precursor_mz,
        st.session_state.adduct,
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
        mass_diff_thresh=20,
        precursor_intensity=1.1
    )

    # A fresh result starts with nothing selected.
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
    st.session_state.run_clicked = True
# ------------------------
# Display visualization
# ------------------------
def _color_for_spectrum(figure, similarity, peak_idx):
    """Color trace 2 by row `peak_idx` of `similarity` and mark that peak red on trace 0."""
    figure.data[2].marker.color = similarity[peak_idx, :]
    figure.data[0].marker.color = [
        "red" if i == peak_idx else "lightgray" for i in range(similarity.shape[0])
    ]


def _color_for_node(figure, similarity, node_idx):
    """Color trace 0 by column `node_idx` of `similarity` and mark that node red on trace 2."""
    figure.data[0].marker.color = similarity[:, node_idx]
    figure.data[2].marker.color = [
        "red" if i == node_idx else "lightgray" for i in range(similarity.shape[1])
    ]


if st.session_state.run_clicked:
    st.text("Only annotated peaks are shown. Peaks assigned the same subformula are combined by summing all the intensities and the smallest m/z value is shown.")
    st.text("Double click on a peak or node to visualize similarity scores")

    view = st.session_state
    fig = view.fig

    # Re-apply the colors for whatever was selected on a previous run; a
    # spectrum selection takes precedence over a node selection.
    if view.selected_spectrum_idx is not None:
        _color_for_spectrum(fig, view.sim_norm, view.selected_spectrum_idx)
    elif view.selected_node_idx is not None:
        _color_for_node(fig, view.sim_norm, view.selected_node_idx)

    # Render figure
    clicked_points = plotly_events(
        fig,
        click_event=True,
        hover_event=False,
        key="events"
    )

    # Record the click and recolor the stored figure. The figure has already
    # been handed to plotly_events above on this run.
    if clicked_points:
        first_point = clicked_points[0]
        trace_no = first_point["curveNumber"]
        point_no = first_point["pointIndex"]
        if trace_no == 0:  # Spectrum clicked
            view.selected_spectrum_idx = point_no
            view.selected_node_idx = None
            _color_for_spectrum(fig, view.sim_norm, point_no)
        elif trace_no == 2:  # Node clicked
            view.selected_node_idx = point_no
            view.selected_spectrum_idx = None
            _color_for_node(fig, view.sim_norm, point_no)