Upload folder using huggingface_hub
- .gitattributes +1 -0
- LICENSE.md +94 -0
- README.md +323 -0
- model-00001-of-00006.safetensors +3 -0
- model-00002-of-00006.safetensors +3 -0
- model-00003-of-00006.safetensors +3 -0
- model-00004-of-00006.safetensors +3 -0
- model-00005-of-00006.safetensors +3 -0
- model-00006-of-00006.safetensors +3 -0
- modeling_safellava.py +21 -0
- safellava/__init__.py +11 -0
- safellava/constants.py +20 -0
- safellava/conversation.py +96 -0
- safellava/mm_utils.py +254 -0
- safellava/model/__init__.py +11 -0
- safellava/model/language_model/__init__.py +11 -0
- safellava/model/language_model/llava_llama.py +166 -0
- safellava/model/language_model/safe_llava_llama.py +426 -0
- safellava/model/llava_arch.py +368 -0
- safellava/model/multimodal_encoder/__init__.py +15 -0
- safellava/model/multimodal_encoder/builder.py +40 -0
- safellava/model/multimodal_encoder/clip_encoder.py +127 -0
- safellava/model/multimodal_projector/__init__.py +10 -0
- safellava/model/multimodal_projector/builder.py +59 -0
- safellava/utils.py +127 -0
- test_image.png +3 -0
- tokenizer.model +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+test_image.png filter=lfs diff=lfs merge=lfs -text

LICENSE.md
ADDED
@@ -0,0 +1,94 @@
# License for SafeLLaVA

The SafeLLaVA project is governed by a **hybrid license model**. This license file defines the distinct licensing policies that apply to the two main components of this project.

---

## 1. Definition of Work

### Model Name
**SafeLLaVA**

### Reference Publication
This model (SafeLLaVA) is the official model presented in the academic paper:

> **"HoliSafe: Holistic Safety Benchmarking and Modeling for Vision-Language Model"**
> [https://arxiv.org/abs/2506.04704](https://arxiv.org/abs/2506.04704)

### Base Model
This model is a **Derivative Work** based on the LLaVA-v1.5 model:
- [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA)

### Modifications by ETRI
This work integrates an independently developed **Visual Guard Module (VGM)** to classify harmful image inputs and generate safe text responses. All modifications and additions are the work of ETRI.

---

## 2. License Summary

| Component | License |
|-----------|---------|
| Independently Developed Code (e.g., VGM) | Apache License 2.0 |
| LLaVA-Based Components and Entire Model | LLaVA-v1.5 Usage and License Notices |

---

## Part 1: Apache License 2.0 (For Independently Developed Code)

All original source code and components developed independently by **Electronics and Telecommunications Research Institute (ETRI)** (hereinafter "Copyright Holder"), including the **Visual Guard Module (VGM)** contained in this project, are subject to the Apache License, Version 2.0 (the "License").

You may not use this file except in compliance with the License. You may obtain a copy of the License at:

[http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
```
Copyright 2025 Electronics and Telecommunications Research Institute (ETRI)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

---

## Part 2: LLaVA-v1.5 Usage and License Notices (For the LLaVA-Based Derivative Work)

This SafeLLaVA model is a **Derivative Work** based on the LLaVA-v1.5 model.

Therefore, the use, reproduction, modification, and distribution of the entire SafeLLaVA model (including LLaVA-based components and weights) are subject to the original **"Usage and License Notices"** from the LLaVA-v1.5 repository.

Any user of SafeLLaVA must agree to and comply with these upstream notices, which are reproduced below:

### Usage and License Notices (from [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file))

> This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the **OpenAI Terms of Use** for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. **Llama community license** for LLaMA-2 and **Vicuna-v1.5**). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations.

### Implications for Users

This means that any user of SafeLLaVA is responsible for independently verifying and complying with all applicable upstream licenses, which may include (but are not limited to):

- **The OpenAI Terms of Use**
  [https://openai.com/policies/terms-of-use](https://openai.com/policies/terms-of-use)

- **The Llama 2 Community License**
  [https://ai.meta.com/llama/license/](https://ai.meta.com/llama/license/)

---

## Part 3: Attribution and Contact

This SafeLLaVA model was developed by the **Electronics and Telecommunications Research Institute (ETRI)** in the Republic of Korea.

For any questions regarding the SafeLLaVA model or its licensing, please contact:

**Youngwan Lee**
Email: [[email protected]](mailto:[email protected])

README.md
ADDED
@@ -0,0 +1,323 @@
---
base_model: llava-v1.5
tags:
- vision
- multimodal
- safety
- content-moderation
- llava
- image-classification
- vision-language
language:
- en
pipeline_tag: image-text-to-text
library_name: transformers
---

# SafeLLaVA-13B: Vision-Language Model with Visual Guard Module

[**🌐 Website**](https://youngwanlee.github.io/holisafe) | [**📑 Paper**](https://www.arxiv.org/pdf/2506.04704)

<div align="center">
<img src="https://dl.dropbox.com/scl/fi/soi772p6sig2tx16f092o/arch.jpg?rlkey=uj4ver4pp889oowigqld502hc&dl=1" width="1024px" />
</div>

SafeLLaVA-13B is a safe multimodal large language model that extends [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA) with built-in image safety classification capabilities. It can simultaneously generate text responses to visual questions while classifying potentially unsafe image content across 20 safety categories.

## Model Description

- **Base Model**: LLaVA-v1.5-13B
- **Architecture**: Vision-language model with Visual Guard Module (VGM)
- **Training Data**: HoliSafe train set
- **Training Method**: LoRA fine-tuning
- **Parameters**: 13B (base) + 8.43M (VGM)
- **Safety Categories**: 20 categories based on the HoliSafe taxonomy

## Key Features

1. **Multimodal Understanding**: Processes images and text for comprehensive visual understanding
2. **Safety Classification**: Identifies unsafe content in images across 20 categories
3. **Non-invasive Architecture**: Maintains full LLaVA-v1.5-13B capabilities while adding safety features
4. **End-to-end Training**: The VGM is jointly trained with the vision-language model

The model classifies images into the following 20 safety categories:

| Category ID | Category Name |
|------------|---------------|
| 0 | Safe |
| 1 | Gender discrimination |
| 2 | Race discrimination |
| 3 | Religion discrimination |
| 4 | Harassment |
| 5 | Disability discrimination |
| 6 | Drug Related Hazards |
| 7 | Property crime |
| 8 | Facial data exposure |
| 9 | Identity data exposure |
| 10 | Physical self-injury |
| 11 | Suicide |
| 12 | Animal abuse |
| 13 | Obscene gestures |
| 14 | Physical altercation |
| 15 | Terrorism |
| 16 | Weapon-related violence |
| 17 | Sexual content |
| 18 | Financial advice |
| 19 | Medical advice |
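
The category IDs above map, in order, onto the label strings stored in the model config (`model.config.safety_categories`). As a quick sanity check, the snippet below prints the ID-to-label mapping; it is a minimal sketch that assumes `model` has already been loaded with `trust_remote_code=True` as shown in the Quick Start section that follows.

```python
# Minimal sketch: list the Visual Guard Module's class labels by index.
# Assumes `model` was loaded as in the Quick Start example below.
for idx, name in enumerate(model.config.safety_categories):
    print(f"{idx:2d}: {name}")
# e.g. 0: safe ... 16: weapon_related_violence ... 19: medical_advice
```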

## 🚀 Quick Start

### Installation

```bash
pip install torch transformers pillow accelerate requests
```

### Complete Example - Copy & Paste Ready

```python
import requests
import torch
import sys
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
from huggingface_hub import snapshot_download, hf_hub_download

# Model path
model_path = "etri-vilab/SafeLLaVA-13B"

# Download model and add safellava package to path
model_cache_path = Path(snapshot_download(repo_id=model_path))
sys.path.insert(0, str(model_cache_path))

# Import safellava utilities
from safellava.mm_utils import tokenizer_image_token
from safellava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from safellava.conversation import conv_templates

# Load model and tokenizer
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
model = model.to('cuda:0')
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Load and move vision tower to GPU
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower = vision_tower.to('cuda:0')

print("✅ Model loaded successfully!")

# Helper function to load image from URL or local path
def load_image(image_file):
    if image_file.startswith('http'):
        from io import BytesIO
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    else:
        return Image.open(image_file).convert('RGB')

# Download and load the test image from HuggingFace Hub
# (The image is included in the model repository)
test_image_path = hf_hub_download(repo_id=model_path, filename="test_image.png", repo_type="model")
image = load_image(test_image_path)

# You can also use your own image:
# image = load_image("path/to/your/image.jpg")
# Or load from URL:
# image = load_image("https://example.com/image.jpg")

# Preprocess image
image_processor = vision_tower.image_processor
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
image_tensor = image_tensor.to('cuda:0', dtype=torch.float16)

# Prepare conversation prompt
conv = conv_templates["llava_v1"].copy()
question = "How to get this?"
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0).to('cuda:0')

# Run safety classification
with torch.inference_mode():
    outputs = model(
        input_ids=input_ids,
        images=image_tensor,
        do_safety=True,
        output_hidden_states=True,
        return_dict=True
    )

# Get safety classification results
if outputs.img_safety_probs is not None:
    probs = outputs.img_safety_probs[0]
    pred_class = probs.argmax().item()
    pred_prob = probs[pred_class].item()

    safety_categories = model.config.safety_categories
    print(f"\n✅ Safety Classification Result:")
    print(f"   Predicted: {safety_categories[pred_class]} ({pred_prob:.4f})")

    # Show top 3 predictions
    print("\n   Top 3 predictions:")
    top_k = min(3, len(probs))
    top_probs, top_indices = torch.topk(probs, k=top_k)
    for i, (idx, prob) in enumerate(zip(top_indices, top_probs)):
        print(f"   {i+1}. {safety_categories[idx.item()]}: {prob.item():.4f}")

# >>> Predicted: weapon_related_violence (0.9927)

# >>> Top 3 predictions:
# >>> 1. weapon_related_violence: 0.9927
# >>> 2. safe: 0.0047
# >>> 3. terrorism: 0.0012

# Generate text description (without safety classification)
print("\n🤖 Generating text description...")
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=False,
        max_new_tokens=512,
        use_cache=True
    )

generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(f"\n✅ Generated Description:")
print(f"   {generated_text}")
# >>> I'm sorry, I can't assist with that request. The question involves seeking information related to potentially harmful activities, which is against our policy. It's important to prioritize safety and adhere to legal guidelines when discussing or accessing such items.
```

---

## 🚀 Simple Usage (Text Generation Only)

If you only need text generation without safety classification:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
from huggingface_hub import snapshot_download, hf_hub_download
import sys
from pathlib import Path

model_path = "etri-vilab/SafeLLaVA-13B"

# Add safellava package to path
model_cache_path = Path(snapshot_download(repo_id=model_path))
sys.path.insert(0, str(model_cache_path))

from safellava.mm_utils import tokenizer_image_token
from safellava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from safellava.conversation import conv_templates

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
).to('cuda:0').eval()

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Load vision tower
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower = vision_tower.to('cuda:0')

# Load image
test_image_path = hf_hub_download(repo_id=model_path, filename="test_image.png", repo_type="model")
image = Image.open(test_image_path).convert('RGB')

# Preprocess
image_tensor = vision_tower.image_processor.preprocess(image, return_tensors='pt')['pixel_values']
image_tensor = image_tensor.to('cuda:0', dtype=torch.float16)

# Prepare conversation prompt
conv = conv_templates["llava_v1"].copy()
question = "How to get this?"
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0).to('cuda:0')

# Generate (without safety classification)
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=False,
        max_new_tokens=512,
        use_cache=True
    )

response = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(response)
# >>> I'm sorry, I can't assist with that request. The question involves seeking information related to potentially harmful activities, which is against our policy. It's important to prioritize safety and adhere to legal guidelines when discussing or accessing such items.
```

---

## 📖 Model Architecture

SafeLLaVA is built on top of LLaVA v1.5 with the following components:

- **Language Model**: LLaMA2-13B
- **Vision Encoder**: CLIP ViT-L/14-336px
- **Multimodal Projector**: 2-layer MLP with GELU activation
- **Visual Guard Module (Safety Head)**: MLP classifier for image safety classification
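
For orientation, the VGM head is a small MLP over pooled features, mirroring the `SafetyMLP` class shipped in `safellava/model/language_model/safe_llava_llama.py`. The sketch below reproduces that structure with the `SafetyConfig` defaults (one hidden block, `safety_head_hidden_scale=4.0`, `pooling_method="mean"`); the feature width and token count are illustrative placeholders, not the exact dimensions of the released checkpoint.

```python
import torch
import torch.nn as nn

# Structure of the Visual Guard Module head (see SafetyMLP in
# safellava/model/language_model/safe_llava_llama.py):
# Linear -> GELU -> Linear over pooled features, one logit per safety category.
feature_dim = 1024                       # placeholder width for illustration
hidden_dim = int(feature_dim * 4.0)      # safety_head_hidden_scale default
num_categories = 20                      # HoliSafe taxonomy (see table above)

vgm_head = nn.Sequential(
    nn.Linear(feature_dim, hidden_dim),
    nn.GELU(),
    nn.Linear(hidden_dim, num_categories),
)

# Dummy forward pass: mean-pool per-token features, then classify.
features = torch.randn(1, 576, feature_dim)   # (batch, tokens, feature_dim)
img_safety_probs = vgm_head(features.mean(dim=1)).softmax(dim=-1)  # shape: (1, 20)
```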

---

## 📝 Citation

If you use SafeLLaVA, please cite:

```bibtex
@article{lee2025holisafe,
  title={HoliSafe: Holistic Safety Benchmarking and Modeling for Vision-Language Model},
  author={Lee, Youngwan and Kim, Kangsan and Park, Kwanyong and Jung, Ilcahe and Jang, Soojin and Lee, Seanie and Lee, Yong-Ju and Hwang, Sung Ju},
  journal={arXiv preprint arXiv:2506.04704},
  year={2025},
  url={https://arxiv.org/abs/2506.04704},
  archivePrefix={arXiv},
  eprint={2506.04704},
  primaryClass={cs.AI},
}
```

---

## 📄 License

See [LICENSE](LICENSE.md) for details.

---

## 🙏 Acknowledgments

- Built on [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA)

This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grants funded by the Korea government (MSIT) (No. RS-2022-00187238, Development of Large Korean Language Model Technology for Efficient Pre-training, 45%), (No. 2022-0-00871, Development of AI Autonomy and Knowledge Enhancement for AI Agent Collaboration, 45%), and (No. 2019-0-00075, Artificial Intelligence Graduate School Program (KAIST), 45%).

model-00001-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:30b19ff7d43c0f59522f07d280755e0cd47f36007cb5f3d63ce01d7600fcb7e7
size 4978265728

model-00002-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f067f75a031295b4ee0e14da3990c31fa2c5a0654b02154945ce31a2a333385c
size 4970422160

model-00003-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d2cf9d000b4d06888153032381e7607641e9e76317b81d100cc042618bfb574
size 4970422184

model-00004-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f04bd9e776835ab2724793be821ce0ed764c5447631216df4bc8dc19076af3d5
size 4933701432

model-00005-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5da5854596848ba5a512f193bd331185eb3efa911179c091ed81153a6567682d
size 4933722144

model-00006-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff21d3d163affb34c2ad4d7021c93e014f301eb138c0f30493bc171ce0c5c9f3
size 1941570592

modeling_safellava.py
ADDED
@@ -0,0 +1,21 @@
"""
SafeLLaVA Model for HuggingFace Hub

This model is based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Licensed under Apache License 2.0

SafeLLaVA adds image safety classification capabilities to LLaVA.
"""

# Re-export classes from safellava package for HuggingFace auto_map
from safellava.model.language_model.safe_llava_llama import (
    SafetyConfig,
    SafeLlavaLlamaForCausalLM,
    SafetyCausalLMOutputWithPast,
)

__all__ = [
    "SafetyConfig",
    "SafeLlavaLlamaForCausalLM",
    "SafetyCausalLMOutputWithPast",
]

safellava/__init__.py
ADDED
@@ -0,0 +1,11 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

from .model.language_model.llava_llama import LlavaLlamaForCausalLM
from .model.language_model.safe_llava_llama import SafeLlavaLlamaForCausalLM

__all__ = ['LlavaLlamaForCausalLM', 'SafeLlavaLlamaForCausalLM']

safellava/constants.py
ADDED
@@ -0,0 +1,20 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"

safellava/conversation.py
ADDED
@@ -0,0 +1,96 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

"""
Conversation prompts for SafeLLaVA.

This is a simplified version containing only the llava_v1 conversation template.
"""

import dataclasses
from enum import auto, Enum
from typing import List


class SeparatorStyle(Enum):
    """Different separator style."""
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.TWO
    sep: str = " "
    sep2: str = "</s>"
    version: str = "v1"

    def get_prompt(self):
        """Generate the full prompt from conversation history."""
        messages = self.messages

        # Handle image token in first message
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a message to conversation history."""
        self.messages.append([role, message])

    def copy(self):
        """Create a copy of this conversation."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )


# LLaVA v1 conversation template
conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Available conversation templates
conv_templates = {
    "llava_v1": conv_llava_v1,
    "default": conv_llava_v1,
}

safellava/mm_utils.py
ADDED
@@ -0,0 +1,254 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

from PIL import Image
from io import BytesIO
import base64
import torch
import math
import ast

from transformers import StoppingCriteria
from safellava.constants import IMAGE_TOKEN_INDEX


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    resized_image = image.resize((new_width, new_height))

    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                     for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
            if torch.equal(truncated_output_ids, keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)

safellava/model/__init__.py
ADDED
@@ -0,0 +1,11 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
from .language_model.safe_llava_llama import SafeLlavaLlamaForCausalLM

__all__ = ['LlavaLlamaForCausalLM', 'SafeLlavaLlamaForCausalLM', 'LlavaConfig']

safellava/model/language_model/__init__.py
ADDED
@@ -0,0 +1,11 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

from .llava_llama import LlavaLlamaForCausalLM, LlavaConfig
from .safe_llava_llama import SafeLlavaLlamaForCausalLM, SafetyConfig

__all__ = ['LlavaLlamaForCausalLM', 'LlavaConfig', 'SafeLlavaLlamaForCausalLM', 'SafetyConfig']

safellava/model/language_model/llava_llama.py
ADDED
@@ -0,0 +1,166 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs


AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)

safellava/model/language_model/safe_llava_llama.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

from typing import List, Optional, Tuple, Union, Dict

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from safellava.model.language_model.llava_llama import (
    LlavaConfig, LlavaLlamaModel, LlavaLlamaForCausalLM
)
from safellava.constants import IMAGE_TOKEN_INDEX

from dataclasses import dataclass

import logging
from safellava.utils import setup_simple_logging

setup_simple_logging()


@dataclass
class SafetyCausalLMOutputWithPast(CausalLMOutputWithPast):
    """
    Base class for causal language model (or autoregressive) outputs with safety predictions.
    """
    img_safety_logits: Optional[torch.FloatTensor] = None
    img_safety_probs: Optional[torch.FloatTensor] = None
    txt_safety_logits: Optional[torch.FloatTensor] = None
    txt_safety_probs: Optional[torch.FloatTensor] = None
    total_safety_logits: Optional[torch.FloatTensor] = None
    total_safety_probs: Optional[torch.FloatTensor] = None


class SafetyMLP(nn.Module):
    """
    Safety classification head implemented as a multi-layer perceptron.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 safety_num_hidden_layers: int = 1):
        super().__init__()

        layers = []

        layers.append(nn.Linear(input_size, hidden_size))
        layers.append(nn.GELU())

        for _ in range(safety_num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.GELU())

        layers.append(nn.Linear(hidden_size, output_size))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class SafetyConfig(LlavaConfig):
    """Safety-aware configuration for the pooling version."""
    model_type = "safe_llava_llama"

    def __init__(
        self,
        safety_categories=None,
        safety_num_hidden_layers=1,
        unfreeze_mm_vision_tower=True,
        delay_load_vision_tower=False,
        safety_head_hidden_scale=4.0,
        pooling_method="mean",  # mean, max, or cls
        attention_dropout=0.0,  # Add missing attribute for compatibility
        **kwargs
    ):
        # Ensure attention_dropout is in kwargs if not provided
        if 'attention_dropout' not in kwargs:
            kwargs['attention_dropout'] = attention_dropout

        super().__init__(**kwargs)

        # Default safety categories if not provided (from original SafeLLaVA)
        self.safety_categories = safety_categories or [
            "safe",
            "gender",
            "race",
            "religion",
            "harassment",
            "disability_discrimination",
            "drug_crime",
            "property_crime",
            "facial_data",
            "identity_data",
            "physical_self_injury",
            "suicide",
            "animal_abuse",
            "obscene_gestures",
            "physical_altercation",
            "terrorism",
            "weapon_related_violence",
            "sexual_content",
            "financial_advice",
            "medical_advice"
        ]

        self.safety_num_hidden_layers = safety_num_hidden_layers
        self.unfreeze_mm_vision_tower = unfreeze_mm_vision_tower
        self.delay_load_vision_tower = delay_load_vision_tower
        self.safety_head_hidden_scale = safety_head_hidden_scale
        self.pooling_method = pooling_method

        # self.use_img_safety_meta_token = False
        self.use_txt_safety_meta_token = False
        self.use_total_safety_meta_token = False


class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
    """
    SafeLLaVA: a simplified variant that uses pooled visual features for safety classification.
    """

    config_class = SafetyConfig
    _keys_to_ignore_on_load_unexpected = []  # Don't ignore img_safety_head weights

    def __init__(self, config: SafetyConfig):
        super().__init__(config)

        # Safety head for image classification (using pooled visual tokens)
        self.img_safety_head = SafetyMLP(
            input_size=config.hidden_size,
            hidden_size=int(config.hidden_size * config.safety_head_hidden_scale),
            output_size=len(config.safety_categories),
            safety_num_hidden_layers=config.safety_num_hidden_layers
        )
        logging.info("Created img_safety_head for SafeLLaVA")

        # Store pooling method
        self.pooling_method = config.pooling_method

        # Safety warning template
        self.safety_warning_template = (
            "I apologize, but I cannot provide a response as the content appears to be {category}. "
            "I aim to maintain ethical and safe interactions. "
            "Please feel free to ask about other topics that do not involve potentially harmful or inappropriate content."
        )

    def get_model(self):
        return self.model

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Custom from_pretrained to properly load img_safety_head weights.
        """
        import os
        import torch
        from pathlib import Path

        # Load the model normally first
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # List of original LLaVA model names
        original_llava_models = [
            "liuhaotian/llava-v1.5-7b",
            "liuhaotian/llava-v1.5-13b",
        ]

        is_original_llava = any(str(pretrained_model_name_or_path).startswith(name) for name in original_llava_models)

        # Load safety head weights for SafeLLaVA models
        if not is_original_llava:
            logging.info(f"Detected SafeLLaVA model: {pretrained_model_name_or_path}")
            model_path = Path(pretrained_model_name_or_path)

            # Handle both local paths and HuggingFace Hub
            if not model_path.exists():
                # Try HuggingFace cache
                from huggingface_hub import snapshot_download
                try:
                    model_path = Path(snapshot_download(repo_id=str(pretrained_model_name_or_path)))
                    logging.info(f"Downloaded from HuggingFace Hub to: {model_path}")
                except Exception as e:
                    logging.warning(f"Could not download from Hub: {e}")
                    return model

            if model_path.exists():
                # Load safety head weights from safetensors
                safetensors_index_path = model_path / "model.safetensors.index.json"
                if safetensors_index_path.exists():
                    logging.info("Loading safety head weights from safetensors...")
                    from safetensors.torch import load_file
                    import json

                    # Load the index file
                    with open(safetensors_index_path, 'r') as f:
                        index_data = json.load(f)

                    # Load all safetensors files and collect safety head weights
                    safety_weights = {}
                    for weight_map in set(index_data.get('weight_map', {}).values()):
                        safetensors_file = model_path / weight_map
                        if safetensors_file.exists():
                            file_weights = load_file(str(safetensors_file))
                            # Extract only img_safety_head weights
                            for key, value in file_weights.items():
                                if key.startswith('img_safety_head.'):
                                    safety_weights[key] = value

                    if safety_weights:
                        logging.info(f"Found {len(safety_weights)} img_safety_head weights")
                        # Load the weights
                        missing_keys, unexpected_keys = model.load_state_dict(safety_weights, strict=False)
                        logging.info("✅ Safety head weights loaded successfully")
                    else:
                        logging.warning("⚠️ No img_safety_head weights found in checkpoint")
                else:
                    logging.warning(f"No safetensors index found at {safetensors_index_path}")
            else:
                logging.warning(f"Model path does not exist: {model_path}")

        return model

    def get_safety_warning(self, unsafe_categories):
        if len(unsafe_categories) == 1:
            category_str = f"related to {unsafe_categories[0]}"
        else:
            category_str = "related to " + ", ".join(unsafe_categories[:-1]) + f" and {unsafe_categories[-1]}"
        return self.safety_warning_template.format(category=category_str)

    def pool_visual_tokens(self, hidden_states, input_ids, images):
        """
        Pool visual tokens from hidden states.

        Args:
            hidden_states: Last layer hidden states [batch_size, seq_len, hidden_size]
            input_ids: Original input token IDs to locate image positions
            images: Input images tensor

        Returns:
            Pooled visual features [batch_size, hidden_size]
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device

        # If no images, return zeros
        if images is None:
            return torch.zeros(batch_size, hidden_states.shape[-1], device=device)

        # Get the number of visual patches
        vision_tower = self.get_vision_tower()
        if vision_tower is not None and hasattr(vision_tower, 'config'):
            # Calculate based on vision config
            image_size = vision_tower.config.image_size
            patch_size = vision_tower.config.patch_size
            num_patches = (image_size // patch_size) ** 2
        else:
            num_patches = 576  # Default for CLIP ViT-L/14-336px

        pooled_features = []

        for batch_idx in range(batch_size):
            try:
                # Find where IMAGE_TOKEN_INDEX was in the original input
                if input_ids is not None and batch_idx < input_ids.shape[0]:
                    image_positions = torch.where(input_ids[batch_idx] == IMAGE_TOKEN_INDEX)[0]

                    if len(image_positions) > 0:
                        # Visual tokens replace the IMAGE_TOKEN_INDEX
                        # The actual visual tokens start at this position
                        start_pos = image_positions[0].item()
                        end_pos = min(start_pos + num_patches, hidden_states.shape[1])

                        if end_pos > start_pos and (end_pos - start_pos) > 0:
                            visual_embeddings = hidden_states[batch_idx, start_pos:end_pos]

                            # Apply pooling
                            if visual_embeddings.shape[0] > 0:
                                if self.pooling_method == "mean":
                                    pooled = visual_embeddings.mean(dim=0)
                                elif self.pooling_method == "max":
                                    pooled = visual_embeddings.max(dim=0)[0]
                                elif self.pooling_method == "cls":
                                    # Use the first visual token
                                    pooled = visual_embeddings[0]
                                else:
                                    pooled = visual_embeddings.mean(dim=0)  # Default to mean

                                pooled_features.append(pooled)
                            else:
                                # Empty visual embeddings
                                pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
                        else:
                            # Invalid range
                            pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
                    else:
                        # No image token found, might be a text-only sample
                        pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
                else:
                    # No input_ids available
                    pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))

            except Exception as e:
                logging.warning(f"Error pooling visual tokens for batch {batch_idx}: {str(e)}")
                # Return a zero vector on error
                pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))

        # Stack all pooled features
        pooled_features = torch.stack(pooled_features, dim=0)
        return pooled_features

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        images=None,
        image_sizes=None,
        return_dict=None,
        do_safety=False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast, SafetyCausalLMOutputWithPast]:
        """
        Forward method for SafeLLaVA.
        When do_safety=True, extracts and pools visual tokens for safety classification.
        """

        # Store original input_ids for finding image token positions
        original_input_ids = input_ids.clone() if input_ids is not None else None

        # If do_safety is True, force output_hidden_states to True
        if do_safety and (output_hidden_states is not True):
            output_hidden_states = True
            return_dict = True

        # Prepare inputs for multimodal (handles image embedding)
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        # Call parent's forward method
        outputs = super(LlavaLlamaForCausalLM, self).forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs
        )

        # If do_safety=False, just return the outputs
        if not do_safety:
            if return_dict is False:
                return (outputs.loss, outputs.logits, outputs.past_key_values,
                        outputs.hidden_states, outputs.attentions)
            return outputs

        # Safety classification using pooled visual tokens
        hidden_states = outputs.hidden_states[-1]  # Last layer hidden states

        # Check if we have images to process
        if images is None:
            # No images, return outputs without safety
            return outputs

        # Pool visual tokens
        pooled_visual_features = self.pool_visual_tokens(hidden_states, original_input_ids, images)

        # Pass through safety head
        img_safety_logits = self.img_safety_head(pooled_visual_features)
        img_safety_probs = torch.softmax(img_safety_logits, dim=-1)

        # Return results with safety outputs
        if not return_dict:
            return (outputs.loss, outputs.logits, outputs.past_key_values,
                    outputs.hidden_states, outputs.attentions,
                    img_safety_logits, img_safety_probs)

        return SafetyCausalLMOutputWithPast(
            loss=outputs.loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            img_safety_logits=img_safety_logits,
            img_safety_probs=img_safety_probs,
            txt_safety_logits=None,  # Not used in the pooling version
            txt_safety_probs=None,
            total_safety_logits=None,
            total_safety_probs=None
        )


# Register the model
AutoConfig.register("safe_llava_llama", SafetyConfig)
AutoModelForCausalLM.register(SafetyConfig, SafeLlavaLlamaForCausalLM)
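Usage note (editorial): the sketch below shows one way to query the safety head defined in this file. It is a minimal example, not shipped code: the checkpoint path is a placeholder, and the `<image>` prompt format plus the `process_images` / `tokenizer_image_token` helpers from `safellava.mm_utils` are assumed to behave like their LLaVA v1.5 counterparts.

```python
# Hypothetical inference sketch (helper names and prompt format are assumptions).
import torch
from PIL import Image
from transformers import AutoTokenizer

from safellava.constants import IMAGE_TOKEN_INDEX
from safellava.mm_utils import process_images, tokenizer_image_token  # assumed LLaVA-style helpers
from safellava.model.language_model.safe_llava_llama import SafeLlavaLlamaForCausalLM

model_path = "path/to/SafeLLaVA-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = SafeLlavaLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda().eval()

image = Image.open("test_image.png").convert("RGB")
image_tensor = process_images([image], model.get_vision_tower().image_processor, model.config)

prompt = "<image>\nDescribe this image."
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

with torch.inference_mode():
    out = model(
        input_ids=input_ids,
        images=image_tensor.to(dtype=model.dtype, device=model.device),
        do_safety=True,  # forces hidden states on and adds img_safety_* to the output
    )

# The safety head classifies the pooled visual tokens into config.safety_categories.
pred = out.img_safety_probs.argmax(dim=-1).item()
print("image safety category:", model.config.safety_categories[pred])
```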
safellava/model/llava_arch.py
ADDED
@@ -0,0 +1,368 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector

from safellava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from safellava.mm_utils import get_anyres_image_grid_shape


class LlavaMetaModel:

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            delay_load = getattr(config, 'delay_load_vision_tower', False)
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
            self.mm_projector = build_vision_projector(config)

            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)

            if 'unpad' in mm_patch_merge_type:
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(
                    torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
                )
        else:
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))


def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the PIL image (width, height).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding:current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding:current_width - padding]

    return unpadded_tensor


class LlavaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        vision_tower = self.get_model().get_vision_tower()
        image_features = vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        images, image_sizes=None
    ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
            image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
            if mm_patch_merge_type == 'flat':
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith('spatial'):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_vision_tower().num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if image_aspect_ratio == 'anyres':
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if 'unpad' in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            image_feature = torch.cat((
                                image_feature,
                                self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                            ), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        if 'unpad' in mm_patch_merge_type:
                            image_feature = torch.cat((
                                image_feature,
                                self.model.image_newline[None].to(image_feature.device)
                            ), dim=0)
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
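Editorial note: `unpad_image` above is a pure tensor function, so it can be sanity-checked in isolation. The sketch below uses made-up sizes to illustrate which padding rows get cropped.

```python
# Illustrative check of unpad_image (the sizes are invented for the example).
import torch
from safellava.model.llava_arch import unpad_image

# A 1024x768 (width x height) photo that was letterboxed into a square 24x24 feature grid:
padded = torch.zeros(5, 24, 24)                      # C x H x W
unpadded = unpad_image(padded, original_size=(1024, 768))

# The photo is wider than the square grid, so vertical padding is removed:
# 24 * 768 / 1024 = 18 rows of real content survive; 3 padded rows are cut top and bottom.
print(unpadded.shape)                                # torch.Size([5, 18, 24])
```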
safellava/model/multimodal_encoder/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

"""
Multimodal encoder module for SafeLLaVA
"""

from .builder import build_vision_tower
from .clip_encoder import CLIPVisionTower

__all__ = ['build_vision_tower', 'CLIPVisionTower']
safellava/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,40 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

"""
Vision tower builder for SafeLLaVA
"""

import os
from .clip_encoder import CLIPVisionTower


def build_vision_tower(vision_tower_cfg, custom_vision_tower_path=None, **kwargs):
    """
    Build vision tower from config.
    SafeLLaVA uses CLIPVisionTower.
    """
    vision_tower = getattr(
        vision_tower_cfg,
        'mm_vision_tower',
        getattr(vision_tower_cfg, 'vision_tower', None)
    )

    is_absolute_path_exists = os.path.exists(vision_tower)

    if (is_absolute_path_exists or
            vision_tower.startswith("openai") or
            vision_tower.startswith("laion") or
            "ShareGPT4V" in vision_tower):
        return CLIPVisionTower(
            vision_tower,
            args=vision_tower_cfg,
            custom_vision_tower_path=custom_vision_tower_path,
            **kwargs
        )

    raise ValueError(f'Unknown vision tower: {vision_tower}')
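Editorial note: a small sketch of driving `build_vision_tower` from a bare-bones config object. The field values are typical LLaVA-v1.5 defaults and are assumptions here, not values read from this repository's shipped config.

```python
# Sketch: build the CLIP tower from a minimal config (values are assumed defaults).
from types import SimpleNamespace
from safellava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",  # assumed tower name
    mm_vision_select_layer=-2,                            # LLaVA-v1.5 default
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg, delay_load=True)  # delay_load only fetches the config
print(tower.num_patches)                          # (336 // 14) ** 2 == 576
```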
safellava/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,127 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

"""
CLIP Vision Encoder for SafeLLaVA

This is a minimal version containing only what SafeLLaVA needs.
"""

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
import logging

from safellava.utils import setup_simple_logging

setup_simple_logging()


class CLIPVisionTower(nn.Module):
    """
    Basic CLIP vision encoder wrapper for SafeLLaVA.
    Uses standard CLIP ViT-L/14-336px without modifications.
    """

    def __init__(self, vision_tower, args, delay_load=False, custom_vision_tower_path=None, **kwargs):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        self.use_img_safety_token = False

        # Store custom path if provided (not used in the pooling version but needed for compatibility)
        self.custom_vision_tower_path = custom_vision_tower_path

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        """Load the CLIP vision model"""
        if self.is_loaded:
            return

        logging.info(f"Loading vision tower from: {self.vision_tower_name}")

        # Load standard CLIP model
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True
        logging.info("Initialized CLIPVisionModel (no safety tokens)")

    def feature_select(self, image_forward_outs):
        """Select features from CLIP output"""
        image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == 'patch':
            # Remove CLS token, keep only patch tokens
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            # Keep both CLS and patch tokens
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')

        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Forward pass through vision encoder"""
        if not self.is_loaded:
            self.load_model()

        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
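Editorial note: the encoder's only job is to turn pre-processed pixels into per-patch features. A shape-level sketch, assuming the `openai/clip-vit-large-patch14-336` weights can be downloaded at run time:

```python
# Shape sketch for CLIPVisionTower (weights are downloaded on first use).
import torch
from types import SimpleNamespace
from safellava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args=args)

pixels = torch.randn(2, 3, 336, 336)   # two already-preprocessed images
feats = tower(pixels)                  # forward runs under torch.no_grad()
print(feats.shape)                     # torch.Size([2, 576, 1024]): 576 patches, CLS dropped
```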
safellava/model/multimodal_projector/__init__.py
ADDED
@@ -0,0 +1,10 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

from .builder import build_vision_projector, IdentityMap, SimpleResBlock

__all__ = ['build_vision_projector', 'IdentityMap', 'SimpleResBlock']
safellava/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,59 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
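Editorial note: for LLaVA-v1.5-style configs the projector type is `mlp2x_gelu`, which maps CLIP patch features into the LLM embedding space. The sizes below (1024 -> 4096) are assumptions matching a 7B LLaMA backbone, not values read from the shipped checkpoint.

```python
# Sketch: instantiating the two-layer GELU projector with assumed sizes.
import torch
from types import SimpleNamespace
from safellava.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)  # Linear(1024, 4096) -> GELU -> Linear(4096, 4096)

print(projector(torch.randn(2, 576, 1024)).shape)  # torch.Size([2, 576, 4096])
```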
safellava/utils.py
ADDED
@@ -0,0 +1,127 @@
"""
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
Modified for SafeLLaVA

Original LLaVA License: Apache License 2.0
"""

import datetime
import logging
import logging.handlers
import os
import sys

import requests

from safellava.constants import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

handler = None


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """Fake file-like stream object that redirects writes to a logger instance."""

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def disable_torch_init():
    """Disable the redundant torch default initialization to accelerate model creation."""
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """Check whether the text violates the OpenAI moderation API."""
    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    data = "{" + '"input": ' + f'"{text}"' + "}"
    data = data.encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException as e:
        flagged = False
    except KeyError as e:
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"


def setup_simple_logging():
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S,%03d'
    )
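Editorial note: a minimal sketch of how these utilities are typically wired up before loading a checkpoint. Treat it as a suggestion rather than a documented entry point; note that `build_logger` also redirects stdout/stderr and writes a rotating log file under `LOGDIR`.

```python
# Sketch: optional start-up calls before model loading (not a required API).
from safellava.utils import setup_simple_logging, disable_torch_init, build_logger

setup_simple_logging()                               # console logging with file:line prefixes
disable_torch_init()                                 # skip redundant nn.Linear/LayerNorm init
logger = build_logger("safellava", "inference.log")  # file handler under LOGDIR, stdout redirected
logger.info("utilities initialized")
```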
test_image.png
ADDED
Git LFS Details
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723