Commit e5d40e3 · 1 parent: 8d272fe
fix: Update Gradio to 4.44.1 and improve interface
- .github/workflows/ci.yml +142 -0
- .pre-commit-config.yaml +69 -0
- README.md +159 -52
- app.py +0 -148
- docs/api/README.md +121 -0
- docs/guides/developer_guide.md +362 -0
- docs/guides/user_guide.md +164 -0
- examples/api_client.py +127 -0
- examples/llava_demo.ipynb +1 -0
- examples/process_image.py +103 -0
- pyproject.toml +181 -0
- requirements-dev.txt +38 -0
- requirements.txt +18 -19
- src/__init__.py +0 -0
- src/api/__init__.py +0 -0
- src/api/app.py +159 -0
- src/configs/__init__.py +0 -0
- src/configs/settings.py +46 -0
- src/models/__init__.py +0 -0
- src/models/llava_model.py +88 -0
- main.py → src/models/main.py +0 -0
- src/requirements.txt +26 -0
- src/utils/__init__.py +0 -0
- src/utils/logging.py +51 -0
- tests/test_model.py +67 -0
.github/workflows/ci.yml
ADDED
@@ -0,0 +1,142 @@
+name: CI
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+
+      - name: Run pre-commit hooks
+        run: |
+          pre-commit install
+          pre-commit run --all-files
+
+      - name: Run tests
+        run: |
+          pytest --cov=src --cov-report=xml
+
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          file: ./coverage.xml
+          fail_ci_if_error: true
+
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements-dev.txt
+
+      - name: Run black
+        run: black --check src tests
+
+      - name: Run isort
+        run: isort --check-only src tests
+
+      - name: Run flake8
+        run: flake8 src tests
+
+      - name: Run mypy
+        run: mypy src
+
+  build:
+    needs: [test, lint]
+    runs-on: ubuntu-latest
+    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build twine
+
+      - name: Build package
+        run: python -m build
+
+      - name: Check package
+        run: twine check dist/*
+
+      - name: Upload artifacts
+        uses: actions/upload-artifact@v4
+        with:
+          name: dist
+          path: dist/
+
+  deploy:
+    needs: build
+    runs-on: ubuntu-latest
+    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Download artifacts
+        uses: actions/download-artifact@v4
+        with:
+          name: dist
+          path: dist
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install twine
+
+      - name: Deploy to PyPI
+        env:
+          TWINE_USERNAME: __token__
+          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
+        run: twine upload dist/*
+
+      - name: Deploy to Hugging Face
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: |
+          pip install huggingface_hub
+          huggingface-cli login --token $HF_TOKEN
+          huggingface-cli upload Prashant26am/llava-chat dist/* --repo-type space
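Note: the deploy job drives the Space upload through `huggingface-cli`; the same step can be scripted with the `huggingface_hub` Python API. A minimal sketch, assuming the same `HF_TOKEN` secret and Space ID as the workflow above:

```python
# Sketch: upload the built artifacts to the Space with the huggingface_hub
# Python API instead of the CLI. Assumes HF_TOKEN is set in the environment.
import os

from huggingface_hub import HfApi

api = HfApi(token=os.environ["HF_TOKEN"])
api.upload_folder(
    folder_path="dist",                 # artifacts produced by `python -m build`
    repo_id="Prashant26am/llava-chat",  # the Space targeted by the workflow
    repo_type="space",
)
```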
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,69 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+      - id: check-ast
+      - id: check-json
+      - id: check-merge-conflict
+      - id: detect-private-key
+      - id: debug-statements
+
+  - repo: https://github.com/psf/black
+    rev: 24.1.1
+    hooks:
+      - id: black
+        language_version: python3.8
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+
+  - repo: https://github.com/pycqa/flake8
+    rev: 7.0.0
+    hooks:
+      - id: flake8
+        additional_dependencies:
+          - flake8-docstrings
+          - flake8-bugbear
+          - flake8-comprehensions
+          - flake8-simplify
+          - flake8-unused-arguments
+          - flake8-variables-names
+          - pep8-naming
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.8.0
+    hooks:
+      - id: mypy
+        additional_dependencies:
+          - types-Pillow
+          - types-requests
+          - types-setuptools
+          - types-urllib3
+
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.15.0
+    hooks:
+      - id: pyupgrade
+        args: [--py38-plus]
+
+  - repo: https://github.com/PyCQA/bandit
+    rev: 1.7.7
+    hooks:
+      - id: bandit
+        args: ["-c", "pyproject.toml"]
+        additional_dependencies: ["bandit[toml]"]
+
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v4.0.0-alpha.8
+    hooks:
+      - id: prettier
+        types_or: [javascript, jsx, ts, tsx, json, css, scss, md, yaml, yml]
+        additional_dependencies:
README.md
CHANGED
@@ -1,52 +1,159 @@
+# LLaVA Implementation
+
+[License: MIT](https://opensource.org/licenses/MIT)
+[Python 3.8+](https://www.python.org/downloads/)
+[Gradio](https://gradio.app/)
+[Live on Hugging Face Spaces](https://huggingface.co/spaces/Prashant26am/llava-chat)
+
+A modern implementation of LLaVA (Large Language and Vision Assistant) with a beautiful web interface. This project combines state-of-the-art vision and language models to create an interactive AI assistant that can understand and discuss images.
+
+## 🌟 Features
+
+- **Modern Web Interface**
+  - Beautiful Gradio-based UI
+  - Real-time image analysis
+  - Interactive chat experience
+  - Responsive design
+
+- **Advanced AI Capabilities**
+  - CLIP ViT-L/14 vision encoder
+  - Vicuna-7B language model
+  - Multimodal understanding
+  - Natural conversation flow
+
+- **Developer Friendly**
+  - Clean, modular codebase
+  - Comprehensive documentation
+  - Easy deployment options
+  - Extensible architecture
+
+## 📋 Project Structure
+
+```
+llava_implementation/
+├── src/            # Source code
+│   ├── api/        # API endpoints and FastAPI app
+│   ├── models/     # Model implementations
+│   ├── utils/      # Utility functions
+│   └── configs/    # Configuration files
+├── tests/          # Test suite
+├── docs/           # Documentation
+│   ├── api/        # API documentation
+│   ├── examples/   # Usage examples
+│   └── guides/     # User and developer guides
+├── assets/         # Static assets
+│   ├── images/     # Example images
+│   └── icons/      # UI icons
+├── scripts/        # Utility scripts
+└── examples/       # Example images for the web interface
+```
+
+## 🚀 Quick Start
+
+### Prerequisites
+
+- Python 3.8+
+- CUDA-capable GPU (recommended)
+- Git
+
+### Installation
+
+1. Clone the repository:
+```bash
+git clone https://github.com/Prashant-ambati/llava-implementation.git
+cd llava-implementation
+```
+
+2. Create and activate a virtual environment:
+```bash
+python -m venv venv
+source venv/bin/activate  # On Windows: venv\Scripts\activate
+```
+
+3. Install dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+### Running Locally
+
+1. Start the development server:
+```bash
+python src/api/app.py
+```
+
+2. Open your browser and navigate to:
+```
+http://localhost:7860
+```
+
+## 🌐 Web Deployment
+
+### Hugging Face Spaces
+
+The application is deployed on Hugging Face Spaces:
+- [Live Demo](https://huggingface.co/spaces/Prashant26am/llava-chat)
+- Automatic deployment from main branch
+- Free GPU resources
+- Public API access
+
+### Local Deployment
+
+For local deployment:
+```bash
+# Build the application
+python -m build
+
+# Run with production settings
+python src/api/app.py --production
+```
+
+## 📚 Documentation
+
+- [API Documentation](docs/api/README.md)
+- [User Guide](docs/guides/user_guide.md)
+- [Developer Guide](docs/guides/developer_guide.md)
+- [Examples](docs/examples/README.md)
+
+## 🛠️ Development
+
+### Running Tests
+
+```bash
+pytest tests/
+```
+
+### Code Style
+
+This project follows PEP 8 guidelines. To check your code:
+
+```bash
+flake8 src/
+black src/
+```
+
+### Contributing
+
+1. Fork the repository
+2. Create a feature branch
+3. Commit your changes
+4. Push to the branch
+5. Create a Pull Request
+
+## 📝 License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+## 🙏 Acknowledgments
+
+- [LLaVA Paper](https://arxiv.org/abs/2304.08485) by Microsoft Research
+- [Gradio](https://gradio.app/) for the web interface
+- [Hugging Face](https://huggingface.co/) for model hosting
+- [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) for the language model
+- [CLIP](https://openai.com/research/clip) for the vision model
+
+## 📞 Contact
+
+- GitHub Issues: [Report a bug](https://github.com/Prashant-ambati/llava-implementation/issues)
+- Email: [Your Email]
+- Twitter: [@YourTwitter]
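Note: besides the browser UI, the deployed Space can be queried programmatically. A minimal sketch using the `gradio_client` package; the endpoint names and argument order are not documented in this README, so discover them first with `view_api()`:

```python
# Sketch: call the hosted Space from Python (pip install gradio_client).
# The concrete endpoint signature is an assumption; view_api() prints the
# real one for the deployed interface.
from gradio_client import Client

client = Client("Prashant26am/llava-chat")
print(client.view_api())  # lists endpoints, parameter names, and types
# result = client.predict(..., api_name="/predict")  # fill in per view_api() output
```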
app.py
DELETED
@@ -1,148 +0,0 @@
-from fastapi import FastAPI, UploadFile, File, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
-import os
-import tempfile
-from typing import Optional
-from pydantic import BaseModel
-import torch
-import gradio as gr
-from models.llava import LLaVA
-
-# Initialize model globally
-model = None
-
-def initialize_model():
-    global model
-    try:
-        model = LLaVA(
-            vision_model_path="openai/clip-vit-large-patch14-336",
-            language_model_path="lmsys/vicuna-7b-v1.5",
-            device="cuda" if torch.cuda.is_available() else "cpu",
-            load_in_8bit=True
-        )
-        print(f"Model initialized on {model.device}")
-        return True
-    except Exception as e:
-        print(f"Error initializing model: {e}")
-        return False
-
-def process_image(image, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
-    if not model:
-        return "Error: Model not initialized"
-
-    try:
-        # Save the uploaded image temporarily
-        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
-            image.save(temp_file.name)
-            temp_path = temp_file.name
-
-        # Generate response
-        response = model.generate_from_image(
-            image_path=temp_path,
-            prompt=prompt,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_p=top_p
-        )
-
-        # Clean up temporary file
-        os.unlink(temp_path)
-        return response
-
-    except Exception as e:
-        return f"Error processing image: {str(e)}"
-
-# Create Gradio interface
-def create_interface():
-    with gr.Blocks(title="LLaVA Chat", theme=gr.themes.Soft()) as demo:
-        gr.Markdown("""
-        # LLaVA Chat
-        Upload an image and chat with LLaVA about it. This model can understand and describe images, answer questions about them, and engage in visual conversations.
-        """)
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                image_input = gr.Image(type="pil", label="Upload Image")
-                prompt_input = gr.Textbox(
-                    label="Ask about the image",
-                    placeholder="What can you see in this image?",
-                    lines=3
-                )
-
-                with gr.Accordion("Advanced Settings", open=False):
-                    max_tokens = gr.Slider(
-                        minimum=32,
-                        maximum=512,
-                        value=256,
-                        step=32,
-                        label="Max New Tokens"
-                    )
-                    temperature = gr.Slider(
-                        minimum=0.1,
-                        maximum=1.0,
-                        value=0.7,
-                        step=0.1,
-                        label="Temperature"
-                    )
-                    top_p = gr.Slider(
-                        minimum=0.1,
-                        maximum=1.0,
-                        value=0.9,
-                        step=0.1,
-                        label="Top P"
-                    )
-
-                submit_btn = gr.Button("Generate Response", variant="primary")
-
-            with gr.Column(scale=1):
-                output = gr.Textbox(
-                    label="Model Response",
-                    lines=10,
-                    show_copy_button=True
-                )
-
-        # Set up the submit action
-        submit_btn.click(
-            fn=process_image,
-            inputs=[image_input, prompt_input, max_tokens, temperature, top_p],
-            outputs=output
-        )
-
-        # Add examples
-        gr.Examples(
-            examples=[
-                ["examples/cat.jpg", "What can you see in this image?"],
-                ["examples/landscape.jpg", "Describe this scene in detail."],
-                ["examples/food.jpg", "What kind of food is this and how would you describe it?"]
-            ],
-            inputs=[image_input, prompt_input]
-        )
-
-    return demo
-
-# Create FastAPI app
-app = FastAPI(title="LLaVA Web Interface")
-
-# Configure CORS
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Create Gradio app
-demo = create_interface()
-
-# Mount Gradio app
-app = gr.mount_gradio_app(app, demo, path="/")
-
-if __name__ == "__main__":
-    # Initialize model
-    if initialize_model():
-        import uvicorn
-        uvicorn.run(app, host="0.0.0.0", port=7860)  # Hugging Face Spaces uses port 7860
-    else:
-        print("Failed to initialize model. Exiting...")
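Note: the deleted handler writes the PIL image to a named temporary file because the old model API took a path, but it leaks that file if `generate_from_image` raises. A small sketch of the same pattern with guaranteed cleanup, for reference only (the replacement `src/api/app.py` added by this commit is not shown here):

```python
# Temp-file round-trip with cleanup in a finally block, so the PNG is removed
# even when generation fails. Illustrative; not code from this commit.
import os
import tempfile

def with_temp_image(image, fn):
    """Save a PIL image to a temporary PNG, call fn(path), always clean up."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        image.save(tmp.name)
        path = tmp.name
    try:
        return fn(path)
    finally:
        os.unlink(path)
```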
docs/api/README.md
ADDED
@@ -0,0 +1,121 @@
+# LLaVA API Documentation
+
+## Overview
+
+The LLaVA API provides a simple interface for interacting with the LLaVA model through a Gradio web interface. The API allows users to upload images and receive AI-generated responses about the image content.
+
+## API Endpoints
+
+### Web Interface
+
+The main interface is served at the root URL (`/`) and provides the following components:
+
+#### Input Components
+
+1. **Image Upload**
+   - Type: Image uploader
+   - Format: PIL Image
+   - Purpose: Upload an image for analysis
+
+2. **Prompt Input**
+   - Type: Text input
+   - Purpose: Enter questions or prompts about the image
+   - Default placeholder: "What can you see in this image?"
+
+3. **Generation Parameters**
+   - Max New Tokens (64-2048, default: 512)
+   - Temperature (0.1-1.0, default: 0.7)
+   - Top P (0.1-1.0, default: 0.9)
+
+#### Output Components
+
+1. **Response**
+   - Type: Text output
+   - Purpose: Displays the model's response
+   - Features: Copy button, scrollable
+
+## Usage Examples
+
+### Basic Usage
+
+1. Upload an image using the image uploader
+2. Enter a prompt in the text input
+3. Click "Generate Response"
+4. View the response in the output box
+
+### Example Prompts
+
+- "What can you see in this image?"
+- "Describe this scene in detail"
+- "What emotions does this image convey?"
+- "What's happening in this picture?"
+- "Can you identify any objects or people in this image?"
+
+## Error Handling
+
+The API handles various error cases:
+
+1. **Invalid Images**
+   - Returns an error message if the image is invalid or corrupted
+   - Supports common image formats (JPEG, PNG, etc.)
+
+2. **Empty Prompts**
+   - Returns an error message if no prompt is provided
+   - Prompts should be non-empty strings
+
+3. **Model Errors**
+   - Returns descriptive error messages for model-related issues
+   - Includes logging for debugging
+
+## Configuration
+
+The API can be configured through environment variables or the settings file:
+
+- `API_HOST`: Server host (default: "0.0.0.0")
+- `API_PORT`: Server port (default: 7860)
+- `GRADIO_THEME`: Interface theme (default: "soft")
+- `DEFAULT_MAX_NEW_TOKENS`: Default token limit (default: 512)
+- `DEFAULT_TEMPERATURE`: Default temperature (default: 0.7)
+- `DEFAULT_TOP_P`: Default top-p value (default: 0.9)
+
+## Development
+
+### Running Locally
+
+```bash
+python src/api/app.py
+```
+
+### Running Tests
+
+```bash
+pytest tests/
+```
+
+### Code Style
+
+The project follows PEP 8 guidelines. To check your code:
+
+```bash
+flake8 src/
+black src/
+```
+
+## Security Considerations
+
+1. The API is designed for public use but should be deployed behind appropriate security measures
+2. Input validation is performed on all user inputs
+3. Large file uploads are handled safely
+4. Error messages are sanitized to prevent information leakage
+
+## Rate Limiting
+
+Currently, no rate limiting is implemented. Consider implementing rate limiting for production deployments.
+
+## Future Improvements
+
+1. Add authentication
+2. Implement rate limiting
+3. Add batch processing capabilities
+4. Support for video input
+5. Real-time streaming responses
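Note: the Configuration section above mirrors the `src/configs/settings.py` module added in this commit (+46 lines, not shown). A minimal sketch of how such a settings module could read those variables, using the documented names and defaults; the actual implementation may differ:

```python
# Illustrative settings module driven by the environment variables documented above.
import os

API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
API_PORT: int = int(os.getenv("API_PORT", "7860"))
GRADIO_THEME: str = os.getenv("GRADIO_THEME", "soft")
DEFAULT_MAX_NEW_TOKENS: int = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", "512"))
DEFAULT_TEMPERATURE: float = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
DEFAULT_TOP_P: float = float(os.getenv("DEFAULT_TOP_P", "0.9"))
```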
docs/guides/developer_guide.md
ADDED
@@ -0,0 +1,362 @@
+# LLaVA Implementation Developer Guide
+
+## Overview
+
+This guide is intended for developers who want to contribute to or extend the LLaVA implementation. The project is structured as a Python package with a Gradio web interface, using modern best practices and tools.
+
+## Project Structure
+
+```
+llava_implementation/
+├── src/                      # Source code
+│   ├── api/                  # API endpoints and FastAPI app
+│   │   ├── __init__.py
+│   │   └── app.py            # Gradio interface
+│   ├── models/               # Model implementations
+│   │   ├── __init__.py
+│   │   └── llava_model.py    # LLaVA model wrapper
+│   ├── utils/                # Utility functions
+│   │   ├── __init__.py
+│   │   └── logging.py        # Logging utilities
+│   └── configs/              # Configuration files
+│       ├── __init__.py
+│       └── settings.py       # Application settings
+├── tests/                    # Test suite
+│   ├── __init__.py
+│   └── test_model.py         # Model tests
+├── docs/                     # Documentation
+│   ├── api/                  # API documentation
+│   ├── examples/             # Usage examples
+│   └── guides/               # User and developer guides
+├── assets/                   # Static assets
+│   ├── images/               # Example images
+│   └── icons/                # UI icons
+├── scripts/                  # Utility scripts
+└── examples/                 # Example images for the web interface
+```
+
+## Development Setup
+
+### Prerequisites
+
+- Python 3.8+
+- Git
+- CUDA-capable GPU (recommended)
+- Virtual environment tool (venv, conda, etc.)
+
+### Installation
+
+1. Clone the repository:
+```bash
+git clone https://github.com/Prashant-ambati/llava-implementation.git
+cd llava-implementation
+```
+
+2. Create and activate a virtual environment:
+```bash
+python -m venv venv
+source venv/bin/activate  # On Windows: venv\Scripts\activate
+```
+
+3. Install development dependencies:
+```bash
+pip install -r requirements.txt
+pip install -r requirements-dev.txt  # Development dependencies
+```
+
+### Development Tools
+
+1. **Code Formatting**
+   - Black for code formatting
+   - isort for import sorting
+   - flake8 for linting
+
+2. **Testing**
+   - pytest for testing
+   - pytest-cov for coverage
+   - pytest-mock for mocking
+
+3. **Type Checking**
+   - mypy for static type checking
+   - types-* packages for type hints
+
+## Code Style
+
+### Python Style Guide
+
+1. Follow PEP 8 guidelines
+2. Use type hints
+3. Write docstrings (Google style)
+4. Keep functions focused and small
+5. Use meaningful variable names
+
+### Example
+
+```python
+from typing import Optional, List
+from PIL import Image
+
+def process_image(
+    image: Image.Image,
+    prompt: str,
+    max_tokens: Optional[int] = None
+) -> List[str]:
+    """
+    Process an image with the given prompt.
+
+    Args:
+        image: Input image as PIL Image
+        prompt: Text prompt for the model
+        max_tokens: Optional maximum tokens to generate
+
+    Returns:
+        List of generated responses
+
+    Raises:
+        ValueError: If image is invalid
+        RuntimeError: If model fails to process
+    """
+    # Implementation
+```
+
+## Testing
+
+### Running Tests
+
+```bash
+# Run all tests
+pytest
+
+# Run with coverage
+pytest --cov=src
+
+# Run specific test file
+pytest tests/test_model.py
+
+# Run with verbose output
+pytest -v
+```
+
+### Writing Tests
+
+1. Use pytest fixtures
+2. Mock external dependencies
+3. Test edge cases
+4. Include both unit and integration tests
+
+Example test:
+```python
+import pytest
+from PIL import Image
+
+def test_process_image(model, sample_image):
+    """Test image processing functionality."""
+    prompt = "What color is this image?"
+    response = model.process_image(
+        image=sample_image,
+        prompt=prompt
+    )
+    assert isinstance(response, str)
+    assert len(response) > 0
+```
+
+## Model Development
+
+### Adding New Models
+
+1. Create a new model class in `src/models/`
+2. Implement required methods
+3. Add tests
+4. Update documentation
+
+Example:
+```python
+class NewModel:
+    """New model implementation."""
+
+    def __init__(self, config: dict):
+        """Initialize the model."""
+        self.config = config
+        self.model = self._load_model()
+
+    def process(self, *args, **kwargs):
+        """Process inputs and generate output."""
+        pass
+```
+
+### Model Configuration
+
+1. Add configuration in `src/configs/settings.py`
+2. Use environment variables for secrets
+3. Document all parameters
+
+## API Development
+
+### Adding New Endpoints
+
+1. Create new endpoint in `src/api/app.py`
+2. Add input validation
+3. Implement error handling
+4. Add tests
+5. Update documentation
+
+### Error Handling
+
+1. Use custom exceptions
+2. Implement proper logging
+3. Return appropriate status codes
+4. Include error messages
+
+Example:
+```python
+class ModelError(Exception):
+    """Base exception for model errors."""
+    pass
+
+def process_request(request):
+    try:
+        result = model.process(request)
+        return result
+    except ModelError as e:
+        logger.error(f"Model error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+```
+
+## Deployment
+
+### Local Deployment
+
+1. Build the package:
+```bash
+python -m build
+```
+
+2. Run the server:
+```bash
+python src/api/app.py
+```
+
+### Hugging Face Spaces
+
+1. Update `README.md` with Space metadata
+2. Ensure all dependencies are in `requirements.txt`
+3. Test the Space locally
+4. Push changes to the Space
+
+### Production Deployment
+
+1. Set up proper logging
+2. Configure security measures
+3. Implement rate limiting
+4. Set up monitoring
+5. Use environment variables
+
+## Contributing
+
+### Workflow
+
+1. Fork the repository
+2. Create a feature branch
+3. Make changes
+4. Run tests
+5. Update documentation
+6. Create a pull request
+
+### Pull Request Process
+
+1. Update documentation
+2. Add tests
+3. Ensure CI passes
+4. Get code review
+5. Address feedback
+6. Merge when approved
+
+## Performance Optimization
+
+### Model Optimization
+
+1. Use model quantization
+2. Implement caching
+3. Batch processing
+4. GPU optimization
+
+### API Optimization
+
+1. Response compression
+2. Request validation
+3. Connection pooling
+4. Caching strategies
+
+## Security
+
+### Best Practices
+
+1. Input validation
+2. Error handling
+3. Rate limiting
+4. Secure configuration
+5. Regular updates
+
+### Security Checklist
+
+- [ ] Validate all inputs
+- [ ] Sanitize outputs
+- [ ] Use secure dependencies
+- [ ] Implement rate limiting
+- [ ] Set up monitoring
+- [ ] Regular security audits
+
+## Monitoring and Logging
+
+### Logging
+
+1. Use structured logging
+2. Include context
+3. Set appropriate levels
+4. Rotate logs
+
+### Monitoring
+
+1. Track key metrics
+2. Set up alerts
+3. Monitor resources
+4. Track errors
+
+## Future Development
+
+### Planned Features
+
+1. Video support
+2. Batch processing
+3. Model fine-tuning
+4. API authentication
+5. Advanced caching
+
+### Contributing Ideas
+
+1. Open issues
+2. Discuss in PRs
+3. Join discussions
+4. Share use cases
+
+## Resources
+
+### Documentation
+
+- [Python Documentation](https://docs.python.org/)
+- [Gradio Documentation](https://gradio.app/docs/)
+- [Hugging Face Docs](https://huggingface.co/docs)
+- [Pytest Documentation](https://docs.pytest.org/)
+
+### Tools
+
+- [Black](https://black.readthedocs.io/)
+- [isort](https://pycqa.github.io/isort/)
+- [flake8](https://flake8.pycqa.org/)
+- [mypy](https://mypy.readthedocs.io/)
+
+### Community
+
+- [GitHub Issues](https://github.com/Prashant-ambati/llava-implementation/issues)
+- [Hugging Face Forums](https://discuss.huggingface.co/)
+- [Stack Overflow](https://stackoverflow.com/)
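Note: the example test in the guide uses `model` and `sample_image` fixtures that the guide never defines. A minimal `conftest.py` sketch that would make it runnable; the fixture bodies are assumptions, not code from `tests/test_model.py`:

```python
# Hypothetical tests/conftest.py supplying the fixtures the guide's test assumes.
import pytest
from PIL import Image

@pytest.fixture
def sample_image():
    """A small solid-color image, enough to exercise the pipeline."""
    return Image.new("RGB", (64, 64), color="red")

@pytest.fixture
def model(mocker):
    """A stand-in model so the test runs without GPU weights (uses pytest-mock)."""
    fake = mocker.Mock()
    fake.process_image.return_value = "A red square."
    return fake
```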
docs/guides/user_guide.md
ADDED
@@ -0,0 +1,164 @@
+# LLaVA Chat User Guide
+
+## Introduction
+
+Welcome to LLaVA Chat! This guide will help you get started with using our AI-powered image understanding and chat interface. LLaVA (Large Language and Vision Assistant) combines advanced vision and language models to provide detailed analysis and natural conversations about images.
+
+## Getting Started
+
+### Accessing the Interface
+
+1. Visit our [Hugging Face Space](https://huggingface.co/spaces/Prashant26am/llava-chat)
+2. Wait for the interface to load (this may take a few moments as the model initializes)
+3. You're ready to start chatting with images!
+
+### Basic Usage
+
+1. **Upload an Image**
+   - Click the image upload area or drag and drop an image
+   - Supported formats: JPEG, PNG, GIF
+   - Maximum file size: 10MB
+
+2. **Enter Your Prompt**
+   - Type your question or prompt in the text box
+   - Be specific about what you want to know
+   - You can ask multiple questions about the same image
+
+3. **Adjust Parameters** (Optional)
+   - Click "Generation Parameters" to expand
+   - Modify settings to control the response:
+     - Max New Tokens: Longer responses (64-2048)
+     - Temperature: More creative responses (0.1-1.0)
+     - Top P: More diverse responses (0.1-1.0)
+
+4. **Generate Response**
+   - Click the "Generate Response" button
+   - Wait for the model to process (usually a few seconds)
+   - Read the response in the output box
+   - Use the copy button to save the response
+
+## Best Practices
+
+### Writing Effective Prompts
+
+1. **Be Specific**
+   - Instead of "What's in this image?", try "What objects can you identify in this image?"
+   - Instead of "Describe this", try "Describe the scene, focusing on the main subject"
+
+2. **Ask Follow-up Questions**
+   - "What emotions does this image convey?"
+   - "Can you identify any specific details about [object]?"
+   - "How would you describe the composition of this image?"
+
+3. **Use Natural Language**
+   - Write as if you're talking to a person
+   - Feel free to ask for clarification or more details
+   - You can have a conversation about the image
+
+### Example Prompts
+
+1. **General Analysis**
+   - "What can you see in this image?"
+   - "Describe this scene in detail"
+   - "What's the main subject of this image?"
+
+2. **Specific Details**
+   - "What colors are prominent in this image?"
+   - "Can you identify any text or signs in the image?"
+   - "What time of day does this image appear to be taken?"
+
+3. **Emotional Response**
+   - "What mood or atmosphere does this image convey?"
+   - "How does this image make you feel?"
+   - "What emotions might this image evoke in viewers?"
+
+4. **Technical Analysis**
+   - "What's the composition of this image?"
+   - "How would you describe the lighting in this image?"
+   - "What camera angle or perspective is used?"
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Image Not Loading**
+   - Check file format (JPEG, PNG, GIF)
+   - Ensure file size is under 10MB
+   - Try refreshing the page
+
+2. **Slow Response**
+   - Reduce image size
+   - Simplify your prompt
+   - Check your internet connection
+
+3. **Unexpected Responses**
+   - Try rephrasing your prompt
+   - Adjust generation parameters
+   - Be more specific in your question
+
+### Getting Help
+
+If you encounter any issues:
+1. Check this guide for solutions
+2. Visit our [GitHub repository](https://github.com/Prashant-ambati/llava-implementation)
+3. Open an issue on GitHub
+4. Contact us through Hugging Face
+
+## Advanced Usage
+
+### Parameter Tuning
+
+1. **Max New Tokens**
+   - Lower values (64-256): Short, concise responses
+   - Medium values (256-512): Balanced responses
+   - Higher values (512+): Detailed, comprehensive responses
+
+2. **Temperature**
+   - Lower values (0.1-0.3): More focused, deterministic responses
+   - Medium values (0.4-0.7): Balanced creativity
+   - Higher values (0.8-1.0): More creative, diverse responses
+
+3. **Top P**
+   - Lower values (0.1-0.3): More focused word choice
+   - Medium values (0.4-0.7): Balanced diversity
+   - Higher values (0.8-1.0): More diverse word choice
+
+### Tips for Better Results
+
+1. **Image Quality**
+   - Use clear, well-lit images
+   - Ensure the subject is clearly visible
+   - Avoid heavily edited or filtered images
+
+2. **Prompt Engineering**
+   - Start with simple questions
+   - Build up to more complex queries
+   - Use follow-up questions for details
+
+3. **Response Management**
+   - Copy important responses
+   - Save interesting conversations
+   - Compare responses with different parameters
+
+## Privacy and Ethics
+
+1. **Image Privacy**
+   - Don't upload sensitive or private images
+   - Be mindful of copyright
+   - Respect others' privacy
+
+2. **Responsible Use**
+   - Use the tool ethically
+   - Don't use for harmful purposes
+   - Respect content guidelines
+
+## Future Updates
+
+We're constantly improving LLaVA Chat. Planned features include:
+1. Support for video input
+2. Batch image processing
+3. More advanced parameter controls
+4. Additional model options
+5. Enhanced response formatting
+
+Stay tuned for updates!
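Note: the Parameter Tuning section maps directly onto standard `transformers`-style sampling arguments. A hedged sketch of bundling the three sliders into `generate()` keyword arguments; how the app actually wires them lives in `src/models/llava_model.py`, which is not shown here:

```python
# Illustrative helper: turn the UI's three sliders into generate() kwargs.
# do_sample must be True for temperature/top_p to have any effect.
from typing import Any, Dict

def generation_kwargs(max_new_tokens: int = 512,
                      temperature: float = 0.7,
                      top_p: float = 0.9) -> Dict[str, Any]:
    return {
        "do_sample": True,
        "max_new_tokens": max_new_tokens,  # guide range: 64-2048
        "temperature": temperature,        # 0.1-0.3 focused, 0.8-1.0 creative
        "top_p": top_p,                    # nucleus sampling: word-choice diversity
    }

# e.g. model.generate(**inputs, **generation_kwargs(temperature=0.2))
```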
examples/api_client.py
ADDED
@@ -0,0 +1,127 @@
+"""
+Example API client for the LLaVA model.
+"""
+
+import argparse
+import json
+from pathlib import Path
+from typing import Dict, Any, Optional
+
+import requests
+from PIL import Image
+import base64
+from io import BytesIO
+
+def encode_image(image_path: str) -> str:
+    """
+    Encode an image to base64 string.
+
+    Args:
+        image_path: Path to the image file
+
+    Returns:
+        str: Base64 encoded image
+    """
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode('utf-8')
+
+def process_image(
+    api_url: str,
+    image_path: str,
+    prompt: str,
+    max_new_tokens: Optional[int] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None
+) -> Dict[str, Any]:
+    """
+    Process an image using the LLaVA API.
+
+    Args:
+        api_url: URL of the API endpoint
+        image_path: Path to the input image
+        prompt: Text prompt for the model
+        max_new_tokens: Optional maximum tokens to generate
+        temperature: Optional sampling temperature
+        top_p: Optional top-p sampling parameter
+
+    Returns:
+        Dict containing the API response
+    """
+    # Prepare the request payload
+    payload = {
+        "image": encode_image(image_path),
+        "prompt": prompt
+    }
+
+    # Add optional parameters if provided
+    if max_new_tokens is not None:
+        payload["max_new_tokens"] = max_new_tokens
+    if temperature is not None:
+        payload["temperature"] = temperature
+    if top_p is not None:
+        payload["top_p"] = top_p
+
+    try:
+        # Send the request
+        response = requests.post(api_url, json=payload)
+        response.raise_for_status()
+        return response.json()
+
+    except requests.exceptions.RequestException as e:
+        print(f"Error making request: {e}")
+        if hasattr(e.response, 'text'):
+            print(f"Response: {e.response.text}")
+        raise
+
+def save_response(response: Dict[str, Any], output_path: Optional[str] = None):
+    """
+    Save or print the API response.
+
+    Args:
+        response: API response dictionary
+        output_path: Optional path to save the response
+    """
+    if output_path:
+        with open(output_path, 'w') as f:
+            json.dump(response, f, indent=2)
+        print(f"Saved response to {output_path}")
+    else:
+        print("\nAPI Response:")
+        print("-" * 50)
+        print(json.dumps(response, indent=2))
+        print("-" * 50)
+
+def main():
+    """Main function to process images using the API."""
+    parser = argparse.ArgumentParser(description="Process images using LLaVA API")
+    parser.add_argument("image_path", type=str, help="Path to the input image")
+    parser.add_argument("prompt", type=str, help="Text prompt for the model")
+    parser.add_argument("--api-url", type=str, default="http://localhost:7860/api/process",
+                        help="URL of the API endpoint")
+    parser.add_argument("--max-tokens", type=int, help="Maximum tokens to generate")
+    parser.add_argument("--temperature", type=float, help="Sampling temperature")
+    parser.add_argument("--top-p", type=float, help="Top-p sampling parameter")
+    parser.add_argument("--output", type=str, help="Path to save the response")
+
+    args = parser.parse_args()
+
+    try:
+        # Process image
+        response = process_image(
+            api_url=args.api_url,
+            image_path=args.image_path,
+            prompt=args.prompt,
+            max_new_tokens=args.max_tokens,
+            temperature=args.temperature,
+            top_p=args.top_p
+        )
+
+        # Save or print response
+        save_response(response, args.output)
+
+    except Exception as e:
+        print(f"Error: {str(e)}")
+        raise
+
+if __name__ == "__main__":
+    main()
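Note: a short programmatic usage of the client above, assuming the server is running locally; the `/api/process` path is the script's own default, not independently documented:

```python
# Calling the example client from Python instead of the CLI.
from examples.api_client import process_image, save_response

response = process_image(
    api_url="http://localhost:7860/api/process",  # the script's default endpoint
    image_path="examples/cat.jpg",                # sample image referenced by the old UI examples
    prompt="What can you see in this image?",
    temperature=0.7,
)
save_response(response)  # prints the JSON when no output path is given
```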
examples/llava_demo.ipynb
ADDED
@@ -0,0 +1 @@
examples/process_image.py
ADDED
@@ -0,0 +1,103 @@
+"""
+Example script for processing images with the LLaVA model.
+"""
+
+import argparse
+from pathlib import Path
+from PIL import Image
+
+from src.models.llava_model import LLaVAModel
+from src.configs.settings import DEFAULT_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
+from src.utils.logging import setup_logging, get_logger
+
+# Set up logging
+setup_logging()
+logger = get_logger(__name__)
+
+def process_image(
+    image_path: str,
+    prompt: str,
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = DEFAULT_TEMPERATURE,
+    top_p: float = DEFAULT_TOP_P
+) -> str:
+    """
+    Process an image with the LLaVA model.
+
+    Args:
+        image_path: Path to the input image
+        prompt: Text prompt for the model
+        max_new_tokens: Maximum number of tokens to generate
+        temperature: Sampling temperature
+        top_p: Top-p sampling parameter
+
+    Returns:
+        str: Model response
+    """
+    try:
+        # Load image
+        image = Image.open(image_path)
+        logger.info(f"Loaded image from {image_path}")
+
+        # Initialize model
+        model = LLaVAModel()
+        logger.info("Model initialized")
+
+        # Generate response
+        response = model(
+            image=image,
+            prompt=prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p
+        )
+        logger.info("Generated response")
+
+        return response
+
+    except Exception as e:
+        logger.error(f"Error processing image: {str(e)}")
+        raise
+
+def main():
+    """Main function to process images from command line."""
+    parser = argparse.ArgumentParser(description="Process images with LLaVA model")
+    parser.add_argument("image_path", type=str, help="Path to the input image")
+    parser.add_argument("prompt", type=str, help="Text prompt for the model")
+    parser.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS,
+                        help="Maximum number of tokens to generate")
+    parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE,
+                        help="Sampling temperature")
+    parser.add_argument("--top-p", type=float, default=DEFAULT_TOP_P,
+                        help="Top-p sampling parameter")
+    parser.add_argument("--output", type=str, help="Path to save the response")
+
+    args = parser.parse_args()
+
+    try:
+        # Process image
+        response = process_image(
+            image_path=args.image_path,
+            prompt=args.prompt,
+            max_new_tokens=args.max_tokens,
+            temperature=args.temperature,
+            top_p=args.top_p
+        )
+
+        # Print or save response
+        if args.output:
+            output_path = Path(args.output)
+            output_path.write_text(response)
+            logger.info(f"Saved response to {output_path}")
+        else:
+            print("\nModel Response:")
+            print("-" * 50)
+            print(response)
+            print("-" * 50)
+
+    except Exception as e:
+        logger.error(f"Error: {str(e)}")
+        raise
+
+if __name__ == "__main__":
+    main()
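Note: the script above instantiates `LLaVAModel()` and calls the instance directly. The wrapper lives in `src/models/llava_model.py` (+88 lines, not shown); a stub recording only the call signature the script assumes:

```python
# Interface stub inferred from examples/process_image.py; not the real wrapper.
from PIL import Image

class LLaVAModel:
    def __call__(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """Return the model's text response for an image/prompt pair."""
        raise NotImplementedError  # see src/models/llava_model.py in this commit
```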
pyproject.toml
ADDED
@@ -0,0 +1,181 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "llava-implementation"
+version = "0.1.0"
+description = "A modern implementation of LLaVA with a beautiful web interface"
+readme = "README.md"
+requires-python = ">=3.8"
+license = {text = "MIT"}
+authors = [
+    {name = "Prashant Ambati", email = "[email protected]"}
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Software Development :: Libraries :: Python Modules",
+]
+dependencies = [
+    "torch>=2.0.0",
+    "torchvision>=0.15.0",
+    "transformers>=4.36.0",
+    "accelerate>=0.25.0",
+    "pillow>=10.0.0",
+    "numpy>=1.24.0",
+    "tqdm>=4.65.0",
+    "matplotlib>=3.7.0",
+    "opencv-python>=4.8.0",
+    "einops>=0.7.0",
+    "timm>=0.9.0",
+    "sentencepiece>=0.1.99",
+    "peft>=0.7.0",
+    "bitsandbytes>=0.41.0",
+    "safetensors>=0.4.0",
+    "gradio==4.44.1",
+    "fastapi>=0.109.0",
+    "uvicorn>=0.27.0",
+    "python-multipart>=0.0.6",
+    "pydantic>=2.5.0",
+    "python-jose>=3.3.0",
+    "passlib>=1.7.4",
+    "bcrypt>=4.0.1",
+    "aiofiles>=23.2.0",
+    "httpx>=0.26.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=8.0.0",
+    "pytest-cov>=4.1.0",
+    "pytest-mock>=3.12.0",
+    "pytest-asyncio>=0.23.5",
+    "pytest-xdist>=3.5.0",
+    "black>=24.1.1",
+    "isort>=5.13.2",
+    "flake8>=7.0.0",
+    "mypy>=1.8.0",
+    "types-Pillow>=10.2.0.20240106",
+    "types-requests>=2.31.0.20240125",
+    "sphinx>=7.2.6",
+    "sphinx-rtd-theme>=2.0.0",
+    "sphinx-autodoc-typehints>=2.0.1",
+    "sphinx-copybutton>=0.5.2",
+    "sphinx-tabs>=3.4.4",
+    "pre-commit>=3.6.0",
+    "ipython>=8.21.0",
+    "jupyter>=1.0.0",
+    "notebook>=7.0.7",
+    "ipykernel>=6.29.0",
+    "build>=1.0.3",
+    "twine>=4.0.2",
+    "wheel>=0.42.0",
+    "memory-profiler>=0.61.0",
+    "line-profiler>=4.1.2",
+    "debugpy>=1.8.0",
+]
+
+[project.urls]
+Homepage = "https://github.com/Prashant-ambati/llava-implementation"
+Documentation = "https://github.com/Prashant-ambati/llava-implementation#readme"
+Repository = "https://github.com/Prashant-ambati/llava-implementation.git"
+Issues = "https://github.com/Prashant-ambati/llava-implementation/issues"
+"Bug Tracker" = "https://github.com/Prashant-ambati/llava-implementation/issues"
+
+[tool.setuptools]
+packages = ["src"]
+
+[tool.black]
+line-length = 88
+target-version = ["py38"]
+include = '\.pyi?$'
+
+[tool.isort]
+profile = "black"
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+ensure_newline_before_comments = true
+line_length = 88
+
+[tool.mypy]
+python_version = "3.8"
+warn_return_any = true
+warn_unused_configs = true
+disallow_untyped_defs = true
+disallow_incomplete_defs = true
+check_untyped_defs = true
+disallow_untyped_decorators = true
+no_implicit_optional = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_no_return = true
+warn_unreachable = true
+strict_optional = true
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+addopts = "-ra -q --cov=src"
+testpaths = [
+    "tests",
+]
+python_files = ["test_*.py"]
+python_classes = ["Test*"]
+python_functions = ["test_*"]
+
+[tool.coverage.run]
+source = ["src"]
+branch = true
+
+[tool.coverage.report]
+exclude_lines = [
+    "pragma: no cover",
+    "def __repr__",
+    "if self.debug:",
+    "raise NotImplementedError",
+    "if __name__ == .__main__.:",
+    "pass",
+    "raise ImportError",
+]
+show_missing = true
+fail_under = 80
+
+[tool.bandit]
+exclude_dirs = ["tests", "docs"]
+skips = ["B101"]
+
+[tool.ruff]
+line-length = 88
+target-version = "py38"
+select = [
+    "E",    # pycodestyle errors
+    "W",    # pycodestyle warnings
+    "F",    # pyflakes
+    "I",    # isort
+    "B",    # flake8-bugbear
+    "C4",   # flake8-comprehensions
+    "UP",   # pyupgrade
+    "N",    # pep8-naming
+    "PL",   # pylint
+    "RUF",  # ruff-specific rules
+]
+ignore = [
+    "E501",  # line length violations
+    "B008",  # do not perform function calls in argument defaults
+]
+
+[tool.ruff.isort]
+known-first-party = ["src"]
+
+[tool.ruff.mccabe]
+max-complexity = 10
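After an editable install (pip install -e ".[dev]" also pulls in the optional dev group), the metadata declared above can be sanity-checked from the standard library. A small sketch using importlib.metadata:

    from importlib.metadata import metadata, version

    # Both values come straight from the [project] table in pyproject.toml.
    print(version("llava-implementation"))            # 0.1.0
    print(metadata("llava-implementation")["Summary"])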
requirements-dev.txt
ADDED
@@ -0,0 +1,38 @@
+# Testing
+pytest==8.0.0
+pytest-cov==4.1.0
+pytest-mock==3.12.0
+pytest-asyncio==0.23.5
+pytest-xdist==3.5.0
+
+# Code Quality
+black==24.1.1
+isort==5.13.2
+flake8==7.0.0
+mypy==1.8.0
+types-Pillow==10.2.0.20240106
+types-requests==2.31.0.20240125
+
+# Documentation
+sphinx==7.2.6
+sphinx-rtd-theme==2.0.0
+sphinx-autodoc-typehints==2.0.1
+sphinx-copybutton==0.5.2
+sphinx-tabs==3.4.4
+
+# Development Tools
+pre-commit==3.6.0
+ipython==8.21.0
+jupyter==1.0.0
+notebook==7.0.7
+ipykernel==6.29.0
+
+# Build Tools
+build==1.0.3
+twine==4.0.2
+wheel==0.42.0
+
+# Monitoring and Debugging
+memory-profiler==0.61.0
+line-profiler==4.1.2
+debugpy==1.8.0
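Because [tool.pytest.ini_options] in pyproject.toml above already sets addopts = "-ra -q --cov=src", a bare pytest run picks up coverage of src/ automatically. The same run can be triggered from Python; a sketch (pytest.main is pytest's public entry point and returns an exit code):

    import pytest

    # The addopts from pyproject.toml apply here too.
    raise SystemExit(pytest.main(["tests"]))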
requirements.txt
CHANGED
@@ -1,26 +1,25 @@
 torch>=2.0.0
 torchvision>=0.15.0
-transformers>=4.
-accelerate>=0.
-pillow>=
+transformers>=4.36.0
+accelerate>=0.25.0
+pillow>=10.0.0
 numpy>=1.24.0
 tqdm>=4.65.0
 matplotlib>=3.7.0
-opencv-python>=4.
-einops>=0.
+opencv-python>=4.8.0
+einops>=0.7.0
 timm>=0.9.0
 sentencepiece>=0.1.99
-
-
-
-
-fastapi
-uvicorn
-python-multipart
-pydantic
-python-jose
-passlib
-bcrypt
-aiofiles
-
-httpx==0.25.2
+peft>=0.7.0
+bitsandbytes>=0.41.0
+safetensors>=0.4.0
+gradio==4.44.1
+fastapi>=0.109.0
+uvicorn>=0.27.0
+python-multipart>=0.0.6
+pydantic>=2.5.0
+python-jose>=3.3.0
+passlib>=1.7.4
+bcrypt>=4.0.1
+aiofiles>=23.2.0
+httpx>=0.26.0
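Note that gradio is the only exactly-pinned dependency while everything else uses lower bounds, so a quick post-install check can catch a resolver pulling in a different build. A one-line sketch:

    import gradio

    # The interface code targets this exact release; fail fast on anything else.
    assert gradio.__version__ == "4.44.1", f"unexpected gradio {gradio.__version__}"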
src/__init__.py
ADDED
File without changes

src/api/__init__.py
ADDED
File without changes
src/api/app.py
ADDED
@@ -0,0 +1,159 @@
+"""
+Gradio interface for the LLaVA model.
+"""
+
+import gradio as gr
+from PIL import Image
+
+from ..configs.settings import (
+    GRADIO_THEME,
+    GRADIO_TITLE,
+    GRADIO_DESCRIPTION,
+    DEFAULT_MAX_NEW_TOKENS,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_P,
+    API_HOST,
+    API_PORT,
+    API_WORKERS,
+    API_RELOAD
+)
+from ..models.llava_model import LLaVAModel
+from ..utils.logging import setup_logging, get_logger
+
+# Set up logging
+setup_logging()
+logger = get_logger(__name__)
+
+# Initialize model
+model = LLaVAModel()
+
+def process_image(
+    image: Image.Image,
+    prompt: str,
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = DEFAULT_TEMPERATURE,
+    top_p: float = DEFAULT_TOP_P
+) -> str:
+    """
+    Process an image with the LLaVA model.
+
+    Args:
+        image: Input image
+        prompt: Text prompt
+        max_new_tokens: Maximum number of tokens to generate
+        temperature: Sampling temperature
+        top_p: Top-p sampling parameter
+
+    Returns:
+        str: Model response
+    """
+    try:
+        logger.info(f"Processing image with prompt: {prompt[:100]}...")
+        response = model(
+            image=image,
+            prompt=prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p
+        )
+        logger.info("Successfully generated response")
+        return response
+    except Exception as e:
+        logger.error(f"Error processing image: {str(e)}")
+        return f"Error: {str(e)}"
+
+def create_interface() -> gr.Blocks:
+    """Create and return the Gradio interface."""
+    with gr.Blocks(theme=GRADIO_THEME) as interface:
+        gr.Markdown(f"""# {GRADIO_TITLE}
+
+{GRADIO_DESCRIPTION}
+
+## Example Prompts
+
+Try these prompts to get started:
+- "What can you see in this image?"
+- "Describe this scene in detail"
+- "What emotions does this image convey?"
+- "What's happening in this picture?"
+- "Can you identify any objects or people in this image?"
+
+## Usage Instructions
+
+1. Upload an image using the image uploader
+2. Enter your prompt in the text box
+3. (Optional) Adjust the generation parameters
+4. Click "Generate Response" to get LLaVA's analysis
+""")
+
+        with gr.Row():
+            with gr.Column():
+                # Input components
+                image_input = gr.Image(type="pil", label="Upload Image")
+                prompt_input = gr.Textbox(
+                    label="Prompt",
+                    placeholder="What can you see in this image?",
+                    lines=3
+                )
+
+                with gr.Accordion("Generation Parameters", open=False):
+                    max_tokens = gr.Slider(
+                        minimum=64,
+                        maximum=2048,
+                        value=DEFAULT_MAX_NEW_TOKENS,
+                        step=64,
+                        label="Max New Tokens"
+                    )
+                    temperature = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        value=DEFAULT_TEMPERATURE,
+                        step=0.1,
+                        label="Temperature"
+                    )
+                    top_p = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        value=DEFAULT_TOP_P,
+                        step=0.1,
+                        label="Top P"
+                    )
+
+                generate_btn = gr.Button("Generate Response", variant="primary")
+
+            with gr.Column():
+                # Output component
+                output = gr.Textbox(
+                    label="Response",
+                    lines=10,
+                    show_copy_button=True
+                )
+
+        # Set up event handlers
+        generate_btn.click(
+            fn=process_image,
+            inputs=[
+                image_input,
+                prompt_input,
+                max_tokens,
+                temperature,
+                top_p
+            ],
+            outputs=output
+        )
+
+    return interface
+
+def main():
+    """Run the Gradio interface."""
+    interface = create_interface()
+    interface.launch(
+        server_name=API_HOST,
+        server_port=API_PORT,
+        share=True,
+        show_error=True,
+        show_api=False
+    )
+
+if __name__ == "__main__":
+    main()
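Since process_image accepts a plain PIL image, the app logic can be smoke-tested without starting the Gradio server. A minimal sketch (LLaVAModel is constructed at import time, so this still loads the model weights):

    from PIL import Image

    from src.api.app import process_image

    # Synthetic image; any PIL.Image works since gr.Image(type="pil") hands one over.
    img = Image.new("RGB", (336, 336), "blue")
    print(process_image(img, "What color is this image?"))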
src/configs/__init__.py
ADDED
File without changes
src/configs/settings.py
ADDED
@@ -0,0 +1,46 @@
+"""
+Configuration settings for the LLaVA implementation.
+"""
+
+import os
+from pathlib import Path
+
+# Project paths
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+SRC_DIR = PROJECT_ROOT / "src"
+ASSETS_DIR = PROJECT_ROOT / "assets"
+EXAMPLES_DIR = PROJECT_ROOT / "examples"
+
+# Model settings
+MODEL_NAME = "liuhaotian/llava-v1.5-7b"
+MODEL_REVISION = "main"
+DEVICE = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
+
+# Generation settings
+DEFAULT_MAX_NEW_TOKENS = 512
+DEFAULT_TEMPERATURE = 0.7
+DEFAULT_TOP_P = 0.9
+
+# API settings
+API_HOST = "0.0.0.0"
+API_PORT = 7860
+API_WORKERS = 1
+API_RELOAD = True
+
+# Gradio settings
+GRADIO_THEME = "soft"
+GRADIO_TITLE = "LLaVA Chat"
+GRADIO_DESCRIPTION = """
+A powerful multimodal AI assistant that can understand and discuss images.
+Upload any image and chat with LLaVA about it!
+"""
+
+# Logging settings
+LOG_LEVEL = "INFO"
+LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+LOG_DIR = PROJECT_ROOT / "logs"
+LOG_FILE = LOG_DIR / "app.log"
+
+# Create necessary directories
+for directory in [ASSETS_DIR, EXAMPLES_DIR, LOG_DIR]:
+    directory.mkdir(parents=True, exist_ok=True)
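The DEVICE heuristic keys off CUDA_VISIBLE_DEVICES, which can misreport: the variable is often unset on GPU machines and may be set on hosts without working drivers. A more direct probe is to ask torch itself; a sketch, at the cost of importing torch from the settings module:

    import torch

    # torch.cuda.is_available() checks for a usable CUDA runtime, not just an env var.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"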
src/models/__init__.py
ADDED
File without changes
src/models/llava_model.py
ADDED
@@ -0,0 +1,88 @@
+"""
+LLaVA model implementation.
+"""
+
+import torch
+from transformers import AutoProcessor, AutoModelForCausalLM
+from PIL import Image
+
+from ..configs.settings import MODEL_NAME, MODEL_REVISION, DEVICE
+from ..utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+class LLaVAModel:
+    """LLaVA model wrapper class."""
+
+    def __init__(self):
+        """Initialize the LLaVA model and processor."""
+        logger.info(f"Initializing LLaVA model from {MODEL_NAME}")
+        self.processor = AutoProcessor.from_pretrained(
+            MODEL_NAME,
+            revision=MODEL_REVISION,
+            trust_remote_code=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            MODEL_NAME,
+            revision=MODEL_REVISION,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+        logger.info("Model initialization complete")
+
+    def generate_response(
+        self,
+        image: Image.Image,
+        prompt: str,
+        max_new_tokens: int = 512,
+        temperature: float = 0.7,
+        top_p: float = 0.9
+    ) -> str:
+        """
+        Generate a response for the given image and prompt.
+
+        Args:
+            image: Input image as PIL Image
+            prompt: Text prompt for the model
+            max_new_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+
+        Returns:
+            str: Generated response
+        """
+        try:
+            # Prepare inputs (keyword args guard against positional-order changes across transformers versions)
+            inputs = self.processor(
+                text=prompt,
+                images=image,
+                return_tensors="pt"
+            ).to(DEVICE)
+
+            # Generate response
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=True
+                )
+
+            # Decode and return response
+            response = self.processor.decode(
+                outputs[0],
+                skip_special_tokens=True
+            )
+
+            logger.debug(f"Generated response: {response[:100]}...")
+            return response
+
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            raise
+
+    def __call__(self, *args, **kwargs):
+        """Convenience method to call generate_response."""
+        return self.generate_response(*args, **kwargs)
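One caveat on the loader above: liuhaotian/llava-v1.5-7b ships weights in the original LLaVA repository layout, and AutoModelForCausalLM may fail to map it even with trust_remote_code=True. If loading fails, the transformers-native conversion is a close substitute; a hedged sketch (the llava-hf checkpoint and its USER/ASSISTANT prompt template are assumptions about that conversion, not part of this repo):

    import torch
    from PIL import Image
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    model_id = "llava-hf/llava-1.5-7b-hf"  # transformers-native LLaVA v1.5 weights
    processor = AutoProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )

    image = Image.new("RGB", (336, 336), "red")  # placeholder input
    prompt = "USER: <image>\nWhat color is this image? ASSISTANT:"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=64)
    print(processor.decode(output[0], skip_special_tokens=True))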
main.py → src/models/main.py
RENAMED
File without changes
src/requirements.txt
ADDED
@@ -0,0 +1,26 @@
+torch>=2.0.0
+torchvision>=0.15.0
+transformers>=4.30.0
+accelerate>=0.20.0
+pillow>=9.0.0
+numpy>=1.24.0
+tqdm>=4.65.0
+matplotlib>=3.7.0
+opencv-python>=4.7.0
+einops>=0.6.0
+timm>=0.9.0
+sentencepiece>=0.1.99
+gradio>=3.35.0
+peft>=0.4.0
+bitsandbytes>=0.40.0
+safetensors>=0.3.1
+fastapi==0.104.1
+uvicorn==0.24.0
+python-multipart==0.0.6
+pydantic==2.5.2
+python-jose==3.3.0
+passlib==1.7.4
+bcrypt==4.0.1
+aiofiles==23.2.1
+python-dotenv==1.0.0
+httpx==0.25.2
src/utils/__init__.py
ADDED
File without changes
src/utils/logging.py
ADDED
@@ -0,0 +1,53 @@
+"""
+Logging utilities for the LLaVA implementation.
+"""
+
+import logging
+import sys
+from typing import Optional
+
+from ..configs.settings import LOG_LEVEL, LOG_FORMAT, LOG_FILE
+
+def setup_logging(name: Optional[str] = None) -> logging.Logger:
+    """
+    Set up logging configuration for the application.
+
+    Args:
+        name: Optional name for the logger. If None, configures the root logger.
+
+    Returns:
+        logging.Logger: Configured logger instance.
+    """
+    # Create logger; bail out early so repeated calls don't attach duplicate handlers
+    logger = logging.getLogger(name)
+    if logger.handlers:
+        return logger
+    logger.setLevel(LOG_LEVEL)
+
+    # Create formatters
+    formatter = logging.Formatter(LOG_FORMAT)
+
+    # Create handlers
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(formatter)
+
+    file_handler = logging.FileHandler(LOG_FILE)
+    file_handler.setFormatter(formatter)
+
+    # Add handlers to logger
+    logger.addHandler(console_handler)
+    logger.addHandler(file_handler)
+
+    return logger
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Get a logger instance with the specified name.
+
+    Args:
+        name: Optional name for the logger. If None, returns the root logger.
+
+    Returns:
+        logging.Logger: Logger instance.
+    """
+    return logging.getLogger(name)
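Typical usage pairs one setup_logging() call at an entry point with get_logger(__name__) in each module; child loggers carry no handlers of their own and propagate records up to the configured root handlers. A short sketch:

    from src.utils.logging import setup_logging, get_logger

    setup_logging()              # configures root handlers once (console + logs/app.log)
    log = get_logger(__name__)   # per-module child logger
    log.info("pipeline started")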
tests/test_model.py
ADDED
@@ -0,0 +1,67 @@
+"""
+Tests for the LLaVA model implementation.
+"""
+
+import pytest
+from PIL import Image
+import torch
+
+from src.models.llava_model import LLaVAModel
+from src.configs.settings import DEFAULT_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
+
+@pytest.fixture
+def model():
+    """Fixture to provide a model instance."""
+    return LLaVAModel()
+
+@pytest.fixture
+def sample_image():
+    """Fixture to provide a sample image."""
+    # Create a simple test image
+    return Image.new('RGB', (224, 224), color='red')
+
+def test_model_initialization(model):
+    """Test that the model initializes correctly."""
+    assert model is not None
+    assert model.processor is not None
+    assert model.model is not None
+
+def test_model_device(model):
+    """Test that the model is on the correct device."""
+    assert next(model.model.parameters()).device.type in ['cuda', 'cpu']
+
+def test_generate_response(model, sample_image):
+    """Test that the model can generate responses."""
+    prompt = "What color is this image?"
+    response = model.generate_response(
+        image=sample_image,
+        prompt=prompt,
+        max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+        temperature=DEFAULT_TEMPERATURE,
+        top_p=DEFAULT_TOP_P
+    )
+
+    assert isinstance(response, str)
+    assert len(response) > 0
+
+def test_generate_response_with_invalid_image(model):
+    """Test that the model handles invalid images correctly."""
+    with pytest.raises(Exception):
+        model.generate_response(
+            image=None,
+            prompt="What color is this image?",
+            max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+            temperature=DEFAULT_TEMPERATURE,
+            top_p=DEFAULT_TOP_P
+        )
+
+def test_generate_response_with_empty_prompt(model, sample_image):
+    """Test that the model handles empty prompts correctly."""
+    with pytest.raises(Exception):
+        model.generate_response(
+            image=sample_image,
+            prompt="",
+            max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+            temperature=DEFAULT_TEMPERATURE,
+            top_p=DEFAULT_TOP_P
+        )
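All of these tests instantiate the real model, which makes the suite slow and network-bound. With pytest-mock (already pinned in requirements-dev.txt) the wrapper's plumbing can be exercised without loading weights; a hypothetical lightweight test along these lines:

    from src.models.llava_model import LLaVAModel

    def test_call_delegates_to_generate_response(mocker):
        """__call__ should forward straight to generate_response."""
        model = LLaVAModel.__new__(LLaVAModel)  # skip __init__, so no weights load
        mocker.patch.object(model, "generate_response", return_value="ok")
        assert model(image=None, prompt="hi") == "ok"
        model.generate_response.assert_called_once_with(image=None, prompt="hi")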