# FLARE / app.py
# Author: yzhouchen001
# Last commit: cb1f3fc ("app")
import streamlit as st
import pandas as pd
import io
import numpy as np
from streamlit_plotly_events import plotly_events
import dgl
from app_utils.viz_utils import run
from app_utils.examples import EXAMPLES
from app_utils.model_utils import load_model_components
# Page setup. Keep set_page_config as the first Streamlit call: older Streamlit
# versions raise an error if any other st.* command runs before it.
st.set_page_config(page_title="Spectra Tool Demo", layout="wide")
st.title("FLARE Peak-to-Node Alignment Visualization")
st.markdown("Provide inputs below or load one of the example datasets.")
# Session-state keys that back the user-editable input boxes, in form order.
FIELDS = ['mzs', 'intensities', 'smiles', 'formula', 'adduct', 'precursor_mz']


def reset_fields():
    """Set every input field stored in session state to the empty string."""
    st.session_state.update(dict.fromkeys(FIELDS, ""))
# ------------------------
# Session state defaults
# ------------------------
# Seed every key the rest of the script reads, so later code can access them
# unconditionally. Keys that already exist (from earlier reruns) are untouched.
_SESSION_DEFAULTS = {
    "run_clicked": False,
    "selected_spectrum_idx": None,
    "selected_node_idx": None,
}
_SESSION_DEFAULTS.update(dict.fromkeys(FIELDS, ""))
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# The model and both featurizers are loaded together, once per Streamlit
# session, and reused on every rerun.
if "model" not in st.session_state:
    spec_featurizer, mol_featurizer, model = load_model_components()
    st.session_state.update(
        spec_featurizer=spec_featurizer,
        mol_featurizer=mol_featurizer,
        model=model,
    )
# ------------------------
# Example loader dropdown
# ------------------------
# Sentinel first entry of the dropdown meaning "no example chosen".
_NO_EXAMPLE = "-- Select --"
example_names = list(EXAMPLES)
# Dropdown menu for selecting example
selected_example = st.selectbox("Choose an example:", [_NO_EXAMPLE, *example_names])
# Load button: copy the chosen example into session state; the input widgets
# below read their initial values from there.
load_requested = st.button("Load Example")
if load_requested and selected_example != _NO_EXAMPLE:
    reset_fields()
    ex_data = EXAMPLES[selected_example]
    for field_name in FIELDS:
        st.session_state[field_name] = ex_data[field_name]
    # A newly loaded example invalidates any previously drawn graph/selection.
    st.session_state.run_clicked = False
    st.session_state.selected_spectrum_idx = st.session_state.selected_node_idx = None
# ------------------------
# Inputs
# ------------------------
# Every widget is pre-filled from session state, so loading an example (which
# writes session state) repopulates the form.
def _section_input(header, label, field):
    """Render a subheader followed by one text box seeded from st.session_state[field]."""
    st.subheader(header)
    return st.text_input(label, value=st.session_state[field])


st.subheader("Spectra")
mz_input = st.text_input(
    "m/z values (comma-separated):",
    value=st.session_state["mzs"],
    placeholder="100,150,200,250,300",
)
intensity_input = st.text_input(
    "Intensities (comma-separated):",
    value=st.session_state["intensities"],
    placeholder="10,50,80,40,20",
)
smiles_input = _section_input("SMILES", "Enter SMILES string:", "smiles")
formula_input = _section_input("Formula", "Enter molecular formula:", "formula")
adduct_input = _section_input("Adduct", "Enter adduct:", "adduct")
precursor_input = _section_input("Precursor mz", "Enter precursor mz:", "precursor_mz")
# ------------------------
# Run model
# ------------------------
def _parse_float_list(text):
    """Parse a comma-separated string of numbers into a list of floats.

    Blank entries (for example from a trailing comma) are skipped.

    Raises:
        ValueError: if a non-blank entry is not a valid float literal.
    """
    return [float(token) for token in text.split(",") if token.strip()]


