"""Streamlit demo page for visualizing FLARE peak-to-node alignments.

The page collects a spectrum (comma-separated m/z and intensity lists), a SMILES
string, a molecular formula, an adduct and a precursor m/z (or loads them from
``EXAMPLES``), runs ``app_utils.viz_utils.run`` and renders the returned Plotly
figure. Clicking a point on trace 0 (spectrum peaks) colors trace 2 (nodes) by
the matching row of ``sim_norm``; clicking a point on trace 2 colors trace 0 by
the matching column.
"""
import io

# NOTE(review): pandas, io and dgl are not referenced in this file; dgl may be
# imported for side effects needed by the model utilities — confirm before removing.
import dgl
import numpy as np
import pandas as pd
import streamlit as st
from streamlit_plotly_events import plotly_events

from app_utils.viz_utils import run
from app_utils.examples import EXAMPLES
from app_utils.model_utils import load_model_components

st.set_page_config(page_title="Spectra Tool Demo", layout="wide")
st.title("FLARE Peak-to-Node Alignment Visualization")
st.markdown("Provide inputs below or load one of the example datasets.")

# Names of the user inputs; also the session_state keys and EXAMPLES entry keys.
FIELDS = ['mzs', 'intensities', 'smiles', 'formula', 'adduct', 'precursor_mz']

# Trace indices in the figure returned by run(): 0 = spectrum peaks, 2 = nodes.
SPECTRUM_TRACE = 0
NODE_TRACE = 2


def reset_fields():
    """Clear every stored input field in ``st.session_state``."""
    for field in FIELDS:
        st.session_state[field] = ""


def _parse_floats(text):
    """Parse a comma-separated string into a list of floats, skipping blank items.

    Raises:
        ValueError: if a non-blank item is not a valid float.
    """
    return [float(item) for item in str(text).split(",") if item.strip()]


def _is_blank(value):
    """Return True when a text-input value is missing or whitespace only."""
    return value is None or not str(value).strip()


def _highlight_selection(fig, sim_norm, spectrum_idx=None, node_idx=None):
    """Recolor ``fig`` in place for the current selection.

    A selected spectrum peak colors the node trace by ``sim_norm[spectrum_idx, :]``
    and marks the chosen peak red; a selected node colors the spectrum trace by
    ``sim_norm[:, node_idx]`` and marks the chosen node red. A spectrum selection
    takes priority; with no selection the figure is left unchanged.
    """
    if spectrum_idx is not None:
        fig.data[NODE_TRACE].marker.color = sim_norm[spectrum_idx, :]
        fig.data[SPECTRUM_TRACE].marker.color = [
            "red" if i == spectrum_idx else "lightgray"
            for i in range(sim_norm.shape[0])
        ]
    elif node_idx is not None:
        fig.data[SPECTRUM_TRACE].marker.color = sim_norm[:, node_idx]
        fig.data[NODE_TRACE].marker.color = [
            "red" if i == node_idx else "lightgray"
            for i in range(sim_norm.shape[1])
        ]


# ------------------------
# Session state defaults
# ------------------------
_DEFAULTS = {
    "run_clicked": False,
    "selected_spectrum_idx": None,
    "selected_node_idx": None,
    **{field: "" for field in FIELDS},
}
for name, default in _DEFAULTS.items():
    if name not in st.session_state:
        st.session_state[name] = default

# Load the model once per browser session.
if "model" not in st.session_state:
    spec_featurizer, mol_featurizer, model = load_model_components()
    st.session_state.spec_featurizer = spec_featurizer
    st.session_state.mol_featurizer = mol_featurizer
    st.session_state.model = model

# ------------------------
# Example loader dropdown
# ------------------------
example_names = list(EXAMPLES.keys())

# Dropdown menu for selecting example
selected_example = st.selectbox("Choose an example:", ["-- Select --"] + example_names)

# Load button
if st.button("Load Example") and selected_example != "-- Select --":
    reset_fields()
    ex_data = EXAMPLES[selected_example]
    for field in FIELDS:
        st.session_state[field] = ex_data[field]
    # A newly loaded example invalidates any previous result and selection.
    st.session_state.run_clicked = False
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None

