Spaces:
Sleeping
Sleeping
File size: 6,837 Bytes
649e64f 97f0281 326e019 66236fb 97f0281 fccf43b 97f0281 649e64f cb1f3fc 649e64f 326e019 649e64f 326e019 649e64f 326e019 649e64f 326e019 649e64f 326e019 649e64f 326e019 28267e0 649e64f 326e019 28267e0 326e019 28267e0 326e019 28267e0 326e019 28267e0 326e019 28267e0 326e019 649e64f 326e019 649e64f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import streamlit as st
import pandas as pd
import io
import numpy as np
from streamlit_plotly_events import plotly_events
import dgl
from app_utils.viz_utils import run
from app_utils.examples import EXAMPLES
from app_utils.model_utils import load_model_components
st.set_page_config(page_title="Spectra Tool Demo", layout="wide")
st.title("FLARE Peak-to-Node Alignment Visualization")
st.markdown("Provide inputs below or load one of the example datasets.")

# Session-state keys backing the user-editable text inputs; reset_fields()
# blanks all of them together.
FIELDS = ['mzs', 'intensities', 'smiles', 'formula', 'adduct', 'precursor_mz']
def reset_fields():
    """Set every input field listed in FIELDS to an empty string in session state."""
    st.session_state.update(dict.fromkeys(FIELDS, ""))
# ------------------------
# Session state defaults
# ------------------------
# Each key is seeded only when it is missing, so values survive later reruns
# of the script within the same session.
_STATE_DEFAULTS = {
    "run_clicked": False,
    "selected_spectrum_idx": None,
    "selected_node_idx": None,
    **dict.fromkeys(FIELDS, ""),
}
for _state_key, _state_default in _STATE_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# The featurizers and model are loaded only when "model" is absent from
# session state, i.e. once per session, and then reused on every rerun.
if "model" not in st.session_state:
    (
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
    ) = load_model_components()
# ------------------------
# Example loader dropdown
# ------------------------
example_names = list(EXAMPLES.keys())
_NO_EXAMPLE = "-- Select --"

# Dropdown menu for selecting an example; the first entry means "none".
selected_example = st.selectbox("Choose an example:", [_NO_EXAMPLE, *example_names])

# The button is always rendered; loading happens only when it was clicked
# with a real example chosen.
load_clicked = st.button("Load Example")
if load_clicked and selected_example != _NO_EXAMPLE:
    reset_fields()
    example = EXAMPLES[selected_example]
    # Copy each input field from the example, in FIELDS order.
    for field_name in FIELDS:
        st.session_state[field_name] = example[field_name]
    # Loading new inputs hides any previous result and clears the selection.
    st.session_state.run_clicked = False
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
# ------------------------
# Inputs
# ------------------------
# Every text box is pre-filled from session state (e.g. a loaded example).
# The returned strings are written back to session state by the Run handler.
def _prefilled_text_input(label, field, **kwargs):
    """Render ``st.text_input`` whose initial value is ``st.session_state[field]``."""
    return st.text_input(label, value=st.session_state[field], **kwargs)


st.subheader("Spectra")
mz_input = _prefilled_text_input(
    "m/z values (comma-separated):", "mzs", placeholder="100,150,200,250,300"
)
intensity_input = _prefilled_text_input(
    "Intensities (comma-separated):", "intensities", placeholder="10,50,80,40,20"
)

st.subheader("SMILES")
smiles_input = _prefilled_text_input("Enter SMILES string:", "smiles")

st.subheader("Formula")
formula_input = _prefilled_text_input("Enter molecular formula:", "formula")

st.subheader("Adduct")
adduct_input = _prefilled_text_input("Enter adduct:", "adduct")

