Andrey Shulga commited on
Commit
92d5847
·
1 Parent(s): 97f90b2

working training code

Browse files
.gitignore CHANGED
@@ -8,6 +8,9 @@ wheels/
8
 
9
  outputs
10
 
 
 
 
11
  data
12
  NOTES.md
13
  .env
 
8
 
9
  outputs
10
 
11
+ notebooks
12
+
13
+
14
  data
15
  NOTES.md
16
  .env
README.md CHANGED
@@ -17,5 +17,6 @@ add .env with COMET_API_KEY, COMET_MODE=GET
17
  chmod +x on scripts
18
 
19
 
 
20
 
21
-
 
17
  chmod +x on scripts
18
 
19
 
20
+ add data to folder data
21
 
22
+ specify cuda device in pipeline script
config/inference_config.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+ from hydra.core.config_store import ConfigStore
4
+
5
+
6
+ @dataclass
7
+ class InferenceConfig:
8
+ """Configuration for inference"""
9
+
10
+ model_name: str = "allenai/scibert_scivocab_uncased"
11
+ checkpoint_path: str = "data/checkpoints/checkpoint-10"
12
+ top_percent: float = 0.95
config/pipeline_config.py CHANGED
@@ -16,13 +16,13 @@ class DatasetConfig:
16
  class CustomTrainingArguments:
17
  output_dir: str = "data/checkpoints"
18
  overwrite_output_dir: bool = True
19
- num_train_epochs: float = 10
20
  learning_rate: float = 5e-5
21
  lr_scheduler_type: str = "cosine"
22
  # lr_scheduler_kwargs={},
23
  warmup_ratio: float = 0.03125
24
- warmup_steps: int = 10
25
- # per_device_train_batch_size: int = 32
26
  gradient_accumulation_steps: int = 1
27
  log_level: str = "error"
28
  # logging_dir="output_dir/runs/CURRENT_DATETIME_HOSTNAME" # логи для tensorboard (default)
@@ -46,13 +46,21 @@ class CustomTrainingArguments:
46
  # resume_from_checkpoint: str = "last-checkpoint"
47
  auto_find_batch_size: bool = True
48
  report_to: str = "comet_ml"
 
 
49
 
50
 
51
  @dataclass
52
  class ModelConfig:
53
  """Configuration for model architecture and parameters"""
54
 
55
- model_name: str = "bert-base-uncased"
 
 
 
 
 
 
56
 
57
 
58
  @dataclass
 
16
  class CustomTrainingArguments:
17
  output_dir: str = "data/checkpoints"
18
  overwrite_output_dir: bool = True
19
+ num_train_epochs: float = 3
20
  learning_rate: float = 5e-5
21
  lr_scheduler_type: str = "cosine"
22
  # lr_scheduler_kwargs={},
23
  warmup_ratio: float = 0.03125
24
+ warmup_steps: int = 1
25
+ # per_device_train_batch_size: int = 64
26
  gradient_accumulation_steps: int = 1
27
  log_level: str = "error"
28
  # logging_dir="output_dir/runs/CURRENT_DATETIME_HOSTNAME" # логи для tensorboard (default)
 
46
  # resume_from_checkpoint: str = "last-checkpoint"
47
  auto_find_batch_size: bool = True
48
  report_to: str = "comet_ml"
49
+ metric_for_best_model: str = "f1"
50
+ greater_is_better: bool = True
51
 
52
 
53
  @dataclass
54
  class ModelConfig:
55
  """Configuration for model architecture and parameters"""
56
 
57
+ model_name: str = "allenai/scibert_scivocab_uncased"
58
+
59
+ # model_name: tp.Literal[
60
+ # "FacebookAI/roberta-base",
61
+ # "distilbert-base-uncased",
62
+ # "allenai/scibert_scivocab_uncased",
63
+ # ] = "allenai/scibert_scivocab_uncased"
64
 
65
 
66
  @dataclass
container_setup/Dockerfile CHANGED
@@ -42,12 +42,10 @@ WORKDIR /app
42
 
43
  COPY --chown=appuser:appgroup . /app
44
 
