Commit 2073e38 by adenshulga · 1 parent: 92d5847

index on master: 92d5847 working training code

.dockerignore ADDED
@@ -0,0 +1,4 @@
+ .venv
+ data/arxivData.json
+ outputs
+ __pycache__/*
.gitattributes ADDED
@@ -0,0 +1 @@
+ data/checkpoints/checkpoint-12300/model.safetensors filter=lfs diff=lfs merge=lfs -text
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
+ [server]
+ fileWatcherType = "none"
Dockerfile ADDED
@@ -0,0 +1,43 @@
+ FROM python:3.13-slim
+
+ # THIS IS DEVELOPMENT DOCKERFILE
+
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends \
+     build-essential \
+     python3-dev && \
+     rm -rf /var/lib/apt/lists/*
+
+ ARG USER_ID=1000
+ ARG GROUP_ID=1000
+
+ # Create a group and user with the specified UID and GID
+ RUN addgroup --gid $GROUP_ID appgroup && \
+     adduser --uid $USER_ID --gid $GROUP_ID --shell /bin/bash --disabled-password --gecos "" appuser
+
+ # Install sudo and grant privileges
+ RUN apt-get update && apt-get install -y sudo && \
+     echo "appuser ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers
+
+ # Create /app directory with proper ownership
+ RUN mkdir -p /app && chown -R appuser:appgroup /app
+
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+ # Switch to the new user
+ USER appuser
+
+ WORKDIR /app
+
+ COPY --chown=appuser:appgroup . /app
+
+ RUN uv venv .venv
+
+ RUN uv sync
+
+ EXPOSE 7860
+
+ CMD scripts/launch_app.sh
README.md CHANGED
@@ -1,22 +1,80 @@
- uv venv
- chmod +x container_setup/build container_setup/launch_container
- in folder data unzip dataset
- add .env with COMET_API_KEY, COMET_MODE=GET
- chmod +x on scripts
- add data to folder data
- specify cuda device in pipeline script
+ # arXiv Paper Classification

+ A machine learning application that predicts arXiv categories for academic papers based on their title and abstract. It uses a fine-tuned SciBERT model to classify papers into arXiv subject categories. It was completed as homework for the YSDA ML 2 course.

+ I personally hate Jupyter notebooks, so as proof that I conducted the experiments, I made the Comet ML project public.

+ The latest training logs, configs and other details can be found here: https://www.comet.com/adenshulga/arxiv-papers-classification/ef1256f1d4eb4b588da881366eb27578?compareXAxis=step&experiment-tab=panels&showOutliers=true&smoothing=0&xAxis=step

+ ## Installation

+ There are two closely related Dockerfile configurations. The `container_setup` folder contains scripts and a Dockerfile for setting up an interactive development environment. The Dockerfile in the root is for deploying the Streamlit app.

+ ### Streamlit App Setup

+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/adenshulga/arxiv-paper-classification.git
+ cd arxiv-paper-classification
+ ```

+ 2. Make the scripts executable:
+ ```
+ chmod +x scripts/pipeline.sh scripts/launch_app.sh
+ ```

+ 3. Build and launch the Docker container:
+ ```
+ docker build -t arxiv-paper-clf .
+ docker run -p 9001:9001 arxiv-paper-clf
+ ```


+ ### Configuration

+ You can modify the inference settings in `config/inference_config.py`:

+ - `model_name`: Base model name from Hugging Face
+ - `checkpoint_path`: Path to the fine-tuned model checkpoint
+ - `top_percent`: Cumulative score threshold for showing predictions
+ - `minimal_score`: Minimum confidence score to display

+ ## Development and Model Training

+ To enter the development environment:
+ 1. Fill in the `container_setup/credentials` file
+
+ 2. Make the build and launch scripts executable:
+ ```
+ chmod +x container_setup/build.sh container_setup/launch_container.sh
+ ```
+
+ 3. Specify resource constraints in `./container_setup/launch_container.sh`
+
+
+ 4. Build and launch the Docker container:
+ ```
+ ./container_setup/build.sh
+ ./container_setup/launch_container.sh
+ ```
+
+ 5. Attach to the running container:
+ ```
+ docker attach <container-id>
+ ```
+
+ 6. Install the dependencies:
+ ```
+ uv venv
+ uv sync
+ ```
+
+ To train the model:
+
+ 1. Download and unzip the arXiv dataset into the `data` folder (https://www.kaggle.com/datasets/neelshah18/arxivdataset)
+ 2. Configure the process in `config/pipeline_config.py`
+
+ Run the training script:
+ ```
+ scripts/pipeline.sh
+ ```
config/inference_config.py CHANGED
@@ -1,6 +1,4 @@
- from dataclasses import dataclass, field
-
- from hydra.core.config_store import ConfigStore
+ from dataclasses import dataclass


  @dataclass
@@ -8,5 +6,9 @@ class InferenceConfig:
      """Configuration for inference"""

      model_name: str = "allenai/scibert_scivocab_uncased"
-     checkpoint_path: str = "data/checkpoints/checkpoint-10"
+     checkpoint_path: str = "data/checkpoints/checkpoint-12300"
      top_percent: float = 0.95
+     minimal_score: float = 0.01
+
+
+ cfg = InferenceConfig()
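The two thresholds in `InferenceConfig` are consumed elsewhere (`get_top_label_names`, whose body is not part of this diff). The following is a minimal sketch of one way they could interact, assuming scores are normalized by their sum before being accumulated. `select_labels` is a hypothetical helper, not the repository's function.

```python
from dataclasses import dataclass


@dataclass
class InferenceConfig:
    """Mirror of the two threshold fields shown in the diff above."""

    top_percent: float = 0.95
    minimal_score: float = 0.01


def select_labels(scores: dict[str, float], cfg: InferenceConfig) -> list[str]:
    """Hypothetical helper: keep the highest-scoring labels until their
    share of the total score reaches cfg.top_percent, stopping early at
    the first label whose raw score is below cfg.minimal_score."""
    total = sum(scores.values())
    if total <= 0:
        return []
    kept: list[str] = []
    cumulative = 0.0
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    for label, score in ranked:
        if cumulative >= cfg.top_percent or score < cfg.minimal_score:
            break
        kept.append(label)
        cumulative += score / total
    return kept


example = {"cs.LG": 0.90, "stat.ML": 0.60, "cs.AI": 0.40, "math.ST": 0.005}
selected = select_labels(example, InferenceConfig())
```

Because the checkpoint is configured for multi-label classification, raw scores need not sum to 1, which is why this sketch normalizes before accumulating; the real implementation may do something different.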
container_setup/Dockerfile CHANGED
@@ -1,5 +1,7 @@
  FROM python:3.13-slim

+ # THIS IS DEVELOPMENT DOCKERFILE
+
  ENV PYTHONDONTWRITEBYTECODE=1
  ENV PYTHONUNBUFFERED=1

@@ -36,16 +38,16 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
  # Switch to the new user
  USER appuser

- SHELL ["/usr/bin/fish", "-c"]
+ # SHELL ["/usr/bin/fish", "-c"]

  WORKDIR /app

- COPY --chown=appuser:appgroup . /app
+ # COPY --chown=appuser:appgroup . /app

- RUN uv venv .venv
+ # RUN uv venv .venv

- RUN uv sync
+ # RUN uv sync

- EXPOSE 9000
+ # EXPOSE 7860

- # CMD ["uv", "run", "python", "entrypoints/app.py"]
+ # CMD scripts/launch_app.sh
container_setup/credentials CHANGED
@@ -5,5 +5,5 @@ CONTAINER_NAME=$USER"-arxiv-papers-classification"
  SRC="." # folder to propulse in docker container
  DOCKER_USER_ID=$(id -u) # to get these values type "id" in shell termilal
  DOCKER_GROUP_ID=$(id -g)
- CONTAINER_PORT=9001 # used in launch_container file
- INNER_PORT=9001
+ CONTAINER_PORT=7860 # used in launch_container file
+ INNER_PORT=7860
data/checkpoints/checkpoint-12300/config.json ADDED
@@ -0,0 +1,351 @@
+ {
+   "architectures": [
+     "BertForSequenceClassification"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "classifier_dropout": null,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "id2label": {
+     "0": "A.m",
+     "1": "Artificial intelligence and nonmonotonic reasoning and belief\n revision",
+     "2": "Comptuational science",
+     "3": "Computer Science",
+     "4": "H.m",
+     "5": "IEEE",
+     "6": "MIMO, relay, queue-aware, distributive resource control",
+     "7": "Mathematical logic and foundations",
+     "8": "aaai.org",
+     "9": "adap-org",
+     "10": "artificial intelligence, approximate reasoning",
+     "11": "astro-ph",
+     "12": "astro-ph.CO",
+     "13": "astro-ph.EP",
+     "14": "astro-ph.GA",
+     "15": "astro-ph.HE",
+     "16": "astro-ph.IM",
+     "17": "astro-ph.SR",
+     "18": "cmp-lg",
+     "19": "cond-mat",
+     "20": "cond-mat.dis-nn",
+     "21": "cond-mat.mes-hall",
+     "22": "cond-mat.mtrl-sci",
+     "23": "cond-mat.other",
+     "24": "cond-mat.quant-gas",
+     "25": "cond-mat.soft",
+     "26": "cond-mat.stat-mech",
+     "27": "cond-mat.str-el",
+     "28": "cond-mat.supr-con",
+     "29": "cs.AI",
+     "30": "cs.AR",
+     "31": "cs.CC",
+     "32": "cs.CE",
+     "33": "cs.CG",
+     "34": "cs.CL",
+     "35": "cs.CL, cs.AI, math.CT",
+     "36": "cs.CR",
+     "37": "cs.CV",
+     "38": "cs.CY",
+     "39": "cs.DB",
+     "40": "cs.DC",
+     "41": "cs.DL",
+     "42": "cs.DM",
+     "43": "cs.DS",
+     "44": "cs.ET",
+     "45": "cs.FL",
+     "46": "cs.GL",
+     "47": "cs.GR",
+     "48": "cs.GT",
+     "49": "cs.HC",
+     "50": "cs.IR",
+     "51": "cs.IT",
+     "52": "cs.LG",
+     "53": "cs.LO",
+     "54": "cs.MA",
+     "55": "cs.MM",
+     "56": "cs.MS",
+     "57": "cs.NA",
+     "58": "cs.NE",
+     "59": "cs.NI",
+     "60": "cs.OH",
+     "61": "cs.OS",
+     "62": "cs.PF",
+     "63": "cs.PL",
+     "64": "cs.RO",
+     "65": "cs.SC",
+     "66": "cs.SD",
+     "67": "cs.SE",
+     "68": "cs.SI",
+     "69": "cs.SY",
+     "70": "econ.EM",
+     "71": "eess.AS",
+     "72": "eess.IV",
+     "73": "eess.SP",
+     "74": "gr-qc",
+     "75": "hep-ex",
+     "76": "hep-lat",
+     "77": "hep-ph",
+     "78": "hep-th",
+     "79": "math-ph",
+     "80": "math.AC",
+     "81": "math.AG",
+     "82": "math.AP",
+     "83": "math.AT",
+     "84": "math.CA",
+     "85": "math.CO",
+     "86": "math.CT",
+     "87": "math.CV",
+     "88": "math.DG",
+     "89": "math.DS",
+     "90": "math.FA",
+     "91": "math.GM",
+     "92": "math.GN",
+     "93": "math.GR",
+     "94": "math.GT",
+     "95": "math.HO",
+     "96": "math.IT",
+     "97": "math.LO",
+     "98": "math.MG",
+     "99": "math.MP",
+     "100": "math.NA",
+     "101": "math.NT",
+     "102": "math.OA",
+     "103": "math.OC",
+     "104": "math.PR",
+     "105": "math.QA",
+     "106": "math.RA",
+     "107": "math.RT",
+     "108": "math.SP",
+     "109": "math.ST",
+     "110": "nlin.AO",
+     "111": "nlin.AO, nlin.CD, q-bio.NC, physics.bio-ph, cond-mat.dis-nn",
+     "112": "nlin.CD",
+     "113": "nlin.CG",
+     "114": "nlin.PS",
+     "115": "nucl-ex",
+     "116": "nucl-th",
+     "117": "physics.ao-ph",
+     "118": "physics.app-ph",
+     "119": "physics.bio-ph",
+     "120": "physics.chem-ph",
+     "121": "physics.class-ph",
+     "122": "physics.comp-ph",
+     "123": "physics.data-an",
+     "124": "physics.flu-dyn",
+     "125": "physics.gen-ph",
+     "126": "physics.geo-ph",
+     "127": "physics.hist-ph",
+     "128": "physics.ins-det",
+     "129": "physics.med-ph",
+     "130": "physics.optics",
+     "131": "physics.pop-ph",
+     "132": "physics.soc-ph",
+     "133": "physics.space-ph",
+     "134": "q-bio",
+     "135": "q-bio.BM",
+     "136": "q-bio.BM, q-bio.MN, q-bio.NC, nlin.AO, nlin.CD",
+     "137": "q-bio.CB",
+     "138": "q-bio.GN",
+     "139": "q-bio.MN",
+     "140": "q-bio.NC",
+     "141": "q-bio.OT",
+     "142": "q-bio.PE",
+     "143": "q-bio.QM",
+     "144": "q-bio.SC",
+     "145": "q-bio.TO",
+     "146": "q-fin.CP",
+     "147": "q-fin.EC",
+     "148": "q-fin.GN",
+     "149": "q-fin.PM",
+     "150": "q-fin.PR",
+     "151": "q-fin.RM",
+     "152": "q-fin.ST",
+     "153": "q-fin.TR",
+     "154": "quant-ph",
+     "155": "stat.AP",
+     "156": "stat.CO",
+     "157": "stat.ME",
+     "158": "stat.ML",
+     "159": "stat.OT",
+     "160": "stat.TH"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "label2id": {
+     "A.m": 0,
+     "Artificial intelligence and nonmonotonic reasoning and belief\n revision": 1,
+     "Comptuational science": 2,
+     "Computer Science": 3,
+     "H.m": 4,
+     "IEEE": 5,
+     "MIMO, relay, queue-aware, distributive resource control": 6,
+     "Mathematical logic and foundations": 7,
+     "aaai.org": 8,
+     "adap-org": 9,
+     "artificial intelligence, approximate reasoning": 10,
+     "astro-ph": 11,
+     "astro-ph.CO": 12,
+     "astro-ph.EP": 13,
+     "astro-ph.GA": 14,
+     "astro-ph.HE": 15,
+     "astro-ph.IM": 16,
+     "astro-ph.SR": 17,
+     "cmp-lg": 18,
+     "cond-mat": 19,
+     "cond-mat.dis-nn": 20,
+     "cond-mat.mes-hall": 21,
+     "cond-mat.mtrl-sci": 22,
+     "cond-mat.other": 23,
+     "cond-mat.quant-gas": 24,
+     "cond-mat.soft": 25,
+     "cond-mat.stat-mech": 26,
+     "cond-mat.str-el": 27,
+     "cond-mat.supr-con": 28,
+     "cs.AI": 29,
+     "cs.AR": 30,
+     "cs.CC": 31,
+     "cs.CE": 32,
+     "cs.CG": 33,
+     "cs.CL": 34,
+     "cs.CL, cs.AI, math.CT": 35,
+     "cs.CR": 36,
+     "cs.CV": 37,
+     "cs.CY": 38,
+     "cs.DB": 39,
+     "cs.DC": 40,
+     "cs.DL": 41,
+     "cs.DM": 42,
+     "cs.DS": 43,
+     "cs.ET": 44,
+     "cs.FL": 45,
+     "cs.GL": 46,
+     "cs.GR": 47,
+     "cs.GT": 48,
+     "cs.HC": 49,
+     "cs.IR": 50,
+     "cs.IT": 51,
+     "cs.LG": 52,
+     "cs.LO": 53,
+     "cs.MA": 54,
+     "cs.MM": 55,
+     "cs.MS": 56,
+     "cs.NA": 57,
+     "cs.NE": 58,
+     "cs.NI": 59,
+     "cs.OH": 60,
+     "cs.OS": 61,
+     "cs.PF": 62,
+     "cs.PL": 63,
+     "cs.RO": 64,
+     "cs.SC": 65,
+     "cs.SD": 66,
+     "cs.SE": 67,
+     "cs.SI": 68,
+     "cs.SY": 69,
+     "econ.EM": 70,
+     "eess.AS": 71,
+     "eess.IV": 72,
+     "eess.SP": 73,
+     "gr-qc": 74,
+     "hep-ex": 75,
+     "hep-lat": 76,
+     "hep-ph": 77,
+     "hep-th": 78,
+     "math-ph": 79,
+     "math.AC": 80,
+     "math.AG": 81,
+     "math.AP": 82,
+     "math.AT": 83,
+     "math.CA": 84,
+     "math.CO": 85,
+     "math.CT": 86,
+     "math.CV": 87,
+     "math.DG": 88,
+     "math.DS": 89,
+     "math.FA": 90,
+     "math.GM": 91,
+     "math.GN": 92,
+     "math.GR": 93,
+     "math.GT": 94,
+     "math.HO": 95,
+     "math.IT": 96,
+     "math.LO": 97,
+     "math.MG": 98,
+     "math.MP": 99,
+     "math.NA": 100,
+     "math.NT": 101,
+     "math.OA": 102,
+     "math.OC": 103,
+     "math.PR": 104,
+     "math.QA": 105,
+     "math.RA": 106,
+     "math.RT": 107,
+     "math.SP": 108,
+     "math.ST": 109,
+     "nlin.AO": 110,
+     "nlin.AO, nlin.CD, q-bio.NC, physics.bio-ph, cond-mat.dis-nn": 111,
+     "nlin.CD": 112,
+     "nlin.CG": 113,
+     "nlin.PS": 114,
+     "nucl-ex": 115,
+     "nucl-th": 116,
+     "physics.ao-ph": 117,
+     "physics.app-ph": 118,
+     "physics.bio-ph": 119,
+     "physics.chem-ph": 120,
+     "physics.class-ph": 121,
+     "physics.comp-ph": 122,
+     "physics.data-an": 123,
+     "physics.flu-dyn": 124,
+     "physics.gen-ph": 125,
+     "physics.geo-ph": 126,
+     "physics.hist-ph": 127,
+     "physics.ins-det": 128,
+     "physics.med-ph": 129,
+     "physics.optics": 130,
+     "physics.pop-ph": 131,
+     "physics.soc-ph": 132,
+     "physics.space-ph": 133,
+     "q-bio": 134,
+     "q-bio.BM": 135,
+     "q-bio.BM, q-bio.MN, q-bio.NC, nlin.AO, nlin.CD": 136,
+     "q-bio.CB": 137,
+     "q-bio.GN": 138,
+     "q-bio.MN": 139,
+     "q-bio.NC": 140,
+     "q-bio.OT": 141,
+     "q-bio.PE": 142,
+     "q-bio.QM": 143,
+     "q-bio.SC": 144,
+     "q-bio.TO": 145,
+     "q-fin.CP": 146,
+     "q-fin.EC": 147,
+     "q-fin.GN": 148,
+     "q-fin.PM": 149,
+     "q-fin.PR": 150,
+     "q-fin.RM": 151,
+     "q-fin.ST": 152,
+     "q-fin.TR": 153,
+     "quant-ph": 154,
+     "stat.AP": 155,
+     "stat.CO": 156,
+     "stat.ME": 157,
+     "stat.ML": 158,
+     "stat.OT": 159,
+     "stat.TH": 160
+   },
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "problem_type": "multi_label_classification",
+   "torch_dtype": "float32",
+   "transformers_version": "4.50.3",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 31090
+ }
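`id2label` and `label2id` are expected to be exact inverses, and several entries above (for example `"cs.CL, cs.AI, math.CT"`) look like unsplit multi-category strings carried over from the dataset. The following is a small sketch of a sanity check one could run over a loaded `config.json`. `check_label_maps` and its comma heuristic are illustrative assumptions, not part of the repository.

```python
import json


def check_label_maps(config: dict) -> list[str]:
    """Return labels that contain a comma and may be fused categories.

    Raises ValueError if id2label and label2id are not inverses.
    """
    id2label = {int(k): v for k, v in config["id2label"].items()}
    label2id = config["label2id"]
    if {label: idx for idx, label in id2label.items()} != label2id:
        raise ValueError("id2label and label2id disagree")
    # Heuristic (an assumption): a comma suggests several tags in one label.
    return [label for label in label2id if "," in label]


# Tiny excerpt of the mapping shown in the diff above.
sample = json.loads(
    """
    {
      "id2label": {"34": "cs.CL", "35": "cs.CL, cs.AI, math.CT", "36": "cs.CR"},
      "label2id": {"cs.CL": 34, "cs.CL, cs.AI, math.CT": 35, "cs.CR": 36}
    }
    """
)
suspicious = check_label_maps(sample)
```

The heuristic also flags free-text labels such as the "MIMO, relay, ..." entry, so its output is a list of candidates to review rather than a verdict.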
data/checkpoints/checkpoint-12300/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe26497cd2f66e6db8739c13ac359fe50c8023ff6bb897c80df6ddae58a77cd3
+ size 440192628
data/checkpoints/checkpoint-12300/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
data/checkpoints/checkpoint-12300/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
data/checkpoints/checkpoint-12300/tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "104": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "extra_special_tokens": {},
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
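`model_max_length` above is effectively unbounded, while the model config lists `max_position_embeddings` of 512, so long title-plus-abstract inputs may need explicit truncation (Hugging Face tokenizers accept `truncation=True` with `max_length=512` for this). The sketch below only illustrates the idea on raw token IDs, using the `[CLS]` and `[SEP]` IDs from `added_tokens_decoder`. `truncate_ids` is a hypothetical helper, not something the repository defines.

```python
CLS_ID = 102  # "[CLS]" in added_tokens_decoder
SEP_ID = 103  # "[SEP]" in added_tokens_decoder
MAX_POSITIONS = 512  # max_position_embeddings in config.json


def truncate_ids(token_ids: list[int], max_len: int = MAX_POSITIONS) -> list[int]:
    """Wrap content token IDs in [CLS] ... [SEP], dropping the tail of the
    content so the result never exceeds max_len."""
    if max_len < 2:
        raise ValueError("max_len must leave room for [CLS] and [SEP]")
    body = token_ids[: max_len - 2]
    return [CLS_ID, *body, SEP_ID]


short = truncate_ids([5, 6, 7])
long = truncate_ids(list(range(1000, 2000)))
```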
data/checkpoints/checkpoint-12300/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
data/checkpoints/checkpoint-12300/training_args.bin ADDED
Binary file (5.37 kB). View file
 
data/checkpoints/checkpoint-12300/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
entrypoints/app.py CHANGED
@@ -1,9 +1,10 @@
  import streamlit as st
- from src.app.setup_model import setup_pipeline, get_top_label_names, LabelScore
+
+ from config.inference_config import cfg
+ from src.app.data_validation import validate_data
+ from src.app.setup_model import LabelScore, get_top_label_names, setup_pipeline
  from src.app.tags_mapping import tags2full_name
  from src.app.visualization import visualize_predicted_categories
- from config.inference_config import InferenceConfig
- from src.app.data_validation import validate_data

  st.title("arXiv Paper Classifier")
  st.markdown("Enter paper details to predict arXiv categories")
@@ -11,17 +12,20 @@ st.markdown("Enter paper details to predict arXiv categories")
  st.text_input("Enter paper name", key="paper_name")
  st.text_area("Enter paper abstract", key="paper_abstract", height=250)

- if st.button("Predict Categories", type="primary"):
-     validate_data(st.session_state["paper_name"], st.session_state["paper_abstract"])
+ if st.button("Predict Categories", type="primary") and validate_data(
+     st.session_state["paper_name"], st.session_state["paper_abstract"]
+ ):
      with st.spinner("Analyzing paper..."):
-         pipeline = setup_pipeline(InferenceConfig())
+         pipeline = setup_pipeline(cfg)
          scores: list[LabelScore] = pipeline(
              st.session_state["paper_name"] + " " + st.session_state["paper_abstract"],
-             output_scores=True,
+             top_k=None,
          )  # type: ignore

-     top_labels = get_top_label_names(scores, tags2full_name, 0.95)
+     top_labels = get_top_label_names(scores, tags2full_name, cfg.top_percent)

-     visualize_predicted_categories(top_labels, scores, tags2full_name)
+     visualize_predicted_categories(
+         top_labels, scores, tags2full_name, minimal_score=cfg.minimal_score
+     )
  else:
      st.info("Enter paper details and click 'Predict Categories' to get predictions.")
scripts/launch_app.sh CHANGED
@@ -0,0 +1,7 @@
+ #!/bin/bash
+
+ export PYTHONPATH='.'
+
+ # source .venv/bin/activate
+
+ CUDA_VISIBLE_DEVICES="" uv run -m streamlit run entrypoints/app.py --server.address=0.0.0.0 --server.port=9001
src/app/data_validation.py CHANGED
@@ -1,12 +1,13 @@
  import streamlit as st


- def validate_data(paper_name: str, paper_abstract: str) -> None:
-     if paper_name == "" or paper_abstract == "":
+ def validate_data(paper_name: str, paper_abstract: str) -> bool:
+     if paper_name == "" and paper_abstract == "":
          st.error("Paper name or abstract are required")
-         return
+         return False
      if paper_abstract == "":
          st.warning(
              "Without abstract, the performance of the model will be significantly worse"
          )
-         return
+         return True
+     return True
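The switch from `or` to `and` changes which inputs are rejected. Previously any empty field triggered the error branch, which made the abstract-only warning unreachable, and the `None` return value could not gate the prediction. The following is a Streamlit-free sketch of the new behaviour. `validate` is an illustrative stand-in that returns a message instead of calling `st.error` or `st.warning`, and the message strings are paraphrased.

```python
from typing import Optional


def validate(paper_name: str, paper_abstract: str) -> tuple[bool, Optional[str]]:
    """Mirror of the new validate_data logic without the Streamlit calls.

    Returns (ok, message); message is None when nothing needs reporting.
    """
    if paper_name == "" and paper_abstract == "":
        return False, "error: paper name or abstract is required"
    if paper_abstract == "":
        return True, "warning: predictions are weaker without an abstract"
    return True, None


cases = [("", ""), ("Title", ""), ("", "Abstract"), ("Title", "Abstract")]
results = [validate(name, abstract) for name, abstract in cases]
```

Only the first case blocks prediction; a title without an abstract proceeds with a warning, and an abstract without a title proceeds silently.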
src/app/setup_model.py CHANGED
@@ -3,8 +3,6 @@ import typing as tp
  from config.inference_config import InferenceConfig
  import streamlit as st

- from src.app.tags_mapping import tags2full_name
-

  class LabelScore(tp.TypedDict):
      label: str
@@ -14,7 +12,9 @@ class LabelScore(tp.TypedDict):
  @st.cache_resource
  def setup_pipeline(cfg: InferenceConfig) -> Pipeline:
      model = pipeline(
-         "text-classification", model=cfg.checkpoint_path, tokenizer=cfg.model_name
+         "text-classification",
+         model=cfg.checkpoint_path,
+         tokenizer=cfg.model_name,
      )
      return model

src/app/visualization.py CHANGED
@@ -7,6 +7,7 @@ def visualize_predicted_categories(
      top_labels: list[LabelScore],
      scores: list[LabelScore],
      label_to_name_mapping: Dict[str, str],
+     minimal_score: float = 0.01,
  ):
      """
      Visualize the predicted categories in a streamlit app
@@ -19,7 +20,9 @@ def visualize_predicted_categories(
      st.subheader("Predicted Categories")

      for i, label in enumerate(top_labels):
-         score = next((s["score"] for s in scores if s["label"] == label["label"]), 0)
+         score = label["score"]
+         if score < minimal_score:
+             continue

          # Color gradient based on confidence
          color_intensity = min(int(score * 255), 255)
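The loop now reads the score straight from each `top_labels` entry and skips anything below `minimal_score`. The following sketch shows the filter and the colour computation on plain data. `visible_rows` is a hypothetical helper, and the `score` field of `LabelScore` is inferred from the `label["score"]` access in the diff.

```python
from typing import TypedDict


class LabelScore(TypedDict):
    label: str
    score: float  # inferred from label["score"] in the diff


def visible_rows(
    top_labels: list[LabelScore], minimal_score: float = 0.01
) -> list[tuple[str, int]]:
    """Return (label, color_intensity) for labels at or above minimal_score."""
    rows: list[tuple[str, int]] = []
    for item in top_labels:
        score = item["score"]
        if score < minimal_score:
            continue  # same skip rule as the loop in the diff
        rows.append((item["label"], min(int(score * 255), 255)))
    return rows


rows = visible_rows(
    [
        {"label": "cs.LG", "score": 1.0},
        {"label": "stat.ML", "score": 0.5},
        {"label": "math.ST", "score": 0.004},
    ]
)
```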