Spaces:
Sleeping
Sleeping
Commit
·
2073e38
1
Parent(s):
92d5847
index on master: 92d5847 working training code
Browse files- .dockerignore +4 -0
- .gitattributes +1 -0
- .streamlit/config.toml +2 -0
- Dockerfile +43 -0
- README.md +65 -7
- config/inference_config.py +6 -4
- container_setup/Dockerfile +8 -6
- container_setup/credentials +2 -2
- data/checkpoints/checkpoint-12300/config.json +351 -0
- data/checkpoints/checkpoint-12300/model.safetensors +3 -0
- data/checkpoints/checkpoint-12300/special_tokens_map.json +7 -0
- data/checkpoints/checkpoint-12300/tokenizer.json +0 -0
- data/checkpoints/checkpoint-12300/tokenizer_config.json +58 -0
- data/checkpoints/checkpoint-12300/trainer_state.json +0 -0
- data/checkpoints/checkpoint-12300/training_args.bin +0 -0
- data/checkpoints/checkpoint-12300/vocab.txt +0 -0
- entrypoints/app.py +13 -9
- scripts/launch_app.sh +7 -0
- src/app/data_validation.py +5 -4
- src/app/setup_model.py +3 -3
- src/app/visualization.py +4 -1
.dockerignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
data/arxivData.json
|
| 3 |
+
outputs
|
| 4 |
+
__pycache__/*
|
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
data/checkpoints/checkpoint-12300/model.safetensors filter=lfs diff=lfs merge=lfs -text
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[server]
|
| 2 |
+
fileWatcherType = "none"
|
Dockerfile
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.13-slim
|
| 2 |
+
|
| 3 |
+
# THIS IS DEVELOPMENT DOCKERFILE
|
| 4 |
+
|
| 5 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 6 |
+
ENV PYTHONUNBUFFERED=1
|
| 7 |
+
|
| 8 |
+
RUN apt-get update && \
|
| 9 |
+
apt-get install -y --no-install-recommends \
|
| 10 |
+
build-essential \
|
| 11 |
+
python3-dev && \
|
| 12 |
+
rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
ARG USER_ID=1000
|
| 15 |
+
ARG GROUP_ID=1000
|
| 16 |
+
|
| 17 |
+
# Create a group and user with the specified UID and GID
|
| 18 |
+
RUN addgroup --gid $GROUP_ID appgroup && \
|
| 19 |
+
adduser --uid $USER_ID --gid $GROUP_ID --shell /bin/bash --disabled-password --gecos "" appuser
|
| 20 |
+
|
| 21 |
+
# Install sudo and grant privileges
|
| 22 |
+
RUN apt-get update && apt-get install -y sudo && \
|
| 23 |
+
echo "appuser ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers
|
| 24 |
+
|
| 25 |
+
# Create /app directory with proper ownership
|
| 26 |
+
RUN mkdir -p /app && chown -R appuser:appgroup /app
|
| 27 |
+
|
| 28 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
| 29 |
+
|
| 30 |
+
# Switch to the new user
|
| 31 |
+
USER appuser
|
| 32 |
+
|
| 33 |
+
WORKDIR /app
|
| 34 |
+
|
| 35 |
+
COPY --chown=appuser:appgroup . /app
|
| 36 |
+
|
| 37 |
+
RUN uv venv .venv
|
| 38 |
+
|
| 39 |
+
RUN uv sync
|
| 40 |
+
|
| 41 |
+
EXPOSE 7860
|
| 42 |
+
|
| 43 |
+
CMD scripts/launch_app.sh
|
README.md
CHANGED
|
@@ -1,22 +1,80 @@
|
|
|
|
|
| 1 |
|
|
|
|
| 2 |
|
|
|
|
| 3 |
|
|
|
|
| 4 |
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
add .env with COMET_API_KEY, COMET_MODE=GET
|
| 15 |
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# arXiv Paper Classification
|
| 2 |
|
| 3 |
+
A machine learning application that predicts arXiv categories for academic papers based on their title and abstract. This tool uses a fine-tuned SciBERT model to classify papers into arXiv subject categories. This project was completed as homework for the YSDA ML 2 course.
|
| 4 |
|
| 5 |
+
I personally hate Jupyter notebooks, so as proof that I conducted the experiments, I made the Comet ML logger project public.
|
| 6 |
|
| 7 |
+
The latest training logs, configs, and other details can be found here: https://www.comet.com/adenshulga/arxiv-papers-classification/ef1256f1d4eb4b588da881366eb27578?compareXAxis=step&experiment-tab=panels&showOutliers=true&smoothing=0&xAxis=step
|
| 8 |
|
| 9 |
+
## Installation
|
| 10 |
|
| 11 |
+
There are two closely related Dockerfile configurations. The `container_setup` folder contains the scripts and Dockerfile for setting up an interactive development environment. The Dockerfile in the repository root is for deploying the Streamlit app.
|
| 12 |
|
| 13 |
+
### Streamlit App Setup
|
| 14 |
|
| 15 |
+
1. Clone the repository:
|
| 16 |
+
```bash
|
| 17 |
+
git clone https://github.com/adenshulga/arxiv-paper-classification.git
|
| 18 |
+
cd arxiv-paper-classification
|
| 19 |
+
```
|
| 20 |
|
| 21 |
+
2. Give permissions for executable scripts:
|
| 22 |
+
```
|
| 23 |
+
chmod +x scripts/pipeline.sh scripts/launch_app.sh
|
| 24 |
+
```
|
| 25 |
|
| 26 |
+
3. Build and launch docker container:
|
| 27 |
+
```
|
| 28 |
+
docker build -t arxiv-paper-clf .
|
| 29 |
+
docker run -p 9001:9001 arxiv-paper-clf
|
| 30 |
+
```
|
| 31 |
|
|
|
|
| 32 |
|
| 33 |
+
### Configuration
|
| 34 |
|
| 35 |
+
You can modify the inference settings in `config/inference_config.py`:
|
| 36 |
|
| 37 |
+
- `model_name`: Base model name from Hugging Face
|
| 38 |
+
- `checkpoint_path`: Path to fine-tuned model checkpoint
|
| 39 |
+
- `top_percent`: Cumulative score threshold for showing predictions
|
| 40 |
+
- `minimal_score`: Minimum confidence score to display
|
| 41 |
|
| 42 |
+
## Development and model Training
|
| 43 |
|
| 44 |
+
To enter the development environment:
|
| 45 |
+
1. Fill in the `container_setup/credentials` file
|
| 46 |
+
|
| 47 |
+
2. Give executable permissions to build and launch scripts:
|
| 48 |
+
```
|
| 49 |
+
chmod +x container_setup/build.sh container_setup/launch_container.sh
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
3. Specify resource constraints in ./container_setup/launch_container.sh
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
4. Build and launch docker container
|
| 56 |
+
```
|
| 57 |
+
./container_setup/build.sh
|
| 58 |
+
./container_setup/launch_container.sh
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
5. Attach to running container
|
| 62 |
+
```
|
| 63 |
+
docker attach <container-id>
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
6. Install the dependencies
|
| 67 |
+
```
|
| 68 |
+
uv venv
|
| 69 |
+
uv sync
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
To train the model:
|
| 73 |
+
|
| 74 |
+
1. Download and unzip the arXiv dataset into the `data` folder (https://www.kaggle.com/datasets/neelshah18/arxivdataset)
|
| 75 |
+
2. Configure the process in config/pipeline_config.py
|
| 76 |
+
|
| 77 |
+
Run the training script:
|
| 78 |
+
```
|
| 79 |
+
scripts/pipeline.sh
|
| 80 |
+
```
|
config/inference_config.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
-
from hydra.core.config_store import ConfigStore
|
| 4 |
|
| 5 |
|
| 6 |
@dataclass
|
|
@@ -8,5 +6,9 @@ class InferenceConfig:
|
|
| 8 |
"""Configuration for inference"""
|
| 9 |
|
| 10 |
model_name: str = "allenai/scibert_scivocab_uncased"
|
| 11 |
-
checkpoint_path: str = "data/checkpoints/checkpoint-
|
| 12 |
top_percent: float = 0.95
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
@dataclass
|
|
|
|
| 6 |
"""Configuration for inference"""
|
| 7 |
|
| 8 |
model_name: str = "allenai/scibert_scivocab_uncased"
|
| 9 |
+
checkpoint_path: str = "data/checkpoints/checkpoint-12300"
|
| 10 |
top_percent: float = 0.95
|
| 11 |
+
minimal_score: float = 0.01
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
cfg = InferenceConfig()
|
container_setup/Dockerfile
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
FROM python:3.13-slim
|
| 2 |
|
|
|
|
|
|
|
| 3 |
ENV PYTHONDONTWRITEBYTECODE=1
|
| 4 |
ENV PYTHONUNBUFFERED=1
|
| 5 |
|
|
@@ -36,16 +38,16 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
|
| 36 |
# Switch to the new user
|
| 37 |
USER appuser
|
| 38 |
|
| 39 |
-
SHELL ["/usr/bin/fish", "-c"]
|
| 40 |
|
| 41 |
WORKDIR /app
|
| 42 |
|
| 43 |
-
COPY --chown=appuser:appgroup . /app
|
| 44 |
|
| 45 |
-
RUN uv venv .venv
|
| 46 |
|
| 47 |
-
RUN uv sync
|
| 48 |
|
| 49 |
-
EXPOSE
|
| 50 |
|
| 51 |
-
# CMD
|
|
|
|
| 1 |
FROM python:3.13-slim
|
| 2 |
|
| 3 |
+
# THIS IS DEVELOPMENT DOCKERFILE
|
| 4 |
+
|
| 5 |
ENV PYTHONDONTWRITEBYTECODE=1
|
| 6 |
ENV PYTHONUNBUFFERED=1
|
| 7 |
|
|
|
|
| 38 |
# Switch to the new user
|
| 39 |
USER appuser
|
| 40 |
|
| 41 |
+
# SHELL ["/usr/bin/fish", "-c"]
|
| 42 |
|
| 43 |
WORKDIR /app
|
| 44 |
|
| 45 |
+
# COPY --chown=appuser:appgroup . /app
|
| 46 |
|
| 47 |
+
# RUN uv venv .venv
|
| 48 |
|
| 49 |
+
# RUN uv sync
|
| 50 |
|
| 51 |
+
# EXPOSE 7860
|
| 52 |
|
| 53 |
+
# CMD scripts/launch_app.sh
|
container_setup/credentials
CHANGED
|
@@ -5,5 +5,5 @@ CONTAINER_NAME=$USER"-arxiv-papers-classification"
|
|
| 5 |
SRC="." # folder to propulse in docker container
|
| 6 |
DOCKER_USER_ID=$(id -u) # to get these values type "id" in shell termilal
|
| 7 |
DOCKER_GROUP_ID=$(id -g)
|
| 8 |
-
CONTAINER_PORT=
|
| 9 |
-
INNER_PORT=
|
|
|
|
| 5 |
SRC="." # folder to propulse in docker container
|
| 6 |
DOCKER_USER_ID=$(id -u) # to get these values type "id" in shell termilal
|
| 7 |
DOCKER_GROUP_ID=$(id -g)
|
| 8 |
+
CONTAINER_PORT=7860 # used in launch_container file
|
| 9 |
+
INNER_PORT=7860
|
data/checkpoints/checkpoint-12300/config.json
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"hidden_act": "gelu",
|
| 8 |
+
"hidden_dropout_prob": 0.1,
|
| 9 |
+
"hidden_size": 768,
|
| 10 |
+
"id2label": {
|
| 11 |
+
"0": "A.m",
|
| 12 |
+
"1": "Artificial intelligence and nonmonotonic reasoning and belief\n revision",
|
| 13 |
+
"2": "Comptuational science",
|
| 14 |
+
"3": "Computer Science",
|
| 15 |
+
"4": "H.m",
|
| 16 |
+
"5": "IEEE",
|
| 17 |
+
"6": "MIMO, relay, queue-aware, distributive resource control",
|
| 18 |
+
"7": "Mathematical logic and foundations",
|
| 19 |
+
"8": "aaai.org",
|
| 20 |
+
"9": "adap-org",
|
| 21 |
+
"10": "artificial intelligence, approximate reasoning",
|
| 22 |
+
"11": "astro-ph",
|
| 23 |
+
"12": "astro-ph.CO",
|
| 24 |
+
"13": "astro-ph.EP",
|
| 25 |
+
"14": "astro-ph.GA",
|
| 26 |
+
"15": "astro-ph.HE",
|
| 27 |
+
"16": "astro-ph.IM",
|
| 28 |
+
"17": "astro-ph.SR",
|
| 29 |
+
"18": "cmp-lg",
|
| 30 |
+
"19": "cond-mat",
|
| 31 |
+
"20": "cond-mat.dis-nn",
|
| 32 |
+
"21": "cond-mat.mes-hall",
|
| 33 |
+
"22": "cond-mat.mtrl-sci",
|
| 34 |
+
"23": "cond-mat.other",
|
| 35 |
+
"24": "cond-mat.quant-gas",
|
| 36 |
+
"25": "cond-mat.soft",
|
| 37 |
+
"26": "cond-mat.stat-mech",
|
| 38 |
+
"27": "cond-mat.str-el",
|
| 39 |
+
"28": "cond-mat.supr-con",
|
| 40 |
+
"29": "cs.AI",
|
| 41 |
+
"30": "cs.AR",
|
| 42 |
+
"31": "cs.CC",
|
| 43 |
+
"32": "cs.CE",
|
| 44 |
+
"33": "cs.CG",
|
| 45 |
+
"34": "cs.CL",
|
| 46 |
+
"35": "cs.CL, cs.AI, math.CT",
|
| 47 |
+
"36": "cs.CR",
|
| 48 |
+
"37": "cs.CV",
|
| 49 |
+
"38": "cs.CY",
|
| 50 |
+
"39": "cs.DB",
|
| 51 |
+
"40": "cs.DC",
|
| 52 |
+
"41": "cs.DL",
|
| 53 |
+
"42": "cs.DM",
|
| 54 |
+
"43": "cs.DS",
|
| 55 |
+
"44": "cs.ET",
|
| 56 |
+
"45": "cs.FL",
|
| 57 |
+
"46": "cs.GL",
|
| 58 |
+
"47": "cs.GR",
|
| 59 |
+
"48": "cs.GT",
|
| 60 |
+
"49": "cs.HC",
|
| 61 |
+
"50": "cs.IR",
|
| 62 |
+
"51": "cs.IT",
|
| 63 |
+
"52": "cs.LG",
|
| 64 |
+
"53": "cs.LO",
|
| 65 |
+
"54": "cs.MA",
|
| 66 |
+
"55": "cs.MM",
|
| 67 |
+
"56": "cs.MS",
|
| 68 |
+
"57": "cs.NA",
|
| 69 |
+
"58": "cs.NE",
|
| 70 |
+
"59": "cs.NI",
|
| 71 |
+
"60": "cs.OH",
|
| 72 |
+
"61": "cs.OS",
|
| 73 |
+
"62": "cs.PF",
|
| 74 |
+
"63": "cs.PL",
|
| 75 |
+
"64": "cs.RO",
|
| 76 |
+
"65": "cs.SC",
|
| 77 |
+
"66": "cs.SD",
|
| 78 |
+
"67": "cs.SE",
|
| 79 |
+
"68": "cs.SI",
|
| 80 |
+
"69": "cs.SY",
|
| 81 |
+
"70": "econ.EM",
|
| 82 |
+
"71": "eess.AS",
|
| 83 |
+
"72": "eess.IV",
|
| 84 |
+
"73": "eess.SP",
|
| 85 |
+
"74": "gr-qc",
|
| 86 |
+
"75": "hep-ex",
|
| 87 |
+
"76": "hep-lat",
|
| 88 |
+
"77": "hep-ph",
|
| 89 |
+
"78": "hep-th",
|
| 90 |
+
"79": "math-ph",
|
| 91 |
+
"80": "math.AC",
|
| 92 |
+
"81": "math.AG",
|
| 93 |
+
"82": "math.AP",
|
| 94 |
+
"83": "math.AT",
|
| 95 |
+
"84": "math.CA",
|
| 96 |
+
"85": "math.CO",
|
| 97 |
+
"86": "math.CT",
|
| 98 |
+
"87": "math.CV",
|
| 99 |
+
"88": "math.DG",
|
| 100 |
+
"89": "math.DS",
|
| 101 |
+
"90": "math.FA",
|
| 102 |
+
"91": "math.GM",
|
| 103 |
+
"92": "math.GN",
|
| 104 |
+
"93": "math.GR",
|
| 105 |
+
"94": "math.GT",
|
| 106 |
+
"95": "math.HO",
|
| 107 |
+
"96": "math.IT",
|
| 108 |
+
"97": "math.LO",
|
| 109 |
+
"98": "math.MG",
|
| 110 |
+
"99": "math.MP",
|
| 111 |
+
"100": "math.NA",
|
| 112 |
+
"101": "math.NT",
|
| 113 |
+
"102": "math.OA",
|
| 114 |
+
"103": "math.OC",
|
| 115 |
+
"104": "math.PR",
|
| 116 |
+
"105": "math.QA",
|
| 117 |
+
"106": "math.RA",
|
| 118 |
+
"107": "math.RT",
|
| 119 |
+
"108": "math.SP",
|
| 120 |
+
"109": "math.ST",
|
| 121 |
+
"110": "nlin.AO",
|
| 122 |
+
"111": "nlin.AO, nlin.CD, q-bio.NC, physics.bio-ph, cond-mat.dis-nn",
|
| 123 |
+
"112": "nlin.CD",
|
| 124 |
+
"113": "nlin.CG",
|
| 125 |
+
"114": "nlin.PS",
|
| 126 |
+
"115": "nucl-ex",
|
| 127 |
+
"116": "nucl-th",
|
| 128 |
+
"117": "physics.ao-ph",
|
| 129 |
+
"118": "physics.app-ph",
|
| 130 |
+
"119": "physics.bio-ph",
|
| 131 |
+
"120": "physics.chem-ph",
|
| 132 |
+
"121": "physics.class-ph",
|
| 133 |
+
"122": "physics.comp-ph",
|
| 134 |
+
"123": "physics.data-an",
|
| 135 |
+
"124": "physics.flu-dyn",
|
| 136 |
+
"125": "physics.gen-ph",
|
| 137 |
+
"126": "physics.geo-ph",
|
| 138 |
+
"127": "physics.hist-ph",
|
| 139 |
+
"128": "physics.ins-det",
|
| 140 |
+
"129": "physics.med-ph",
|
| 141 |
+
"130": "physics.optics",
|
| 142 |
+
"131": "physics.pop-ph",
|
| 143 |
+
"132": "physics.soc-ph",
|
| 144 |
+
"133": "physics.space-ph",
|
| 145 |
+
"134": "q-bio",
|
| 146 |
+
"135": "q-bio.BM",
|
| 147 |
+
"136": "q-bio.BM, q-bio.MN, q-bio.NC, nlin.AO, nlin.CD",
|
| 148 |
+
"137": "q-bio.CB",
|
| 149 |
+
"138": "q-bio.GN",
|
| 150 |
+
"139": "q-bio.MN",
|
| 151 |
+
"140": "q-bio.NC",
|
| 152 |
+
"141": "q-bio.OT",
|
| 153 |
+
"142": "q-bio.PE",
|
| 154 |
+
"143": "q-bio.QM",
|
| 155 |
+
"144": "q-bio.SC",
|
| 156 |
+
"145": "q-bio.TO",
|
| 157 |
+
"146": "q-fin.CP",
|
| 158 |
+
"147": "q-fin.EC",
|
| 159 |
+
"148": "q-fin.GN",
|
| 160 |
+
"149": "q-fin.PM",
|
| 161 |
+
"150": "q-fin.PR",
|
| 162 |
+
"151": "q-fin.RM",
|
| 163 |
+
"152": "q-fin.ST",
|
| 164 |
+
"153": "q-fin.TR",
|
| 165 |
+
"154": "quant-ph",
|
| 166 |
+
"155": "stat.AP",
|
| 167 |
+
"156": "stat.CO",
|
| 168 |
+
"157": "stat.ME",
|
| 169 |
+
"158": "stat.ML",
|
| 170 |
+
"159": "stat.OT",
|
| 171 |
+
"160": "stat.TH"
|
| 172 |
+
},
|
| 173 |
+
"initializer_range": 0.02,
|
| 174 |
+
"intermediate_size": 3072,
|
| 175 |
+
"label2id": {
|
| 176 |
+
"A.m": 0,
|
| 177 |
+
"Artificial intelligence and nonmonotonic reasoning and belief\n revision": 1,
|
| 178 |
+
"Comptuational science": 2,
|
| 179 |
+
"Computer Science": 3,
|
| 180 |
+
"H.m": 4,
|
| 181 |
+
"IEEE": 5,
|
| 182 |
+
"MIMO, relay, queue-aware, distributive resource control": 6,
|
| 183 |
+
"Mathematical logic and foundations": 7,
|
| 184 |
+
"aaai.org": 8,
|
| 185 |
+
"adap-org": 9,
|
| 186 |
+
"artificial intelligence, approximate reasoning": 10,
|
| 187 |
+
"astro-ph": 11,
|
| 188 |
+
"astro-ph.CO": 12,
|
| 189 |
+
"astro-ph.EP": 13,
|
| 190 |
+
"astro-ph.GA": 14,
|
| 191 |
+
"astro-ph.HE": 15,
|
| 192 |
+
"astro-ph.IM": 16,
|
| 193 |
+
"astro-ph.SR": 17,
|
| 194 |
+
"cmp-lg": 18,
|
| 195 |
+
"cond-mat": 19,
|
| 196 |
+
"cond-mat.dis-nn": 20,
|
| 197 |
+
"cond-mat.mes-hall": 21,
|
| 198 |
+
"cond-mat.mtrl-sci": 22,
|
| 199 |
+
"cond-mat.other": 23,
|
| 200 |
+
"cond-mat.quant-gas": 24,
|
| 201 |
+
"cond-mat.soft": 25,
|
| 202 |
+
"cond-mat.stat-mech": 26,
|
| 203 |
+
"cond-mat.str-el": 27,
|
| 204 |
+
"cond-mat.supr-con": 28,
|
| 205 |
+
"cs.AI": 29,
|
| 206 |
+
"cs.AR": 30,
|
| 207 |
+
"cs.CC": 31,
|
| 208 |
+
"cs.CE": 32,
|
| 209 |
+
"cs.CG": 33,
|
| 210 |
+
"cs.CL": 34,
|
| 211 |
+
"cs.CL, cs.AI, math.CT": 35,
|
| 212 |
+
"cs.CR": 36,
|
| 213 |
+
"cs.CV": 37,
|
| 214 |
+
"cs.CY": 38,
|
| 215 |
+
"cs.DB": 39,
|
| 216 |
+
"cs.DC": 40,
|
| 217 |
+
"cs.DL": 41,
|
| 218 |
+
"cs.DM": 42,
|
| 219 |
+
"cs.DS": 43,
|
| 220 |
+
"cs.ET": 44,
|
| 221 |
+
"cs.FL": 45,
|
| 222 |
+
"cs.GL": 46,
|
| 223 |
+
"cs.GR": 47,
|
| 224 |
+
"cs.GT": 48,
|
| 225 |
+
"cs.HC": 49,
|
| 226 |
+
"cs.IR": 50,
|
| 227 |
+
"cs.IT": 51,
|
| 228 |
+
"cs.LG": 52,
|
| 229 |
+
"cs.LO": 53,
|
| 230 |
+
"cs.MA": 54,
|
| 231 |
+
"cs.MM": 55,
|
| 232 |
+
"cs.MS": 56,
|
| 233 |
+
"cs.NA": 57,
|
| 234 |
+
"cs.NE": 58,
|
| 235 |
+
"cs.NI": 59,
|
| 236 |
+
"cs.OH": 60,
|
| 237 |
+
"cs.OS": 61,
|
| 238 |
+
"cs.PF": 62,
|
| 239 |
+
"cs.PL": 63,
|
| 240 |
+
"cs.RO": 64,
|
| 241 |
+
"cs.SC": 65,
|
| 242 |
+
"cs.SD": 66,
|
| 243 |
+
"cs.SE": 67,
|
| 244 |
+
"cs.SI": 68,
|
| 245 |
+
"cs.SY": 69,
|
| 246 |
+
"econ.EM": 70,
|
| 247 |
+
"eess.AS": 71,
|
| 248 |
+
"eess.IV": 72,
|
| 249 |
+
"eess.SP": 73,
|
| 250 |
+
"gr-qc": 74,
|
| 251 |
+
"hep-ex": 75,
|
| 252 |
+
"hep-lat": 76,
|
| 253 |
+
"hep-ph": 77,
|
| 254 |
+
"hep-th": 78,
|
| 255 |
+
"math-ph": 79,
|
| 256 |
+
"math.AC": 80,
|
| 257 |
+
"math.AG": 81,
|
| 258 |
+
"math.AP": 82,
|
| 259 |
+
"math.AT": 83,
|
| 260 |
+
"math.CA": 84,
|
| 261 |
+
"math.CO": 85,
|
| 262 |
+
"math.CT": 86,
|
| 263 |
+
"math.CV": 87,
|
| 264 |
+
"math.DG": 88,
|
| 265 |
+
"math.DS": 89,
|
| 266 |
+
"math.FA": 90,
|
| 267 |
+
"math.GM": 91,
|
| 268 |
+
"math.GN": 92,
|
| 269 |
+
"math.GR": 93,
|
| 270 |
+
"math.GT": 94,
|
| 271 |
+
"math.HO": 95,
|
| 272 |
+
"math.IT": 96,
|
| 273 |
+
"math.LO": 97,
|
| 274 |
+
"math.MG": 98,
|
| 275 |
+
"math.MP": 99,
|
| 276 |
+
"math.NA": 100,
|
| 277 |
+
"math.NT": 101,
|
| 278 |
+
"math.OA": 102,
|
| 279 |
+
"math.OC": 103,
|
| 280 |
+
"math.PR": 104,
|
| 281 |
+
"math.QA": 105,
|
| 282 |
+
"math.RA": 106,
|
| 283 |
+
"math.RT": 107,
|
| 284 |
+
"math.SP": 108,
|
| 285 |
+
"math.ST": 109,
|
| 286 |
+
"nlin.AO": 110,
|
| 287 |
+
"nlin.AO, nlin.CD, q-bio.NC, physics.bio-ph, cond-mat.dis-nn": 111,
|
| 288 |
+
"nlin.CD": 112,
|
| 289 |
+
"nlin.CG": 113,
|
| 290 |
+
"nlin.PS": 114,
|
| 291 |
+
"nucl-ex": 115,
|
| 292 |
+
"nucl-th": 116,
|
| 293 |
+
"physics.ao-ph": 117,
|
| 294 |
+
"physics.app-ph": 118,
|
| 295 |
+
"physics.bio-ph": 119,
|
| 296 |
+
"physics.chem-ph": 120,
|
| 297 |
+
"physics.class-ph": 121,
|
| 298 |
+
"physics.comp-ph": 122,
|
| 299 |
+
"physics.data-an": 123,
|
| 300 |
+
"physics.flu-dyn": 124,
|
| 301 |
+
"physics.gen-ph": 125,
|
| 302 |
+
"physics.geo-ph": 126,
|
| 303 |
+
"physics.hist-ph": 127,
|
| 304 |
+
"physics.ins-det": 128,
|
| 305 |
+
"physics.med-ph": 129,
|
| 306 |
+
"physics.optics": 130,
|
| 307 |
+
"physics.pop-ph": 131,
|
| 308 |
+
"physics.soc-ph": 132,
|
| 309 |
+
"physics.space-ph": 133,
|
| 310 |
+
"q-bio": 134,
|
| 311 |
+
"q-bio.BM": 135,
|
| 312 |
+
"q-bio.BM, q-bio.MN, q-bio.NC, nlin.AO, nlin.CD": 136,
|
| 313 |
+
"q-bio.CB": 137,
|
| 314 |
+
"q-bio.GN": 138,
|
| 315 |
+
"q-bio.MN": 139,
|
| 316 |
+
"q-bio.NC": 140,
|
| 317 |
+
"q-bio.OT": 141,
|
| 318 |
+
"q-bio.PE": 142,
|
| 319 |
+
"q-bio.QM": 143,
|
| 320 |
+
"q-bio.SC": 144,
|
| 321 |
+
"q-bio.TO": 145,
|
| 322 |
+
"q-fin.CP": 146,
|
| 323 |
+
"q-fin.EC": 147,
|
| 324 |
+
"q-fin.GN": 148,
|
| 325 |
+
"q-fin.PM": 149,
|
| 326 |
+
"q-fin.PR": 150,
|
| 327 |
+
"q-fin.RM": 151,
|
| 328 |
+
"q-fin.ST": 152,
|
| 329 |
+
"q-fin.TR": 153,
|
| 330 |
+
"quant-ph": 154,
|
| 331 |
+
"stat.AP": 155,
|
| 332 |
+
"stat.CO": 156,
|
| 333 |
+
"stat.ME": 157,
|
| 334 |
+
"stat.ML": 158,
|
| 335 |
+
"stat.OT": 159,
|
| 336 |
+
"stat.TH": 160
|
| 337 |
+
},
|
| 338 |
+
"layer_norm_eps": 1e-12,
|
| 339 |
+
"max_position_embeddings": 512,
|
| 340 |
+
"model_type": "bert",
|
| 341 |
+
"num_attention_heads": 12,
|
| 342 |
+
"num_hidden_layers": 12,
|
| 343 |
+
"pad_token_id": 0,
|
| 344 |
+
"position_embedding_type": "absolute",
|
| 345 |
+
"problem_type": "multi_label_classification",
|
| 346 |
+
"torch_dtype": "float32",
|
| 347 |
+
"transformers_version": "4.50.3",
|
| 348 |
+
"type_vocab_size": 2,
|
| 349 |
+
"use_cache": true,
|
| 350 |
+
"vocab_size": 31090
|
| 351 |
+
}
|
data/checkpoints/checkpoint-12300/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe26497cd2f66e6db8739c13ac359fe50c8023ff6bb897c80df6ddae58a77cd3
|
| 3 |
+
size 440192628
|
data/checkpoints/checkpoint-12300/special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
data/checkpoints/checkpoint-12300/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/checkpoints/checkpoint-12300/tokenizer_config.json
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"101": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"102": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"103": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"104": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": true,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_basic_tokenize": true,
|
| 47 |
+
"do_lower_case": true,
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"mask_token": "[MASK]",
|
| 50 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 51 |
+
"never_split": null,
|
| 52 |
+
"pad_token": "[PAD]",
|
| 53 |
+
"sep_token": "[SEP]",
|
| 54 |
+
"strip_accents": null,
|
| 55 |
+
"tokenize_chinese_chars": true,
|
| 56 |
+
"tokenizer_class": "BertTokenizer",
|
| 57 |
+
"unk_token": "[UNK]"
|
| 58 |
+
}
|
data/checkpoints/checkpoint-12300/trainer_state.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/checkpoints/checkpoint-12300/training_args.bin
ADDED
|
Binary file (5.37 kB). View file
|
|
|
data/checkpoints/checkpoint-12300/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
entrypoints/app.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
| 3 |
from src.app.tags_mapping import tags2full_name
|
| 4 |
from src.app.visualization import visualize_predicted_categories
|
| 5 |
-
from config.inference_config import InferenceConfig
|
| 6 |
-
from src.app.data_validation import validate_data
|
| 7 |
|
| 8 |
st.title("arXiv Paper Classifier")
|
| 9 |
st.markdown("Enter paper details to predict arXiv categories")
|
|
@@ -11,17 +12,20 @@ st.markdown("Enter paper details to predict arXiv categories")
|
|
| 11 |
st.text_input("Enter paper name", key="paper_name")
|
| 12 |
st.text_area("Enter paper abstract", key="paper_abstract", height=250)
|
| 13 |
|
| 14 |
-
if st.button("Predict Categories", type="primary")
|
| 15 |
-
|
|
|
|
| 16 |
with st.spinner("Analyzing paper..."):
|
| 17 |
-
pipeline = setup_pipeline(
|
| 18 |
scores: list[LabelScore] = pipeline(
|
| 19 |
st.session_state["paper_name"] + " " + st.session_state["paper_abstract"],
|
| 20 |
-
|
| 21 |
) # type: ignore
|
| 22 |
|
| 23 |
-
top_labels = get_top_label_names(scores, tags2full_name,
|
| 24 |
|
| 25 |
-
visualize_predicted_categories(
|
|
|
|
|
|
|
| 26 |
else:
|
| 27 |
st.info("Enter paper details and click 'Predict Categories' to get predictions.")
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
|
| 3 |
+
from config.inference_config import cfg
|
| 4 |
+
from src.app.data_validation import validate_data
|
| 5 |
+
from src.app.setup_model import LabelScore, get_top_label_names, setup_pipeline
|
| 6 |
from src.app.tags_mapping import tags2full_name
|
| 7 |
from src.app.visualization import visualize_predicted_categories
|
|
|
|
|
|
|
| 8 |
|
| 9 |
st.title("arXiv Paper Classifier")
|
| 10 |
st.markdown("Enter paper details to predict arXiv categories")
|
|
|
|
| 12 |
st.text_input("Enter paper name", key="paper_name")
|
| 13 |
st.text_area("Enter paper abstract", key="paper_abstract", height=250)
|
| 14 |
|
| 15 |
+
if st.button("Predict Categories", type="primary") and validate_data(
|
| 16 |
+
st.session_state["paper_name"], st.session_state["paper_abstract"]
|
| 17 |
+
):
|
| 18 |
with st.spinner("Analyzing paper..."):
|
| 19 |
+
pipeline = setup_pipeline(cfg)
|
| 20 |
scores: list[LabelScore] = pipeline(
|
| 21 |
st.session_state["paper_name"] + " " + st.session_state["paper_abstract"],
|
| 22 |
+
top_k=None,
|
| 23 |
) # type: ignore
|
| 24 |
|
| 25 |
+
top_labels = get_top_label_names(scores, tags2full_name, cfg.top_percent)
|
| 26 |
|
| 27 |
+
visualize_predicted_categories(
|
| 28 |
+
top_labels, scores, tags2full_name, minimal_score=cfg.minimal_score
|
| 29 |
+
)
|
| 30 |
else:
|
| 31 |
st.info("Enter paper details and click 'Predict Categories' to get predictions.")
|
scripts/launch_app.sh
CHANGED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
export PYTHONPATH='.'
|
| 4 |
+
|
| 5 |
+
# source .venv/bin/activate
|
| 6 |
+
|
| 7 |
+
CUDA_VISIBLE_DEVICES="" uv run -m streamlit run entrypoints/app.py --server.address=0.0.0.0 --server.port=9001
|
src/app/data_validation.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
|
| 4 |
-
def validate_data(paper_name: str, paper_abstract: str) ->
|
| 5 |
-
if paper_name == ""
|
| 6 |
st.error("Paper name or abstract are required")
|
| 7 |
-
return
|
| 8 |
if paper_abstract == "":
|
| 9 |
st.warning(
|
| 10 |
"Without abstract, the performance of the model will be significantly worse"
|
| 11 |
)
|
| 12 |
-
return
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
|
| 4 |
+
def validate_data(paper_name: str, paper_abstract: str) -> bool:
|
| 5 |
+
if paper_name == "" and paper_abstract == "":
|
| 6 |
st.error("Paper name or abstract are required")
|
| 7 |
+
return False
|
| 8 |
if paper_abstract == "":
|
| 9 |
st.warning(
|
| 10 |
"Without abstract, the performance of the model will be significantly worse"
|
| 11 |
)
|
| 12 |
+
return True
|
| 13 |
+
return True
|
src/app/setup_model.py
CHANGED
|
@@ -3,8 +3,6 @@ import typing as tp
|
|
| 3 |
from config.inference_config import InferenceConfig
|
| 4 |
import streamlit as st
|
| 5 |
|
| 6 |
-
from src.app.tags_mapping import tags2full_name
|
| 7 |
-
|
| 8 |
|
| 9 |
class LabelScore(tp.TypedDict):
|
| 10 |
label: str
|
|
@@ -14,7 +12,9 @@ class LabelScore(tp.TypedDict):
|
|
| 14 |
@st.cache_resource
|
| 15 |
def setup_pipeline(cfg: InferenceConfig) -> Pipeline:
|
| 16 |
model = pipeline(
|
| 17 |
-
"text-classification",
|
|
|
|
|
|
|
| 18 |
)
|
| 19 |
return model
|
| 20 |
|
|
|
|
| 3 |
from config.inference_config import InferenceConfig
|
| 4 |
import streamlit as st
|
| 5 |
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class LabelScore(tp.TypedDict):
|
| 8 |
label: str
|
|
|
|
| 12 |
@st.cache_resource
|
| 13 |
def setup_pipeline(cfg: InferenceConfig) -> Pipeline:
|
| 14 |
model = pipeline(
|
| 15 |
+
"text-classification",
|
| 16 |
+
model=cfg.checkpoint_path,
|
| 17 |
+
tokenizer=cfg.model_name,
|
| 18 |
)
|
| 19 |
return model
|
| 20 |
|
src/app/visualization.py
CHANGED
|
@@ -7,6 +7,7 @@ def visualize_predicted_categories(
|
|
| 7 |
top_labels: list[LabelScore],
|
| 8 |
scores: list[LabelScore],
|
| 9 |
label_to_name_mapping: Dict[str, str],
|
|
|
|
| 10 |
):
|
| 11 |
"""
|
| 12 |
Visualize the predicted categories in a streamlit app
|
|
@@ -19,7 +20,9 @@ def visualize_predicted_categories(
|
|
| 19 |
st.subheader("Predicted Categories")
|
| 20 |
|
| 21 |
for i, label in enumerate(top_labels):
|
| 22 |
-
score =
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Color gradient based on confidence
|
| 25 |
color_intensity = min(int(score * 255), 255)
|
|
|
|
| 7 |
top_labels: list[LabelScore],
|
| 8 |
scores: list[LabelScore],
|
| 9 |
label_to_name_mapping: Dict[str, str],
|
| 10 |
+
minimal_score: float = 0.01,
|
| 11 |
):
|
| 12 |
"""
|
| 13 |
Visualize the predicted categories in a streamlit app
|
|
|
|
| 20 |
st.subheader("Predicted Categories")
|
| 21 |
|
| 22 |
for i, label in enumerate(top_labels):
|
| 23 |
+
score = label["score"]
|
| 24 |
+
if score < minimal_score:
|
| 25 |
+
continue
|
| 26 |
|
| 27 |
# Color gradient based on confidence
|
| 28 |
color_intensity = min(int(score * 255), 255)
|