st.subheader("Precursor mz")
precursor_input = _prefilled_text_input("Enter precursor mz:", "precursor_mz")
# ------------------------
# Run model
# ------------------------
if st.button("Run"):
    # Copy the current widget values into session state BEFORE validating.
    # The text inputs are not keyed, so anything the user typed exists only in
    # these local variables until it is stored here. Checking session state
    # first reported every manually entered field as empty.
    st.session_state.mzs = mz_input
    st.session_state.intensities = intensity_input
    st.session_state.smiles = smiles_input
    st.session_state.formula = formula_input
    st.session_state.adduct = adduct_input
    st.session_state.precursor_mz = precursor_input

    # Every field is required; whitespace-only input counts as empty.
    for f in FIELDS:
        if not st.session_state[f].strip():
            st.error(f"Field {f} is empty.")
            reset_fields()
            st.stop()

    # Parse the comma-separated spectrum, skipping blank entries such as the
    # one produced by a trailing comma.
    try:
        mz_values = [float(x) for x in st.session_state.mzs.split(",") if x.strip()]
        intensity_values = [
            float(x) for x in st.session_state.intensities.split(",") if x.strip()
        ]
    except ValueError as err:
        st.error(f"m/z and intensity values must be numbers ({err}).")
        reset_fields()
        st.stop()

    if len(mz_values) != len(intensity_values):
        st.error("Number of m/z values must match the number of intensity values")
        reset_fields()
        st.stop()
    if not mz_values:
        st.error("At least one m/z / intensity pair is required.")
        reset_fields()
        st.stop()

    # One (m/z, intensity) row per peak.
    ms = np.array(list(zip(mz_values, intensity_values)))
    # NOTE(review): precursor_mz is forwarded as the raw input string, exactly
    # as before; confirm that run() performs the numeric conversion.
    st.session_state.fig, st.session_state.sim_norm = run(
        ms,
        st.session_state.smiles,
        st.session_state.formula,
        st.session_state.precursor_mz,
        st.session_state.adduct,
        st.session_state.spec_featurizer,
        st.session_state.mol_featurizer,
        st.session_state.model,
        mass_diff_thresh=20,
        precursor_intensity=1.1
    )
    # A fresh result starts with nothing selected.
    st.session_state.selected_spectrum_idx = None
    st.session_state.selected_node_idx = None
    st.session_state.run_clicked = True
# ------------------------
# Display visualization
# ------------------------
def _apply_selection_colors(fig, sim_norm, spectrum_idx, node_idx):
    """Recolor the peak (trace 0) and node (trace 2) markers for a selection.

    If ``spectrum_idx`` is set, that peak is drawn red and the other peaks
    light gray, and the node markers are colored by row ``spectrum_idx`` of
    ``sim_norm``. Otherwise, if ``node_idx`` is set, the same is done the other
    way round using column ``node_idx``. With no selection, ``fig`` is left
    unchanged.
    """
    if spectrum_idx is not None:
        fig.data[2].marker.color = sim_norm[spectrum_idx, :]
        fig.data[0].marker.color = [
            "red" if i == spectrum_idx else "lightgray" for i in range(sim_norm.shape[0])
        ]
    elif node_idx is not None:
        fig.data[0].marker.color = sim_norm[:, node_idx]
        fig.data[2].marker.color = [
            "red" if i == node_idx else "lightgray" for i in range(sim_norm.shape[1])
        ]


if st.session_state.run_clicked:
    st.text("Only annotated peaks are shown. Peaks assigned the same subformula are combined by summing all the intensities and the smallest m/z value is shown.")
    st.text("Click on a peak or node to visualize similarity scores")

    sim_norm = st.session_state.sim_norm
    # Apply the stored selection BEFORE rendering so the drawn figure shows it.
    _apply_selection_colors(
        st.session_state.fig,
        sim_norm,
        st.session_state.selected_spectrum_idx,
        st.session_state.selected_node_idx,
    )

    # Render figure
    selected = plotly_events(
        st.session_state.fig,
        click_event=True,
        hover_event=False,
        key="events"
    )

    if selected:
        point = selected[0]
        curve, idx = point["curveNumber"], point["pointIndex"]
        # Trace 0 holds spectrum peaks (rows of sim_norm) and trace 2 holds
        # graph nodes (columns); clicks on other traces are ignored. The bounds
        # checks drop a click whose index does not fit the current sim_norm,
        # e.g. one still reported from before the latest Run.
        new_selection = None
        if curve == 0 and 0 <= idx < sim_norm.shape[0]:  # Spectrum clicked
            new_selection = (idx, None)
        elif curve == 2 and 0 <= idx < sim_norm.shape[1]:  # Node clicked
            new_selection = (None, idx)

        current_selection = (
            st.session_state.selected_spectrum_idx,
            st.session_state.selected_node_idx,
        )
        # The component can keep returning the last click on later reruns, so
        # act only when the selection actually changes (this also prevents a
        # rerun loop).
        if new_selection is not None and new_selection != current_selection:
            (
                st.session_state.selected_spectrum_idx,
                st.session_state.selected_node_idx,
            ) = new_selection
            # The figure was already sent to the browser above, so recoloring
            # it here would only show after the next interaction. Rerun so it
            # is redrawn with the new selection right away. Prefer st.rerun
            # (newer Streamlit) and fall back to st.experimental_rerun.
            _rerun = getattr(st, "rerun", None) or st.experimental_rerun
            _rerun()
|