# ------------------------
# Inputs
# ------------------------
st.subheader("Spectra")
mz_input = st.text_input(
    "m/z values (comma-separated):",
    value=st.session_state.mzs,
    placeholder="100,150,200,250,300",
)
intensity_input = st.text_input(
    "Intensities (comma-separated):",
    value=st.session_state.intensities,
    placeholder="10,50,80,40,20",
)

st.subheader("SMILES")
smiles_input = st.text_input("Enter SMILES string:", value=st.session_state.smiles)

st.subheader("Formula")
formula_input = st.text_input("Enter molecular formula:", value=st.session_state.formula)

st.subheader("Adduct")
adduct_input = st.text_input("Enter adduct:", value=st.session_state.adduct)

st.subheader("Precursor mz")
precursor_input = st.text_input("Enter precursor mz:", value=st.session_state.precursor_mz)

# ------------------------
# Run model
# ------------------------
if st.button("Run"):
    inputs = {
        "mzs": mz_input,
        "intensities": intensity_input,
        "smiles": smiles_input,
        "formula": formula_input,
        "adduct": adduct_input,
        "precursor_mz": precursor_input,
    }

    # Validate what the user actually typed, not the previously stored values
    # (those are "" for manually entered data and stale after edits). Inputs are
    # left intact on error so the user can correct them.
    for field in FIELDS:
        if _is_blank(inputs[field]):
            st.error(f"Field {field} is empty.")
            st.stop()

    try:
        mzs = _parse_floats(inputs["mzs"])
        intensities = _parse_floats(inputs["intensities"])
    except ValueError:
        st.error("m/z values and intensities must be comma-separated numbers.")
        st.stop()

    if len(mzs) != len(intensities):
        st.error("Number of m/z values must match the number of intensity values")
        st.stop()
    if not mzs:
        st.error("At least one (m/z, intensity) pair is required.")
        st.stop()

    for field in FIELDS:
        st.session_state[field] = inputs[field]

    # One row per peak: column 0 = m/z, column 1 = intensity.
    ms = np.array(list(zip(mzs, intensities)))
    # NOTE(review): precursor_mz is passed as the raw text-input string while the
    # peaks are parsed to floats; confirm run() converts it.
    st.session_state.fig, st.session_state.sim_norm = run(
        ms,
        st.session_state.smiles,
        st.session_state.formula,
        st.session_state.precursor_mz,
        st.session_state.adduct,
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
        mass_diff_thresh=20,
        precursor_intensity=1.1,
    )
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
    st.session_state.run_clicked = True

# ------------------------
# Display visualization
# ------------------------
if st.session_state.run_clicked:
    st.text("Only annotated peaks are shown. Peaks assigned the same subformula are combined by summing all the intensities and the smallest m/z value is shown.")
    st.text("Double click on a peak or node to visualize similarity scores")

    fig = st.session_state.fig
    sim_norm = st.session_state.sim_norm

    # Re-apply the selection stored by the previous run before rendering.
    _highlight_selection(
        fig,
        sim_norm,
        spectrum_idx=st.session_state.selected_spectrum_idx,
        node_idx=st.session_state.selected_node_idx,
    )

    # Render figure; returns the clicked point(s), if any.
    selected = plotly_events(fig, click_event=True, hover_event=False, key="events")

    # Handle click. NOTE(review): plotly_events has already rendered the figure at
    # this point, so the recolor below only becomes visible on the next rerun
    # (presumably why users are told to double-click). Triggering st.rerun() when
    # the selection changes would fix this — confirm the installed Streamlit
    # version supports it.
    if selected:
        point = selected[0]
        curve, idx = point["curveNumber"], point["pointIndex"]
        # Bounds checks guard against a stale click that no longer fits sim_norm.
        if curve == SPECTRUM_TRACE and 0 <= idx < sim_norm.shape[0]:
            st.session_state.selected_spectrum_idx = idx
            st.session_state.selected_node_idx = None
            _highlight_selection(fig, sim_norm, spectrum_idx=idx)
        elif curve == NODE_TRACE and 0 <= idx < sim_norm.shape[1]:
            st.session_state.selected_node_idx = idx
            st.session_state.selected_spectrum_idx = None
            _highlight_selection(fig, sim_norm, node_idx=idx)