if st.button("Run"):
    # Validate what is currently typed in the widgets. At this point session
    # state still holds the values from the previous run/example, so checking
    # it rejected freshly typed input as "empty" on the first click.
    current_inputs = {
        "mzs": mz_input,
        "intensities": intensity_input,
        "smiles": smiles_input,
        "formula": formula_input,
        "adduct": adduct_input,
        "precursor_mz": precursor_input,
    }
    for field in FIELDS:
        value = current_inputs[field]
        if not (value and value.strip()):
            st.error(f"Field {field} is empty.")
            # The form is left as-is so the user can fill in the missing field.
            st.stop()

    try:
        mz_values = _parse_float_list(mz_input)
        intensity_values = _parse_float_list(intensity_input)
    except ValueError:
        st.error("m/z values and intensities must be comma-separated numbers.")
        st.stop()
    if not mz_values:
        st.error("Enter at least one m/z value.")
        st.stop()
    if len(mz_values) != len(intensity_values):
        st.error("Number of m/z values must match the number of intensity values")
        st.stop()

    # Inputs are valid: commit them so the widgets are pre-filled with these
    # values on later reruns.
    for field in FIELDS:
        st.session_state[field] = current_inputs[field]

    # One row per peak: column 0 is m/z, column 1 is intensity.
    ms = np.column_stack((mz_values, intensity_values))
    st.session_state.fig, st.session_state.sim_norm = run(
        ms,
        st.session_state.smiles,
        st.session_state.formula,
        st.session_state.precursor_mz,
        st.session_state.adduct,
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
        mass_diff_thresh=20,
        precursor_intensity=1.1,
    )
    # A new figure invalidates any previous peak/node selection.
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
    st.session_state.run_clicked = True
# ------------------------
# Display visualization
# ------------------------
# Trace indices in the figure returned by `run`: clicks on trace 0 select a
# spectrum peak, clicks on trace 2 select a graph node.
_SPECTRUM_TRACE = 0
_NODE_TRACE = 2


def _apply_highlight(fig, sim_norm, spectrum_idx, node_idx):
    """Recolor `fig` in place to show similarity scores for the current selection.

    sim_norm is indexed as sim_norm[peak, node]. When a peak is selected, the
    nodes are colored by that peak's row and the chosen peak is drawn red
    (others light gray). When a node is selected, the roles are swapped.
    With no selection, or an index that does not fit sim_norm, the figure is
    left unchanged.
    """
    n_peaks, n_nodes = sim_norm.shape
    if spectrum_idx is not None and 0 <= spectrum_idx < n_peaks:
        fig.data[_NODE_TRACE].marker.color = sim_norm[spectrum_idx, :]
        fig.data[_SPECTRUM_TRACE].marker.color = [
            "red" if i == spectrum_idx else "lightgray" for i in range(n_peaks)
        ]
    elif node_idx is not None and 0 <= node_idx < n_nodes:
        fig.data[_SPECTRUM_TRACE].marker.color = sim_norm[:, node_idx]
        fig.data[_NODE_TRACE].marker.color = [
            "red" if i == node_idx else "lightgray" for i in range(n_nodes)
        ]


if st.session_state.run_clicked:
    st.text("Only annotated peaks are shown. Peaks assigned the same subformula are combined by summing all the intensities and the smallest m/z value is shown.")
    st.text("Double click on a peak or node to visualize similarity scores")

    # Apply the selection stored by an earlier click before rendering.
    # NOTE(review): the bounds check in _apply_highlight also guards against
    # an index that does not fit the current figure (e.g. if the events
    # component re-reports a click made on a previous run's figure) — confirm.
    _apply_highlight(
        st.session_state.fig,
        st.session_state.sim_norm,
        st.session_state.selected_spectrum_idx,
        st.session_state.selected_node_idx,
    )

    # Render figure
    selected = plotly_events(
        st.session_state.fig,
        click_event=True,
        hover_event=False,
        key="events",
    )

    # plotly_events reports a click only after the chart has been drawn, so
    # recoloring here cannot change what is already on screen. The stored
    # selection is applied before rendering on the next rerun, which is
    # presumably why the UI asks for a double click.
    # NOTE(review): calling st.rerun() when the stored selection changes would
    # make a single click take effect immediately.
    if selected:
        point = selected[0]
        curve, idx = point["curveNumber"], point["pointIndex"]
        if curve == _SPECTRUM_TRACE:
            st.session_state.selected_spectrum_idx = idx
            st.session_state.selected_node_idx = None
        elif curve == _NODE_TRACE:
            st.session_state.selected_node_idx = idx
            st.session_state.selected_spectrum_idx = None