45
- # 1) Create a dedicated venv in .venv
46
  RUN uv venv .venv
47
 
48
- # 2) Install / sync packages into that .venv
49
  RUN uv sync
50
 
51
- EXPOSE 8000
52
 
53
- CMD ["uv", "run", "python", "entrypoints/app.py"]
 
42
 
43
  COPY --chown=appuser:appgroup . /app
44
 
 
45
  RUN uv venv .venv
46
 
 
47
  RUN uv sync
48
 
49
+ EXPOSE 9000
50
 
51
+ # CMD ["uv", "run", "python", "entrypoints/app.py"]
container_setup/build.sh CHANGED
@@ -2,7 +2,7 @@
2
 
3
  source container_setup/credentials
4
 
5
- docker build --no-cache -f container_setup/Dockerfile -t ${DOCKER_NAME} . \
6
  --build-arg DOCKER_NAME=${DOCKER_NAME} \
7
  --build-arg USER_ID=${DOCKER_USER_ID} \
8
  --build-arg GROUP_ID=${DOCKER_GROUP_ID}
 
2
 
3
  source container_setup/credentials
4
 
5
+ docker build -f container_setup/Dockerfile -t ${DOCKER_NAME} . \
6
  --build-arg DOCKER_NAME=${DOCKER_NAME} \
7
  --build-arg USER_ID=${DOCKER_USER_ID} \
8
  --build-arg GROUP_ID=${DOCKER_GROUP_ID}
container_setup/credentials CHANGED
@@ -5,5 +5,5 @@ CONTAINER_NAME=$USER"-arxiv-papers-classification"
5
  SRC="." # folder to propulse in docker container
6
  DOCKER_USER_ID=$(id -u) # to get these values type "id" in shell termilal
7
  DOCKER_GROUP_ID=$(id -g)
8
- CONTAINER_PORT=8001 # used in launch_container file
9
- INNER_PORT=8001
 
5
  SRC="." # folder to propulse in docker container
6
  DOCKER_USER_ID=$(id -u) # to get these values type "id" in shell termilal
7
  DOCKER_GROUP_ID=$(id -g)
8
+ CONTAINER_PORT=9001 # used in launch_container file
9
+ INNER_PORT=9001
container_setup/launch_container.sh CHANGED
@@ -11,6 +11,7 @@ docker run \
11
  --rm \
12
  -it \
13
  --init \
 
14
  -v ${SRC}:/app \
15
  -p ${INNER_PORT}:${CONTAINER_PORT} \
16
  ${DOCKER_NAME} \
 
11
  --rm \
12
  -it \
13
  --init \
14
+ --gpus '"device=0,1,2"' \
15
  -v ${SRC}:/app \
16
  -p ${INNER_PORT}:${CONTAINER_PORT} \
17
  ${DOCKER_NAME} \
entrypoints/app.py CHANGED
@@ -1,60 +1,27 @@
1
- # import streamlit as st
2
-
3
- # st.title("This is a title")
4
- # st.header("This is a header")
5
- # st.subheader("This is a subheader")
6
- # st.text("This is a text")
7
- # st.markdown("# This is a markdown header 1")
8
- # st.markdown("## This is a markdown header 2")
9
- # st.markdown("### This is a markdown header 3")
10
- # st.markdown("This is a markdown: *bold* **italic** `inline code` ~strikethrough~")
11
- # st.markdown("""This is a code block with syntax highlighting
12
- # ```python
13
- # print("Hello world!")
14
- # ```
15
- # """)
16
- # st.html(
17
- # "image from url example with html: "
18
- # "<img src='https://www.wallpaperflare.com/static/450/825/286/kitten-cute-animals-grass-5k-wallpaper.jpg' width=400px>",
19
- # )
20
-
21
-
22
- # st.write("Text with write")
23
- # st.write(range(10))
24
-
25
- # st.success("Success")
26
- # st.info("Information")
27
- # st.warning("Warning")
28
- # st.error("Error")
29
- # exp = ZeroDivisionError("Trying to divide by Zero")
30
- # st.exception(exp)
31
-
32
- # # инициализируем переменные
33
- # st.session_state.key1 = "value1" # Attribute API
34
- # st.session_state["key2"] = "value2" # Dictionary like API
35
-
36
- # # посмотреть что в st.session_state
37
- # st.write(st.session_state)
38
-
39
- # # magic
40
- # st.session_state
41
-
42
- # # ошибка если неправильный ключ
43
- # # st.write(st.session_state["missing_key"])
44
-
45
- # import streamlit as st
46
- # from transformers import pipeline
47
-
48
-
49
- # @st.cache_resource # кэширование
50
- # def load_model():
51
- # return pipeline("sentiment-analysis") # скачивание модели
52
-
53
-
54
- # model = load_model()
55
-
56
- # query = st.text_input("Your query", value="I love Streamlit! 🎈")
57
- # if query:
58
- # result = model(query)[0] # классифицируем
59
- # st.write(query)
60
- # st.write(result)
 
