adenshulga's picture
index on master: 92d5847 working training code
2073e38
import streamlit as st
from config.inference_config import cfg
from src.app.data_validation import validate_data
from src.app.setup_model import LabelScore, get_top_label_names, setup_pipeline
from src.app.tags_mapping import tags2full_name
from src.app.visualization import visualize_predicted_categories
# --- Page header -----------------------------------------------------------
st.title("arXiv Paper Classifier")
st.markdown("Enter paper details to predict arXiv categories")

# --- Inputs ----------------------------------------------------------------
# The widgets keep their session-state keys; the values they return are the
# same strings that are stored under those keys.
title_text = st.text_input("Enter paper name", key="paper_name")
abstract_text = st.text_area("Enter paper abstract", key="paper_abstract", height=250)

# --- Prediction ------------------------------------------------------------
predict_clicked = st.button("Predict Categories", type="primary")

# `and` short-circuits, so validate_data is only consulted after a click.
if not (predict_clicked and validate_data(title_text, abstract_text)):
    st.info("Enter paper details and click 'Predict Categories' to get predictions.")
else:
    with st.spinner("Analyzing paper..."):
        classifier = setup_pipeline(cfg)
        # The model receives the title and abstract as one space-separated string.
        model_input = " ".join((title_text, abstract_text))
        # NOTE(review): top_k=None presumably requests a score for every label;
        # confirm against the pipeline's API.
        label_scores: list[LabelScore] = classifier(
            model_input,
            top_k=None,
        )  # type: ignore
        best_labels = get_top_label_names(label_scores, tags2full_name, cfg.top_percent)

    visualize_predicted_categories(
        best_labels,
        label_scores,
        tags2full_name,
        minimal_score=cfg.minimal_score,
    )