File size: 1,233 Bytes
92d5847
2073e38
 
 
 
92d5847
 
 
 
 
 
 
 
 
2073e38
 
 
92d5847
2073e38
92d5847
 
2073e38
92d5847
 
2073e38
92d5847
2073e38
 
 
92d5847
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import streamlit as st

from config.inference_config import cfg
from src.app.data_validation import validate_data
from src.app.setup_model import LabelScore, get_top_label_names, setup_pipeline
from src.app.tags_mapping import tags2full_name
from src.app.visualization import visualize_predicted_categories

st.title("arXiv Paper Classifier")
st.markdown("Enter paper details to predict arXiv categories")

# Both inputs are bound to session-state keys so their current values can be
# read back below through st.session_state.
st.text_input("Enter paper name", key="paper_name")
st.text_area("Enter paper abstract", key="paper_abstract", height=250)

# The button is rendered first. validate_data only runs on a run in which the
# button was clicked (short-circuit `and`), so idle reruns never validate.
predict_clicked = st.button("Predict Categories", type="primary")
ready_to_predict = predict_clicked and validate_data(
    st.session_state["paper_name"], st.session_state["paper_abstract"]
)

if ready_to_predict:
    with st.spinner("Analyzing paper..."):
        # NOTE(review): setup_pipeline is invoked on every successful click;
        # presumably it caches the model inside src.app.setup_model — confirm.
        classifier = setup_pipeline(cfg)
        # Title and abstract are classified together as one space-separated text.
        paper_text = " ".join(
            (st.session_state["paper_name"], st.session_state["paper_abstract"])
        )
        # top_k=None presumably requests a score for every label rather than
        # only the best one — verify against the pipeline setup_pipeline returns.
        label_scores: list[LabelScore] = classifier(paper_text, top_k=None)  # type: ignore
        predicted_labels = get_top_label_names(
            label_scores, tags2full_name, cfg.top_percent
        )

    # Rendering happens outside the spinner so the spinner clears first.
    visualize_predicted_categories(
        predicted_labels, label_scores, tags2full_name, minimal_score=cfg.minimal_score
    )
else:
    # Reached whenever the button was not clicked on this run, or it was
    # clicked but validate_data returned a falsy result.
    st.info("Enter paper details and click 'Predict Categories' to get predictions.")