1
+ import streamlit as st
2
+ from src.app.setup_model import setup_pipeline, get_top_label_names, LabelScore
3
+ from src.app.tags_mapping import tags2full_name
4
+ from src.app.visualization import visualize_predicted_categories
5
+ from config.inference_config import InferenceConfig
6
+ from src.app.data_validation import validate_data
7
+
8
+ st.title("arXiv Paper Classifier")
9
+ st.markdown("Enter paper details to predict arXiv categories")
10
+
11
+ st.text_input("Enter paper name", key="paper_name")
12
+ st.text_area("Enter paper abstract", key="paper_abstract", height=250)
13
+
14
+ if st.button("Predict Categories", type="primary"):
15
+ validate_data(st.session_state["paper_name"], st.session_state["paper_abstract"])
16
+ with st.spinner("Analyzing paper..."):
17
+ pipeline = setup_pipeline(InferenceConfig())
18
+ scores: list[LabelScore] = pipeline(
19
+ st.session_state["paper_name"] + " " + st.session_state["paper_abstract"],
20
+ output_scores=True,
21
+ ) # type: ignore
22
+
23
+ top_labels = get_top_label_names(scores, tags2full_name, 0.95)
24
+
25
+ visualize_predicted_categories(top_labels, scores, tags2full_name)
26
+ else:
27
+ st.info("Enter paper details and click 'Predict Categories' to get predictions.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pyproject.toml CHANGED
@@ -17,6 +17,7 @@ dependencies = [
17
  "python-dotenv>=1.1.0",
18
  "scikit-learn>=1.6.1",
19
  "streamlit>=1.44.1",
 
20
  "torch>=2.6.0",
21
  "transformers>=4.50.3",
22
  ]
 
17
  "python-dotenv>=1.1.0",
18
  "scikit-learn>=1.6.1",
19
  "streamlit>=1.44.1",
20
+ "tiktoken>=0.9.0",
21
  "torch>=2.6.0",
22
  "transformers>=4.50.3",
23
  ]
scripts/pipeline.sh CHANGED
@@ -4,4 +4,4 @@ export PYTHONPATH='.'
4
 
5
  source .venv/bin/activate
6
 
7
- python entrypoints/pipeline.py
 
4
 
5
  source .venv/bin/activate
6
 
