Spaces:
Sleeping
Sleeping
Andrey Shulga
committed on
Commit
·
92d5847
1
Parent(s):
97f90b2
working training code
Browse files- .gitignore +3 -0
- README.md +2 -1
- config/inference_config.py +12 -0
- config/pipeline_config.py +12 -4
- container_setup/Dockerfile +2 -4
- container_setup/build.sh +1 -1
- container_setup/credentials +2 -2
- container_setup/launch_container.sh +1 -0
- entrypoints/app.py +27 -60
- pyproject.toml +1 -0
- scripts/pipeline.sh +1 -1
- src/app/data_validation.py +12 -0
- src/app/setup_model.py +49 -0
- src/app/tags_mapping.py +157 -0
- src/app/visualization.py +57 -0
- src/pipeline/arxiv_dataset.py +2 -1
- src/pipeline/env_setup.py +1 -0
- src/pipeline/metrics.py +30 -7
- uv.lock +20 -0
.gitignore
CHANGED
|
@@ -8,6 +8,9 @@ wheels/
|
|
| 8 |
|
| 9 |
outputs
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
data
|
| 12 |
NOTES.md
|
| 13 |
.env
|
|
|
|
| 8 |
|
| 9 |
outputs
|
| 10 |
|
| 11 |
+
notebooks
|
| 12 |
+
|
| 13 |
+
|
| 14 |
data
|
| 15 |
NOTES.md
|
| 16 |
.env
|
README.md
CHANGED
|
@@ -17,5 +17,6 @@ add .env with COMET_API_KEY, COMET_MODE=GET
|
|
| 17 |
chmod +x on scripts
|
| 18 |
|
| 19 |
|
|
|
|
| 20 |
|
| 21 |
-
|
|
|
|
| 17 |
chmod +x on scripts
|
| 18 |
|
| 19 |
|
| 20 |
+
add the dataset (arxivData.json) to the data/ folder
|
| 21 |
|
| 22 |
+
specify the CUDA device (CUDA_VISIBLE_DEVICES) in scripts/pipeline.sh
|
config/inference_config.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
|
| 3 |
+
from hydra.core.config_store import ConfigStore
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class InferenceConfig:
|
| 8 |
+
"""Configuration for inference"""
|
| 9 |
+
|
| 10 |
+
model_name: str = "allenai/scibert_scivocab_uncased"
|
| 11 |
+
checkpoint_path: str = "data/checkpoints/checkpoint-10"
|
| 12 |
+
top_percent: float = 0.95
|
config/pipeline_config.py
CHANGED
|
@@ -16,13 +16,13 @@ class DatasetConfig:
|
|
| 16 |
class CustomTrainingArguments:
|
| 17 |
output_dir: str = "data/checkpoints"
|
| 18 |
overwrite_output_dir: bool = True
|
| 19 |
-
num_train_epochs: float =
|
| 20 |
learning_rate: float = 5e-5
|
| 21 |
lr_scheduler_type: str = "cosine"
|
| 22 |
# lr_scheduler_kwargs={},
|
| 23 |
warmup_ratio: float = 0.03125
|
| 24 |
-
warmup_steps: int =
|
| 25 |
-
# per_device_train_batch_size: int =
|
| 26 |
gradient_accumulation_steps: int = 1
|
| 27 |
log_level: str = "error"
|
| 28 |
# logging_dir="output_dir/runs/CURRENT_DATETIME_HOSTNAME" # TensorBoard logs (default)
|
|
@@ -46,13 +46,21 @@ class CustomTrainingArguments:
|
|
| 46 |
# resume_from_checkpoint: str = "last-checkpoint"
|
| 47 |
auto_find_batch_size: bool = True
|
| 48 |
report_to: str = "comet_ml"
|
|
|
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
@dataclass
|
| 52 |
class ModelConfig:
|
| 53 |
"""Configuration for model architecture and parameters"""
|
| 54 |
|
| 55 |
-
model_name: str = "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 16 |
class CustomTrainingArguments:
|
| 17 |
output_dir: str = "data/checkpoints"
|
| 18 |
overwrite_output_dir: bool = True
|
| 19 |
+
num_train_epochs: float = 3
|
| 20 |
learning_rate: float = 5e-5
|
| 21 |
lr_scheduler_type: str = "cosine"
|
| 22 |
# lr_scheduler_kwargs={},
|
| 23 |
warmup_ratio: float = 0.03125
|
| 24 |
+
warmup_steps: int = 1
|
| 25 |
+
# per_device_train_batch_size: int = 64
|
| 26 |
gradient_accumulation_steps: int = 1
|
| 27 |
log_level: str = "error"
|
| 28 |
# logging_dir="output_dir/runs/CURRENT_DATETIME_HOSTNAME" # TensorBoard logs (default)
|
|
|
|
| 46 |
# resume_from_checkpoint: str = "last-checkpoint"
|
| 47 |
auto_find_batch_size: bool = True
|
| 48 |
report_to: str = "comet_ml"
|
| 49 |
+
metric_for_best_model: str = "f1"
|
| 50 |
+
greater_is_better: bool = True
|
| 51 |
|
| 52 |
|
| 53 |
@dataclass
|
| 54 |
class ModelConfig:
|
| 55 |
"""Configuration for model architecture and parameters"""
|
| 56 |
|
| 57 |
+
model_name: str = "allenai/scibert_scivocab_uncased"
|
| 58 |
+
|
| 59 |
+
# model_name: tp.Literal[
|
| 60 |
+
# "FacebookAI/roberta-base",
|
| 61 |
+
# "distilbert-base-uncased",
|
| 62 |
+
# "allenai/scibert_scivocab_uncased",
|
| 63 |
+
# ] = "allenai/scibert_scivocab_uncased"
|
| 64 |
|
| 65 |
|
| 66 |
@dataclass
|
container_setup/Dockerfile
CHANGED
|
@@ -42,12 +42,10 @@ WORKDIR /app
|
|
| 42 |
|
| 43 |
COPY --chown=appuser:appgroup . /app
|
| 44 |
|
| 45 |
-
# 1) Create a dedicated venv in .venv
|
| 46 |
RUN uv venv .venv
|
| 47 |
|
| 48 |
-
# 2) Install / sync packages into that .venv
|
| 49 |
RUN uv sync
|
| 50 |
|
| 51 |
-
EXPOSE
|
| 52 |
|
| 53 |
-
CMD ["uv", "run", "python", "entrypoints/app.py"]
|
|
|
|
| 42 |
|
| 43 |
COPY --chown=appuser:appgroup . /app
|
| 44 |
|
|
|
|
| 45 |
RUN uv venv .venv
|
| 46 |
|
|
|
|
| 47 |
RUN uv sync
|
| 48 |
|
| 49 |
+
EXPOSE 9000
|
| 50 |
|
| 51 |
+
# CMD ["uv", "run", "python", "entrypoints/app.py"]
|
container_setup/build.sh
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
|
| 3 |
source container_setup/credentials
|
| 4 |
|
| 5 |
-
docker build
|
| 6 |
--build-arg DOCKER_NAME=${DOCKER_NAME} \
|
| 7 |
--build-arg USER_ID=${DOCKER_USER_ID} \
|
| 8 |
--build-arg GROUP_ID=${DOCKER_GROUP_ID}
|
|
|
|
| 2 |
|
| 3 |
source container_setup/credentials
|
| 4 |
|
| 5 |
+
docker build -f container_setup/Dockerfile -t ${DOCKER_NAME} . \
|
| 6 |
--build-arg DOCKER_NAME=${DOCKER_NAME} \
|
| 7 |
--build-arg USER_ID=${DOCKER_USER_ID} \
|
| 8 |
--build-arg GROUP_ID=${DOCKER_GROUP_ID}
|
container_setup/credentials
CHANGED
|
@@ -5,5 +5,5 @@ CONTAINER_NAME=$USER"-arxiv-papers-classification"
|
|
| 5 |
SRC="." # folder to mount into the docker container
|
| 6 |
DOCKER_USER_ID=$(id -u) # to get these values type "id" in shell terminal
|
| 7 |
DOCKER_GROUP_ID=$(id -g)
|
| 8 |
-
CONTAINER_PORT=
|
| 9 |
-
INNER_PORT=
|
|
|
|
| 5 |
SRC="." # folder to mount into the docker container
|
| 6 |
DOCKER_USER_ID=$(id -u) # to get these values type "id" in shell terminal
|
| 7 |
DOCKER_GROUP_ID=$(id -g)
|
| 8 |
+
CONTAINER_PORT=9001 # used in launch_container file
|
| 9 |
+
INNER_PORT=9001
|
container_setup/launch_container.sh
CHANGED
|
@@ -11,6 +11,7 @@ docker run \
|
|
| 11 |
--rm \
|
| 12 |
-it \
|
| 13 |
--init \
|
|
|
|
| 14 |
-v ${SRC}:/app \
|
| 15 |
-p ${INNER_PORT}:${CONTAINER_PORT} \
|
| 16 |
${DOCKER_NAME} \
|
|
|
|
| 11 |
--rm \
|
| 12 |
-it \
|
| 13 |
--init \
|
| 14 |
+
--gpus '"device=0,1,2"' \
|
| 15 |
-v ${SRC}:/app \
|
| 16 |
-p ${INNER_PORT}:${CONTAINER_PORT} \
|
| 17 |
${DOCKER_NAME} \
|
entrypoints/app.py
CHANGED
|
@@ -1,60 +1,27 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# st.error("Error")
|
| 29 |
-
# exp = ZeroDivisionError("Trying to divide by Zero")
|
| 30 |
-
# st.exception(exp)
|
| 31 |
-
|
| 32 |
-
# # инициализируем переменные
|
| 33 |
-
# st.session_state.key1 = "value1" # Attribute API
|
| 34 |
-
# st.session_state["key2"] = "value2" # Dictionary like API
|
| 35 |
-
|
| 36 |
-
# # посмотреть что в st.session_state
|
| 37 |
-
# st.write(st.session_state)
|
| 38 |
-
|
| 39 |
-
# # magic
|
| 40 |
-
# st.session_state
|
| 41 |
-
|
| 42 |
-
# # ошибка если неправильный ключ
|
| 43 |
-
# # st.write(st.session_state["missing_key"])
|
| 44 |
-
|
| 45 |
-
# import streamlit as st
|
| 46 |
-
# from transformers import pipeline
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# @st.cache_resource # кэширование
|
| 50 |
-
# def load_model():
|
| 51 |
-
# return pipeline("sentiment-analysis") # скачивание модели
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
# model = load_model()
|
| 55 |
-
|
| 56 |
-
# query = st.text_input("Your query", value="I love Streamlit! 🎈")
|
| 57 |
-
# if query:
|
| 58 |
-
# result = model(query)[0] # классифицируем
|
| 59 |
-
# st.write(query)
|
| 60 |
-
# st.write(result)
|
|
|
|
| 1 |
+
"""Streamlit entry point: predict arXiv categories from a paper title and abstract."""

