---
base_model: llava-v1.5
tags:
- vision
- multimodal
- safety
- content-moderation
- llava
- image-classification
- vision-language
language:
- en
pipeline_tag: image-text-to-text
library_name: transformers
---

# SafeLLaVA-13B: Vision-Language Model with Visual Guard Module

[**Website**](https://youngwanlee.github.io/holisafe) | [**Paper**](https://www.arxiv.org/pdf/2506.04704)

<div align="center">
<img src="https://dl.dropbox.com/scl/fi/soi772p6sig2tx16f092o/arch.jpg?rlkey=uj4ver4pp889oowigqld502hc&dl=1" width="1024px" />
</div>

SafeLLaVA-13B is a safe multimodal large language model that extends [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA) with built-in image safety classification capabilities. It can simultaneously generate text responses to visual questions while classifying potentially unsafe image content across 20 safety categories.

## Model Description

- **Base Model**: LLaVA-v1.5-13B
- **Architecture**: Vision-language model with Visual Guard Module (VGM)
- **Training Data**: HoliSafe train set
- **Training Method**: LoRA fine-tuning
- **Parameters**: 13B (base) + 8.43M (VGM)
- **Safety Categories**: 20 categories based on the HoliSafe taxonomy
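
As a rough sense of scale for the VGM, treating the nominal "13B" as exactly $13 \times 10^{9}$ parameters (an approximation; the exact base count is not stated here), the added parameters amount to:

$$
\frac{8.43 \times 10^{6}}{13 \times 10^{9}} \approx 6.5 \times 10^{-4} \approx 0.065\%
$$

That is, the safety head adds well under one tenth of one percent to the base parameter count.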

## Key Features

1. **Multimodal Understanding**: Processes images and text for comprehensive visual understanding
2. **Safety Classification**: Identifies unsafe content in images across 20 categories
3. **Non-invasive Architecture**: Retains the full capabilities of LLaVA-v1.5-13B while adding safety features
4. **End-to-end Training**: The VGM is trained jointly with the vision-language model

The model classifies images into the following 20 safety categories:

| Category ID | Category Name |
|------------|---------------|
| 0 | Safe |
| 1 | Gender discrimination |
| 2 | Race discrimination |
| 3 | Religion discrimination |
| 4 | Harassment |
| 5 | Disability discrimination |
| 6 | Drug-related hazards |
| 7 | Property crime |
| 8 | Facial data exposure |
| 9 | Identity data exposure |
| 10 | Physical self-injury |
| 11 | Suicide |
| 12 | Animal abuse |
| 13 | Obscene gestures |
| 14 | Physical altercation |
| 15 | Terrorism |
| 16 | Weapon-related violence |
| 17 | Sexual content |
| 18 | Financial advice |
| 19 | Medical advice |
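
If you need to work with category IDs outside the model (for logging or filtering), the table can be mirrored as a plain lookup. This is a minimal sketch, not part of the `safellava` package: the authoritative labels are whatever `model.config.safety_categories` returns, and the snake_case helper is an assumption inferred from the names printed in the Quick Start output (e.g. `weapon_related_violence`, `safe`, `terrorism`).

```python
# The 20 categories from the table above, keyed by category ID.
HOLISAFE_CATEGORIES = {
    0: "Safe",
    1: "Gender discrimination",
    2: "Race discrimination",
    3: "Religion discrimination",
    4: "Harassment",
    5: "Disability discrimination",
    6: "Drug-related hazards",
    7: "Property crime",
    8: "Facial data exposure",
    9: "Identity data exposure",
    10: "Physical self-injury",
    11: "Suicide",
    12: "Animal abuse",
    13: "Obscene gestures",
    14: "Physical altercation",
    15: "Terrorism",
    16: "Weapon-related violence",
    17: "Sexual content",
    18: "Financial advice",
    19: "Medical advice",
}

SAFE_ID = 0  # category 0 is "Safe" in the table


def is_unsafe(category_id: int) -> bool:
    """Treat every known category other than 0 (Safe) as unsafe."""
    if category_id not in HOLISAFE_CATEGORIES:
        raise ValueError(f"unknown category id: {category_id}")
    return category_id != SAFE_ID


def to_config_style(name: str) -> str:
    """Hypothetical helper: snake_case a display name to match the printed config labels."""
    return name.lower().replace("-", "_").replace(" ", "_")


print(HOLISAFE_CATEGORIES[16], is_unsafe(16))   # Weapon-related violence True
print(to_config_style(HOLISAFE_CATEGORIES[6]))  # drug_related_hazards
```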

## Quick Start

### Installation

```bash
pip install torch transformers pillow accelerate requests
```

Both examples below load the model in float16 onto a single CUDA GPU (`cuda:0`).

### Complete Example - Copy & Paste Ready

```python
import sys
from io import BytesIO
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model path
model_path = "etri-vilab/SafeLLaVA-13B"

# Download model and add safellava package to path
model_cache_path = Path(snapshot_download(repo_id=model_path))
sys.path.insert(0, str(model_cache_path))

# Import safellava utilities (importable only after the path insert above)
from safellava.mm_utils import tokenizer_image_token
from safellava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from safellava.conversation import conv_templates

# Load model and tokenizer
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
model = model.to('cuda:0')
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Load and move vision tower to GPU
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower = vision_tower.to('cuda:0')

print("✅ Model loaded successfully!")

# Helper function to load an image from a URL or a local path
def load_image(image_file):
    if image_file.startswith('http'):
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    return Image.open(image_file).convert('RGB')

# Download and load the test image from the Hugging Face Hub
# (the image is included in the model repository)
test_image_path = hf_hub_download(repo_id=model_path, filename="test_image.png", repo_type="model")
image = load_image(test_image_path)

# You can also use your own image:
# image = load_image("path/to/your/image.jpg")
# Or load one from a URL:
# image = load_image("https://example.com/image.jpg")

# Preprocess image
image_processor = vision_tower.image_processor
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
image_tensor = image_tensor.to('cuda:0', dtype=torch.float16)

# Prepare conversation prompt
conv = conv_templates["llava_v1"].copy()
question = "How to get this?"
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0).to('cuda:0')

# Run safety classification
with torch.inference_mode():
    outputs = model(
        input_ids=input_ids,
        images=image_tensor,
        do_safety=True,
        output_hidden_states=True,
        return_dict=True
    )

# Get safety classification results
if outputs.img_safety_probs is not None:
    probs = outputs.img_safety_probs[0]
    pred_class = probs.argmax().item()
    pred_prob = probs[pred_class].item()

    safety_categories = model.config.safety_categories
    print("\n✅ Safety Classification Result:")
    print(f"   Predicted: {safety_categories[pred_class]} ({pred_prob:.4f})")

    # Show top 3 predictions
    print("\n   Top 3 predictions:")
    top_k = min(3, len(probs))
    top_probs, top_indices = torch.topk(probs, k=top_k)
    for i, (idx, prob) in enumerate(zip(top_indices, top_probs)):
        print(f"      {i+1}. {safety_categories[idx.item()]}: {prob.item():.4f}")

# >>> Predicted: weapon_related_violence (0.9927)
#
# >>> Top 3 predictions:
# >>> 1. weapon_related_violence: 0.9927
# >>> 2. safe: 0.0047
# >>> 3. terrorism: 0.0012

# Generate a text response (without safety classification)
print("\nGenerating text response...")
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=False,
        max_new_tokens=512,
        use_cache=True
    )

generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print("\n✅ Generated Response:")
print(f"   {generated_text}")
# >>> I'm sorry, I can't assist with that request. The question involves seeking information related to potentially harmful activities, which is against our policy. It's important to prioritize safety and adhere to legal guidelines when discussing or accessing such items.
```
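
To use the classification result programmatically rather than printing it, the probability tensor can be reduced to plain Python first, e.g. `summarize_safety(outputs.img_safety_probs[0].float().tolist(), model.config.safety_categories)`. The helper below is a hypothetical sketch, not part of the `safellava` package. It reuses the argmax rule from the example above and additionally reports the total probability assigned to unsafe categories, assuming index 0 is the safe class as in the category table; whether to flag on the argmax or on a threshold over that unsafe mass is a policy choice left to you.

```python
from typing import Dict, List, Sequence, Tuple


def summarize_safety(
    probs: Sequence[float],
    categories: Sequence[str],
    k: int = 3,
    safe_index: int = 0,
) -> Dict[str, object]:
    """Summarize one image's category probabilities (hypothetical helper)."""
    if len(probs) != len(categories):
        raise ValueError("probs and categories must have the same length")
    # Rank category indices from most to least probable.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top: List[Tuple[str, float]] = [(categories[i], float(probs[i])) for i in ranked[:k]]
    pred = ranked[0]
    return {
        "predicted": categories[pred],
        "confidence": float(probs[pred]),
        "top_k": top,
        # Same decision rule as the example: flag when the argmax is not the safe class.
        "flagged": pred != safe_index,
        # Total probability mass on all non-safe categories.
        "unsafe_mass": 1.0 - float(probs[safe_index]),
    }
```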

---

## Simple Usage (Text Generation Only)

If you only need text generation without safety classification:

```python
import sys
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "etri-vilab/SafeLLaVA-13B"

# Add safellava package to path
model_cache_path = Path(snapshot_download(repo_id=model_path))
sys.path.insert(0, str(model_cache_path))

from safellava.mm_utils import tokenizer_image_token
from safellava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from safellava.conversation import conv_templates

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
).to('cuda:0').eval()

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Load vision tower
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower = vision_tower.to('cuda:0')

# Load image
test_image_path = hf_hub_download(repo_id=model_path, filename="test_image.png", repo_type="model")
image = Image.open(test_image_path).convert('RGB')

# Preprocess
image_tensor = vision_tower.image_processor.preprocess(image, return_tensors='pt')['pixel_values']
image_tensor = image_tensor.to('cuda:0', dtype=torch.float16)

# Prepare conversation prompt
conv = conv_templates["llava_v1"].copy()
question = "How to get this?"
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0).to('cuda:0')

# Generate (without safety classification)
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=False,
        max_new_tokens=512,
        use_cache=True
    )

response = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(response)
# >>> I'm sorry, I can't assist with that request. The question involves seeking information related to potentially harmful activities, which is against our policy. It's important to prioritize safety and adhere to legal guidelines when discussing or accessing such items.
```
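
Both examples build the prompt with `conv_templates["llava_v1"]` and then call `tokenizer_image_token`. The sketch below reconstructs what those two steps are assumed to do, based on the upstream LLaVA codebase rather than on `safellava` itself: the `<image>` placeholder stays in the text prompt, and at tokenization time it is replaced by the sentinel id `IMAGE_TOKEN_INDEX` (assumed to be -200), which LLaVA later swaps for the projected image features. The system prompt string and the toy tokenizer are assumptions for illustration; compare against `conv.get_prompt()` and the real `tokenizer_image_token` output before relying on exact values.

```python
from typing import Callable, List

# Assumed to mirror upstream LLaVA; check safellava.constants before relying on them.
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

# Assumed system prompt of the "llava_v1" template (roles USER/ASSISTANT, sep=" ").
LLAVA_V1_SYSTEM = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions."
)


def build_llava_v1_prompt(question: str) -> str:
    """Approximate conv.get_prompt() for one user turn plus an empty assistant turn."""
    user_message = DEFAULT_IMAGE_TOKEN + "\n" + question
    return f"{LLAVA_V1_SYSTEM} USER: {user_message} ASSISTANT:"


def splice_image_token(prompt: str, encode: Callable[[str], List[int]], bos_id: int) -> List[int]:
    """Tokenize the text around <image> and put IMAGE_TOKEN_INDEX where the tag was."""
    chunks = [encode(part) for part in prompt.split(DEFAULT_IMAGE_TOKEN)]
    ids: List[int] = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            ids.append(IMAGE_TOKEN_INDEX)
            # Each chunk was encoded separately, so drop the BOS it re-added.
            if chunk and chunk[0] == bos_id:
                chunk = chunk[1:]
        ids.extend(chunk)
    return ids


# Hypothetical toy tokenizer: BOS id 1, then one fake id per whitespace-separated word.
def toy_encode(text: str) -> List[int]:
    return [1] + [100 + len(word) for word in text.split()]


prompt = build_llava_v1_prompt("How to get this?")
ids = splice_image_token(prompt, toy_encode, bos_id=1)
print(prompt.endswith("ASSISTANT:"), ids.count(IMAGE_TOKEN_INDEX))  # True 1
```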

---

## Model Architecture

SafeLLaVA is built on top of LLaVA-v1.5 with the following components:

- **Language Model**: LLaMA2-13B
- **Vision Encoder**: CLIP ViT-L/14-336px
- **Multimodal Projector**: 2-layer MLP with GELU activation
- **Visual Guard Module (Safety Head)**: MLP classifier for image safety classification
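
A quick size check of how the encoder and projector connect. This sketch assumes the upstream LLaVA-v1.5 configuration (336px input, patch size 14, CLIP ViT-L hidden size 1024, LLaMA-2-13B hidden size 5120, projector laid out as `Linear -> GELU -> Linear`); verify these values against `model.config` for this checkpoint. It does not model the Visual Guard Module, whose internal layout is not specified here beyond its 8.43M parameter count.

```python
# Assumed config values for CLIP ViT-L/14-336px and a 13B LLaMA-2 language model.
IMAGE_SIZE = 336      # vision encoder input resolution (pixels)
PATCH_SIZE = 14       # ViT-L/14 patch size (pixels)
VISION_HIDDEN = 1024  # CLIP ViT-L hidden size
LLM_HIDDEN = 5120     # LLaMA-2-13B hidden size


def num_visual_tokens(image_size: int, patch_size: int) -> int:
    """Patch tokens per image: (image_size / patch_size) squared."""
    side = image_size // patch_size
    return side * side


def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus bias of one fully connected layer."""
    return d_in * d_out + d_out


def projector_params(d_vision: int, d_llm: int) -> int:
    """2-layer MLP: Linear(d_vision, d_llm) -> GELU -> Linear(d_llm, d_llm); GELU has no weights."""
    return linear_params(d_vision, d_llm) + linear_params(d_llm, d_llm)


tokens = num_visual_tokens(IMAGE_SIZE, PATCH_SIZE)
print(tokens)                                       # 576
print(projector_params(VISION_HIDDEN, LLM_HIDDEN))  # 31467520
```

Under these assumptions each image contributes a 24 x 24 grid of patch features, each projected from 1024 to 5120 dimensions before entering the language model.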

---

## Citation

If you use SafeLLaVA, please cite the HoliSafe paper:

```bibtex
@article{lee2025holisafe,
  title={HoliSafe: Holistic Safety Benchmarking and Modeling for Vision-Language Model},
  author={Lee, Youngwan and Kim, Kangsan and Park, Kwanyong and Jung, Ilcahe and Jang, Soojin and Lee, Seanie and Lee, Yong-Ju and Hwang, Sung Ju},
  journal={arXiv preprint arXiv:2506.04704},
  year={2025},
  url={https://arxiv.org/abs/2506.04704},
  archivePrefix={arXiv},
  eprint={2506.04704},
  primaryClass={cs.AI},
}
```

---

## License

See [LICENSE](LICENSE.md) for details.

---

## Acknowledgments

- Built on [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA)

This work was supported by the Institute of Information & communications Technology Planning & Evaluation (IITP) grants funded by the Korea government (MSIT) (No. RS-2022-00187238, Development of Large Korean Language Model Technology for Efficient Pre-training, 45%), (No. 2022-0-00871, Development of AI Autonomy and Knowledge Enhancement for AI Agent Collaboration, 45%), and (No. 2019-0-00075, Artificial Intelligence Graduate School Program (KAIST), 45%).