rudradcruze committed
Commit 1c25c67 · 1 Parent(s): ff10d38

upload toxicity api application
.env_example ADDED
@@ -0,0 +1 @@
+ HF_TOKEN="your_huggingface_token_here"
.gitignore ADDED
@@ -0,0 +1,427 @@
+ # ==============================================================================
+ # TOXICITY PREDICTION API - .GITIGNORE
+ # ==============================================================================
+
+ # Byte-compiled / optimized / DLL files
+ *__pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ Pipfile.lock
+
+ # poetry
+ poetry.lock
+
+ # pdm
+ .pdm.toml
+
+ # PEP 582
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ .idea/
+
+ # ==============================================================================
+ # MACHINE LEARNING & DATA SCIENCE SPECIFIC
+ # ==============================================================================
+
+ # Model files (commented out since we need to deploy them)
+ # *.pt
+ # *.pth
+ # *.pkl
+ # *.joblib
+ # *.h5
+ # *.hdf5
+
+ # Datasets (keep models but ignore large datasets)
+ data/
+ dataset/
+ datasets/
+ *.csv
+ *.tsv
+ *.json
+ *.jsonl
+ *.parquet
+
+ # Large files
+ *.zip
+ *.tar.gz
+ *.rar
+ *.7z
+
+ # Training outputs
+ logs/
+ runs/
+ experiments/
+ outputs/
+ checkpoints/
+ wandb/
+ mlruns/
+
+ # Tensorboard logs
+ events.out.tfevents.*
+
+ # ==============================================================================
+ # HUGGINGFACE & API SPECIFIC
+ # ==============================================================================
+
+ # HuggingFace cache
+ .cache/
+ transformers_cache/
+ huggingface_hub/
+
+ # API keys and tokens (CRITICAL SECURITY)
+ .env
+ .env.local
+ .env.development
+ .env.test
+ .env.production
+ *.token
+ *_token
+ api_keys.txt
+ secrets.txt
+ credentials.json
+ config.json
+
+ # HuggingFace specific
+ hf_token.txt
+ huggingface_token
+ .huggingface_token
+
+ # ==============================================================================
+ # GRADIO SPECIFIC
+ # ==============================================================================
+
+ # Gradio temporary files
+ gradio_cached_examples/
+ flagged/
+ gradio_queue.db
+
+ # ==============================================================================
+ # OPERATING SYSTEM FILES
+ # ==============================================================================
+
+ # macOS
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+ Icon?
+ ._*
+ .DocumentRevisions-V100
+ .fseventsd
+ .Spotlight-V100
+ .TemporaryItems
+ .Trashes
+ .VolumeIcon.icns
+ .com.apple.timemachine.donotpresent
+ .AppleDB
+ .AppleDesktop
+ Network Trash Folder
+ Temporary Items
+ .apdisk
+
+ # Windows
+ Thumbs.db
+ Thumbs.db:encryptable
+ ehthumbs.db
+ ehthumbs_vista.db
+ *.tmp
+ *.temp
+ Desktop.ini
+ $RECYCLE.BIN/
+ *.cab
+ *.msi
+ *.msix
+ *.msm
+ *.msp
+ *.lnk
+
+ # Linux
+ *~
+ .fuse_hidden*
+ .directory
+ .Trash-*
+ .nfs*
+
+ # ==============================================================================
+ # IDE AND EDITOR FILES
+ # ==============================================================================
+
+ # Visual Studio Code
+ .vscode/
+ *.code-workspace
+
+ # JetBrains IDEs
+ .idea/
+ *.iws
+ *.iml
+ *.ipr
+
+ # Sublime Text
+ *.sublime-project
+ *.sublime-workspace
+
+ # Vim
+ *.swp
+ *.swo
+ *~
+ .viminfo
+
+ # Emacs
+ *~
+ \#*\#
+ /.emacs.desktop
+ /.emacs.desktop.lock
+ *.elc
+ auto-save-list
+ tramp
+ .\#*
+
+ # Atom
+ .atom/
+
+ # ==============================================================================
+ # DEVELOPMENT AND TESTING
+ # ==============================================================================
+
+ # Testing
+ .tox/
+ .coverage
+ htmlcov/
+ .pytest_cache/
+ test_results/
+ test_outputs/
+
+ # Local development
+ local/
+ tmp/
+ temp/
+ .tmp/
+ .temp/
+
+ # Backup files
+ *.bak
+ *.backup
+ *.old
+ *_backup
+ *_old
+
+ # ==============================================================================
+ # PROJECT SPECIFIC
+ # ==============================================================================
+
+ # Original dataset folder (if you have it locally)
+ Original/
+ original_dataset/
+
+ # Feature extraction outputs (if regenerating)
+ feature_extraction_outputs/
+ extracted_features/
+
+ # Training artifacts (if retraining)
+ training_logs/
+ model_checkpoints/
+ training_outputs/
+
+ # Test images and results
+ test_images/
+ test_results/
+ prediction_outputs/
+
+ # Documentation builds
+ docs/build/
+ documentation/build/
+
+ # Deployment artifacts (optional)
+ deployment_logs/
+ build_logs/
+
+ # Personal notes and scratch files
+ notes.txt
+ todo.txt
+ scratch.py
+ test.py
+ debug.py
+ playground.py
+
+ # ==============================================================================
+ # SECURITY SENSITIVE FILES (CRITICAL)
+ # ==============================================================================
+
+ # Never commit these files containing sensitive information
+ **/secrets/**
+ **/credentials/**
+ **/*_secret*
+ **/*_key*
+ **/*_password*
+ **/*_token*
+ **/*credentials*
+ private_key*
+ public_key*
+ *.pem
+ *.key
+ *.crt
+ *.cert
+
+ # ==============================================================================
+ # LARGE FILES AND BINARIES
+ # ==============================================================================
+
+ # Large model files (ignored: the API downloads them from the HuggingFace Hub at runtime)
+ models/*.pt
+ models/*.pth
+ model_cache/*.pt
+ model_cache/*.pth
+ models/
+ model_cache/
+ *.bin
+ *.pt
+ *.pkl
+ *.h5
+ *.onnx
+
+ # Videos and large media
+ *.mp4
+ *.avi
+ *.mov
+ *.mkv
+ *.webm
+ *.gif
+
+ # Large images (keep examples small)
+ # *.png
+ # *.jpg
+ # *.jpeg
+ # *.tiff
+ # *.bmp
+
+ models/feature_extractor.pt
+ models/feature_scaler.pt
+ models/multi_head_self_attention_classifier.pt
+ *model_cache
+ venv
+
+ # ==============================================================================
+ # END OF .GITIGNORE
+ # ==============================================================================
Dockerfile ADDED
@@ -0,0 +1,39 @@
+ FROM python:3.10-slim
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1
+ ENV PYTHONDONTWRITEBYTECODE=1
+
+ # Create a non-root user for security
+ RUN useradd -m -u 1000 user
+ USER user
+
+ # Set PATH for user local binaries
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements first for better Docker layer caching
+ COPY --chown=user requirements.txt requirements.txt
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ # Create models directory with proper permissions
+ RUN mkdir -p /app/models
+
+ # Copy utils directory (model classes)
+ COPY --chown=user ./utils /app/utils
+
+ # Copy main application
+ COPY --chown=user ./app.py /app/
+
+ # Copy any additional files you might have
+ COPY --chown=user ./*.py /app/
+
+ # Expose port 7860 (required for HuggingFace Spaces)
+ EXPOSE 7860
+
+ # Command to run the application
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
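+
+ # Local sanity check (a sketch: the image tag is illustrative and assumes Docker
+ # is installed and HF_TOKEN holds a valid HuggingFace access token):
+ #   docker build -t toxicity-api .
+ #   docker run -p 7860:7860 -e HF_TOKEN=your_huggingface_token_here toxicity-api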
LICENSE ADDED
@@ -0,0 +1,66 @@
+ MIT License
+
+ Copyright (c) 2025 CAMLAs (Computer Vision and Machine Learning Lab)
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ---
+
+ MEDICAL DISCLAIMER:
+
+ This software is intended for research and educational purposes only.
+ It is NOT intended for clinical diagnosis, medical decision-making, or
+ patient care. The software should NOT be used as a substitute for
+ professional medical advice, diagnosis, or treatment.
+
+ Users of this software acknowledge that:
+
+ 1. The software is experimental and may contain errors or inaccuracies
+ 2. Medical decisions should always be made by qualified healthcare professionals
+ 3. The developers and CAMLAs organization are not responsible for any
+    medical decisions or outcomes resulting from the use of this software
+ 4. Users assume all risks associated with the use of this software
+
+ By using this software, you agree to these terms and acknowledge that you
+ understand the limitations and appropriate use cases for this technology.
+
+ ---
+
+ ATTRIBUTION:
+
+ If you use this software in academic research, please cite:
+
+ CAMLAs Research Team. (2025). Toxicity Prediction API using ProtBERT
+ Embeddings and MHSA-GRU Classifier. HuggingFace Spaces.
+ https://huggingface.co/spaces/camlas/toxicity
+
+ ---
+
+ THIRD-PARTY LICENSES:
+
+ This software uses the following third-party libraries and frameworks:
+
+ - PyTorch: BSD-style license (https://github.com/pytorch/pytorch/blob/master/LICENSE)
+ - timm: Apache License 2.0 (https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE)
+ - scikit-learn: BSD License (https://github.com/scikit-learn/scikit-learn/blob/main/COPYING)
+ - Gradio: Apache License 2.0 (https://github.com/gradio-app/gradio/blob/main/LICENSE)
+ - NumPy: BSD License (https://github.com/numpy/numpy/blob/main/LICENSE.txt)
+ - Pillow: HPND License (https://github.com/python-pillow/Pillow/blob/main/LICENSE)
+
+ All third-party libraries retain their original licenses and copyrights.
README.md CHANGED
@@ -1,11 +1,388 @@
  ---
- title: Toxicity
- emoji: 👀
  colorFrom: green
- colorTo: green
  sdk: docker
  pinned: false
- short_description: Toxicity Prediction API
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Toxicity Prediction API
+ description: A FastAPI-based REST API for predicting protein sequence toxicity using ProtBERT embeddings and an MHSA-GRU classifier.
+ short_description: Toxicity Prediction API
+ version: 1.0.0
+ emoji: 🧬
  colorFrom: green
+ colorTo: blue
  sdk: docker
+ app_file: app.py
  pinned: false
+ license: mit
+ tags:
+ - protein-toxicity
+ - protbert
+ - mhsa-gru
+ - pytorch
+ - fastapi
+ ---
+
+ # Toxicity Prediction API
+
+ A FastAPI-based REST API for predicting protein sequence toxicity using ProtBERT embeddings and an MHSA-GRU classifier.
+
+ Developed by the CAMLAs research team - [Francis Rudra D Cruze](https://linkedin.com/in/rudradcruze).
+
+ ## 🚀 Features
+
+ - **ProtBERT Feature Extraction**: Uses a state-of-the-art protein language model
+ - **MHSA-GRU Classification**: Multi-Head Self-Attention with GRU for accurate predictions
+ - **Single & Batch Predictions**: Process one or multiple sequences
+ - **HuggingFace Integration**: Automatic model loading from a private repository
+ - **Production Ready**: Health checks, error handling, and comprehensive logging
+
+ ## 📋 Requirements
+
+ - Python 3.8+
+ - CUDA-capable GPU (optional, but recommended)
+ - HuggingFace account with access to the private repository
+
+ ## 🔧 Installation
+
+ 1. **Clone the repository**
+
+ ```bash
+ git clone https://huggingface.co/spaces/camlas/toxicity
+ cd toxicity
+ ```
+
+ 2. **Create a virtual environment**
+
+ ```bash
+ python -m venv venv
+ source venv/bin/activate  # On Windows: venv\Scripts\activate
+ ```
+
+ 3. **Install dependencies**
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 4. **Create a `.env` file**
+
+ ```bash
+ echo "HF_TOKEN=your_huggingface_token_here" > .env
+ ```
+
+ Get your HuggingFace token from: https://huggingface.co/settings/tokens
+
+ ## 🎯 Usage
+
+ ### Start the API Server
+
+ ```bash
+ python app.py
+ ```
+
+ Or run uvicorn directly:
+
+ ```bash
+ uvicorn app:app --host 0.0.0.0 --port 8000 --reload
+ ```
+
+ The API will be available at `http://localhost:8000`.
+
+ ### Run Tests
+
+ ```bash
+ python test_api.py
+ ```
+
+ ## 📡 API Endpoints
+
+ ### 1. Root Endpoint
+
+ **GET** `/`
+
+ Returns API information and available endpoints.
+
+ ```bash
+ curl http://localhost:8000/
+ ```
+
+ ### 2. Health Check
+
+ **GET** `/health`
+
+ Checks the API status and whether all models are loaded.
+
+ ```bash
+ curl http://localhost:8000/health
+ ```
+
+ **Response:**
+
+ ```json
+ {
+   "status_code": 200,
+   "status": "healthy",
+   "service": "Toxicity Prediction API",
+   "api_version": "1.0.0",
+   "model_version": "MHSA-GRU-Transformer-v1.0",
+   "models_loaded": true,
+   "device": "cuda",
+   "timestamp": "2025-01-21T10:30:00Z"
+ }
+ ```
+
+ ### 3. Single Prediction
+
+ **POST** `/predict`
+
+ Predicts toxicity for a single protein sequence.
+
+ **Request:**
+
+ ```bash
+ curl -X POST http://localhost:8000/predict \
+   -H "Content-Type: application/json" \
+   -d '{"sequence": "MKTAYIAKQRQISFVKSHFSRQLE"}'
+ ```
+
+ **Response:**
+
+ ```json
+ {
+   "status_code": 200,
+   "status": "success",
+   "success": true,
+   "data": {
+     "sequence": "MKTAYIAKQRQISFVKSHFSRQLE",
+     "sequence_length": 24,
+     "prediction": {
+       "predicted_class": "Toxic",
+       "confidence": 0.85,
+       "confidence_level": "high",
+       "toxicity_score": 0.925,
+       "non_toxicity_score": 0.075
+     },
+     "metadata": {
+       "embedding_model": "ProtBERT",
+       "embedding_type": "Bert",
+       "model_version": "MHSA-GRU-Transformer-v1.0",
+       "device": "cuda"
+     }
+   },
+   "timestamp": "2025-01-21T10:30:00Z",
+   "api_version": "1.0.0",
+   "processing_time_ms": 45.2
+ }
+ ```
+
+ ### 4. Batch Prediction
+
+ **POST** `/predict/batch`
+
+ Predicts toxicity for multiple sequences at once.
+
+ **Request (Postman/cURL):**
+
+ ```bash
+ curl -X POST http://localhost:8000/predict/batch \
+   -H "Content-Type: application/json" \
+   -d '{
+     "sequences": [
+       "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
+       "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK",
+       "MKTAYIAKQRQISFVKSHFSRQLE"
+     ]
+   }'
+ ```
+
+ **Request Body (JSON):**
+
+ ```json
+ {
+   "sequences": [
+     "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
+     "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK"
+   ]
+ }
+ ```
+
+ **Response:**
+
+ ```json
+ {
+   "status_code": 200,
+   "status": "success",
+   "success": true,
+   "data": {
+     "total_sequences": 2,
+     "results": [
+       {
+         "sequence": "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
+         "sequence_length": 51,
+         "predicted_class": "Toxic",
+         "toxicity_score": 0.925,
+         "confidence": 0.85
+       },
+       {
+         "sequence": "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK",
+         "sequence_length": 51,
+         "predicted_class": "Non-Toxic",
+         "toxicity_score": 0.125,
+         "confidence": 0.75
+       }
+     ],
+     "metadata": {
+       "embedding_model": "ProtBERT",
+       "embedding_type": "Bert",
+       "model_version": "MHSA-GRU-Transformer-v1.0",
+       "device": "cuda"
+     }
+   },
+   "timestamp": "2025-01-21T10:30:00Z",
+   "api_version": "1.0.0",
+   "processing_time_ms": 125.8
+ }
+ ```
+
+ ## 🐍 Python Usage Examples
+
+ ### Single Prediction
+
+ ```python
+ import requests
+
+ response = requests.post(
+     "http://localhost:8000/predict",
+     json={"sequence": "MKTAYIAKQRQISFVKSHFSRQLE"}
+ )
+
+ result = response.json()
+ print(f"Predicted Class: {result['data']['prediction']['predicted_class']}")
+ print(f"Toxicity Score: {result['data']['prediction']['toxicity_score']:.4f}")
+ print(f"Confidence: {result['data']['prediction']['confidence']:.4f}")
+ ```
+
+ ### Batch Prediction
+
+ ```python
+ import requests
+
+ sequences = [
+     "MKTAYIAKQRQISFVKSHFSRQLE",
+     "ARNDCEQGHILKMFPSTWYV",
+     "MVHLTPEEKS"
+ ]
+
+ response = requests.post(
+     "http://localhost:8000/predict/batch",
+     json={"sequences": sequences}
+ )
+
+ results = response.json()
+ for i, pred in enumerate(results['data']['results'], 1):
+     print(f"Sequence {i}: {pred['predicted_class']} ({pred['toxicity_score']:.4f})")
+ ```
+
+ ## 📁 Project Structure
+
+ ```
+ toxicity-api/
+ ├── app.py              # Main FastAPI application
+ ├── requirements.txt    # Python dependencies
+ ├── test_api.py         # Test suite
+ ├── .env                # Environment variables (create this)
+ ├── models/             # Downloaded models (auto-created)
+ └── README.md           # This file
+ ```
+
+ ## 🔒 HuggingFace Repository Structure
+
+ Your private repository `camlas/toxicity` should contain:
+
+ ```
+ camlas/toxicity/
+ ├── mhsa_gru_classifier.pth    # Trained MHSA-GRU model
+ ├── scaler.pkl                 # Feature scaler
+ ├── config.json                # ProtBERT config
+ ├── model.safetensors          # ProtBERT weights
+ ├── vocab.txt                  # ProtBERT vocabulary
+ ├── tokenizer_config.json      # Tokenizer configuration
+ └── special_tokens_map.json    # Special tokens mapping
+ ```
+
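+ At startup the API fetches these files with `huggingface_hub`; the equivalent manual download of a single file looks like this (a minimal sketch, assuming `HF_TOKEN` grants read access to the repository):
+
+ ```python
+ import os
+ from huggingface_hub import hf_hub_download
+
+ # Downloads into the local HuggingFace cache and returns the file path
+ path = hf_hub_download(
+     repo_id="camlas/toxicity",
+     filename="mhsa_gru_classifier.pth",
+     repo_type="model",
+     token=os.getenv("HF_TOKEN"),
+ )
+ print(path)
+ ```
+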
+ ## 🎨 Model Architecture
+
+ 1. **Feature Extraction**: ProtBERT (1024-dimensional [CLS] embeddings)
+ 2. **Feature Scaling**: StandardScaler
+ 3. **Classification**: MHSA-GRU
+    - Multi-Head Self-Attention (3 layers)
+    - GRU (2 layers)
+    - Fully connected layers with dropout
+
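+ The three stages map directly onto the inference path in `app.py`. A minimal sketch of the flow, assuming the tokenizer, ProtBERT model, scaler, and classifier are already loaded (the helper name `predict_toxicity` is illustrative):
+
+ ```python
+ import torch
+
+ def predict_toxicity(sequence, tokenizer, protbert, scaler, classifier, device):
+     # 1. ProtBERT expects space-separated residues: "MKT..." -> "M K T ..."
+     spaced = " ".join(sequence.upper().strip())
+     inputs = tokenizer(spaced, return_tensors="pt", truncation=True, max_length=512)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     with torch.no_grad():
+         # [CLS] token embedding -> shape (1, 1024)
+         features = protbert(**inputs).last_hidden_state[:, 0, :].cpu().numpy()
+
+     # 2. Apply the StandardScaler fitted during training
+     scaled = scaler.transform(features)
+
+     # 3. The MHSA-GRU head outputs a sigmoid toxicity probability
+     with torch.no_grad():
+         probability = classifier(torch.FloatTensor(scaled).to(device)).item()
+     return ("Toxic" if probability > 0.5 else "Non-Toxic", probability)
+ ```
+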
+ ## ⚠️ Error Codes
+
+ - `MISSING_SEQUENCE`: No sequence provided in request
+ - `SEQUENCE_TOO_SHORT`: Sequence length < 10 amino acids
+ - `MODEL_NOT_LOADED`: Models failed to load from HuggingFace
+ - `INTERNAL_ERROR`: Unexpected server error
+
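+ Error payloads are returned under FastAPI's `detail` key, so clients can branch on `error_code`. A short sketch against the `/predict` endpoint:
+
+ ```python
+ import requests
+
+ resp = requests.post("http://localhost:8000/predict", json={"sequence": "MKT"})
+ if resp.status_code != 200:
+     detail = resp.json().get("detail", {})
+     if detail.get("error_code") == "SEQUENCE_TOO_SHORT":
+         print("Sequence must be at least 10 amino acids long")
+     else:
+         print(f"Request failed: {detail.get('error')}")
+ ```
+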
+ ## 📊 Performance
+
+ - Single prediction: ~40-50 ms (GPU)
+ - Batch prediction (10 sequences): ~100-150 ms (GPU)
+ - Model loading time: ~10-15 seconds (first run)
+
+ ## 🐛 Troubleshooting
+
+ ### Models not loading
+
+ 1. Check your HuggingFace token in `.env`
+ 2. Verify you have access to the private repository
+ 3. Check your internet connection
+ 4. Look at the console logs for specific errors
+
+ ### CUDA out of memory
+
+ - Reduce the batch size
+ - Use the CPU instead: set `device = torch.device("cpu")` in `app.py` (see the sketch below)
+ - Process sequences one at a time
+
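+ The CPU fallback is a one-line change where the device is selected (the variable name matches the app code):
+
+ ```python
+ import torch
+
+ # Force CPU inference regardless of CUDA availability
+ device = torch.device("cpu")
+ ```
+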
+ ### Slow predictions
+
+ - Ensure the GPU is being used (check the `/health` endpoint)
+ - The first prediction is always slower (model initialization)
+
+ ## 🌐 Public Usage Guidelines
+
+ - **Free to Use**: No authentication or API keys required.
+ - **Rate Limiting**: Fair usage is expected. Please do not abuse the service.
+ - **Educational Purpose**: Designed for research and educational use.
+ - **Medical Disclaimer**: Not for clinical diagnosis. See the disclaimer below.
+ - **Availability**: Best-effort uptime, not guaranteed 24/7.
+
+ ## ⚠️ Medical Disclaimer
+
+ **IMPORTANT**: This API is designed for **research and educational purposes only**. It should **NOT** be used for clinical diagnosis or medical decision-making. Always consult qualified medical professionals for diagnostic decisions.
+
+ ## 🏢 About CAMLAs
+
+ **CAMLAs** (Centre for Advanced Machine Learning & Applications) is a research organization focused on advancing AI applications in medical imaging and healthcare.
+
+ **Team Members:**
+
+ - **S M Hasan Mahmud** – Principal Investigator & Supervisor
+   _Roles:_ Writing – Original Draft, Writing – Review & Editing, Conceptualization, Supervision, Project Administration
+
+ - **Francis Rudra D Cruze** – Lead Developer & Researcher
+   _Roles:_ Methodology, Software, Formal Analysis, Investigation, Resources, Visualization
+
+ ## 📞 Support & Contact
+
+ - **Issues**: [GitHub Repository Issues](https://github.com/camlas/ovarian-cancer)
+ - **Email**: [email protected]
+ - **Documentation**: This README
+ - **API Status**: Check the `/health` endpoint
+ - **Website Integration**: Perfect for ovarian.francisrudra.com
+
+ ## 📄 License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
+
  ---
 
+ **CAMLAs** - Centre for Advanced Machine Learning & Applications
+ _Advancing Medical AI Research with a Public FastAPI Service_ 🌐🚀
app-worked-backup-1.py ADDED
@@ -0,0 +1,304 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ import torch
+ import numpy as np
+ import os
+ import time
+ import joblib
+ from pathlib import Path
+ from datetime import datetime, timezone
+ from typing import Optional
+ from contextlib import asynccontextmanager
+ from dotenv import load_dotenv
+ import shutil
+ from huggingface_hub import hf_hub_download
+
+ # Transformers imports specifically for ProtBERT
+ from transformers import BertTokenizer, BertModel
+
+ # Import your custom model structure
+ from utils.model_classes import MHSA_GRU
+
+ load_dotenv()
+
+ # ========================= CONFIGURATION ==========================
+
+ # Repository details (where your trained classifier/scaler live)
+ MODEL_REPO = {
+     "repo_id": "camlas/toxicity",
+     "files": {
+         "classifier": "mhsa_gru_classifier.pth",
+         "scaler": "scaler.pkl"
+     }
+ }
+
+ # Feature extraction config - updated for ProtBERT
+ TRANSFORMER_CONFIG = {
+     "model_name": "Rostlab/prot_bert",
+     "model_type": "ProtBERT",
+     "tokenizer_class": BertTokenizer,
+     "model_class": BertModel
+ }
+
+ CLASSES = ["Non-Toxic", "Toxic"]
+ API_VERSION = "2.0.0-protbert"
+ MODEL_VERSION = "ProtBERT-MHSA-GRU-v1"
+
+ # Global variables to hold loaded models
+ models = {
+     "transformer": None,
+     "tokenizer": None,
+     "classifier": None,
+     "scaler": None
+ }
+
+ # Device selection
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # ========================= HELPER FUNCTIONS =========================
+
+ def ensure_models_directory():
+     models_dir = "models"
+     Path(models_dir).mkdir(exist_ok=True)
+     return models_dir
+
+ def download_model_from_hub(model_key: str) -> Optional[str]:
+     """Download custom trained models (classifier/scaler) from the private HF repo"""
+     try:
+         filename = MODEL_REPO["files"][model_key]
+         repo_id = MODEL_REPO["repo_id"]
+         models_dir = ensure_models_directory()
+         local_path = os.path.join(models_dir, filename)
+
+         # If the file exists locally, use it
+         if os.path.exists(local_path):
+             print(f"✅ Found {model_key} locally: {local_path}")
+             return local_path
+
+         print(f"📥 Downloading {model_key} from {repo_id}...")
+         token = os.getenv("HF_TOKEN")
+
+         if not token:
+             print("⚠️ Warning: HF_TOKEN not found in .env. Private repos will fail.")
+
+         temp_path = hf_hub_download(
+             repo_id=repo_id,
+             filename=filename,
+             repo_type="model",
+             token=token
+         )
+         shutil.copy2(temp_path, local_path)
+         return local_path
+     except Exception as e:
+         print(f"❌ Error downloading {model_key}: {e}")
+         return None
+
+ def load_feature_extractor():
+     """Load the ProtBERT model from HuggingFace"""
+     print(f"🔄 Loading Transformer: {TRANSFORMER_CONFIG['model_name']}...")
+     try:
+         # Load specifically with do_lower_case=False for ProtBERT
+         tokenizer = TRANSFORMER_CONFIG['tokenizer_class'].from_pretrained(
+             TRANSFORMER_CONFIG['model_name'],
+             do_lower_case=False
+         )
+         model = TRANSFORMER_CONFIG['model_class'].from_pretrained(
+             TRANSFORMER_CONFIG['model_name']
+         )
+         model.to(device)
+         model.eval()
+
+         models["tokenizer"] = tokenizer
+         models["transformer"] = model
+         print("✅ ProtBERT Transformer loaded successfully")
+         return True
+     except Exception as e:
+         print(f"❌ Error loading Transformer: {e}")
+         return False
+
+ def load_classifier_and_scaler():
+     """Load the custom MHSA-GRU classifier and scaler"""
+     try:
+         # 1. Load scaler
+         scaler_path = download_model_from_hub("scaler")
+         if scaler_path:
+             models["scaler"] = joblib.load(scaler_path)
+             print("✅ Scaler loaded")
+
+         # 2. Load classifier
+         clf_path = download_model_from_hub("classifier")
+         if clf_path:
+             # ProtBERT output dimension is 1024
+             input_dim = 1024
+
+             print(f"ℹ️ Initializing MHSA_GRU with input_dim={input_dim} (ProtBERT)")
+
+             classifier = MHSA_GRU(
+                 input_dim=input_dim,
+                 hidden_dim=256,  # Matching your training code
+                 num_heads=8,
+                 num_gru_layers=2,
+                 dropout=0.3
+             )
+
+             state_dict = torch.load(clf_path, map_location=device)
+             classifier.load_state_dict(state_dict)
+             classifier.to(device)
+             classifier.eval()
+             models["classifier"] = classifier
+             print("✅ Classifier loaded")
+
+         return models["scaler"] is not None and models["classifier"] is not None
+     except Exception as e:
+         print(f"❌ Error loading custom models: {e}")
+         return False
+
+ def preprocess_sequence(sequence: str):
+     """
+     Preprocess a sequence for ProtBERT.
+     ProtBERT expects spaces between amino acids: 'M K T A Y...'
+     """
+     # Clean and uppercase
+     sequence = sequence.upper().strip().replace('\n', '').replace('\r', '')
+
+     # Add spaces between residues
+     spaced_sequence = " ".join(list(sequence))
+     return spaced_sequence
+
+ def extract_features(sequence: str):
+     """Run a sequence through ProtBERT to get the [CLS] embedding"""
+     tokenizer = models["tokenizer"]
+     model = models["transformer"]
+
+     processed_seq = preprocess_sequence(sequence)
+
+     inputs = tokenizer(
+         [processed_seq],
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=512  # ProtBERT max length
+     )
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Extract the [CLS] token embedding (index 0)
+     # shape: (batch_size, hidden_dim) -> (1, 1024)
+     features = outputs.last_hidden_state[:, 0, :]
+
+     return features.cpu().numpy()
+
+ # ========================= FASTAPI LIFESPAN =========================
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     print("🚀 Starting Toxicity Detection API (ProtBERT Edition)...")
+
+     # Check that utils/model_classes.py exists
+     if not os.path.exists("utils/model_classes.py"):
+         print("❌ Error: utils/model_classes.py not found. Please create it.")
+
+     success_tf = load_feature_extractor()
+     success_custom = load_classifier_and_scaler()
+
+     if not (success_tf and success_custom):
+         print("⚠️ Warning: Not all models loaded successfully")
+     yield
+     print("🔄 Shutting down API...")
+
+ app = FastAPI(
+     title="Peptide Toxicity Detection API",
+     description="API using ProtBERT features + MHSA-GRU classifier",
+     version=API_VERSION,
+     lifespan=lifespan
+ )
+
+ # ========================= PYDANTIC MODELS =========================
+
+ class SequenceRequest(BaseModel):
+     sequence: str
+
+ class PredictionResponse(BaseModel):
+     sequence_preview: str
+     is_toxic: bool
+     label: str
+     score: float
+     confidence_level: str
+     model_used: str
+     processing_time_ms: float
+     timestamp: str
+
+ # ========================= ENDPOINTS =========================
+
+ @app.get("/")
+ async def root():
+     return {"message": "Toxicity Detection API is running. Use /predict to analyze sequences."}
+
+ @app.get("/health")
+ async def health_check():
+     loaded = all(v is not None for v in models.values())
+     return {
+         "status": "healthy" if loaded else "degraded",
+         "models_loaded": {k: v is not None for k, v in models.items()},
+         "device": str(device),
+         "model_version": MODEL_VERSION,
+         "feature_extractor": TRANSFORMER_CONFIG["model_name"]
+     }
+
+ @app.post("/predict", response_model=PredictionResponse)
+ async def predict(request: SequenceRequest):
+     start_time = time.time()
+
+     if not all(models.values()):
+         raise HTTPException(status_code=503, detail="Models are not fully initialized.")
+
+     if not request.sequence:
+         raise HTTPException(status_code=400, detail="Empty sequence provided.")
+
+     try:
+         # 1. Extract features (ProtBERT [CLS] token)
+         # This handles the 'M K T' spacing internally
+         raw_features = extract_features(request.sequence)
+
+         # 2. Scale features with the scaler loaded from your repo
+         scaled_features = models["scaler"].transform(raw_features)
+
+         # 3. Predict (MHSA-GRU)
+         features_tensor = torch.FloatTensor(scaled_features).to(device)
+
+         with torch.no_grad():
+             # Get probability (sigmoid output)
+             probability = models["classifier"](features_tensor).item()
+
+         # 4. Interpret results (threshold 0.5)
+         prediction_class = 1 if probability > 0.5 else 0
+         predicted_label = CLASSES[prediction_class]
+
+         # Confidence calculation
+         confidence_score = abs(probability - 0.5) * 2
+         confidence_level = "High" if confidence_score > 0.8 else "Medium" if confidence_score > 0.5 else "Low"
+
+         processing_time = round((time.time() - start_time) * 1000, 2)
+
+         return PredictionResponse(
+             sequence_preview=request.sequence[:20] + "..." if len(request.sequence) > 20 else request.sequence,
+             is_toxic=(prediction_class == 1),
+             label=predicted_label,
+             score=probability,
+             confidence_level=confidence_level,
+             model_used="ProtBERT + MHSA-GRU",
+             processing_time_ms=processing_time,
+             timestamp=datetime.now(timezone.utc).isoformat()
+         )
+
+     except Exception as e:
+         print(f"Error during prediction: {e}")
+         raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
app-worked-backup-2.py ADDED
@@ -0,0 +1,702 @@
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from typing import Optional, List
7
+ import time
8
+ from datetime import datetime, timezone
9
+ import os
10
+ import warnings
11
+ from huggingface_hub import hf_hub_download
12
+ from contextlib import asynccontextmanager
13
+ import uvicorn
14
+ from dotenv import load_dotenv
15
+ import shutil
16
+ import joblib
17
+ from pathlib import Path
18
+ from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel, DistilBertTokenizer, DistilBertModel
19
+
20
+ load_dotenv()
21
+ warnings.filterwarnings('ignore')
22
+
23
+ # ========================= MODEL CLASSES =========================
24
+ class MultiHeadSelfAttention(nn.Module):
25
+ """Multi-Head Self-Attention mechanism"""
26
+ def __init__(self, embed_dim, num_heads, dropout=0.3):
27
+ super(MultiHeadSelfAttention, self).__init__()
28
+ self.attention = nn.MultiheadAttention(
29
+ embed_dim=embed_dim,
30
+ num_heads=num_heads,
31
+ dropout=dropout,
32
+ batch_first=True
33
+ )
34
+ self.layer_norm = nn.LayerNorm(embed_dim)
35
+ self.dropout = nn.Dropout(dropout)
36
+
37
+ def forward(self, x):
38
+ attn_output, _ = self.attention(x, x, x)
39
+ x = self.layer_norm(x + self.dropout(attn_output))
40
+ return x
41
+
42
+
43
+ class MHSA_GRU(nn.Module):
44
+ """Multi-Head Self-Attention with GRU model"""
45
+ def __init__(self, input_dim, hidden_dim=256, num_heads=8, num_gru_layers=2, dropout=0.3):
46
+ super(MHSA_GRU, self).__init__()
47
+
48
+ self.input_dim = input_dim
49
+ self.hidden_dim = hidden_dim
50
+
51
+ self.input_projection = nn.Linear(input_dim, hidden_dim)
52
+ self.mhsa1 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
53
+ self.mhsa2 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
54
+
55
+ self.gru = nn.GRU(
56
+ input_size=hidden_dim,
57
+ hidden_size=hidden_dim,
58
+ num_layers=num_gru_layers,
59
+ batch_first=True,
60
+ dropout=dropout if num_gru_layers > 1 else 0,
61
+ bidirectional=False
62
+ )
63
+
64
+ self.mhsa3 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
68
+ self.fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
69
+ self.fc3 = nn.Linear(hidden_dim // 4, 1)
70
+
71
+ self.bn1 = nn.BatchNorm1d(hidden_dim // 2)
72
+ self.bn2 = nn.BatchNorm1d(hidden_dim // 4)
73
+
74
+ def forward(self, x):
75
+ batch_size = x.size(0)
76
+ x = self.input_projection(x)
77
+ x = x.unsqueeze(1)
78
+
79
+ x = self.mhsa1(x)
80
+ x = self.mhsa2(x)
81
+ gru_out, hidden = self.gru(x)
82
+ x = self.mhsa3(gru_out)
83
+ x = x[:, -1, :]
84
+
85
+ x = self.dropout(x)
86
+ x = torch.relu(self.bn1(self.fc1(x)))
87
+ x = self.dropout(x)
88
+ x = torch.relu(self.bn2(self.fc2(x)))
89
+ x = self.dropout(x)
90
+ x = self.fc3(x)
91
+
92
+ return torch.sigmoid(x)
93
+
94
+
95
+ # ========================= CONFIGURATION =========================
96
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
97
+
98
+ API_VERSION = "1.0.0"
99
+ MODEL_VERSION = "MHSA-GRU-Transformer-v1.0"
100
+
101
+ # Model repository configuration
102
+ MODEL_REPO = {
103
+ "repo_id": "camlas/toxicity",
104
+ "files": {
105
+ "classifier": "mhsa_gru_classifier.pth",
106
+ "scaler": "scaler.pkl",
107
+ "config": "config.json",
108
+ "model_weights": "model.safetensors",
109
+ "vocab": "vocab.txt",
110
+ "tokenizer_config": "tokenizer_config.json",
111
+ "special_tokens_map": "special_tokens_map.json"
112
+ }
113
+ }
114
+
115
+ # Global model variables
116
+ classifier = None
117
+ scaler = None
118
+ transformer_model = None
119
+ transformer_tokenizer = None
120
+ EMBEDDING_TYPE = "Bert"
121
+ MODEL_NAME = "ProtBERT"
122
+
123
+
124
+ # ========================= PYDANTIC MODELS =========================
125
+ class SequenceRequest(BaseModel):
126
+ sequence: str
127
+
128
+
129
+ class BatchSequenceRequest(BaseModel):
130
+ sequences: List[str]
131
+
132
+
133
+ class PredictionResponse(BaseModel):
134
+ status_code: int
135
+ status: str
136
+ success: bool
137
+ data: Optional[dict] = None
138
+ error: Optional[str] = None
139
+ error_code: Optional[str] = None
140
+ timestamp: str
141
+ api_version: str
142
+ processing_time_ms: float
143
+
144
+
145
+ class HealthResponse(BaseModel):
146
+ status_code: int
147
+ status: str
148
+ service: str
149
+ api_version: str
150
+ model_version: str
151
+ models_loaded: bool
152
+ models_loaded_count: int
153
+ total_models_required: int
154
+ model_sources: dict
155
+ repository_info: dict
156
+ device: str
157
+ timestamp: str
158
+
159
+
160
+ # ========================= HELPER FUNCTIONS =========================
161
+ def create_kmers(sequence, k=6):
162
+ """Convert DNA sequence to k-mer tokens (for DNABERT)"""
163
+ kmers = []
164
+ for i in range(len(sequence) - k + 1):
165
+ kmer = sequence[i:i+k]
166
+ kmers.append(kmer)
167
+ return ' '.join(kmers)
168
+
169
+
170
+ def ensure_models_directory():
171
+ models_dir = "models"
172
+ if not os.path.exists(models_dir):
173
+ os.makedirs(models_dir)
174
+ print(f"✅ Created {models_dir} directory")
175
+ return models_dir
176
+
177
+
178
+ def download_model_from_hub(model_name: str) -> Optional[str]:
179
+ """Download individual model files from HuggingFace Hub"""
180
+ try:
181
+ if model_name not in MODEL_REPO["files"]:
182
+ raise ValueError(f"Unknown model: {model_name}")
183
+
184
+ filename = MODEL_REPO["files"][model_name]
185
+ repo_id = MODEL_REPO["repo_id"]
186
+ models_dir = ensure_models_directory()
187
+ local_path = os.path.join(models_dir, filename)
188
+
189
+ if os.path.exists(local_path):
190
+ print(f"✅ Found {model_name} in local models directory: {local_path}")
191
+ return local_path
192
+
193
+ print(f"📥 Downloading {model_name} ({filename}) from {repo_id}...")
194
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
195
+
196
+ if not token:
197
+ print("⚠️ Warning: No HF token found. This may fail for private repositories.")
198
+
199
+ temp_model_path = hf_hub_download(
200
+ repo_id=repo_id,
201
+ filename=filename,
202
+ repo_type="model",
203
+ token=token
204
+ )
205
+
206
+ shutil.copy2(temp_model_path, local_path)
207
+ print(f"✅ {model_name} downloaded and stored!")
208
+ return local_path
209
+
210
+ except Exception as e:
211
+ print(f"❌ Error downloading {model_name}: {e}")
212
+ return None
213
+
214
+
215
+ def extract_features_from_sequence(sequence: str):
216
+ """Extract features from sequence using ProtBERT"""
217
+ global transformer_model, transformer_tokenizer
218
+
219
+ if transformer_model is None or transformer_tokenizer is None:
220
+ raise ValueError("ProtBERT model not loaded")
221
+
222
+ # ProtBERT expects sequences with spaces between amino acids
223
+ # Convert "MKTAYIAKQR" to "M K T A Y I A K Q R"
224
+ processed_seq = ' '.join(list(sequence.upper()))
225
+
226
+ # Tokenize
227
+ inputs = transformer_tokenizer(
228
+ processed_seq,
229
+ return_tensors="pt",
230
+ padding=True,
231
+ truncation=True,
232
+ max_length=512
233
+ )
234
+ inputs = {k: v.to(device) for k, v in inputs.items()}
235
+
236
+ # Extract features
237
+ with torch.no_grad():
238
+ outputs = transformer_model(**inputs)
239
+ # Use [CLS] token embedding
240
+ cls_embeddings = outputs.last_hidden_state[:, 0, :]
241
+
242
+ return cls_embeddings.cpu().numpy()
243
+
244
+
245
+ def load_all_models():
246
+ """Load all models from HuggingFace Hub"""
247
+ global classifier, scaler, transformer_model, transformer_tokenizer
248
+
249
+ models_dir = ensure_models_directory()
250
+ models_loaded = {
251
+ "classifier": False,
252
+ "scaler": False,
253
+ "transformer_model": False,
254
+ "transformer_tokenizer": False
255
+ }
256
+
257
+ print(f"🚀 Loading models from {MODEL_REPO['repo_id']}...")
258
+ print("=" * 60)
259
+
260
+ try:
261
+ # Download all necessary files
262
+ print("📥 Downloading ProtBERT model files...")
263
+
264
+ files_to_download = ["config", "model_weights", "vocab",
265
+ "tokenizer_config", "special_tokens_map"]
266
+
267
+ for file_key in files_to_download:
268
+ download_model_from_hub(file_key)
269
+
270
+ # Load ProtBERT Tokenizer
271
+ print("🔄 Loading ProtBERT tokenizer...")
272
+ try:
273
+ transformer_tokenizer = BertTokenizer.from_pretrained(
274
+ models_dir,
275
+ do_lower_case=False,
276
+ local_files_only=True
277
+ )
278
+ models_loaded["transformer_tokenizer"] = True
279
+ print("✅ ProtBERT tokenizer loaded!")
280
+ except Exception as e:
281
+ print(f"❌ Error loading tokenizer: {e}")
282
+ # Try loading from HuggingFace directly
283
+ print("🔄 Trying to load tokenizer directly from HuggingFace...")
284
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
285
+ transformer_tokenizer = BertTokenizer.from_pretrained(
286
+ MODEL_REPO["repo_id"],
287
+ do_lower_case=False,
288
+ token=token
289
+ )
290
+ models_loaded["transformer_tokenizer"] = True
291
+ print("✅ ProtBERT tokenizer loaded from HuggingFace!")
292
+
293
+ # Load ProtBERT Model
294
+ print("🔄 Loading ProtBERT model...")
295
+ try:
296
+ transformer_model = BertModel.from_pretrained(
297
+ models_dir,
298
+ local_files_only=True
299
+ )
300
+ models_loaded["transformer_model"] = True
301
+ print("✅ ProtBERT model loaded!")
302
+ except Exception as e:
303
+ print(f"❌ Error loading model: {e}")
304
+ # Try loading from HuggingFace directly
305
+ print("🔄 Trying to load model directly from HuggingFace...")
306
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
307
+ transformer_model = BertModel.from_pretrained(
308
+ MODEL_REPO["repo_id"],
309
+ token=token
310
+ )
311
+ models_loaded["transformer_model"] = True
312
+ print("✅ ProtBERT model loaded from HuggingFace!")
313
+
314
+ transformer_model.to(device)
315
+ transformer_model.eval()
316
+
317
+ # Load Classifier
318
+ print("🔄 Loading classifier (MHSA-GRU)...")
319
+ clf_path = os.path.join(models_dir, MODEL_REPO["files"]["classifier"])
320
+
321
+ if not os.path.exists(clf_path):
322
+ print("📥 Classifier not found locally, downloading...")
323
+ clf_path = download_model_from_hub("classifier")
324
+
325
+ if clf_path and os.path.exists(clf_path):
326
+ checkpoint = torch.load(clf_path, map_location=device, weights_only=False)
327
+
328
+ # Handle different checkpoint formats
329
+ if 'input_dim' in checkpoint:
330
+ input_dim = checkpoint['input_dim']
331
+ else:
332
+ # ProtBERT embedding size is 1024
333
+ input_dim = 1024
334
+
335
+ classifier = MHSA_GRU(input_dim, hidden_dim=256)
336
+
337
+ # Load state dict
338
+ if 'model_state_dict' in checkpoint:
339
+ classifier.load_state_dict(checkpoint['model_state_dict'])
340
+ else:
341
+ classifier.load_state_dict(checkpoint)
342
+
343
+ classifier.to(device)
344
+ classifier.eval()
345
+ models_loaded["classifier"] = True
346
+ print(f"✅ Classifier loaded! (input_dim: {input_dim})")
347
+
348
+ # Load Scaler
349
+ print("🔄 Loading feature scaler...")
350
+ scaler_path = os.path.join(models_dir, MODEL_REPO["files"]["scaler"])
351
+
352
+ if not os.path.exists(scaler_path):
353
+ print("📥 Scaler not found locally, downloading...")
354
+ scaler_path = download_model_from_hub("scaler")
355
+
356
+ if scaler_path and os.path.exists(scaler_path):
357
+ scaler = joblib.load(scaler_path)
358
+ models_loaded["scaler"] = True
359
+ print("✅ Scaler loaded!")
360
+
361
+ loaded_count = sum(models_loaded.values())
362
+ total_count = len(models_loaded)
363
+
364
+ print(f"\n📊 Model Loading Summary:")
365
+ print(f" • Successfully loaded: {loaded_count}/{total_count}")
366
+ print(f" • Repository: {MODEL_REPO['repo_id']}")
367
+ print(f" • Embedding Model: {MODEL_NAME}")
368
+ print(f" • Device: {device}")
369
+
370
+ critical_models = ["classifier", "scaler", "transformer_model", "transformer_tokenizer"]
371
+ critical_loaded = all(models_loaded[m] for m in critical_models)
372
+
373
+ if critical_loaded:
374
+ print("🎉 All critical models loaded successfully!")
375
+ return True
376
+ else:
377
+ print("⚠️ Some critical models failed to load")
378
+ print(f" Models status: {models_loaded}")
379
+ return False
380
+
381
+ except Exception as e:
382
+ print(f"❌ Error loading models: {e}")
383
+ import traceback
384
+ traceback.print_exc()
385
+ return False
386
+
387
+
388
+ # ========================= FASTAPI APPLICATION =========================
389
+ @asynccontextmanager
390
+ async def lifespan(app: FastAPI):
391
+ # Startup
392
+ print("🚀 Starting Toxicity Prediction API...")
393
+ success = load_all_models()
394
+ if not success:
395
+ print("⚠️ Warning: Not all models loaded successfully")
396
+ yield
397
+ # Shutdown
398
+ print("🔄 Shutting down API...")
399
+
400
+
401
+ app = FastAPI(
402
+ title="Toxicity Prediction API",
403
+ description="API for toxicity prediction using MHSA-GRU with Transformer embeddings",
404
+ version="1.0.0",
405
+ lifespan=lifespan
406
+ )
407
+
408
+
409
+ @app.get("/")
410
+ async def root():
411
+ return {
412
+ "message": "Toxicity Prediction API",
413
+ "version": API_VERSION,
414
+ "endpoints": {
415
+ "/predict": "POST - Predict toxicity for a single sequence",
416
+ "/predict/batch": "POST - Predict toxicity for multiple sequences",
417
+ "/health": "GET - Check API health and model status"
418
+ }
419
+ }
420
+
421
+
422
+ @app.post("/predict", response_model=PredictionResponse)
423
+ async def predict(request: SequenceRequest):
424
+ start_time = time.time()
425
+ timestamp = datetime.now(timezone.utc).isoformat()
426
+
427
+ try:
428
+ if not request.sequence or len(request.sequence) == 0:
429
+ raise HTTPException(
430
+ status_code=400,
431
+ detail={
432
+ "status_code": 400,
433
+ "status": "error",
434
+ "success": False,
435
+ "error": "No sequence provided",
436
+ "error_code": "MISSING_SEQUENCE",
437
+ "timestamp": timestamp,
438
+ "api_version": API_VERSION,
439
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
440
+ }
441
+ )
442
+
443
+ # Check if models are loaded
444
+ if classifier is None or scaler is None or transformer_model is None:
445
+ raise HTTPException(
446
+ status_code=503,
447
+ detail={
448
+ "status_code": 503,
449
+ "status": "error",
450
+ "success": False,
451
+ "error": "Models not loaded properly",
452
+ "error_code": "MODEL_NOT_LOADED",
453
+ "timestamp": timestamp,
454
+ "api_version": API_VERSION,
455
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
456
+ }
457
+ )
458
+
459
+ # Validate sequence
460
+ sequence = request.sequence.upper().strip()
461
+ if len(sequence) < 10:
462
+ raise HTTPException(
463
+ status_code=400,
464
+ detail={
465
+ "status_code": 400,
466
+ "status": "error",
467
+ "success": False,
468
+ "error": "Sequence too short (minimum 10 characters)",
469
+ "error_code": "SEQUENCE_TOO_SHORT",
470
+ "timestamp": timestamp,
471
+ "api_version": API_VERSION,
472
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
473
+ }
474
+ )
475
+
476
+ # Step 1: Extract features using ProtBERT
477
+ features = extract_features_from_sequence(sequence)
478
+
479
+ # Step 2: Scale features
480
+ scaled_features = scaler.transform(features)
481
+
482
+ # Step 3: Predict using MHSA-GRU
483
+ features_tensor = torch.FloatTensor(scaled_features).to(device)
484
+
485
+ with torch.no_grad():
486
+ probability = classifier(features_tensor).cpu().numpy()[0, 0]
487
+
488
+ # Determine prediction
489
+ prediction_class = 1 if probability > 0.5 else 0
490
+ predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic"
491
+ confidence = float(abs(probability - 0.5) * 2)
492
+
493
+ # Determine confidence level
494
+ if confidence > 0.8:
495
+ confidence_level = "high"
496
+ elif confidence > 0.6:
497
+ confidence_level = "medium"
498
+ else:
499
+ confidence_level = "low"
500
+
501
+ processing_time = round((time.time() - start_time) * 1000, 2)
502
+
503
+ return PredictionResponse(
504
+ status_code=200,
505
+ status="success",
506
+ success=True,
507
+ data={
508
+ "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
509
+ "sequence_length": len(sequence),
510
+ "prediction": {
511
+ "predicted_class": predicted_label,
512
+ "confidence": confidence,
513
+ "confidence_level": confidence_level,
514
+ "toxicity_score": float(probability),
515
+ "non_toxicity_score": float(1 - probability)
516
+ },
517
+ "metadata": {
518
+ "embedding_model": MODEL_NAME,
519
+ "embedding_type": EMBEDDING_TYPE,
520
+ "model_version": MODEL_VERSION,
521
+ "device": str(device)
522
+ }
523
+ },
524
+ timestamp=timestamp,
525
+ api_version=API_VERSION,
526
+ processing_time_ms=processing_time
527
+ )
528
+
529
+ except HTTPException:
530
+ raise
531
+ except Exception as e:
532
+ processing_time = round((time.time() - start_time) * 1000, 2)
533
+ raise HTTPException(
534
+ status_code=500,
535
+ detail={
536
+ "status_code": 500,
537
+ "status": "error",
538
+ "success": False,
539
+ "error": f"Internal server error: {str(e)}",
540
+ "error_code": "INTERNAL_ERROR",
541
+ "timestamp": timestamp,
542
+ "api_version": API_VERSION,
543
+ "processing_time_ms": processing_time
544
+ }
545
+ )
546
+
547
+
548
+ @app.post("/predict/batch", response_model=PredictionResponse)
549
+ async def predict_batch(request: BatchSequenceRequest):
550
+ start_time = time.time()
551
+ timestamp = datetime.now(timezone.utc).isoformat()
552
+
553
+ try:
554
+ if not request.sequences or len(request.sequences) == 0:
555
+ raise HTTPException(
556
+ status_code=400,
557
+ detail={
558
+ "status_code": 400,
559
+ "status": "error",
560
+ "success": False,
561
+ "error": "No sequences provided",
562
+ "error_code": "MISSING_SEQUENCES",
563
+ "timestamp": timestamp,
564
+ "api_version": API_VERSION,
565
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
566
+ }
567
+ )
568
+
569
+ # Check if models are loaded
570
+ if classifier is None or scaler is None or transformer_model is None:
571
+ raise HTTPException(
572
+ status_code=503,
573
+ detail={
574
+ "status_code": 503,
575
+ "status": "error",
576
+ "success": False,
577
+ "error": "Models not loaded properly",
578
+ "error_code": "MODEL_NOT_LOADED",
579
+ "timestamp": timestamp,
580
+ "api_version": API_VERSION,
581
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
582
+ }
583
+ )
584
+
585
+ results = []
586
+
587
+ for seq in request.sequences:
588
+ sequence = seq.upper().strip()
589
+
590
+ # Extract features using ProtBERT
591
+ features = extract_features_from_sequence(sequence)
592
+ scaled_features = scaler.transform(features)
593
+ features_tensor = torch.FloatTensor(scaled_features).to(device)
594
+
595
+ with torch.no_grad():
596
+ probability = classifier(features_tensor).cpu().numpy()[0, 0]
597
+
598
+ prediction_class = 1 if probability > 0.5 else 0
599
+ predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic"
600
+ confidence = float(abs(probability - 0.5) * 2)
601
+
602
+ results.append({
603
+ "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
604
+ "sequence_length": len(sequence),
605
+ "predicted_class": predicted_label,
606
+ "toxicity_score": float(probability),
607
+ "confidence": confidence
608
+ })
609
+
610
+ processing_time = round((time.time() - start_time) * 1000, 2)
611
+
612
+ return PredictionResponse(
613
+ status_code=200,
614
+ status="success",
615
+ success=True,
616
+ data={
617
+ "total_sequences": len(request.sequences),
618
+ "results": results,
619
+ "metadata": {
620
+ "embedding_model": MODEL_NAME,
621
+ "embedding_type": EMBEDDING_TYPE,
622
+ "model_version": MODEL_VERSION,
623
+ "device": str(device)
624
+ }
625
+ },
626
+ timestamp=timestamp,
627
+ api_version=API_VERSION,
628
+ processing_time_ms=processing_time
629
+ )
630
+
631
+ except HTTPException:
632
+ raise
633
+ except Exception as e:
634
+ processing_time = round((time.time() - start_time) * 1000, 2)
635
+ raise HTTPException(
636
+ status_code=500,
637
+ detail={
638
+ "status_code": 500,
639
+ "status": "error",
640
+ "success": False,
641
+ "error": f"Internal server error: {str(e)}",
642
+ "error_code": "INTERNAL_ERROR",
643
+ "timestamp": timestamp,
644
+ "api_version": API_VERSION,
645
+ "processing_time_ms": processing_time
646
+ }
647
+ )
648
+
649
+
650
+ @app.get("/health", response_model=HealthResponse)
651
+ async def health_check():
652
+ models_loaded = all([
653
+ classifier is not None,
654
+ scaler is not None,
655
+ transformer_model is not None,
656
+ transformer_tokenizer is not None
657
+ ])
658
+
659
+ model_sources = {
660
+ "classifier": {
661
+ "loaded": classifier is not None,
662
+ "source": "huggingface_hub",
663
+ "repository": MODEL_REPO["repo_id"]
664
+ },
665
+ "scaler": {
666
+ "loaded": scaler is not None,
667
+ "source": "huggingface_hub",
668
+ "repository": MODEL_REPO["repo_id"]
669
+ },
670
+ "transformer_model": {
671
+ "loaded": transformer_model is not None,
672
+ "model_name": MODEL_NAME,
673
+ "source": "huggingface_hub",
674
+ "repository": MODEL_REPO["repo_id"]
675
+ }
676
+ }
677
+
678
+ repository_info = {
679
+ "repository_id": MODEL_REPO["repo_id"],
680
+ "embedding_type": EMBEDDING_TYPE,
681
+ "model_name": MODEL_NAME,
682
+ "total_models": len(MODEL_REPO["files"])
683
+ }
684
+
685
+ return HealthResponse(
686
+ status_code=200 if models_loaded else 503,
687
+ status="healthy" if models_loaded else "unhealthy",
688
+ service="Toxicity Prediction API",
689
+ api_version=API_VERSION,
690
+ model_version=MODEL_VERSION,
691
+ models_loaded=models_loaded,
692
+ models_loaded_count=sum(1 for m in (classifier, scaler, transformer_model, transformer_tokenizer) if m is not None),  # include the tokenizer so the count can reach total_models_required=4
693
+ total_models_required=4,
694
+ model_sources=model_sources,
695
+ repository_info=repository_info,
696
+ device=str(device),
697
+ timestamp=datetime.now(timezone.utc).isoformat()
698
+ )
699
+
700
+
701
+ if __name__ == "__main__":
702
+ uvicorn.run(app, host="0.0.0.0", port=8000)
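For reference, a minimal client sketch for the endpoints above — assuming the service is running locally on the host/port from the uvicorn.run call (http://localhost:8000); the request body and response shape follow the handlers shown above, and requests is already listed in requirements.txt:

import requests

# Single-sequence prediction against /predict.
resp = requests.post(
    "http://localhost:8000/predict",
    json={"sequence": "MKTAYIAKQRQISFVKSHFSRQLE"},
    timeout=60,
)
pred = resp.json()["data"]["prediction"]
print(pred["predicted_class"], pred["toxicity_score"], pred["confidence_level"])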
app.py ADDED
@@ -0,0 +1,813 @@
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from typing import Optional, List
7
+ import time
8
+ from datetime import datetime, timezone
9
+ import os
10
+ import warnings
11
+ from huggingface_hub import hf_hub_download
12
+ from contextlib import asynccontextmanager
13
+ import uvicorn
14
+ from dotenv import load_dotenv
15
+ import shutil
16
+ import joblib
17
+ from pathlib import Path
18
+ from transformers import BertTokenizer, BertModel
19
+ from utils.model_classes import MHSA_GRU, MultiHeadSelfAttention
20
+
21
+ load_dotenv()
22
+ warnings.filterwarnings('ignore')
23
+
24
+ # ========================= CONFIGURATION =========================
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ API_VERSION = "1.0.0"
28
+ MODEL_VERSION = "MHSA-GRU-Transformer-v1.0"
29
+
30
+ # Model repository configuration
31
+ MODEL_REPO = {
32
+ "repo_id": "camlas/toxicity",
33
+ "files": {
34
+ "classifier": "mhsa_gru_classifier.pth",
35
+ "scaler": "scaler.pkl",
36
+ "config": "config.json",
37
+ "model_weights": "model.safetensors",
38
+ "vocab": "vocab.txt",
39
+ "tokenizer_config": "tokenizer_config.json",
40
+ "special_tokens_map": "special_tokens_map.json"
41
+ }
42
+ }
43
+
44
+ # Global model variables
45
+ classifier = None
46
+ scaler = None
47
+ transformer_model = None
48
+ transformer_tokenizer = None
49
+ EMBEDDING_TYPE = "Bert"
50
+ MODEL_NAME = "ProtBERT"
51
+
52
+
53
+ # ========================= PYDANTIC MODELS =========================
54
+ class SequenceRequest(BaseModel):
55
+ sequence: str
56
+
57
+
58
+ class BatchSequenceRequest(BaseModel):
59
+ sequences: List[str]
60
+
61
+
62
+ class PredictionResponse(BaseModel):
63
+ status_code: int
64
+ status: str
65
+ success: bool
66
+ data: Optional[dict] = None
67
+ error: Optional[str] = None
68
+ error_code: Optional[str] = None
69
+ timestamp: str
70
+ api_version: str
71
+ processing_time_ms: float
72
+
73
+
74
+ class HealthResponse(BaseModel):
75
+ status_code: int
76
+ status: str
77
+ service: str
78
+ api_version: str
79
+ model_version: str
80
+ models_loaded: bool
81
+ models_loaded_count: int
82
+ total_models_required: int
83
+ model_sources: dict
84
+ repository_info: dict
85
+ device: str
86
+ timestamp: str
87
+
88
+
89
+ # ========================= HELPER FUNCTIONS =========================
90
+ def create_kmers(sequence, k=6):
91
+ """Convert a DNA sequence to overlapping k-mer tokens (DNABERT-style input; unused by the ProtBERT protein pipeline below)"""
92
+ kmers = []
93
+ for i in range(len(sequence) - k + 1):
94
+ kmer = sequence[i:i+k]
95
+ kmers.append(kmer)
96
+ return ' '.join(kmers)
97
+
98
+
99
+ def ensure_models_directory():
100
+ models_dir = "models"
101
+ if not os.path.exists(models_dir):
102
+ os.makedirs(models_dir)
103
+ print(f"✅ Created {models_dir} directory")
104
+ return models_dir
105
+
106
+
107
+ def download_model_from_hub(model_name: str) -> Optional[str]:
108
+ """Download individual model files from HuggingFace Hub"""
109
+ try:
110
+ if model_name not in MODEL_REPO["files"]:
111
+ raise ValueError(f"Unknown model: {model_name}")
112
+
113
+ filename = MODEL_REPO["files"][model_name]
114
+ repo_id = MODEL_REPO["repo_id"]
115
+ models_dir = ensure_models_directory()
116
+ local_path = os.path.join(models_dir, filename)
117
+
118
+ if os.path.exists(local_path):
119
+ print(f"✅ Found {model_name} in local models directory: {local_path}")
120
+ return local_path
121
+
122
+ print(f"📥 Downloading {model_name} ({filename}) from {repo_id}...")
123
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
124
+
125
+ if not token:
126
+ print("⚠️ Warning: No HF token found. This may fail for private repositories.")
127
+
128
+ temp_model_path = hf_hub_download(
129
+ repo_id=repo_id,
130
+ filename=filename,
131
+ repo_type="model",
132
+ token=token
133
+ )
134
+
135
+ shutil.copy2(temp_model_path, local_path)
136
+ print(f"✅ {model_name} downloaded and stored!")
137
+ return local_path
138
+
139
+ except Exception as e:
140
+ print(f"❌ Error downloading {model_name}: {e}")
141
+ return None
142
+
143
+
144
+ def extract_features_from_sequence(sequence: str):
145
+ """Extract features from sequence using ProtBERT"""
146
+ global transformer_model, transformer_tokenizer
147
+
148
+ if transformer_model is None or transformer_tokenizer is None:
149
+ raise ValueError("ProtBERT model not loaded")
150
+
151
+ # ProtBERT expects sequences with spaces between amino acids
152
+ # Convert "MKTAYIAKQR" to "M K T A Y I A K Q R"
153
+ processed_seq = ' '.join(list(sequence.upper()))
154
+
155
+ # Tokenize
156
+ inputs = transformer_tokenizer(
157
+ processed_seq,
158
+ return_tensors="pt",
159
+ padding=True,
160
+ truncation=True,
161
+ max_length=512
162
+ )
163
+ inputs = {k: v.to(device) for k, v in inputs.items()}
164
+
165
+ # Extract features
166
+ with torch.no_grad():
167
+ outputs = transformer_model(**inputs)
168
+ # Use [CLS] token embedding
169
+ cls_embeddings = outputs.last_hidden_state[:, 0, :]
170
+
171
+ return cls_embeddings.cpu().numpy()
172
+
173
+
174
+ def load_all_models():
175
+ """Load all models from HuggingFace Hub"""
176
+ global classifier, scaler, transformer_model, transformer_tokenizer
177
+
178
+ models_dir = ensure_models_directory()
179
+ models_loaded = {
180
+ "classifier": False,
181
+ "scaler": False,
182
+ "transformer_model": False,
183
+ "transformer_tokenizer": False
184
+ }
185
+
186
+ print(f"🚀 Loading models from {MODEL_REPO['repo_id']}...")
187
+ print("=" * 60)
188
+
189
+ try:
190
+ # Download all necessary files
191
+ print("📥 Downloading ProtBERT model files...")
192
+
193
+ files_to_download = ["config", "model_weights", "vocab",
194
+ "tokenizer_config", "special_tokens_map"]
195
+
196
+ for file_key in files_to_download:
197
+ download_model_from_hub(file_key)
198
+
199
+ # Load ProtBERT Tokenizer
200
+ print("🔄 Loading ProtBERT tokenizer...")
201
+ try:
202
+ transformer_tokenizer = BertTokenizer.from_pretrained(
203
+ models_dir,
204
+ do_lower_case=False,
205
+ local_files_only=True
206
+ )
207
+ models_loaded["transformer_tokenizer"] = True
208
+ print("✅ ProtBERT tokenizer loaded!")
209
+ except Exception as e:
210
+ print(f"❌ Error loading tokenizer: {e}")
211
+ # Try loading from HuggingFace directly
212
+ print("🔄 Trying to load tokenizer directly from HuggingFace...")
213
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
214
+ transformer_tokenizer = BertTokenizer.from_pretrained(
215
+ MODEL_REPO["repo_id"],
216
+ do_lower_case=False,
217
+ token=token
218
+ )
219
+ models_loaded["transformer_tokenizer"] = True
220
+ print("✅ ProtBERT tokenizer loaded from HuggingFace!")
221
+
222
+ # Load ProtBERT Model
223
+ print("🔄 Loading ProtBERT model...")
224
+ try:
225
+ transformer_model = BertModel.from_pretrained(
226
+ models_dir,
227
+ local_files_only=True
228
+ )
229
+ models_loaded["transformer_model"] = True
230
+ print("✅ ProtBERT model loaded!")
231
+ except Exception as e:
232
+ print(f"❌ Error loading model: {e}")
233
+ # Try loading from HuggingFace directly
234
+ print("🔄 Trying to load model directly from HuggingFace...")
235
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
236
+ transformer_model = BertModel.from_pretrained(
237
+ MODEL_REPO["repo_id"],
238
+ token=token
239
+ )
240
+ models_loaded["transformer_model"] = True
241
+ print("✅ ProtBERT model loaded from HuggingFace!")
242
+
243
+ transformer_model.to(device)
244
+ transformer_model.eval()
245
+
246
+ # Load Classifier
247
+ print("🔄 Loading classifier (MHSA-GRU)...")
248
+ clf_path = os.path.join(models_dir, MODEL_REPO["files"]["classifier"])
249
+
250
+ if not os.path.exists(clf_path):
251
+ print("📥 Classifier not found locally, downloading...")
252
+ clf_path = download_model_from_hub("classifier")
253
+
254
+ if clf_path and os.path.exists(clf_path):
255
+ checkpoint = torch.load(clf_path, map_location=device, weights_only=False)
256
+
257
+ # Handle different checkpoint formats
258
+ if 'input_dim' in checkpoint:
259
+ input_dim = checkpoint['input_dim']
260
+ else:
261
+ # ProtBERT embedding size is 1024
262
+ input_dim = 1024
263
+
264
+ classifier = MHSA_GRU(input_dim, hidden_dim=256)
265
+
266
+ # Load state dict
267
+ if 'model_state_dict' in checkpoint:
268
+ classifier.load_state_dict(checkpoint['model_state_dict'])
269
+ else:
270
+ classifier.load_state_dict(checkpoint)
271
+
272
+ classifier.to(device)
273
+ classifier.eval()
274
+ models_loaded["classifier"] = True
275
+ print(f"✅ Classifier loaded! (input_dim: {input_dim})")
276
+
277
+ # Load Scaler
278
+ print("🔄 Loading feature scaler...")
279
+ scaler_path = os.path.join(models_dir, MODEL_REPO["files"]["scaler"])
280
+
281
+ if not os.path.exists(scaler_path):
282
+ print("📥 Scaler not found locally, downloading...")
283
+ scaler_path = download_model_from_hub("scaler")
284
+
285
+ if scaler_path and os.path.exists(scaler_path):
286
+ scaler = joblib.load(scaler_path)
287
+ models_loaded["scaler"] = True
288
+ print("✅ Scaler loaded!")
289
+
290
+ loaded_count = sum(models_loaded.values())
291
+ total_count = len(models_loaded)
292
+
293
+ print(f"\n📊 Model Loading Summary:")
294
+ print(f" • Successfully loaded: {loaded_count}/{total_count}")
295
+ print(f" • Repository: {MODEL_REPO['repo_id']}")
296
+ print(f" • Embedding Model: {MODEL_NAME}")
297
+ print(f" • Device: {device}")
298
+
299
+ critical_models = ["classifier", "scaler", "transformer_model", "transformer_tokenizer"]
300
+ critical_loaded = all(models_loaded[m] for m in critical_models)
301
+
302
+ if critical_loaded:
303
+ print("🎉 All critical models loaded successfully!")
304
+ return True
305
+ else:
306
+ print("⚠️ Some critical models failed to load")
307
+ print(f" Models status: {models_loaded}")
308
+ return False
309
+
310
+ except Exception as e:
311
+ print(f"❌ Error loading models: {e}")
312
+ import traceback
313
+ traceback.print_exc()
314
+ return False
315
+
316
+
317
+ # ========================= FASTAPI APPLICATION =========================
318
+ @asynccontextmanager
319
+ async def lifespan(app: FastAPI):
320
+ # Startup
321
+ print("🚀 Starting Toxicity Prediction API...")
322
+ success = load_all_models()
323
+ if not success:
324
+ print("⚠️ Warning: Not all models loaded successfully")
325
+ yield
326
+ # Shutdown
327
+ print("🔄 Shutting down API...")
328
+
329
+
330
+ app = FastAPI(
331
+ title="Toxicity Prediction API",
332
+ description="API for toxicity prediction using MHSA-GRU with Transformer embeddings",
333
+ version="1.0.0",
334
+ lifespan=lifespan
335
+ )
336
+
337
+
338
+ @app.get("/")
339
+ async def root():
340
+ return {
341
+ "message": "Toxicity Prediction API",
342
+ "version": API_VERSION,
343
+ "endpoints": {
344
+ "/predict": "POST - Predict toxicity for a single sequence",
345
+ "/predict/batch": "POST - Predict toxicity for multiple sequences",
346
+ "/example": "GET - Try the API with a hardcoded example sequence",
347
+ "/health": "GET - Check API health and model status"
348
+ },
349
+ "example_usage": {
350
+ "single": {
351
+ "method": "POST",
352
+ "url": "/predict",
353
+ "body": {"sequence": "MKTAYIAKQRQISFVKSHFSRQLE"}
354
+ },
355
+ "batch": {
356
+ "method": "POST",
357
+ "url": "/predict/batch",
358
+ "body": {
359
+ "sequences": [
360
+ "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
361
+ "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK"
362
+ ]
363
+ }
364
+ },
365
+ "example": {
366
+ "method": "GET",
367
+ "url": "/example",
368
+ "description": "No input needed - just call this endpoint"
369
+ }
370
+ }
371
+ }
372
+
373
+
374
+ @app.post("/predict", response_model=PredictionResponse)
375
+ async def predict(request: SequenceRequest):
376
+ start_time = time.time()
377
+ timestamp = datetime.now(timezone.utc).isoformat()
378
+
379
+ try:
380
+ if not request.sequence or len(request.sequence) == 0:
381
+ raise HTTPException(
382
+ status_code=400,
383
+ detail={
384
+ "status_code": 400,
385
+ "status": "error",
386
+ "success": False,
387
+ "error": "No sequence provided",
388
+ "error_code": "MISSING_SEQUENCE",
389
+ "timestamp": timestamp,
390
+ "api_version": API_VERSION,
391
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
392
+ }
393
+ )
394
+
395
+ # Check if models are loaded
396
+ if classifier is None or scaler is None or transformer_model is None:
397
+ raise HTTPException(
398
+ status_code=503,
399
+ detail={
400
+ "status_code": 503,
401
+ "status": "error",
402
+ "success": False,
403
+ "error": "Models not loaded properly",
404
+ "error_code": "MODEL_NOT_LOADED",
405
+ "timestamp": timestamp,
406
+ "api_version": API_VERSION,
407
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
408
+ }
409
+ )
410
+
411
+ # Validate sequence
412
+ sequence = request.sequence.upper().strip()
413
+ if len(sequence) < 10:
414
+ raise HTTPException(
415
+ status_code=400,
416
+ detail={
417
+ "status_code": 400,
418
+ "status": "error",
419
+ "success": False,
420
+ "error": "Sequence too short (minimum 10 characters)",
421
+ "error_code": "SEQUENCE_TOO_SHORT",
422
+ "timestamp": timestamp,
423
+ "api_version": API_VERSION,
424
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
425
+ }
426
+ )
427
+
428
+ # Step 1: Extract features using ProtBERT
429
+ features = extract_features_from_sequence(sequence)
430
+
431
+ # Step 2: Scale features
432
+ scaled_features = scaler.transform(features)
433
+
434
+ # Step 3: Predict using MHSA-GRU
435
+ features_tensor = torch.FloatTensor(scaled_features).to(device)
436
+
437
+ with torch.no_grad():
438
+ probability = classifier(features_tensor).cpu().numpy()[0, 0]
439
+
440
+ # Determine prediction
441
+ prediction_class = 1 if probability > 0.5 else 0
442
+ predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic"
443
+ confidence = float(abs(probability - 0.5) * 2)
444
+
445
+ # Determine confidence level
446
+ if confidence > 0.8:
447
+ confidence_level = "high"
448
+ elif confidence > 0.6:
449
+ confidence_level = "medium"
450
+ else:
451
+ confidence_level = "low"
452
+
453
+ processing_time = round((time.time() - start_time) * 1000, 2)
454
+
455
+ return PredictionResponse(
456
+ status_code=200,
457
+ status="success",
458
+ success=True,
459
+ data={
460
+ "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
461
+ "sequence_length": len(sequence),
462
+ "prediction": {
463
+ "predicted_class": predicted_label,
464
+ "confidence": confidence,
465
+ "confidence_level": confidence_level,
466
+ "toxicity_score": float(probability),
467
+ "non_toxicity_score": float(1 - probability)
468
+ },
469
+ "metadata": {
470
+ "embedding_model": MODEL_NAME,
471
+ "embedding_type": EMBEDDING_TYPE,
472
+ "model_version": MODEL_VERSION,
473
+ "device": str(device)
474
+ }
475
+ },
476
+ timestamp=timestamp,
477
+ api_version=API_VERSION,
478
+ processing_time_ms=processing_time
479
+ )
480
+
481
+ except HTTPException:
482
+ raise
483
+ except Exception as e:
484
+ processing_time = round((time.time() - start_time) * 1000, 2)
485
+ raise HTTPException(
486
+ status_code=500,
487
+ detail={
488
+ "status_code": 500,
489
+ "status": "error",
490
+ "success": False,
491
+ "error": f"Internal server error: {str(e)}",
492
+ "error_code": "INTERNAL_ERROR",
493
+ "timestamp": timestamp,
494
+ "api_version": API_VERSION,
495
+ "processing_time_ms": processing_time
496
+ }
497
+ )
498
+
499
+
500
+ @app.post("/predict/batch", response_model=PredictionResponse)
501
+ async def predict_batch(request: BatchSequenceRequest):
502
+ """
503
+ Predict toxicity for multiple sequences at once.
504
+
505
+ Example request body:
506
+ {
507
+ "sequences": [
508
+ "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
509
+ "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK"
510
+ ]
511
+ }
512
+ """
513
+ start_time = time.time()
514
+ timestamp = datetime.now(timezone.utc).isoformat()
515
+
516
+ try:
517
+ if not request.sequences or len(request.sequences) == 0:
518
+ raise HTTPException(
519
+ status_code=400,
520
+ detail={
521
+ "status_code": 400,
522
+ "status": "error",
523
+ "success": False,
524
+ "error": "No sequences provided",
525
+ "error_code": "MISSING_SEQUENCES",
526
+ "timestamp": timestamp,
527
+ "api_version": API_VERSION,
528
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
529
+ }
530
+ )
531
+
532
+ # Check if models are loaded
533
+ if classifier is None or scaler is None or transformer_model is None:
534
+ raise HTTPException(
535
+ status_code=503,
536
+ detail={
537
+ "status_code": 503,
538
+ "status": "error",
539
+ "success": False,
540
+ "error": "Models not loaded properly",
541
+ "error_code": "MODEL_NOT_LOADED",
542
+ "timestamp": timestamp,
543
+ "api_version": API_VERSION,
544
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
545
+ }
546
+ )
547
+
548
+ results = []
549
+
550
+ for idx, seq in enumerate(request.sequences, 1):
551
+ try:
552
+ sequence = seq.upper().strip()
553
+
554
+ # Validate sequence length
555
+ if len(sequence) < 10:
556
+ results.append({
557
+ "sequence_index": idx,
558
+ "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
559
+ "sequence_length": len(sequence),
560
+ "error": "Sequence too short (minimum 10 characters)",
561
+ "predicted_class": None,
562
+ "toxicity_score": None,
563
+ "confidence": None
564
+ })
565
+ continue
566
+
567
+ # Extract features using ProtBERT
568
+ features = extract_features_from_sequence(sequence)
569
+ scaled_features = scaler.transform(features)
570
+ features_tensor = torch.FloatTensor(scaled_features).to(device)
571
+
572
+ with torch.no_grad():
573
+ probability = classifier(features_tensor).cpu().numpy()[0, 0]
574
+
575
+ prediction_class = 1 if probability > 0.5 else 0
576
+ predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic"
577
+ confidence = float(abs(probability - 0.5) * 2)
578
+
579
+ # Determine confidence level
580
+ if confidence > 0.8:
581
+ confidence_level = "high"
582
+ elif confidence > 0.6:
583
+ confidence_level = "medium"
584
+ else:
585
+ confidence_level = "low"
586
+
587
+ results.append({
588
+ "sequence_index": idx,
589
+ "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
590
+ "sequence_length": len(sequence),
591
+ "predicted_class": predicted_label,
592
+ "toxicity_score": float(probability),
593
+ "non_toxicity_score": float(1 - probability),
594
+ "confidence": confidence,
595
+ "confidence_level": confidence_level,
596
+ "error": None
597
+ })
598
+
599
+ except Exception as e:
600
+ # Handle individual sequence errors without stopping the batch
601
+ results.append({
602
+ "sequence_index": idx,
603
+ "sequence": seq[:100] + "..." if len(seq) > 100 else seq,
604
+ "sequence_length": len(seq),
605
+ "error": f"Error processing sequence: {str(e)}",
606
+ "predicted_class": None,
607
+ "toxicity_score": None,
608
+ "confidence": None
609
+ })
610
+
611
+ processing_time = round((time.time() - start_time) * 1000, 2)
612
+
613
+ # Count successful predictions
614
+ successful_predictions = sum(1 for r in results if r.get("predicted_class") is not None)
615
+
616
+ return PredictionResponse(
617
+ status_code=200,
618
+ status="success",
619
+ success=True,
620
+ data={
621
+ "total_sequences": len(request.sequences),
622
+ "successful_predictions": successful_predictions,
623
+ "failed_predictions": len(request.sequences) - successful_predictions,
624
+ "results": results,
625
+ "metadata": {
626
+ "embedding_model": MODEL_NAME,
627
+ "embedding_type": EMBEDDING_TYPE,
628
+ "model_version": MODEL_VERSION,
629
+ "device": str(device)
630
+ }
631
+ },
632
+ timestamp=timestamp,
633
+ api_version=API_VERSION,
634
+ processing_time_ms=processing_time
635
+ )
636
+
637
+ except HTTPException:
638
+ raise
639
+ except Exception as e:
640
+ processing_time = round((time.time() - start_time) * 1000, 2)
641
+ raise HTTPException(
642
+ status_code=500,
643
+ detail={
644
+ "status_code": 500,
645
+ "status": "error",
646
+ "success": False,
647
+ "error": f"Internal server error: {str(e)}",
648
+ "error_code": "INTERNAL_ERROR",
649
+ "timestamp": timestamp,
650
+ "api_version": API_VERSION,
651
+ "processing_time_ms": processing_time
652
+ }
653
+ )
654
+
655
+ @app.get("/example", response_model=PredictionResponse)
656
+ async def predict_example():
657
+ """
658
+ Predict using a hardcoded example protein sequence.
659
+ No input required - just call this endpoint to see how the API works.
660
+
661
+ Example sequence: MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES
662
+ """
663
+ start_time = time.time()
664
+ timestamp = datetime.now(timezone.utc).isoformat()
665
+
666
+ # Hardcoded example sequence
667
+ EXAMPLE_SEQUENCE = "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES"
668
+
669
+ try:
670
+ # Check if models are loaded
671
+ if classifier is None or scaler is None or transformer_model is None:
672
+ raise HTTPException(
673
+ status_code=503,
674
+ detail={
675
+ "status_code": 503,
676
+ "status": "error",
677
+ "success": False,
678
+ "error": "Models not loaded properly",
679
+ "error_code": "MODEL_NOT_LOADED",
680
+ "timestamp": timestamp,
681
+ "api_version": API_VERSION,
682
+ "processing_time_ms": round((time.time() - start_time) * 1000, 2)
683
+ }
684
+ )
685
+
686
+ sequence = EXAMPLE_SEQUENCE.upper().strip()
687
+
688
+ # Step 1: Extract features using ProtBERT
689
+ features = extract_features_from_sequence(sequence)
690
+
691
+ # Step 2: Scale features
692
+ scaled_features = scaler.transform(features)
693
+
694
+ # Step 3: Predict using MHSA-GRU
695
+ features_tensor = torch.FloatTensor(scaled_features).to(device)
696
+
697
+ with torch.no_grad():
698
+ probability = classifier(features_tensor).cpu().numpy()[0, 0]
699
+
700
+ # Determine prediction
701
+ prediction_class = 1 if probability > 0.5 else 0
702
+ predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic"
703
+ confidence = float(abs(probability - 0.5) * 2)
704
+
705
+ # Determine confidence level
706
+ if confidence > 0.8:
707
+ confidence_level = "high"
708
+ elif confidence > 0.6:
709
+ confidence_level = "medium"
710
+ else:
711
+ confidence_level = "low"
712
+
713
+ processing_time = round((time.time() - start_time) * 1000, 2)
714
+
715
+ return PredictionResponse(
716
+ status_code=200,
717
+ status="success",
718
+ success=True,
719
+ data={
720
+ "note": "This is an example prediction using a hardcoded sequence",
721
+ "sequence": sequence,
722
+ "sequence_length": len(sequence),
723
+ "prediction": {
724
+ "predicted_class": predicted_label,
725
+ "confidence": confidence,
726
+ "confidence_level": confidence_level,
727
+ "toxicity_score": float(probability),
728
+ "non_toxicity_score": float(1 - probability)
729
+ },
730
+ "metadata": {
731
+ "embedding_model": MODEL_NAME,
732
+ "embedding_type": EMBEDDING_TYPE,
733
+ "model_version": MODEL_VERSION,
734
+ "device": str(device),
735
+ "source": "hardcoded_example"
736
+ }
737
+ },
738
+ timestamp=timestamp,
739
+ api_version=API_VERSION,
740
+ processing_time_ms=processing_time
741
+ )
742
+
743
+ except HTTPException:
744
+ raise
745
+ except Exception as e:
746
+ processing_time = round((time.time() - start_time) * 1000, 2)
747
+ raise HTTPException(
748
+ status_code=500,
749
+ detail={
750
+ "status_code": 500,
751
+ "status": "error",
752
+ "success": False,
753
+ "error": f"Internal server error: {str(e)}",
754
+ "error_code": "INTERNAL_ERROR",
755
+ "timestamp": timestamp,
756
+ "api_version": API_VERSION,
757
+ "processing_time_ms": processing_time
758
+ }
759
+ )
760
+
761
+ @app.get("/health", response_model=HealthResponse)
762
+ async def health_check():
763
+ models_loaded = all([
764
+ classifier is not None,
765
+ scaler is not None,
766
+ transformer_model is not None,
767
+ transformer_tokenizer is not None
768
+ ])
769
+
770
+ model_sources = {
771
+ "classifier": {
772
+ "loaded": classifier is not None,
773
+ "source": "huggingface_hub",
774
+ "repository": MODEL_REPO["repo_id"]
775
+ },
776
+ "scaler": {
777
+ "loaded": scaler is not None,
778
+ "source": "huggingface_hub",
779
+ "repository": MODEL_REPO["repo_id"]
780
+ },
781
+ "transformer_model": {
782
+ "loaded": transformer_model is not None,
783
+ "model_name": MODEL_NAME,
784
+ "source": "huggingface_hub",
785
+ "repository": MODEL_REPO["repo_id"]
786
+ }
787
+ }
788
+
789
+ repository_info = {
790
+ "repository_id": MODEL_REPO["repo_id"],
791
+ "embedding_type": EMBEDDING_TYPE,
792
+ "model_name": MODEL_NAME,
793
+ "total_models": len(MODEL_REPO["files"])
794
+ }
795
+
796
+ return HealthResponse(
797
+ status_code=200 if models_loaded else 503,
798
+ status="healthy" if models_loaded else "unhealthy",
799
+ service="Toxicity Prediction API",
800
+ api_version=API_VERSION,
801
+ model_version=MODEL_VERSION,
802
+ models_loaded=models_loaded,
803
+ models_loaded_count=sum(1 for m in (classifier, scaler, transformer_model, transformer_tokenizer) if m is not None),  # include the tokenizer so the count can reach total_models_required=4
804
+ total_models_required=4,
805
+ model_sources=model_sources,
806
+ repository_info=repository_info,
807
+ device=str(device),
808
+ timestamp=datetime.now(timezone.utc).isoformat()
809
+ )
810
+
811
+
812
+ if __name__ == "__main__":
813
+ uvicorn.run(app, host="0.0.0.0", port=8000)
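app.py extends the earlier handlers with per-sequence error handling in /predict/batch and a zero-input /example demo endpoint. A usage sketch under the same local-deployment assumption:

import requests

BASE = "http://localhost:8000"  # assumed local deployment

# Batch prediction: a too-short sequence produces a per-item error entry
# instead of failing the whole request (see the per-sequence try/except above).
resp = requests.post(f"{BASE}/predict/batch", json={"sequences": [
    "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
    "SHORT",  # fewer than 10 characters
]})
for item in resp.json()["data"]["results"]:
    print(item["sequence_index"], item.get("predicted_class"), item.get("error"))

# Built-in demo and the health probe:
print(requests.get(f"{BASE}/example").json()["data"]["prediction"]["predicted_class"])
print(requests.get(f"{BASE}/health").json()["status"])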
convert_base64.ipynb ADDED
@@ -0,0 +1,72 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "81ff91ce53ae83fe",
7
+ "metadata": {
8
+ "ExecuteTime": {
9
+ "end_time": "2025-07-10T07:07:50.829656Z",
10
+ "start_time": "2025-07-10T07:07:50.824248Z"
11
+ }
12
+ },
13
+ "outputs": [],
14
+ "source": [
15
+ "import base64"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "id": "initial_id",
22
+ "metadata": {
23
+ "ExecuteTime": {
24
+ "end_time": "2025-07-10T07:08:19.010102Z",
25
+ "start_time": "2025-07-10T07:08:19.004314Z"
26
+ },
27
+ "collapsed": true
28
+ },
29
+ "outputs": [],
30
+ "source": [
31
+ "with open(\"examples/cancer_example.jpg\", \"rb\") as f:\n",
32
+ " encoded = base64.b64encode(f.read()).decode()"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "id": "35cac43020ae6db3",
39
+ "metadata": {
40
+ "ExecuteTime": {
41
+ "end_time": "2025-07-10T07:08:35.977343Z",
42
+ "start_time": "2025-07-10T07:08:35.973715Z"
43
+ }
44
+ },
45
+ "outputs": [],
46
+ "source": [
47
+ "print(encoded)"
48
+ ]
49
+ }
50
+ ],
51
+ "metadata": {
52
+ "kernelspec": {
53
+ "display_name": "3.12.2",
54
+ "language": "python",
55
+ "name": "python3"
56
+ },
57
+ "language_info": {
58
+ "codemirror_mode": {
59
+ "name": "ipython",
60
+ "version": 3
61
+ },
62
+ "file_extension": ".py",
63
+ "mimetype": "text/x-python",
64
+ "name": "python",
65
+ "nbconvert_exporter": "python",
66
+ "pygments_lexer": "ipython3",
67
+ "version": "3.12.2"
68
+ }
69
+ },
70
+ "nbformat": 4,
71
+ "nbformat_minor": 5
72
+ }
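The notebook base64-encodes an example image into a JSON-safe string. For completeness, a sketch of the reverse step (the output path is illustrative):

import base64

with open("examples/cancer_example.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode()

# Decode back to bytes and verify the round trip.
decoded = base64.b64decode(encoded)
with open("examples/roundtrip.jpg", "wb") as f:  # hypothetical output path
    f.write(decoded)
assert decoded[:2] == b"\xff\xd8"  # JPEG data starts with the FF D8 marker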
images/camlas-background.png ADDED
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ torch
2
+ torchvision
3
+ huggingface_hub
4
+ numpy<2.3.0
5
+ pandas
6
+ scikit-learn
7
+ Pillow
8
+ matplotlib
9
+ seaborn
10
+ plotly
11
+ requests
12
+ python-dotenv  # provides `from dotenv import load_dotenv` used in app.py
13
+ fastapi
14
+ uvicorn[standard]
15
+ pydantic
16
+ timm
17
+ python-multipart
18
+ transformers
19
+ # opencv-python
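One dependency note: app.py imports joblib, which is not listed here and is installed only transitively via scikit-learn, so pinning it explicitly would be safer. A quick sanity check after pip install -r requirements.txt:

# Verify that the core imports used by app.py resolve in this environment.
import torch, transformers, fastapi, dotenv, joblib  # joblib arrives via scikit-learn
print("torch", torch.__version__, "| transformers", transformers.__version__)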
utils/model_classes.py ADDED
@@ -0,0 +1,72 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class MultiHeadSelfAttention(nn.Module):
5
+ """Multi-Head Self-Attention mechanism"""
6
+ def __init__(self, embed_dim, num_heads, dropout=0.3):
7
+ super(MultiHeadSelfAttention, self).__init__()
8
+ self.attention = nn.MultiheadAttention(
9
+ embed_dim=embed_dim,
10
+ num_heads=num_heads,
11
+ dropout=dropout,
12
+ batch_first=True
13
+ )
14
+ self.layer_norm = nn.LayerNorm(embed_dim)
15
+ self.dropout = nn.Dropout(dropout)
16
+
17
+ def forward(self, x):
18
+ attn_output, _ = self.attention(x, x, x)
19
+ x = self.layer_norm(x + self.dropout(attn_output))
20
+ return x
21
+
22
+
23
+ class MHSA_GRU(nn.Module):
24
+ """Multi-Head Self-Attention with GRU model"""
25
+ def __init__(self, input_dim, hidden_dim=256, num_heads=8, num_gru_layers=2, dropout=0.3):
26
+ super(MHSA_GRU, self).__init__()
27
+
28
+ self.input_dim = input_dim
29
+ self.hidden_dim = hidden_dim
30
+
31
+ self.input_projection = nn.Linear(input_dim, hidden_dim)
32
+ self.mhsa1 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
33
+ self.mhsa2 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
34
+
35
+ self.gru = nn.GRU(
36
+ input_size=hidden_dim,
37
+ hidden_size=hidden_dim,
38
+ num_layers=num_gru_layers,
39
+ batch_first=True,
40
+ dropout=dropout if num_gru_layers > 1 else 0,
41
+ bidirectional=False
42
+ )
43
+
44
+ self.mhsa3 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
45
+ self.dropout = nn.Dropout(dropout)
46
+
47
+ self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
48
+ self.fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
49
+ self.fc3 = nn.Linear(hidden_dim // 4, 1)
50
+
51
+ self.bn1 = nn.BatchNorm1d(hidden_dim // 2)
52
+ self.bn2 = nn.BatchNorm1d(hidden_dim // 4)
53
+
54
+ def forward(self, x):
55
+ batch_size = x.size(0)
56
+ x = self.input_projection(x)
57
+ x = x.unsqueeze(1)
58
+
59
+ x = self.mhsa1(x)
60
+ x = self.mhsa2(x)
61
+ gru_out, hidden = self.gru(x)
62
+ x = self.mhsa3(gru_out)
63
+ x = x[:, -1, :]
64
+
65
+ x = self.dropout(x)
66
+ x = torch.relu(self.bn1(self.fc1(x)))
67
+ x = self.dropout(x)
68
+ x = torch.relu(self.bn2(self.fc2(x)))
69
+ x = self.dropout(x)
70
+ x = self.fc3(x)
71
+
72
+ return torch.sigmoid(x)
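A minimal smoke test for the classifier head, assuming ProtBERT's 1024-dimensional [CLS] embeddings as input (the default input_dim that app.py falls back to when the checkpoint stores none):

import torch
from utils.model_classes import MHSA_GRU

model = MHSA_GRU(input_dim=1024, hidden_dim=256).eval()  # eval() fixes BatchNorm/Dropout for inference
with torch.no_grad():
    out = model(torch.randn(4, 1024))  # a batch of 4 scaled feature vectors
print(out.shape)  # torch.Size([4, 1]); sigmoid probabilities in [0, 1]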