import streamlit as st

from config.inference_config import InferenceConfig
from src.app.data_validation import validate_data
from src.app.setup_model import LabelScore, get_top_labels, setup_pipeline
from src.app.tags_mapping import tags2full_name
from src.app.visualization import visualize_predicted_categories

# Upper bound on tokens sent to the model; same value the training
# preprocessing passes as max_length.
MAX_INPUT_TOKENS = 512

st.title("arXiv Paper Classifier")
st.markdown("Enter paper details to predict arXiv categories")

st.text_input("Enter paper name", key="paper_name")
st.text_area("Enter paper abstract", key="paper_abstract", height=250)

if st.button("Predict Categories", type="primary"):
    paper_name: str = st.session_state["paper_name"]
    paper_abstract: str = st.session_state["paper_abstract"]

    validate_data(paper_name, paper_abstract)
    # validate_data only reports problems in the UI and returns nothing, so
    # stop the script run here when there is no text at all to classify.
    if not paper_name.strip() and not paper_abstract.strip():
        st.stop()

    cfg = InferenceConfig()
    with st.spinner("Analyzing paper..."):
        classifier = setup_pipeline(cfg)
        # top_k=None asks the pipeline for the score of every label rather
        # than only the best one; truncation keeps long abstracts within the
        # model's input limit.
        scores: list[LabelScore] = classifier(
            paper_name + " " + paper_abstract,
            top_k=None,
            truncation=True,
            max_length=MAX_INPUT_TOKENS,
        )  # type: ignore

    # Keep raw tags here: the visualization matches labels against `scores`
    # and resolves display names itself via the mapping.
    top_labels = get_top_labels(scores, cfg.top_percent)

    visualize_predicted_categories(top_labels, scores, tags2full_name)