7
+ CUDA_VISIBLE_DEVICES=0 uv run entrypoints/pipeline.py
src/app/data_validation.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def validate_data(paper_name: str, paper_abstract: str) -> None:
5
+ if paper_name == "" or paper_abstract == "":
6
+ st.error("Paper name or abstract are required")
7
+ return
8
+ if paper_abstract == "":
9
+ st.warning(
10
+ "Without abstract, the performance of the model will be significantly worse"
11
+ )
12
+ return
src/app/setup_model.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, Pipeline
2
+ import typing as tp
3
+ from config.inference_config import InferenceConfig
4
+ import streamlit as st
5
+
6
+ from src.app.tags_mapping import tags2full_name
7
+
8
+
9
+ class LabelScore(tp.TypedDict):
10
+ label: str
11
+ score: float
12
+
13
+
14
+ @st.cache_resource
15
+ def setup_pipeline(cfg: InferenceConfig) -> Pipeline:
16
+ model = pipeline(
17
+ "text-classification", model=cfg.checkpoint_path, tokenizer=cfg.model_name
18
+ )
19
+ return model
20
+
21
+
22
+ def get_top_labels(scores: list[LabelScore], top_percent: float) -> list[LabelScore]:
23
+ top_scores = sorted(scores, key=lambda x: x["score"], reverse=True)
24
+ cumulative_score = 0
25
+ selected_labels: list[LabelScore] = []
26
+ for score in top_scores:
27
+ cumulative_score += score["score"]
28
+ selected_labels.append(score)
29
+ if cumulative_score >= top_percent:
30
+ break
31
+ return selected_labels
32
+
33
+
34
+ def get_full_names(
35
+ labels: list[LabelScore], label2name: dict[str, str]
36
+ ) -> list[LabelScore]:
37
+ return [
38
+ LabelScore(label=label2name[label["label"]], score=label["score"])
39
+ if label["label"] in label2name
40
+ else LabelScore(label=label["label"], score=label["score"])
41
+ for label in labels
42
+ ]
43
+
44
+
45
+ def get_top_label_names(
46
+ scores: list[LabelScore], label2name: dict[str, str], top_percent: float
47
+ ) -> list[LabelScore]:
48
+ top_labels = get_top_labels(scores, top_percent)
49
+ return get_full_names(top_labels, label2name)
src/app/tags_mapping.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tags2full_name = {
2
+ "cs.AI": "Artificial Intelligence",
3
+ "cs.AR": "Hardware Architecture",
4
+ "cs.CC": "Computational Complexity",
5
+ "cs.CE": "Computational Engineering, Finance, and Science",
6
+ "cs.CG": "Computational Geometry",
7
+ "cs.CL": "Computation and Language",
8
+ "cs.CR": "Cryptography and Security",
9
+ "cs.CV": "Computer Vision and Pattern Recognition",
10
+ "cs.CY": "Computers and Society",
11
+ "cs.DB": "Databases",
12
+ "cs.DC": "Distributed, Parallel, and Cluster Computing",
13
+ "cs.DL": "Digital Libraries",
14
+ "cs.DM": "Discrete Mathematics",
15
+ "cs.DS": "Data Structures and Algorithms",
16
+ "cs.ET": "Emerging Technologies",
17
+ "cs.FL": "Formal Languages and Automata Theory",
18
+ "cs.GL": "General Literature",
19
+ "cs.GR": "Graphics",
20
+ "cs.GT": "Computer Science and Game Theory",
21
+ "cs.HC": "Human-Computer Interaction",
22
+ "cs.IR": "Information Retrieval",
23
+ "cs.IT": "Information Theory",
24
+ "cs.LG": "Machine Learning",
25
+ "cs.LO": "Logic in Computer Science",
26
+ "cs.MA": "Multiagent Systems",
27
+ "cs.MM": "Multimedia",
28
+ "cs.MS": "Mathematical Software",
29
+ "cs.NA": "Numerical Analysis",
30
+ "cs.NE": "Neural and Evolutionary Computing",
31
+ "cs.NI": "Networking and Internet Architecture",
32
+ "cs.OH": "Other Computer Science",
33
+ "cs.OS": "Operating Systems",
34
+ "cs.PF": "Performance",
35
+ "cs.PL": "Programming Languages",
36
+ "cs.RO": "Robotics",
37
+ "cs.SC": "Symbolic Computation",
38
+ "cs.SD": "Sound",
39
+ "cs.SE": "Software Engineering",
40
+ "cs.SI": "Social and Information Networks",
41
+ "cs.SY": "Systems and Control",
42
+ "econ.EM": "Econometrics",
43
+ "econ.GN": "General Economics",
44
+ "econ.TH": "Theoretical Economics",
45
+ "eess.AS": "Audio and Speech Processing",
46
+ "eess.IV": "Image and Video Processing",
47
+ "eess.SP": "Signal Processing",
48
+ "eess.SY": "Systems and Control",
49
+ "math.AC": "Commutative Algebra",
50
+ "math.AG": "Algebraic Geometry",
51
+ "math.AP": "Analysis of PDEs",
52
+ "math.AT": "Algebraic Topology",
53
+ "math.CA": "Classical Analysis and ODEs",
54
+ "math.CO": "Combinatorics",
55
+ "math.CT": "Category Theory",
56
+ "math.CV": "Complex Variables",
57
+ "math.DG": "Differential Geometry",
58
+ "math.DS": "Dynamical Systems",
59
+ "math.FA": "Functional Analysis",
60
+ "math.GM": "General Mathematics",
61
+ "math.GN": "General Topology",
62
+ "math.GR": "Group Theory",
63
+ "math.GT": "Geometric Topology",
64
+ "math.HO": "History and Overview",
65
+ "math.IT": "Information Theory",
66
+ "math.KT": "K-Theory and Homology",
67
+ "math.LO": "Logic",
68
+ "math.MG": "Metric Geometry",
69
+ "math.MP": "Mathematical Physics",
70
+ "math.NA": "Numerical Analysis",
71
+ "math.NT": "Number Theory",
72
+ "math.OA": "Operator Algebras",
73
+ "math.OC": "Optimization and Control",
74
+ "math.PR": "Probability",
75
+ "math.QA": "Quantum Algebra",
76
+ "math.RA": "Rings and Algebras",
77
+ "math.RT": "Representation Theory",
78
+ "math.SG": "Symplectic Geometry",
79
+ "math.SP": "Spectral Theory",
80
+ "math.ST": "Statistics Theory",
81
+ "astro-ph.CO": "Cosmology and Nongalactic Astrophysics",
82
+ "astro-ph.EP": "Earth and Planetary Astrophysics",
83
+ "astro-ph.GA": "Astrophysics of Galaxies",
84
+ "astro-ph.HE": "High Energy Astrophysical Phenomena",
85
+ "astro-ph.IM": "Instrumentation and Methods for Astrophysics",
86
+ "astro-ph.SR": "Solar and Stellar Astrophysics",
87
+ "cond-mat.dis-nn": "Disordered Systems and Neural Networks",
88
+ "cond-mat.mes-hall": "Mesoscale and Nanoscale Physics",
89
+ "cond-mat.mtrl-sci": "Materials Science",
90
+ "cond-mat.other": "Other Condensed Matter",
91
+ "cond-mat.quant-gas": "Quantum Gases",
92
+ "cond-mat.soft": "Soft Condensed Matter",
93
+ "cond-mat.stat-mech": "Statistical Mechanics",
94
+ "cond-mat.str-el": "Strongly Correlated Electrons",
95
+ "cond-mat.supr-con": "Superconductivity",
96
+ "gr-qc": "General Relativity and Quantum Cosmology",
97
+ "hep-ex": "High Energy Physics - Experiment",
98
+ "hep-lat": "High Energy Physics - Lattice",
99
+ "hep-ph": "High Energy Physics - Phenomenology",
100
+ "hep-th": "High Energy Physics - Theory",
101
+ "math-ph": "Mathematical Physics",
102
+ "nlin.AO": "Adaptation and Self-Organizing Systems",
103
+ "nlin.CD": "Chaotic Dynamics",
104
+ "nlin.CG": "Cellular Automata and Lattice Gases",
105
+ "nlin.PS": "Pattern Formation and Solitons",
106
+ "nlin.SI": "Exactly Solvable and Integrable Systems",
107
+ "nucl-ex": "Nuclear Experiment",
108
+ "nucl-th": "Nuclear Theory",
109
+ "physics.acc-ph": "Accelerator Physics",
110
+ "physics.ao-ph": "Atmospheric and Oceanic Physics",
111
+ "physics.app-ph": "Applied Physics",
112
+ "physics.atm-clus": "Atomic and Molecular Clusters",
113
+ "physics.atom-ph": "Atomic Physics",
114
+ "physics.bio-ph": "Biological Physics",
115
+ "physics.chem-ph": "Chemical Physics",
116
+ "physics.class-ph": "Classical Physics",
117
+ "physics.comp-ph": "Computational Physics",
118
+ "physics.data-an": "Data Analysis, Statistics and Probability",
119
+ "physics.ed-ph": "Physics Education",
120
+ "physics.flu-dyn": "Fluid Dynamics",
121
+ "physics.gen-ph": "General Physics",
122
+ "physics.geo-ph": "Geophysics",
123
+ "physics.hist-ph": "History and Philosophy of Physics",
124
+ "physics.ins-det": "Instrumentation and Detectors",
125
+ "physics.med-ph": "Medical Physics",
126
+ "physics.optics": "Optics",
127
+ "physics.plasm-ph": "Plasma Physics",
128
+ "physics.pop-ph": "Popular Physics",
129
+ "physics.soc-ph": "Physics and Society",
130
+ "physics.space-ph": "Space Physics",
131
+ "quant-ph": "Quantum Physics",
132
+ "q-bio.BM": "Biomolecules",
133
+ "q-bio.CB": "Cell Behavior",
134
+ "q-bio.GN": "Genomics",
135
+ "q-bio.MN": "Molecular Networks",
136
+ "q-bio.NC": "Neurons and Cognition",
137
+ "q-bio.OT": "Other Quantitative Biology",
138
+ "q-bio.PE": "Populations and Evolution",
139
+ "q-bio.QM": "Quantitative Methods",
140
+ "q-bio.SC": "Subcellular Processes",
141
+ "q-bio.TO": "Tissues and Organs",
142
+ "q-fin.CP": "Computational Finance",
143
+ "q-fin.EC": "Economics",
144
+ "q-fin.GN": "General Finance",
145
+ "q-fin.MF": "Mathematical Finance",
146
+ "q-fin.PM": "Portfolio Management",
147
+ "q-fin.PR": "Pricing of Securities",
148
+ "q-fin.RM": "Risk Management",
149
+ "q-fin.ST": "Statistical Finance",
150
+ "q-fin.TR": "Trading and Market Microstructure",
151
+ "stat.AP": "Applications",
152
+ "stat.CO": "Computation",
153
+ "stat.ME": "Methodology",
154
+ "stat.ML": "Machine Learning",
155
+ "stat.OT": "Other Statistics",
156
+ "stat.TH": "Statistics Theory",
157
+ }
src/app/visualization.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from src.app.setup_model import LabelScore
3
+ from typing import Dict
4
+
5
+
6
+ def visualize_predicted_categories(
7
+ top_labels: list[LabelScore],
8
+ scores: list[LabelScore],
9
+ label_to_name_mapping: Dict[str, str],
10
+ ):
11
+ """
12
+ Visualize the predicted categories in a streamlit app
13
+
14
+ Args:
15
+ top_labels: List of top labels to display
16
+ scores: All scores from the model
17
+ label_to_name_mapping: Mapping from label codes to full names
18
+ """
19
+ st.subheader("Predicted Categories")
20
+
21
+ for i, label in enumerate(top_labels):
22
+ score = next((s["score"] for s in scores if s["label"] == label["label"]), 0)
23
+
24
+ # Color gradient based on confidence
25
+ color_intensity = min(int(score * 255), 255)
26
+
27
+ with st.container(border=True):
28
+ cols = st.columns([3, 1])
29
+ with cols[0]:
30
+ # Access full_name from the mapping if available
31
+ full_name = label_to_name_mapping.get(label["label"], label["label"])
32
+ st.markdown(f"**{full_name}**")
33
+ st.caption(f"Tag: {label['label']}")
34
+ with cols[1]:
35
+ st.markdown(
36
+ f"<h3 style='text-align: right; color: rgb(0, {color_intensity}, {255 - color_intensity});'>{score:.2f}</h3>",
37
+ unsafe_allow_html=True,
38
+ )
39
+
40
+ display_all_scores(scores, label_to_name_mapping)
41
+
42
+
43
+ def display_all_scores(scores: list[LabelScore], label_to_name_mapping: Dict[str, str]):
44
+ """
45
+ Display all scores in an expandable section
46
+
47
+ Args:
48
+ scores: All scores from the model
49
+ label_to_name_mapping: Mapping from label codes to full names
50
+ """
51
+ with st.expander("View all category scores"):
52
+ sorted_scores = sorted(scores, key=lambda x: x["score"], reverse=True)
53
+ for score_item in sorted_scores[:20]: # Show top 20
54
+ label_name = label_to_name_mapping.get(
55
+ score_item["label"], score_item["label"]
56
+ )
57
+ st.text(f"{label_name} ({score_item['label']}): {score_item['score']:.4f}")
src/pipeline/arxiv_dataset.py CHANGED
@@ -22,7 +22,7 @@ class ArxivPaper(tp.TypedDict):
22
 
