Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from config.inference_config import cfg | |
| from src.app.data_validation import validate_data | |
| from src.app.setup_model import LabelScore, get_top_label_names, setup_pipeline | |
| from src.app.tags_mapping import tags2full_name | |
| from src.app.visualization import visualize_predicted_categories | |
| st.title("arXiv Paper Classifier") | |
| st.markdown("Enter paper details to predict arXiv categories") | |
| st.text_input("Enter paper name", key="paper_name") | |
| st.text_area("Enter paper abstract", key="paper_abstract", height=250) | |
| if st.button("Predict Categories", type="primary") and validate_data( | |
| st.session_state["paper_name"], st.session_state["paper_abstract"] | |
| ): | |
| with st.spinner("Analyzing paper..."): | |
| pipeline = setup_pipeline(cfg) | |
| scores: list[LabelScore] = pipeline( | |
| st.session_state["paper_name"] + " " + st.session_state["paper_abstract"], | |
| top_k=None, | |
| ) # type: ignore | |
| top_labels = get_top_label_names(scores, tags2full_name, cfg.top_percent) | |
| visualize_predicted_categories( | |
| top_labels, scores, tags2full_name, minimal_score=cfg.minimal_score | |
| ) | |
| else: | |
| st.info("Enter paper details and click 'Predict Categories' to get predictions.") | |