else:
    st.info("Enter paper details and click 'Predict Categories' to get predictions.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pyproject.toml
CHANGED
|
@@ -17,6 +17,7 @@ dependencies = [
|
|
| 17 |
"python-dotenv>=1.1.0",
|
| 18 |
"scikit-learn>=1.6.1",
|
| 19 |
"streamlit>=1.44.1",
|
|
|
|
| 20 |
"torch>=2.6.0",
|
| 21 |
"transformers>=4.50.3",
|
| 22 |
]
|
|
|
|
| 17 |
"python-dotenv>=1.1.0",
|
| 18 |
"scikit-learn>=1.6.1",
|
| 19 |
"streamlit>=1.44.1",
|
| 20 |
+
"tiktoken>=0.9.0",
|
| 21 |
"torch>=2.6.0",
|
| 22 |
"transformers>=4.50.3",
|
| 23 |
]
|
scripts/pipeline.sh
CHANGED
|
@@ -4,4 +4,4 @@ export PYTHONPATH='.'
|
|
| 4 |
|
| 5 |
source .venv/bin/activate
|
| 6 |
|
| 7 |
-
|
|
|
|
| 4 |
|
| 5 |
source .venv/bin/activate
|
| 6 |
|
| 7 |
+
CUDA_VISIBLE_DEVICES=0 uv run entrypoints/pipeline.py
|
src/app/data_validation.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def validate_data(paper_name: str, paper_abstract: str) -> None:
    """Report missing paper fields in the Streamlit UI.

    Shows an error when both the name and the abstract are blank, and a
    warning when only the abstract is blank. Whitespace-only values count
    as blank. The function reports problems but does not stop the app.

    Args:
        paper_name: Paper title entered by the user.
        paper_abstract: Paper abstract entered by the user.

    Returns:
        None.
    """
    has_name = bool(paper_name.strip())
    has_abstract = bool(paper_abstract.strip())

    # Nothing to classify at all: this is the only hard error.
    if not has_name and not has_abstract:
        st.error("Paper name or abstract are required")
        return

    # A title alone is accepted, but the user is told quality will suffer.
    if not has_abstract:
        st.warning(
            "Without abstract, the performance of the model will be significantly worse"
        )