23
 
24
  def load_arxiv_dataset() -> Dataset:
25
- df = pd.read_json("data/arxivData.json").head(100)
26
  dataset = Dataset.from_pandas(df[["summary", "tag", "title"]])
27
  return dataset
28
 
@@ -51,6 +51,7 @@ def generate_preprocessing_function(
51
  text,
52
  truncation=True,
53
  padding="max_length",
 
54
  )
55
 
56
  tags_list: list[ArxivTag] = ast.literal_eval(row["tag"])
 
22
 
23
 
24
  def load_arxiv_dataset() -> Dataset:
25
+ df = pd.read_json("data/arxivData.json")
26
  dataset = Dataset.from_pandas(df[["summary", "tag", "title"]])
27
  return dataset
28
 
 
51
  text,
52
  truncation=True,
53
  padding="max_length",
54
+ max_length=512,
55
  )
56
 
57
  tags_list: list[ArxivTag] = ast.literal_eval(row["tag"])
src/pipeline/env_setup.py CHANGED
@@ -1,5 +1,6 @@
1
  from dotenv import load_dotenv
2
  import os
 
3
 
4
  REQUIRED_ENV_VARS = ["COMET_API_KEY", "COMET_MODE"]
5
 
 
1
  from dotenv import load_dotenv
2
  import os