|
src/app/setup_model.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import pipeline, Pipeline
|
| 2 |
+
import typing as tp
|
| 3 |
+
from config.inference_config import InferenceConfig
|
| 4 |
+
import streamlit as st
|
| 5 |
+
|
| 6 |
+
from src.app.tags_mapping import tags2full_name
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LabelScore(tp.TypedDict):
|
| 10 |
+
label: str
|
| 11 |
+
score: float
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@st.cache_resource
|
| 15 |
+
def setup_pipeline(cfg: InferenceConfig) -> Pipeline:
|
| 16 |
+
model = pipeline(
|
| 17 |
+
"text-classification", model=cfg.checkpoint_path, tokenizer=cfg.model_name
|
| 18 |
+
)
|
| 19 |
+
return model
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_top_labels(scores: list[LabelScore], top_percent: float) -> list[LabelScore]:
    """Select the highest-scoring labels whose summed score reaches a threshold.

    Labels are taken in descending score order; the label that makes the
    running sum reach ``top_percent`` is included and selection stops there.
    If the threshold is never reached, every label is returned.

    Args:
        scores: Label/score entries to choose from.
        top_percent: Cumulative score at which selection stops.

    Returns:
        The selected entries, highest score first.
    """
    ranked = sorted(scores, key=lambda item: item["score"], reverse=True)
    running_total = 0
    chosen: list[LabelScore] = []
    for entry in ranked:
        running_total += entry["score"]
        chosen.append(entry)
        if running_total >= top_percent:
            return chosen
    return chosen
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_full_names(
    labels: list[LabelScore], label2name: dict[str, str]
) -> list[LabelScore]:
    """Replace label tags with their full names, keeping the scores.

    Args:
        labels: Label/score entries whose ``label`` is a tag.
        label2name: Mapping from tag to full name.

    Returns:
        New entries in the same order; a tag missing from ``label2name``
        is kept unchanged. The input entries are not modified.
    """
    # dict.get does the lookup once and falls back to the tag itself.
    return [
        {"label": label2name.get(item["label"], item["label"]), "score": item["score"]}
        for item in labels
    ]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_top_label_names(
|
| 46 |
+
scores: list[LabelScore], label2name: dict[str, str], top_percent: float
|
| 47 |
+
) -> list[LabelScore]:
|
| 48 |
+
top_labels = get_top_labels(scores, top_percent)
|
| 49 |
+
return get_full_names(top_labels, label2name)
|
src/app/tags_mapping.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tags2full_name = {
|
| 2 |
+
"cs.AI": "Artificial Intelligence",
|
| 3 |
+
"cs.AR": "Hardware Architecture",
|
| 4 |
+
"cs.CC": "Computational Complexity",
|
| 5 |
+
"cs.CE": "Computational Engineering, Finance, and Science",
|
| 6 |
+
"cs.CG": "Computational Geometry",
|
| 7 |
+
"cs.CL": "Computation and Language",
|
| 8 |
+
"cs.CR": "Cryptography and Security",
|
| 9 |
+
"cs.CV": "Computer Vision and Pattern Recognition",
|
| 10 |
+
"cs.CY": "Computers and Society",
|
| 11 |
+
"cs.DB": "Databases",
|
| 12 |
+
"cs.DC": "Distributed, Parallel, and Cluster Computing",
|
| 13 |
+
"cs.DL": "Digital Libraries",
|
| 14 |
+
"cs.DM": "Discrete Mathematics",
|
| 15 |
+
"cs.DS": "Data Structures and Algorithms",
|
| 16 |
+
"cs.ET": "Emerging Technologies",
|
| 17 |
+
"cs.FL": "Formal Languages and Automata Theory",
|
| 18 |
+
"cs.GL": "General Literature",
|
| 19 |
+
"cs.GR": "Graphics",
|
| 20 |
+
"cs.GT": "Computer Science and Game Theory",
|
| 21 |
+
"cs.HC": "Human-Computer Interaction",
|
| 22 |
+
"cs.IR": "Information Retrieval",
|
| 23 |
+
"cs.IT": "Information Theory",
|
| 24 |
+
"cs.LG": "Machine Learning",
|
| 25 |
+
"cs.LO": "Logic in Computer Science",
|
| 26 |
+
"cs.MA": "Multiagent Systems",
|
| 27 |
+
"cs.MM": "Multimedia",
|
| 28 |
+
"cs.MS": "Mathematical Software",
|
| 29 |
+
"cs.NA": "Numerical Analysis",
|
| 30 |
+
"cs.NE": "Neural and Evolutionary Computing",
|
| 31 |
+
"cs.NI": "Networking and Internet Architecture",
|
| 32 |
+
"cs.OH": "Other Computer Science",
|
| 33 |
+
"cs.OS": "Operating Systems",
|
| 34 |
+
"cs.PF": "Performance",
|
| 35 |
+
"cs.PL": "Programming Languages",
|
| 36 |
+
"cs.RO": "Robotics",
|
| 37 |
+
"cs.SC": "Symbolic Computation",
|
| 38 |
+
"cs.SD": "Sound",
|
| 39 |
+
"cs.SE": "Software Engineering",
|
| 40 |
+
"cs.SI": "Social and Information Networks",
|
| 41 |
+
"cs.SY": "Systems and Control",
|
| 42 |
+
"econ.EM": "Econometrics",
|
| 43 |
+
"econ.GN": "General Economics",
|
| 44 |
+
"econ.TH": "Theoretical Economics",
|
| 45 |
+
"eess.AS": "Audio and Speech Processing",
|
| 46 |
+
"eess.IV": "Image and Video Processing",
|
| 47 |
+
"eess.SP": "Signal Processing",
|
| 48 |
+
"eess.SY": "Systems and Control",
|
| 49 |
+
"math.AC": "Commutative Algebra",
|
| 50 |
+
"math.AG": "Algebraic Geometry",
|
| 51 |
+
"math.AP": "Analysis of PDEs",
|
| 52 |
+
"math.AT": "Algebraic Topology",
|
| 53 |
+
"math.CA": "Classical Analysis and ODEs",
|
| 54 |
+
"math.CO": "Combinatorics",
|
| 55 |
+
"math.CT": "Category Theory",
|
| 56 |
+
"math.CV": "Complex Variables",
|
| 57 |
+
"math.DG": "Differential Geometry",
|
| 58 |
+
"math.DS": "Dynamical Systems",
|
| 59 |
+
"math.FA": "Functional Analysis",
|
| 60 |
+
"math.GM": "General Mathematics",
|
| 61 |
+
"math.GN": "General Topology",
|
| 62 |
+
"math.GR": "Group Theory",
|
| 63 |
+
"math.GT": "Geometric Topology",
|
| 64 |
+
"math.HO": "History and Overview",
|
| 65 |
+
"math.IT": "Information Theory",
|
| 66 |
+
"math.KT": "K-Theory and Homology",
|
| 67 |
+
"math.LO": "Logic",
|
| 68 |
+
"math.MG": "Metric Geometry",
|
| 69 |
+
"math.MP": "Mathematical Physics",
|
| 70 |
+
"math.NA": "Numerical Analysis",
|
| 71 |
+
"math.NT": "Number Theory",
|
| 72 |
+
"math.OA": "Operator Algebras",
|
| 73 |
+
"math.OC": "Optimization and Control",
|
| 74 |
+
"math.PR": "Probability",
|
| 75 |
+
"math.QA": "Quantum Algebra",
|
| 76 |
+
"math.RA": "Rings and Algebras",
|
| 77 |
+
"math.RT": "Representation Theory",
|
| 78 |
+
"math.SG": "Symplectic Geometry",
|
| 79 |
+
"math.SP": "Spectral Theory",
|
| 80 |
+
"math.ST": "Statistics Theory",
|
| 81 |
+
"astro-ph.CO": "Cosmology and Nongalactic Astrophysics",
|
| 82 |
+
"astro-ph.EP": "Earth and Planetary Astrophysics",
|
| 83 |
+
"astro-ph.GA": "Astrophysics of Galaxies",
|
| 84 |
+
"astro-ph.HE": "High Energy Astrophysical Phenomena",
|
| 85 |
+
"astro-ph.IM": "Instrumentation and Methods for Astrophysics",
|
| 86 |
+
"astro-ph.SR": "Solar and Stellar Astrophysics",
|
| 87 |
+
"cond-mat.dis-nn": "Disordered Systems and Neural Networks",
|
| 88 |
+
"cond-mat.mes-hall": "Mesoscale and Nanoscale Physics",
|
| 89 |
+
"cond-mat.mtrl-sci": "Materials Science",
|
| 90 |
+
"cond-mat.other": "Other Condensed Matter",
|
| 91 |
+
"cond-mat.quant-gas": "Quantum Gases",
|
| 92 |
+
"cond-mat.soft": "Soft Condensed Matter",
|
| 93 |
+
"cond-mat.stat-mech": "Statistical Mechanics",
|
| 94 |
+
"cond-mat.str-el": "Strongly Correlated Electrons",
|
| 95 |
+
"cond-mat.supr-con": "Superconductivity",
|
| 96 |
+
"gr-qc": "General Relativity and Quantum Cosmology",
|
| 97 |
+
"hep-ex": "High Energy Physics - Experiment",
|
| 98 |
+
"hep-lat": "High Energy Physics - Lattice",
|
| 99 |
+
"hep-ph": "High Energy Physics - Phenomenology",
|
| 100 |
+
"hep-th": "High Energy Physics - Theory",
|
| 101 |
+
"math-ph": "Mathematical Physics",
|
| 102 |
+
"nlin.AO": "Adaptation and Self-Organizing Systems",
|
| 103 |
+
"nlin.CD": "Chaotic Dynamics",
|
| 104 |
+
"nlin.CG": "Cellular Automata and Lattice Gases",
|
| 105 |
+
"nlin.PS": "Pattern Formation and Solitons",
|
| 106 |
+
"nlin.SI": "Exactly Solvable and Integrable Systems",
|
| 107 |
+
"nucl-ex": "Nuclear Experiment",
|
| 108 |
+
"nucl-th": "Nuclear Theory",
|
| 109 |
+
"physics.acc-ph": "Accelerator Physics",
|
| 110 |
+
"physics.ao-ph": "Atmospheric and Oceanic Physics",
|
| 111 |
+
"physics.app-ph": "Applied Physics",
|
| 112 |
+
"physics.atm-clus": "Atomic and Molecular Clusters",
|
| 113 |
+
"physics.atom-ph": "Atomic Physics",
|
| 114 |
+
"physics.bio-ph": "Biological Physics",
|
| 115 |
+
"physics.chem-ph": "Chemical Physics",
|
| 116 |
+
"physics.class-ph": "Classical Physics",
|
| 117 |
+
"physics.comp-ph": "Computational Physics",
|
| 118 |
+
"physics.data-an": "Data Analysis, Statistics and Probability",
|
| 119 |
+
"physics.ed-ph": "Physics Education",
|
| 120 |
+
"physics.flu-dyn": "Fluid Dynamics",
|
| 121 |
+
"physics.gen-ph": "General Physics",
|
| 122 |
+
"physics.geo-ph": "Geophysics",
|
| 123 |
+
"physics.hist-ph": "History and Philosophy of Physics",
|
| 124 |
+
"physics.ins-det": "Instrumentation and Detectors",
|
| 125 |
+
"physics.med-ph": "Medical Physics",
|
| 126 |
+
"physics.optics": "Optics",
|
| 127 |
+
"physics.plasm-ph": "Plasma Physics",
|
| 128 |
+
"physics.pop-ph": "Popular Physics",
|
| 129 |
+
"physics.soc-ph": "Physics and Society",
|
| 130 |
+
"physics.space-ph": "Space Physics",
|
| 131 |
+
"quant-ph": "Quantum Physics",
|
| 132 |
+
"q-bio.BM": "Biomolecules",
|
| 133 |
+
"q-bio.CB": "Cell Behavior",
|
| 134 |
+
"q-bio.GN": "Genomics",
|
| 135 |
+
"q-bio.MN": "Molecular Networks",
|
| 136 |
+
"q-bio.NC": "Neurons and Cognition",
|
| 137 |
+
"q-bio.OT": "Other Quantitative Biology",
|
| 138 |
+
"q-bio.PE": "Populations and Evolution",
|
| 139 |
+
"q-bio.QM": "Quantitative Methods",
|
| 140 |
+
"q-bio.SC": "Subcellular Processes",
|
| 141 |
+
"q-bio.TO": "Tissues and Organs",
|
| 142 |
+
"q-fin.CP": "Computational Finance",
|
| 143 |
+
"q-fin.EC": "Economics",
|
| 144 |
+
"q-fin.GN": "General Finance",
|
| 145 |
+
"q-fin.MF": "Mathematical Finance",
|
| 146 |
+
"q-fin.PM": "Portfolio Management",
|
| 147 |
+
"q-fin.PR": "Pricing of Securities",
|
| 148 |
+
"q-fin.RM": "Risk Management",
|
| 149 |
+
"q-fin.ST": "Statistical Finance",
|
| 150 |
+
"q-fin.TR": "Trading and Market Microstructure",
|
| 151 |
+
"stat.AP": "Applications",
|
| 152 |
+
"stat.CO": "Computation",
|
| 153 |
+
"stat.ME": "Methodology",
|
| 154 |
+
"stat.ML": "Machine Learning",
|
| 155 |
+
"stat.OT": "Other Statistics",
|
| 156 |
+
"stat.TH": "Statistics Theory",
|
| 157 |
+
}
|
src/app/visualization.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from src.app.setup_model import LabelScore
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def visualize_predicted_categories(
    top_labels: list[LabelScore],
    scores: list[LabelScore],
    label_to_name_mapping: Dict[str, str],
):
    """
    Visualize the predicted categories in a streamlit app

    Renders one bordered card per entry of ``top_labels`` (name, tag caption
    and a colored score), then the expandable list of scores.

    Args:
        top_labels: List of top labels to display
        scores: All scores from the model
        label_to_name_mapping: Mapping from label codes to full names
    """
    st.subheader("Predicted Categories")

    # Index the model output once instead of scanning it for every label;
    # setdefault keeps the first occurrence of a label.
    score_by_label: dict[str, float] = {}
    for item in scores:
        score_by_label.setdefault(item["label"], item["score"])

    for label in top_labels:
        # Fall back to the score carried by the label itself when it is not
        # present in `scores`, rather than displaying 0.
        score = score_by_label.get(label["label"], label["score"])

        # Color gradient based on confidence
        color_intensity = min(int(score * 255), 255)

        with st.container(border=True):
            cols = st.columns([3, 1])
            with cols[0]:
                # Access full_name from the mapping if available
                full_name = label_to_name_mapping.get(label["label"], label["label"])
                st.markdown(f"**{full_name}**")
                st.caption(f"Tag: {label['label']}")
            with cols[1]:
                st.markdown(
                    f"<h3 style='text-align: right; color: rgb(0, {color_intensity}, {255 - color_intensity});'>{score:.2f}</h3>",
                    unsafe_allow_html=True,
                )

    display_all_scores(scores, label_to_name_mapping)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def display_all_scores(scores: list[LabelScore], label_to_name_mapping: Dict[str, str]):
    """
    List category scores inside an expandable section

    Only the 20 highest-scoring entries are shown, best first.

    Args:
        scores: All scores from the model
        label_to_name_mapping: Mapping from label codes to full names
    """
    with st.expander("View all category scores"):
        ranked = sorted(scores, key=lambda entry: entry["score"], reverse=True)
        top_twenty = ranked[:20]
        for entry in top_twenty:
            tag = entry["label"]
            # Fall back to the raw tag when the mapping has no full name.
            readable = label_to_name_mapping.get(tag, tag)
            st.text(f"{readable} ({tag}): {entry['score']:.4f}")
|
src/pipeline/arxiv_dataset.py
CHANGED
|
@@ -22,7 +22,7 @@ class ArxivPaper(tp.TypedDict):
|
|
| 22 |
|
| 23 |
|
| 24 |
def load_arxiv_dataset() -> Dataset:
|
| 25 |
-
df = pd.read_json("data/arxivData.json")
|
| 26 |
dataset = Dataset.from_pandas(df[["summary", "tag", "title"]])
|
| 27 |
return dataset
|
| 28 |
|
|
@@ -51,6 +51,7 @@ def generate_preprocessing_function(
|
|
| 51 |
text,
|
| 52 |
truncation=True,
|
| 53 |
padding="max_length",
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
tags_list: list[ArxivTag] = ast.literal_eval(row["tag"])
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def load_arxiv_dataset() -> Dataset:
|
| 25 |
+
df = pd.read_json("data/arxivData.json")
|
| 26 |
dataset = Dataset.from_pandas(df[["summary", "tag", "title"]])
|
| 27 |
return dataset
|
| 28 |
|
|
|
|
| 51 |
text,
|
| 52 |
truncation=True,
|
| 53 |
padding="max_length",
|
| 54 |
+
max_length=512,
|
| 55 |
)
|
| 56 |
|
| 57 |
tags_list: list[ArxivTag] = ast.literal_eval(row["tag"])
|
src/pipeline/env_setup.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
import os
|
|
|
|
| 3 |
|
| 4 |
REQUIRED_ENV_VARS = ["COMET_API_KEY", "COMET_MODE"]
|
| 5 |
|
|
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
import os
|
| 3 |
+
import torch
|
| 4 |
|
| 5 |
REQUIRED_ENV_VARS = ["COMET_API_KEY", "COMET_MODE"]
|
| 6 |
|
src/pipeline/metrics.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
| 1 |
-
import evaluate
|
| 2 |
import numpy as np
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def sigmoid(x):
|
|
@@ -10,8 +14,27 @@ def sigmoid(x):
|
|
| 10 |
|
| 11 |
def compute_metrics(eval_pred):
|
| 12 |
predictions, labels = eval_pred
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
predictions=predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
+
from sklearn.metrics import (
|
| 3 |
+
accuracy_score,
|
| 4 |
+
f1_score,
|
| 5 |
+
precision_score,
|
| 6 |
+
recall_score,
|
| 7 |
+
roc_auc_score,
|
| 8 |
+
)
|
| 9 |
|
| 10 |
|
| 11 |
def sigmoid(x):
|
|
|
|
| 14 |
|
| 15 |
def compute_metrics(eval_pred):
|
| 16 |
predictions, labels = eval_pred
|
| 17 |
+
|
| 18 |
+
# Handle T5 model output which can be a tuple
|
| 19 |
+
if isinstance(predictions, tuple):
|
| 20 |
+
predictions = predictions[0]
|
| 21 |
+
|
| 22 |
+
prediction_scores = sigmoid(predictions)
|
| 23 |
+
predictions = (prediction_scores > 0.5).astype(int)
|
| 24 |
+
|
| 25 |
+
# Multi-label metrics
|
| 26 |
+
accuracy = accuracy_score(labels, predictions)
|
| 27 |
+
roc_auc = roc_auc_score(labels, prediction_scores)
|
| 28 |
+
f1 = f1_score(labels, predictions, average="weighted", zero_division=0)
|
| 29 |
+
precision = precision_score(
|
| 30 |
+
labels, predictions, average="weighted", zero_division=0
|
| 31 |
)
|
| 32 |
+
recall = recall_score(labels, predictions, average="weighted", zero_division=0)
|
| 33 |
+
|
| 34 |
+
return {
|
| 35 |
+
"accuracy": accuracy,
|
| 36 |
+
"f1": f1,
|
| 37 |
+
"precision": precision,
|
| 38 |
+
"recall": recall,
|
| 39 |
+
"roc_auc": roc_auc,
|
| 40 |
+
}
|
uv.lock
CHANGED
|
@@ -122,6 +122,7 @@ dependencies = [
|
|
| 122 |
{ name = "python-dotenv" },
|
| 123 |
{ name = "scikit-learn" },
|
| 124 |
{ name = "streamlit" },
|
|
|
|
| 125 |
{ name = "torch" },
|
| 126 |
{ name = "transformers" },
|
| 127 |
]
|
|
@@ -140,6 +141,7 @@ requires-dist = [
|
|
| 140 |
{ name = "python-dotenv", specifier = ">=1.1.0" },
|
| 141 |
{ name = "scikit-learn", specifier = ">=1.6.1" },
|
| 142 |
{ name = "streamlit", specifier = ">=1.44.1" },
|
|
|
|
| 143 |
{ name = "torch", specifier = ">=2.6.0" },
|
| 144 |
{ name = "transformers", specifier = ">=4.50.3" },
|
| 145 |
]
|
|
@@ -1717,6 +1719,24 @@ wheels = [
|
|
| 1717 |
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638 },
|
| 1718 |
]
|
| 1719 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1720 |
[[package]]
|
| 1721 |
name = "tokenizers"
|
| 1722 |
version = "0.21.1"
|
|
|
|
| 122 |
{ name = "python-dotenv" },
|
| 123 |
{ name = "scikit-learn" },
|
| 124 |
{ name = "streamlit" },
|
| 125 |
+
{ name = "tiktoken" },
|
| 126 |
{ name = "torch" },
|
| 127 |
{ name = "transformers" },
|
| 128 |
]
|
|
|
|
| 141 |
{ name = "python-dotenv", specifier = ">=1.1.0" },
|
| 142 |
{ name = "scikit-learn", specifier = ">=1.6.1" },
|
| 143 |
{ name = "streamlit", specifier = ">=1.44.1" },
|
| 144 |
+
{ name = "tiktoken", specifier = ">=0.9.0" },
|
| 145 |
{ name = "torch", specifier = ">=2.6.0" },
|
| 146 |
{ name = "transformers", specifier = ">=4.50.3" },
|
| 147 |
]
|
|
|
|
| 1719 |
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638 },
|
| 1720 |
]
|
| 1721 |
|
| 1722 |
+
[[package]]
|
| 1723 |
+
name = "tiktoken"
|
| 1724 |
+
version = "0.9.0"
|
| 1725 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1726 |
+
dependencies = [
|
| 1727 |
+
{ name = "regex" },
|
| 1728 |
+
{ name = "requests" },
|
| 1729 |
+
]
|
| 1730 |
+
sdist = { url = "https://files.pythonhosted.org/packages/ea/cf/756fedf6981e82897f2d570dd25fa597eb3f4459068ae0572d7e888cfd6f/tiktoken-0.9.0.tar.gz", hash = "sha256:d02a5ca6a938e0490e1ff957bc48c8b078c88cb83977be1625b1fd8aac792c5d", size = 35991 }
|
| 1731 |
+
wheels = [
|
| 1732 |
+
{ url = "https://files.pythonhosted.org/packages/7a/11/09d936d37f49f4f494ffe660af44acd2d99eb2429d60a57c71318af214e0/tiktoken-0.9.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2b0e8e05a26eda1249e824156d537015480af7ae222ccb798e5234ae0285dbdb", size = 1064919 },
|
| 1733 |
+
{ url = "https://files.pythonhosted.org/packages/80/0e/f38ba35713edb8d4197ae602e80837d574244ced7fb1b6070b31c29816e0/tiktoken-0.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:27d457f096f87685195eea0165a1807fae87b97b2161fe8c9b1df5bd74ca6f63", size = 1007877 },
|
| 1734 |
+
{ url = "https://files.pythonhosted.org/packages/fe/82/9197f77421e2a01373e27a79dd36efdd99e6b4115746ecc553318ecafbf0/tiktoken-0.9.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cf8ded49cddf825390e36dd1ad35cd49589e8161fdcb52aa25f0583e90a3e01", size = 1140095 },
|
| 1735 |
+
{ url = "https://files.pythonhosted.org/packages/f2/bb/4513da71cac187383541facd0291c4572b03ec23c561de5811781bbd988f/tiktoken-0.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc156cb314119a8bb9748257a2eaebd5cc0753b6cb491d26694ed42fc7cb3139", size = 1195649 },
|
| 1736 |
+
{ url = "https://files.pythonhosted.org/packages/fa/5c/74e4c137530dd8504e97e3a41729b1103a4ac29036cbfd3250b11fd29451/tiktoken-0.9.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:cd69372e8c9dd761f0ab873112aba55a0e3e506332dd9f7522ca466e817b1b7a", size = 1258465 },
|
| 1737 |
+
{ url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669 },
|
| 1738 |
+
]
|
| 1739 |
+
|
| 1740 |
[[package]]
|
| 1741 |
name = "tokenizers"
|
| 1742 |
version = "0.21.1"
|