3
+ import torch
4
 
5
  REQUIRED_ENV_VARS = ["COMET_API_KEY", "COMET_MODE"]
6
 
src/pipeline/metrics.py CHANGED
@@ -1,7 +1,11 @@
1
- import evaluate
2
  import numpy as np
3
-
4
- clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
 
 
 
 
 
5
 
6
 
7
  def sigmoid(x):
@@ -10,8 +14,27 @@ def sigmoid(x):
10
 
11
  def compute_metrics(eval_pred):
12
  predictions, labels = eval_pred
13
- predictions = sigmoid(predictions)
14
- predictions = (predictions > 0.5).astype(int).reshape(-1)
15
- return clf_metrics.compute(
16
- predictions=predictions, references=labels.astype(int).reshape(-1)
 
 
 
 
 
 
 
 
 
 
17
  )
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
+ from sklearn.metrics import (
3
+ accuracy_score,
4
+ f1_score,
5
+ precision_score,
6
+ recall_score,
7
+ roc_auc_score,
8
+ )
9
 
10
 
11
  def sigmoid(x):
 
14
 
15
  def compute_metrics(eval_pred):
16
  predictions, labels = eval_pred
17
+
18
+ # Handle T5 model output which can be a tuple
19
+ if isinstance(predictions, tuple):
20
+ predictions = predictions[0]
21
+
22
+ prediction_scores = sigmoid(predictions)
23
+ predictions = (prediction_scores > 0.5).astype(int)
24
+
25
+ # Multi-label metrics
26
+ accuracy = accuracy_score(labels, predictions)
27
+ roc_auc = roc_auc_score(labels, prediction_scores)
28
+ f1 = f1_score(labels, predictions, average="weighted", zero_division=0)
29
+ precision = precision_score(
30
+ labels, predictions, average="weighted", zero_division=0
31
  )
32
+ recall = recall_score(labels, predictions, average="weighted", zero_division=0)
33
+
34
+ return {
35
+ "accuracy": accuracy,
36
+ "f1": f1,
37
+ "precision": precision,
38
+ "recall": recall,
39
+ "roc_auc": roc_auc,
40
+ }
uv.lock CHANGED
@@ -122,6 +122,7 @@ dependencies = [
122
  { name = "python-dotenv" },
123
  { name = "scikit-learn" },
124
  { name = "streamlit" },
 
125
  { name = "torch" },
126
  { name = "transformers" },
127
  ]
@@ -140,6 +141,7 @@ requires-dist = [
140
  { name = "python-dotenv", specifier = ">=1.1.0" },
141
  { name = "scikit-learn", specifier = ">=1.6.1" },
142
  { name = "streamlit", specifier = ">=1.44.1" },
 
143
  { name = "torch", specifier = ">=2.6.0" },
144
  { name = "transformers", specifier = ">=4.50.3" },
145
  ]
@@ -1717,6 +1719,24 @@ wheels = [
1717
  { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638 },
1718
  ]
1719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1720
  [[package]]
1721
  name = "tokenizers"
1722
  version = "0.21.1"
 
122
  { name = "python-dotenv" },
123
  { name = "scikit-learn" },
124
  { name = "streamlit" },
125
+ { name = "tiktoken" },
126
  { name = "torch" },
127
  { name = "transformers" },
128
  ]
 
141
  { name = "python-dotenv", specifier = ">=1.1.0" },
142
  { name = "scikit-learn", specifier = ">=1.6.1" },
143
  { name = "streamlit", specifier = ">=1.44.1" },
144
+ { name = "tiktoken", specifier = ">=0.9.0" },
145
  { name = "torch", specifier = ">=2.6.0" },
146
  { name = "transformers", specifier = ">=4.50.3" },
147
  ]
 
1719
  { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638 },
1720
  ]
1721
 
1722
+ [[package]]
1723
+ name = "tiktoken"
1724
+ version = "0.9.0"
1725
+ source = { registry = "https://pypi.org/simple" }
1726
+ dependencies = [
1727
+ { name = "regex" },
1728
+ { name = "requests" },
1729
+ ]
1730
+ sdist = { url = "https://files.pythonhosted.org/packages/ea/cf/756fedf6981e82897f2d570dd25fa597eb3f4459068ae0572d7e888cfd6f/tiktoken-0.9.0.tar.gz", hash = "sha256:d02a5ca6a938e0490e1ff957bc48c8b078c88cb83977be1625b1fd8aac792c5d", size = 35991 }
1731
+ wheels = [
1732
+ { url = "https://files.pythonhosted.org/packages/7a/11/09d936d37f49f4f494ffe660af44acd2d99eb2429d60a57c71318af214e0/tiktoken-0.9.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2b0e8e05a26eda1249e824156d537015480af7ae222ccb798e5234ae0285dbdb", size = 1064919 },
1733
+ { url = "https://files.pythonhosted.org/packages/80/0e/f38ba35713edb8d4197ae602e80837d574244ced7fb1b6070b31c29816e0/tiktoken-0.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:27d457f096f87685195eea0165a1807fae87b97b2161fe8c9b1df5bd74ca6f63", size = 1007877 },
1734
+ { url = "https://files.pythonhosted.org/packages/fe/82/9197f77421e2a01373e27a79dd36efdd99e6b4115746ecc553318ecafbf0/tiktoken-0.9.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cf8ded49cddf825390e36dd1ad35cd49589e8161fdcb52aa25f0583e90a3e01", size = 1140095 },
1735
+ { url = "https://files.pythonhosted.org/packages/f2/bb/4513da71cac187383541facd0291c4572b03ec23c561de5811781bbd988f/tiktoken-0.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc156cb314119a8bb9748257a2eaebd5cc0753b6cb491d26694ed42fc7cb3139", size = 1195649 },
1736
+ { url = "https://files.pythonhosted.org/packages/fa/5c/74e4c137530dd8504e97e3a41729b1103a4ac29036cbfd3250b11fd29451/tiktoken-0.9.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:cd69372e8c9dd761f0ab873112aba55a0e3e506332dd9f7522ca466e817b1b7a", size = 1258465 },
1737
+ { url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669 },
1738
+ ]
1739
+
1740
  [[package]]
1741
  name = "tokenizers"
1742
  version = "0.21.1"