ywlee88 committed on
Commit 1e68047 · verified · 1 Parent(s): b72a92c

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+test_image.png filter=lfs diff=lfs merge=lfs -text
LICENSE.md ADDED
@@ -0,0 +1,94 @@
1
+ # License for SafeLLaVA
2
+
3
+ The SafeLLaVA project is governed by a **hybrid license model**. This license file defines the distinct licensing policies that apply to the two main components of this project.
4
+
5
+ ---
6
+
7
+ ## 1. Definition of Work
8
+
9
+ ### Model Name
10
+ **SafeLLaVA**
11
+
12
+ ### Reference Publication
13
+ This model (SafeLLaVA) is the official model presented in the academic paper:
14
+
15
+ > **"HoliSafe: Holistic Safety Benchmarking and Modeling for Vision-Language Model"**
16
+ > [https://arxiv.org/abs/2506.04704](https://arxiv.org/abs/2506.04704)
17
+
18
+ ### Base Model
19
+ This model is a **Derivative Work** based on the LLaVA-v1.5 model:
20
+ - [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA)
21
+
22
+ ### Modifications by ETRI
23
+ This work integrates an independently developed **Visual Guard Module (VGM)** to classify harmful image inputs and generate safe text responses. All modifications and additions are the work of ETRI.
24
+
25
+ ---
26
+
27
+ ## 2. License Summary
28
+
29
+ | Component | License |
30
+ |-----------|---------|
31
+ | Independently Developed Code (e.g., VGM) | Apache License 2.0 |
32
+ | LLaVA-Based Components and Entire Model | LLaVA-v1.5 Usage and License Notices |
33
+
34
+ ---
35
+
36
+ ## Part 1: Apache License 2.0 (For Independently Developed Code)
37
+
38
+ All original source code and components developed independently by **Electronics and Telecommunications Research Institute (ETRI)** (hereinafter "Copyright Holder"), including the **Visual Guard Module (VGM)** contained in this project, are subject to the Apache License, Version 2.0 (the "License").
39
+
40
+ You may not use this file except in compliance with the License. You may obtain a copy of the License at:
41
+
42
+ [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)
43
+
44
+ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
45
+ ```
46
+ Copyright 2025 Electronics and Telecommunications Research Institute (ETRI)
47
+
48
+ Licensed under the Apache License, Version 2.0 (the "License");
49
+ you may not use this file except in compliance with the License.
50
+ You may obtain a copy of the License at
51
+
52
+ http://www.apache.org/licenses/LICENSE-2.0
53
+
54
+ Unless required by applicable law or agreed to in writing, software
55
+ distributed under the License is distributed on an "AS IS" BASIS,
56
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57
+ See the License for the specific language governing permissions and
58
+ limitations under the License.
59
+ ```
60
+
61
+ ---
62
+
63
+ ## Part 2: LLaVA-v1.5 Usage and License Notices (For the LLaVA-Based Derivative Work)
64
+
65
+ This SafeLLaVA model is a **Derivative Work** based on the LLaVA-v1.5 model.
66
+
67
+ Therefore, the use, reproduction, modification, and distribution of the entire SafeLLaVA model (including LLaVA-based components and weights) are subject to the original **"Usage and License Notices"** from the LLaVA-v1.5 repository.
68
+
69
+ Any user of SafeLLaVA must agree to and comply with these upstream notices, which are reproduced below:
70
+
71
+ ### Usage and License Notices (from [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file))
72
+
73
+ > This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the **OpenAI Terms of Use** for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. **Llama community license** for LLaMA-2 and **Vicuna-v1.5**). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations.
74
+
75
+ ### Implications for Users
76
+
77
+ This means that any user of SafeLLaVA is responsible for independently verifying and complying with all applicable upstream licenses, which may include (but are not limited to):
78
+
79
+ - **The OpenAI Terms of Use**
80
+ [https://openai.com/policies/terms-of-use](https://openai.com/policies/terms-of-use)
81
+
82
+ - **The Llama 2 Community License**
83
+ [https://ai.meta.com/llama/license/](https://ai.meta.com/llama/license/)
84
+
85
+ ---
86
+
87
+ ## Part 3: Attribution and Contact
88
+
89
+ This SafeLLaVA model was developed by the **Electronics and Telecommunications Research Institute (ETRI)** in the Republic of Korea.
90
+
91
+ For any questions regarding the SafeLLaVA model or its licensing, please contact:
92
+
93
+ **Youngwan Lee**
94
README.md ADDED
@@ -0,0 +1,323 @@
1
+ ---
2
+ base_model: llava-v1.5
3
+ tags:
4
+ - vision
5
+ - multimodal
6
+ - safety
7
+ - content-moderation
8
+ - llava
9
+ - image-classification
10
+ - vision-language
11
+ language:
12
+ - en
13
+ pipeline_tag: image-text-to-text
14
+ library_name: transformers
15
+ ---
16
+
17
+ # SafeLLaVA-13B: Vision-Language Model with Visual Guard Module
18
+
19
+ [**🌐 Website**](https://youngwanlee.github.io/holisafe) | [**📑 Paper**](https://www.arxiv.org/pdf/2506.04704)
20
+
21
+
22
+
23
+ <div align="center">
24
+ <img src="https://dl.dropbox.com/scl/fi/soi772p6sig2tx16f092o/arch.jpg?rlkey=uj4ver4pp889oowigqld502hc&dl=1" width="1024px" />
25
+ </div>
26
+
27
+ SafeLLaVA-13B is a safe multimodal large language model that extends [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA) with built-in image safety classification capabilities. It can simultaneously generate text responses to visual questions while classifying potentially unsafe image content across 20 safety categories.
28
+
29
+ ## Model Description
30
+
31
+ - **Base Model**: LLaVA-v1.5-13B
32
+ - **Architecture**: Vision-language model with Visual Guard Module (VGM)
33
+ - **Training Data**: HoliSafe train set
34
+ - **Training Method**: LoRA fine-tuning
35
+ - **Parameters**: 13B (base) + 8.43M (VGM)
36
+ - **Safety Categories**: 20 categories based on HoliSafe taxonomy
37
+
38
+ ## Key Features
39
+
40
+ 1. **Multimodal Understanding**: Processes images and text for comprehensive visual understanding
41
+ 2. **Safety Classification**: Identifies unsafe content in images across 20 categories
42
+ 3. **Non-invasive Architecture**: Maintains full LLaVA-v1.5-13B capabilities while adding safety features
43
+ 4. **End-to-end Training**: VGM is jointly trained with the vision-language model
44
+
45
+ The model classifies images into the following 20 safety categories:
46
+
47
+ | Category ID | Category Name |
48
+ |------------|---------------|
49
+ | 0 | Safe |
50
+ | 1 | Gender discrimination |
51
+ | 2 | Race discrimination |
52
+ | 3 | Religion discrimination |
53
+ | 4 | Harassment |
54
+ | 5 | Disability discrimination |
55
+ | 6 | Drug Related Hazards |
56
+ | 7 | Property crime |
57
+ | 8 | Facial data exposure |
58
+ | 9 | Identity data exposure |
59
+ | 10 | Physical self-injury |
60
+ | 11 | Suicide |
61
+ | 12 | Animal abuse |
62
+ | 13 | Obscene gestures |
63
+ | 14 | Physical altercation |
64
+ | 15 | Terrorism |
65
+ | 16 | Weapon-related violence |
66
+ | 17 | Sexual content |
67
+ | 18 | Financial advice |
68
+ | 19 | Medical advice |
69
+
70
+ ## 🚀 Quick Start
71
+
72
+ ### Installation
73
+
74
+ ```bash
75
+ pip install torch transformers pillow accelerate requests
76
+ ```
77
+
78
+ ### Complete Example - Copy & Paste Ready
79
+
80
+ ```python
81
+ import requests
82
+ import torch
83
+ import sys
84
+ from pathlib import Path
85
+ from transformers import AutoModelForCausalLM, AutoTokenizer
86
+ from PIL import Image
87
+ from huggingface_hub import snapshot_download, hf_hub_download
88
+
89
+ # Model path
90
+ model_path = "etri-vilab/SafeLLaVA-13B"
91
+
92
+ # Download model and add safellava package to path
93
+ model_cache_path = Path(snapshot_download(repo_id=model_path))
94
+ sys.path.insert(0, str(model_cache_path))
95
+
96
+ # Import safellava utilities
97
+ from safellava.mm_utils import tokenizer_image_token
98
+ from safellava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
99
+ from safellava.conversation import conv_templates
100
+
101
+ # Load model and tokenizer
102
+ print("Loading model...")
103
+ model = AutoModelForCausalLM.from_pretrained(
104
+ model_path,
105
+ trust_remote_code=True,
106
+ torch_dtype=torch.float16,
107
+ low_cpu_mem_usage=True
108
+ )
109
+ model = model.to('cuda:0')
110
+ model.eval()
111
+
112
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
113
+
114
+ # Load and move vision tower to GPU
115
+ vision_tower = model.get_vision_tower()
116
+ if not vision_tower.is_loaded:
117
+ vision_tower.load_model()
118
+ vision_tower = vision_tower.to('cuda:0')
119
+
120
+ print("✅ Model loaded successfully!")
121
+
122
+ # Helper function to load image from URL or local path
123
+ def load_image(image_file):
124
+ if image_file.startswith('http'):
125
+ from io import BytesIO
126
+ response = requests.get(image_file, timeout=30)
127
+ response.raise_for_status()
128
+ return Image.open(BytesIO(response.content)).convert('RGB')
129
+ else:
130
+ return Image.open(image_file).convert('RGB')
131
+
132
+ # Download and load the test image from HuggingFace Hub
133
+ # (The image is included in the model repository)
134
+ test_image_path = hf_hub_download(repo_id=model_path, filename="test_image.png", repo_type="model")
135
+ image = load_image(test_image_path)
136
+
137
+ # You can also use your own image:
138
+ # image = load_image("path/to/your/image.jpg")
139
+ # Or load from URL:
140
+ # image = load_image("https://example.com/image.jpg")
141
+
142
+ # Preprocess image
143
+ image_processor = vision_tower.image_processor
144
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
145
+ image_tensor = image_tensor.to('cuda:0', dtype=torch.float16)
146
+
147
+ # Prepare conversation prompt
148
+ conv = conv_templates["llava_v1"].copy()
149
+ question = "How to get this?"
150
+ conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
151
+ conv.append_message(conv.roles[1], None)
152
+ prompt = conv.get_prompt()
153
+
154
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
155
+ input_ids = input_ids.unsqueeze(0).to('cuda:0')
156
+
157
+ # Run safety classification
158
+ with torch.inference_mode():
159
+ outputs = model(
160
+ input_ids=input_ids,
161
+ images=image_tensor,
162
+ do_safety=True,
163
+ output_hidden_states=True,
164
+ return_dict=True
165
+ )
166
+
167
+ # Get safety classification results
168
+ if outputs.img_safety_probs is not None:
169
+ probs = outputs.img_safety_probs[0]
170
+ pred_class = probs.argmax().item()
171
+ pred_prob = probs[pred_class].item()
172
+
173
+ safety_categories = model.config.safety_categories
174
+ print(f"\n✅ Safety Classification Result:")
175
+ print(f" Predicted: {safety_categories[pred_class]} ({pred_prob:.4f})")
176
+
177
+ # Show top 3 predictions
178
+ print("\n Top 3 predictions:")
179
+ top_k = min(3, len(probs))
180
+ top_probs, top_indices = torch.topk(probs, k=top_k)
181
+ for i, (idx, prob) in enumerate(zip(top_indices, top_probs)):
182
+ print(f" {i+1}. {safety_categories[idx.item()]}: {prob.item():.4f}")
183
+
184
+ # >>> Predicted: weapon_related_violence (0.9927)
185
+
186
+ # >>> Top 3 predictions:
187
+ # >>> 1. weapon_related_violence: 0.9927
188
+ # >>> 2. safe: 0.0047
189
+ # >>> 3. terrorism: 0.0012
190
+
191
+ # Generate text description (without safety classification)
192
+ print("\n🤖 Generating text description...")
193
+ with torch.inference_mode():
194
+ output_ids = model.generate(
195
+ input_ids,
196
+ images=image_tensor,
197
+ do_sample=False,
198
+ max_new_tokens=512,
199
+ use_cache=True
200
+ )
201
+
202
+ generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
203
+ print(f"\n✅ Generated Description:")
204
+ print(f" {generated_text}")
205
+ # >>> I'm sorry, I can't assist with that request. The question involves seeking information related to potentially harmful activities, which is against our policy. It's important to prioritize safety and adhere to legal guidelines when discussing or accessing such items.
206
+ ```
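+
+ The predicted labels come from `model.config.safety_categories` (internal identifiers such as `weapon_related_violence`), which correspond one-to-one to the category table above. If you only need a binary moderation decision, one simple option is to threshold the probability of the `safe` class. The snippet below is an illustrative sketch that reuses `probs` and `safety_categories` from the example above; the threshold value is an assumption, not part of the released model:
+
+ ```python
+ # Illustrative only: flag the image as unsafe when the "safe" probability is low.
+ SAFE_THRESHOLD = 0.5  # hypothetical value; tune for your application
+
+ safe_prob = probs[safety_categories.index("safe")].item()
+ is_unsafe = safe_prob < SAFE_THRESHOLD
+ print(f"Unsafe: {is_unsafe} (safe probability: {safe_prob:.4f})")
+ ```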
207
+
208
+ ---
209
+
210
+ ## 🚀 Simple Usage (Text Generation Only)
211
+
212
+ If you only need text generation without safety classification:
213
+
214
+ ```python
215
+ import torch
216
+ from transformers import AutoModelForCausalLM, AutoTokenizer
217
+ from PIL import Image
218
+ from huggingface_hub import snapshot_download, hf_hub_download
219
+ import sys
220
+ from pathlib import Path
221
+
222
+ model_path = "etri-vilab/SafeLLaVA-13B"
223
+
224
+ # Add safellava package to path
225
+ model_cache_path = Path(snapshot_download(repo_id=model_path))
226
+ sys.path.insert(0, str(model_cache_path))
227
+
228
+ from safellava.mm_utils import tokenizer_image_token
229
+ from safellava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
230
+ from safellava.conversation import conv_templates
231
+
232
+ # Load model
233
+ model = AutoModelForCausalLM.from_pretrained(
234
+ model_path,
235
+ trust_remote_code=True,
236
+ torch_dtype=torch.float16,
237
+ low_cpu_mem_usage=True
238
+ ).to('cuda:0').eval()
239
+
240
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
241
+
242
+ # Load vision tower
243
+ vision_tower = model.get_vision_tower()
244
+ if not vision_tower.is_loaded:
245
+ vision_tower.load_model()
246
+ vision_tower = vision_tower.to('cuda:0')
247
+
248
+ # Load image
249
+ test_image_path = hf_hub_download(repo_id=model_path, filename="test_image.png", repo_type="model")
250
+ image = Image.open(test_image_path).convert('RGB')
251
+
252
+ # Preprocess
253
+ image_tensor = vision_tower.image_processor.preprocess(image, return_tensors='pt')['pixel_values']
254
+ image_tensor = image_tensor.to('cuda:0', dtype=torch.float16)
255
+
256
+ # Prepare conversation prompt
257
+ conv = conv_templates["llava_v1"].copy()
258
+ question = "How to get this?"
259
+ conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
260
+ conv.append_message(conv.roles[1], None)
261
+ prompt = conv.get_prompt()
262
+
263
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
264
+ input_ids = input_ids.unsqueeze(0).to('cuda:0')
265
+
266
+ # Generate (without safety classification)
267
+ with torch.inference_mode():
268
+ output_ids = model.generate(
269
+ input_ids,
270
+ images=image_tensor,
271
+ do_sample=False,
272
+ max_new_tokens=512,
273
+ use_cache=True
274
+ )
275
+
276
+ response = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
277
+ print(response)
278
+ # >>> I'm sorry, I can't assist with that request. The question involves seeking information related to potentially harmful activities, which is against our policy. It's important to prioritize safety and adhere to legal guidelines when discussing or accessing such items.
279
+ ```
280
+
281
+ ---
282
+
283
+ ## 📖 Model Architecture
284
+
285
+ SafeLLaVA is built on top of LLaVA v1.5 with the following components:
286
+
287
+ - **Language Model**: LLaMA2-13B
288
+ - **Vision Encoder**: CLIP ViT-L/14-336px
289
+ - **Multimodal Projector**: 2-layer MLP with GELU activation
290
+ - **Visual Guard Module (Safety Head)**: MLP classifier for image safety classification (see the sketch below)
291
+
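+ As a rough sketch, the Visual Guard Module is a small MLP head applied to pooled visual-token features from the last hidden layer. The actual implementation is `SafetyMLP` in `safellava/model/language_model/safe_llava_llama.py`; the shapes below simply restate its defaults (hidden scale 4.0, one hidden layer, 20 output categories) and are illustrative only:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ hidden_size = 5120      # LLaMA2-13B hidden size
+ num_categories = 20     # HoliSafe safety taxonomy
+
+ # Minimal sketch of the VGM safety head (see SafetyMLP for the real module)
+ safety_head = nn.Sequential(
+     nn.Linear(hidden_size, int(hidden_size * 4.0)),
+     nn.GELU(),
+     nn.Linear(int(hidden_size * 4.0), num_categories),
+ )
+
+ # Visual tokens are mean-pooled (576 patches for CLIP ViT-L/14-336px) before classification
+ pooled_visual = torch.randn(1, hidden_size)
+ probs = torch.softmax(safety_head(pooled_visual), dim=-1)  # shape [1, 20]
+ ```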
292
+ ---
293
+
294
+ ## 📝 Citation
295
+
296
+ If you use SafeLLaVA, please cite the HoliSafe paper:
297
+
298
+ ```bibtex
299
+ @article{lee2025holisafe,
300
+ title={HoliSafe: Holistic Safety Benchmarking and Modeling for Vision-Language Model},
301
+ author={Lee, Youngwan and Kim, Kangsan and Park, Kwanyong and Jung, Ilcahe and Jang, Soojin and Lee, Seanie and Lee, Yong-Ju and Hwang, Sung Ju},
302
+ journal={arXiv preprint arXiv:2506.04704},
303
+ year={2025},
304
+ url={https://arxiv.org/abs/2506.04704},
305
+ archivePrefix={arXiv},
306
+ eprint={2506.04704},
307
+ primaryClass={cs.AI},
308
+ }
309
+ ```
310
+
311
+ ---
312
+
313
+ ## 📄 License
314
+
315
+ See [LICENSE](LICENSE.md) for details.
316
+
317
+ ---
318
+
319
+ ## 🙏 Acknowledgments
320
+
321
+ - Built on [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA)
322
+
323
+ This work was supported by Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. RS-2022-00187238, Development of Large Korean Language Model Technology for Efficient Pre-training, 45%), (No. 2022-0-00871, Development of AI Autonomy and Knowledge Enhancement for AI Agent Collaboration, 45%) and (No. 2019-0-00075, Artificial Intelligence Graduate School Program (KAIST), 45%).
model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30b19ff7d43c0f59522f07d280755e0cd47f36007cb5f3d63ce01d7600fcb7e7
3
+ size 4978265728
model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f067f75a031295b4ee0e14da3990c31fa2c5a0654b02154945ce31a2a333385c
3
+ size 4970422160
model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d2cf9d000b4d06888153032381e7607641e9e76317b81d100cc042618bfb574
3
+ size 4970422184
model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f04bd9e776835ab2724793be821ce0ed764c5447631216df4bc8dc19076af3d5
3
+ size 4933701432
model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5da5854596848ba5a512f193bd331185eb3efa911179c091ed81153a6567682d
3
+ size 4933722144
model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff21d3d163affb34c2ad4d7021c93e014f301eb138c0f30493bc171ce0c5c9f3
3
+ size 1941570592
modeling_safellava.py ADDED
@@ -0,0 +1,21 @@
1
+ """
2
+ SafeLLaVA Model for HuggingFace Hub
3
+
4
+ This model is based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
5
+ Licensed under Apache License 2.0
6
+
7
+ SafeLLaVA adds image safety classification capabilities to LLaVA.
8
+ """
9
+
10
+ # Re-export classes from safellava package for HuggingFace auto_map
11
+ from safellava.model.language_model.safe_llava_llama import (
12
+ SafetyConfig,
13
+ SafeLlavaLlamaForCausalLM,
14
+ SafetyCausalLMOutputWithPast,
15
+ )
16
+
17
+ __all__ = [
18
+ "SafetyConfig",
19
+ "SafeLlavaLlamaForCausalLM",
20
+ "SafetyCausalLMOutputWithPast",
21
+ ]
safellava/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ from .model.language_model.llava_llama import LlavaLlamaForCausalLM
9
+ from .model.language_model.safe_llava_llama import SafeLlavaLlamaForCausalLM
10
+
11
+ __all__ = ['LlavaLlamaForCausalLM', 'SafeLlavaLlamaForCausalLM']
safellava/constants.py ADDED
@@ -0,0 +1,20 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
9
+ WORKER_HEART_BEAT_INTERVAL = 15
10
+
11
+ LOGDIR = "."
12
+
13
+ # Model Constants
14
+ IGNORE_INDEX = -100
15
+ IMAGE_TOKEN_INDEX = -200
16
+ DEFAULT_IMAGE_TOKEN = "<image>"
17
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
18
+ DEFAULT_IM_START_TOKEN = "<im_start>"
19
+ DEFAULT_IM_END_TOKEN = "<im_end>"
20
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
safellava/conversation.py ADDED
@@ -0,0 +1,96 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ """
9
+ Conversation prompts for SafeLLaVA.
10
+
11
+ This is a simplified version containing only the llava_v1 conversation template.
12
+ """
13
+
14
+ import dataclasses
15
+ from enum import auto, Enum
16
+ from typing import List
17
+
18
+
19
+ class SeparatorStyle(Enum):
20
+ """Different separator style."""
21
+ TWO = auto()
22
+
23
+
24
+ @dataclasses.dataclass
25
+ class Conversation:
26
+ """A class that keeps all conversation history."""
27
+ system: str
28
+ roles: List[str]
29
+ messages: List[List[str]]
30
+ offset: int
31
+ sep_style: SeparatorStyle = SeparatorStyle.TWO
32
+ sep: str = " "
33
+ sep2: str = "</s>"
34
+ version: str = "v1"
35
+
36
+ def get_prompt(self):
37
+ """Generate the full prompt from conversation history."""
38
+ messages = self.messages
39
+
40
+ # Handle image token in first message
41
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
42
+ messages = self.messages.copy()
43
+ init_role, init_msg = messages[0].copy()
44
+ init_msg = init_msg[0].replace("<image>", "").strip()
45
+ messages[0] = (init_role, "<image>\n" + init_msg)
46
+
47
+ if self.sep_style == SeparatorStyle.TWO:
48
+ seps = [self.sep, self.sep2]
49
+ ret = self.system + seps[0]
50
+ for i, (role, message) in enumerate(messages):
51
+ if message:
52
+ if type(message) is tuple:
53
+ message, _, _ = message
54
+ ret += role + ": " + message + seps[i % 2]
55
+ else:
56
+ ret += role + ":"
57
+ return ret
58
+ else:
59
+ raise ValueError(f"Invalid style: {self.sep_style}")
60
+
61
+ def append_message(self, role, message):
62
+ """Append a message to conversation history."""
63
+ self.messages.append([role, message])
64
+
65
+ def copy(self):
66
+ """Create a copy of this conversation."""
67
+ return Conversation(
68
+ system=self.system,
69
+ roles=self.roles,
70
+ messages=[[x, y] for x, y in self.messages],
71
+ offset=self.offset,
72
+ sep_style=self.sep_style,
73
+ sep=self.sep,
74
+ sep2=self.sep2,
75
+ version=self.version,
76
+ )
77
+
78
+
79
+ # LLaVA v1 conversation template
80
+ conv_llava_v1 = Conversation(
81
+ system="A chat between a curious human and an artificial intelligence assistant. "
82
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
83
+ roles=("USER", "ASSISTANT"),
84
+ version="v1",
85
+ messages=(),
86
+ offset=0,
87
+ sep_style=SeparatorStyle.TWO,
88
+ sep=" ",
89
+ sep2="</s>",
90
+ )
91
+
92
+ # Available conversation templates
93
+ conv_templates = {
94
+ "llava_v1": conv_llava_v1,
95
+ "default": conv_llava_v1,
96
+ }
safellava/mm_utils.py ADDED
@@ -0,0 +1,254 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ from PIL import Image
9
+ from io import BytesIO
10
+ import base64
11
+ import torch
12
+ import math
13
+ import ast
14
+
15
+ from transformers import StoppingCriteria
16
+ from safellava.constants import IMAGE_TOKEN_INDEX
17
+
18
+
19
+ def select_best_resolution(original_size, possible_resolutions):
20
+ """
21
+ Selects the best resolution from a list of possible resolutions based on the original size.
22
+
23
+ Args:
24
+ original_size (tuple): The original size of the image in the format (width, height).
25
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
26
+
27
+ Returns:
28
+ tuple: The best fit resolution in the format (width, height).
29
+ """
30
+ original_width, original_height = original_size
31
+ best_fit = None
32
+ max_effective_resolution = 0
33
+ min_wasted_resolution = float('inf')
34
+
35
+ for width, height in possible_resolutions:
36
+ scale = min(width / original_width, height / original_height)
37
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
38
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
39
+ wasted_resolution = (width * height) - effective_resolution
40
+
41
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
42
+ max_effective_resolution = effective_resolution
43
+ min_wasted_resolution = wasted_resolution
44
+ best_fit = (width, height)
45
+
46
+ return best_fit
47
+
48
+
49
+ def resize_and_pad_image(image, target_resolution):
50
+ """
51
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
52
+
53
+ Args:
54
+ image (PIL.Image.Image): The input image.
55
+ target_resolution (tuple): The target resolution (width, height) of the image.
56
+
57
+ Returns:
58
+ PIL.Image.Image: The resized and padded image.
59
+ """
60
+ original_width, original_height = image.size
61
+ target_width, target_height = target_resolution
62
+
63
+ scale_w = target_width / original_width
64
+ scale_h = target_height / original_height
65
+
66
+ if scale_w < scale_h:
67
+ new_width = target_width
68
+ new_height = min(math.ceil(original_height * scale_w), target_height)
69
+ else:
70
+ new_height = target_height
71
+ new_width = min(math.ceil(original_width * scale_h), target_width)
72
+
73
+ resized_image = image.resize((new_width, new_height))
74
+
75
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
76
+ paste_x = (target_width - new_width) // 2
77
+ paste_y = (target_height - new_height) // 2
78
+ new_image.paste(resized_image, (paste_x, paste_y))
79
+
80
+ return new_image
81
+
82
+
83
+ def divide_to_patches(image, patch_size):
84
+ """
85
+ Divides an image into patches of a specified size.
86
+
87
+ Args:
88
+ image (PIL.Image.Image): The input image.
89
+ patch_size (int): The size of each patch.
90
+
91
+ Returns:
92
+ list: A list of PIL.Image.Image objects representing the patches.
93
+ """
94
+ patches = []
95
+ width, height = image.size
96
+ for i in range(0, height, patch_size):
97
+ for j in range(0, width, patch_size):
98
+ box = (j, i, j + patch_size, i + patch_size)
99
+ patch = image.crop(box)
100
+ patches.append(patch)
101
+
102
+ return patches
103
+
104
+
105
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
106
+ """
107
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
108
+
109
+ Args:
110
+ image_size (tuple): The size of the input image in the format (width, height).
111
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
112
+ patch_size (int): The size of each image patch.
113
+
114
+ Returns:
115
+ tuple: The shape of the image patch grid in the format (width, height).
116
+ """
117
+ if type(grid_pinpoints) is list:
118
+ possible_resolutions = grid_pinpoints
119
+ else:
120
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
121
+ width, height = select_best_resolution(image_size, possible_resolutions)
122
+ return width // patch_size, height // patch_size
123
+
124
+
125
+ def process_anyres_image(image, processor, grid_pinpoints):
126
+ """
127
+ Process an image with variable resolutions.
128
+
129
+ Args:
130
+ image (PIL.Image.Image): The input image to be processed.
131
+ processor: The image processor object.
132
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
133
+
134
+ Returns:
135
+ torch.Tensor: A tensor containing the processed image patches.
136
+ """
137
+ if type(grid_pinpoints) is list:
138
+ possible_resolutions = grid_pinpoints
139
+ else:
140
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
141
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
142
+ image_padded = resize_and_pad_image(image, best_resolution)
143
+
144
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
145
+
146
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
147
+
148
+ image_patches = [image_original_resize] + patches
149
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
150
+ for image_patch in image_patches]
151
+ return torch.stack(image_patches, dim=0)
152
+
153
+
154
+ def load_image_from_base64(image):
155
+ return Image.open(BytesIO(base64.b64decode(image)))
156
+
157
+
158
+ def expand2square(pil_img, background_color):
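+ """Pad a PIL image to a square canvas of the given background color, centering the original image (used for the 'pad' aspect-ratio mode)."""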
159
+ width, height = pil_img.size
160
+ if width == height:
161
+ return pil_img
162
+ elif width > height:
163
+ result = Image.new(pil_img.mode, (width, width), background_color)
164
+ result.paste(pil_img, (0, (width - height) // 2))
165
+ return result
166
+ else:
167
+ result = Image.new(pil_img.mode, (height, height), background_color)
168
+ result.paste(pil_img, ((height - width) // 2, 0))
169
+ return result
170
+
171
+
172
+ def process_images(images, image_processor, model_cfg):
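+ """
+ Preprocess a list of PIL images according to the model config.
+
+ 'pad' pads each image to a square using the processor's mean color, 'anyres' tiles the
+ image into multi-resolution patches, and any other setting falls back to the processor's
+ default preprocessing. Returns a stacked tensor when all outputs share the same shape.
+ """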
173
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
174
+ new_images = []
175
+ if image_aspect_ratio == 'pad':
176
+ for image in images:
177
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
178
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
179
+ new_images.append(image)
180
+ elif image_aspect_ratio == "anyres":
181
+ for image in images:
182
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
183
+ new_images.append(image)
184
+ else:
185
+ return image_processor(images, return_tensors='pt')['pixel_values']
186
+ if all(x.shape == new_images[0].shape for x in new_images):
187
+ new_images = torch.stack(new_images, dim=0)
188
+ return new_images
189
+
190
+
191
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
192
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
193
+
194
+ def insert_separator(X, sep):
195
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
196
+
197
+ input_ids = []
198
+ offset = 0
199
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
200
+ offset = 1
201
+ input_ids.append(prompt_chunks[0][0])
202
+
203
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
204
+ input_ids.extend(x[offset:])
205
+
206
+ if return_tensors is not None:
207
+ if return_tensors == 'pt':
208
+ return torch.tensor(input_ids, dtype=torch.long)
209
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
210
+ return input_ids
211
+
212
+
213
+ def get_model_name_from_path(model_path):
214
+ model_path = model_path.strip("/")
215
+ model_paths = model_path.split("/")
216
+ if model_paths[-1].startswith('checkpoint-'):
217
+ return model_paths[-2] + "_" + model_paths[-1]
218
+ else:
219
+ return model_paths[-1]
220
+
221
+
222
+ class KeywordsStoppingCriteria(StoppingCriteria):
223
+ def __init__(self, keywords, tokenizer, input_ids):
224
+ self.keywords = keywords
225
+ self.keyword_ids = []
226
+ self.max_keyword_len = 0
227
+ for keyword in keywords:
228
+ cur_keyword_ids = tokenizer(keyword).input_ids
229
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
230
+ cur_keyword_ids = cur_keyword_ids[1:]
231
+ if len(cur_keyword_ids) > self.max_keyword_len:
232
+ self.max_keyword_len = len(cur_keyword_ids)
233
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
234
+ self.tokenizer = tokenizer
235
+ self.start_len = input_ids.shape[1]
236
+
237
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
238
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
239
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
240
+ for keyword_id in self.keyword_ids:
241
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
242
+ if torch.equal(truncated_output_ids, keyword_id):
243
+ return True
244
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
245
+ for keyword in self.keywords:
246
+ if keyword in outputs:
247
+ return True
248
+ return False
249
+
250
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
251
+ outputs = []
252
+ for i in range(output_ids.shape[0]):
253
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
254
+ return all(outputs)
safellava/model/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
9
+ from .language_model.safe_llava_llama import SafeLlavaLlamaForCausalLM
10
+
11
+ __all__ = ['LlavaLlamaForCausalLM', 'SafeLlavaLlamaForCausalLM', 'LlavaConfig']
safellava/model/language_model/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ from .llava_llama import LlavaLlamaForCausalLM, LlavaConfig
9
+ from .safe_llava_llama import SafeLlavaLlamaForCausalLM, SafetyConfig
10
+
11
+ __all__ = ['LlavaLlamaForCausalLM', 'LlavaConfig', 'SafeLlavaLlamaForCausalLM', 'SafetyConfig']
safellava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,166 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ # Copyright 2023 Haotian Liu
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+ from transformers import AutoConfig, AutoModelForCausalLM, \
29
+ LlamaConfig, LlamaModel, LlamaForCausalLM
30
+
31
+ from transformers.modeling_outputs import CausalLMOutputWithPast
32
+ from transformers.generation.utils import GenerateOutput
33
+
34
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
35
+
36
+
37
+ class LlavaConfig(LlamaConfig):
38
+ model_type = "llava_llama"
39
+
40
+
41
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
42
+ config_class = LlavaConfig
43
+
44
+ def __init__(self, config: LlamaConfig):
45
+ super(LlavaLlamaModel, self).__init__(config)
46
+
47
+
48
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
49
+ config_class = LlavaConfig
50
+
51
+ def __init__(self, config):
52
+ super(LlamaForCausalLM, self).__init__(config)
53
+ self.model = LlavaLlamaModel(config)
54
+ self.pretraining_tp = config.pretraining_tp
55
+ self.vocab_size = config.vocab_size
56
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+
58
+ self.post_init()
59
+
60
+ def get_model(self):
61
+ return self.model
62
+
63
+ def forward(
64
+ self,
65
+ input_ids: torch.LongTensor = None,
66
+ attention_mask: Optional[torch.Tensor] = None,
67
+ position_ids: Optional[torch.LongTensor] = None,
68
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
69
+ inputs_embeds: Optional[torch.FloatTensor] = None,
70
+ labels: Optional[torch.LongTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ cache_position: Optional[int] = None,
73
+ output_attentions: Optional[bool] = None,
74
+ output_hidden_states: Optional[bool] = None,
75
+ images: Optional[torch.FloatTensor] = None,
76
+ image_sizes: Optional[List[List[int]]] = None,
77
+ return_dict: Optional[bool] = None,
78
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
79
+
80
+ if inputs_embeds is None:
81
+ (
82
+ input_ids,
83
+ position_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ inputs_embeds,
87
+ labels
88
+ ) = self.prepare_inputs_labels_for_multimodal(
89
+ input_ids,
90
+ position_ids,
91
+ attention_mask,
92
+ past_key_values,
93
+ labels,
94
+ images,
95
+ image_sizes
96
+ )
97
+
98
+ return super().forward(
99
+ input_ids=input_ids,
100
+ attention_mask=attention_mask,
101
+ position_ids=position_ids,
102
+ past_key_values=past_key_values,
103
+ inputs_embeds=inputs_embeds,
104
+ labels=labels,
105
+ use_cache=use_cache,
106
+ output_attentions=output_attentions,
107
+ output_hidden_states=output_hidden_states,
108
+ return_dict=return_dict
109
+ )
110
+
111
+ @torch.no_grad()
112
+ def generate(
113
+ self,
114
+ inputs: Optional[torch.Tensor] = None,
115
+ images: Optional[torch.Tensor] = None,
116
+ image_sizes: Optional[torch.Tensor] = None,
117
+ **kwargs,
118
+ ) -> Union[GenerateOutput, torch.LongTensor]:
119
+ position_ids = kwargs.pop("position_ids", None)
120
+ attention_mask = kwargs.pop("attention_mask", None)
121
+ if "inputs_embeds" in kwargs:
122
+ raise NotImplementedError("`inputs_embeds` is not supported")
123
+
124
+ if images is not None:
125
+ (
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ _,
130
+ inputs_embeds,
131
+ _
132
+ ) = self.prepare_inputs_labels_for_multimodal(
133
+ inputs,
134
+ position_ids,
135
+ attention_mask,
136
+ None,
137
+ None,
138
+ images,
139
+ image_sizes=image_sizes
140
+ )
141
+ else:
142
+ inputs_embeds = self.get_model().embed_tokens(inputs)
143
+
144
+ return super().generate(
145
+ position_ids=position_ids,
146
+ attention_mask=attention_mask,
147
+ inputs_embeds=inputs_embeds,
148
+ **kwargs
149
+ )
150
+
151
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
152
+ inputs_embeds=None, **kwargs):
153
+ images = kwargs.pop("images", None)
154
+ image_sizes = kwargs.pop("image_sizes", None)
155
+ inputs = super().prepare_inputs_for_generation(
156
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
157
+ )
158
+ if images is not None:
159
+ inputs['images'] = images
160
+ if image_sizes is not None:
161
+ inputs['image_sizes'] = image_sizes
162
+ return inputs
163
+
164
+
165
+ AutoConfig.register("llava_llama", LlavaConfig)
166
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
safellava/model/language_model/safe_llava_llama.py ADDED
@@ -0,0 +1,426 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ from typing import List, Optional, Tuple, Union, Dict
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from transformers import AutoConfig, AutoModelForCausalLM
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast
14
+
15
+ from safellava.model.language_model.llava_llama import (
16
+ LlavaConfig, LlavaLlamaModel, LlavaLlamaForCausalLM
17
+ )
18
+ from safellava.constants import IMAGE_TOKEN_INDEX
19
+
20
+ from dataclasses import dataclass
21
+
22
+ import logging
23
+ from safellava.utils import setup_simple_logging
24
+
25
+ setup_simple_logging()
26
+
27
+
28
+ @dataclass
29
+ class SafetyCausalLMOutputWithPast(CausalLMOutputWithPast):
30
+ """
31
+ Base class for causal language model (or autoregressive) outputs with safety predictions.
32
+ """
33
+ img_safety_logits: Optional[torch.FloatTensor] = None
34
+ img_safety_probs: Optional[torch.FloatTensor] = None
35
+ txt_safety_logits: Optional[torch.FloatTensor] = None
36
+ txt_safety_probs: Optional[torch.FloatTensor] = None
37
+ total_safety_logits: Optional[torch.FloatTensor] = None
38
+ total_safety_probs: Optional[torch.FloatTensor] = None
39
+
40
+
41
+ class SafetyMLP(nn.Module):
42
+ """
43
+ Safety classification head implemented as Multi-layer Perceptron.
44
+ """
45
+
46
+ def __init__(self, input_size: int, hidden_size: int, output_size: int,
47
+ safety_num_hidden_layers: int = 1):
48
+ super().__init__()
49
+
50
+ layers = []
51
+
52
+ layers.append(nn.Linear(input_size, hidden_size))
53
+ layers.append(nn.GELU())
54
+
55
+ for _ in range(safety_num_hidden_layers - 1):
56
+ layers.append(nn.Linear(hidden_size, hidden_size))
57
+ layers.append(nn.GELU())
58
+
59
+ layers.append(nn.Linear(hidden_size, output_size))
60
+
61
+ self.mlp = nn.Sequential(*layers)
62
+
63
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
64
+ return self.mlp(x)
65
+
66
+
67
+ class SafetyConfig(LlavaConfig):
68
+ """Safety-aware configuration for pooling version """
69
+ model_type = "safe_llava_llama"
70
+
71
+ def __init__(
72
+ self,
73
+ safety_categories=None,
74
+ safety_num_hidden_layers=1,
75
+ unfreeze_mm_vision_tower=True,
76
+ delay_load_vision_tower=False,
77
+ safety_head_hidden_scale=4.0,
78
+ pooling_method="mean", # mean, max, or cls
79
+ attention_dropout=0.0, # Add missing attribute for compatibility
80
+ **kwargs
81
+ ):
82
+ # Ensure attention_dropout is in kwargs if not provided
83
+ if 'attention_dropout' not in kwargs:
84
+ kwargs['attention_dropout'] = attention_dropout
85
+
86
+ super().__init__(**kwargs)
87
+
88
+ # Default safety categories if not provided (from original SafeLLaVA)
89
+ self.safety_categories = safety_categories or [
90
+ "safe",
91
+ "gender",
92
+ "race",
93
+ "religion",
94
+ "harassment",
95
+ "disability_discrimination",
96
+ "drug_crime",
97
+ "property_crime",
98
+ "facial_data",
99
+ "identity_data",
100
+ "physical_self_injury",
101
+ "suicide",
102
+ "animal_abuse",
103
+ "obscene_gestures",
104
+ "physical_altercation",
105
+ "terrorism",
106
+ "weapon_related_violence",
107
+ "sexual_content",
108
+ "financial_advice",
109
+ "medical_advice"
110
+ ]
111
+
112
+ self.safety_num_hidden_layers = safety_num_hidden_layers
113
+ self.unfreeze_mm_vision_tower = unfreeze_mm_vision_tower
114
+ self.delay_load_vision_tower = delay_load_vision_tower
115
+ self.safety_head_hidden_scale = safety_head_hidden_scale
116
+ self.pooling_method = pooling_method
117
+
118
+ # self.use_img_safety_meta_token = False
119
+ self.use_txt_safety_meta_token = False
120
+ self.use_total_safety_meta_token = False
121
+
122
+
123
+ class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
124
+ """
125
+ SafeLLaVA: a simplified variant that uses pooled visual features for safety classification.
126
+ """
127
+
128
+ config_class = SafetyConfig
129
+ _keys_to_ignore_on_load_unexpected = [] # Don't ignore img_safety_head weights
130
+
131
+ def __init__(self, config: SafetyConfig):
132
+ super().__init__(config)
133
+
134
+ # Safety head for image classification (using pooled visual tokens)
135
+ self.img_safety_head = SafetyMLP(
136
+ input_size=config.hidden_size,
137
+ hidden_size=int(config.hidden_size * config.safety_head_hidden_scale),
138
+ output_size=len(config.safety_categories),
139
+ safety_num_hidden_layers=config.safety_num_hidden_layers
140
+ )
141
+ logging.info("Created img_safety_head for SafeLLaVA")
142
+
143
+ # Store pooling method
144
+ self.pooling_method = config.pooling_method
145
+
146
+ # Safety warning template
147
+ self.safety_warning_template = (
148
+ "I apologize, but I cannot provide a response as the content appears to be {category}. "
149
+ "I aim to maintain ethical and safe interactions. "
150
+ "Please feel free to ask about other topics that do not involve potentially harmful or inappropriate content."
151
+ )
152
+
153
+ def get_model(self):
154
+ return self.model
155
+
156
+ @classmethod
157
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
158
+ """
159
+ Custom from_pretrained to properly load img_safety_head weights.
160
+ """
161
+ import os
162
+ import torch
163
+ from pathlib import Path
164
+
165
+ # Load the model normally first
166
+ model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
167
+
168
+ # List of original LLaVA model names
169
+ original_llava_models = [
170
+ "liuhaotian/llava-v1.5-7b",
171
+ "liuhaotian/llava-v1.5-13b",
172
+ ]
173
+
174
+ is_original_llava = any(str(pretrained_model_name_or_path).startswith(name) for name in original_llava_models)
175
+
176
+ # Load safety head weights for SafeLLaVA models
177
+ if not is_original_llava:
178
+ logging.info(f"Detected SafeLLaVA model: {pretrained_model_name_or_path}")
179
+ model_path = Path(pretrained_model_name_or_path)
180
+
181
+ # Handle both local paths and HuggingFace Hub
182
+ if not model_path.exists():
183
+ # Try HuggingFace cache
184
+ from huggingface_hub import snapshot_download
185
+ try:
186
+ model_path = Path(snapshot_download(repo_id=str(pretrained_model_name_or_path)))
187
+ logging.info(f"Downloaded from HuggingFace Hub to: {model_path}")
188
+ except Exception as e:
189
+ logging.warning(f"Could not download from Hub: {e}")
190
+ return model
191
+
192
+ if model_path.exists():
193
+ # Load safety head weights from safetensors
194
+ safetensors_index_path = model_path / "model.safetensors.index.json"
195
+ if safetensors_index_path.exists():
196
+ logging.info("Loading safety head weights from safetensors...")
197
+ from safetensors.torch import load_file
198
+ import json
199
+
200
+ # Load the index file
201
+ with open(safetensors_index_path, 'r') as f:
202
+ index_data = json.load(f)
203
+
204
+ # Load all safetensors files and collect safety head weights
205
+ safety_weights = {}
206
+ for weight_map in set(index_data.get('weight_map', {}).values()):
207
+ safetensors_file = model_path / weight_map
208
+ if safetensors_file.exists():
209
+ file_weights = load_file(str(safetensors_file))
210
+ # Extract only img_safety_head weights
211
+ for key, value in file_weights.items():
212
+ if key.startswith('img_safety_head.'):
213
+ safety_weights[key] = value
214
+
215
+ if safety_weights:
216
+ logging.info(f"Found {len(safety_weights)} img_safety_head weights")
217
+ # Load the weights
218
+ missing_keys, unexpected_keys = model.load_state_dict(safety_weights, strict=False)
219
+ logging.info("✅ Safety head weights loaded successfully")
220
+ else:
221
+ logging.warning("⚠️ No img_safety_head weights found in checkpoint")
222
+ else:
223
+ logging.warning(f"No safetensors index found at {safetensors_index_path}")
224
+ else:
225
+ logging.warning(f"Model path does not exist: {model_path}")
226
+
227
+ return model
228
+
229
+ def get_safety_warning(self, unsafe_categories):
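+ """Format the refusal/warning message for a list of predicted unsafe category names."""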
230
+ if len(unsafe_categories) == 1:
231
+ category_str = f"related to {unsafe_categories[0]}"
232
+ else:
233
+ category_str = "related to " + ", ".join(unsafe_categories[:-1]) + f" and {unsafe_categories[-1]}"
234
+ return self.safety_warning_template.format(category=category_str)
235
+
236
+ def pool_visual_tokens(self, hidden_states, input_ids, images):
237
+ """
238
+ Pool visual tokens from hidden states.
239
+
240
+ Args:
241
+ hidden_states: Last layer hidden states [batch_size, seq_len, hidden_size]
242
+ input_ids: Original input token IDs to locate image positions
243
+ images: Input images tensor
244
+
245
+ Returns:
246
+ Pooled visual features [batch_size, hidden_size]
247
+ """
248
+ batch_size = hidden_states.shape[0]
249
+ device = hidden_states.device
250
+
251
+ # If no images, return zeros
252
+ if images is None:
253
+ return torch.zeros(batch_size, hidden_states.shape[-1], device=device)
254
+
255
+ # Get the number of visual patches
256
+ vision_tower = self.get_vision_tower()
257
+ if vision_tower is not None and hasattr(vision_tower, 'config'):
258
+ # Calculate based on vision config
259
+ image_size = vision_tower.config.image_size
260
+ patch_size = vision_tower.config.patch_size
261
+ num_patches = (image_size // patch_size) ** 2
262
+ else:
263
+ num_patches = 576 # Default for CLIP ViT-L/14-336px
264
+
265
+ pooled_features = []
266
+
267
+ for batch_idx in range(batch_size):
268
+ try:
269
+ # Find where IMAGE_TOKEN_INDEX was in the original input
270
+ if input_ids is not None and batch_idx < input_ids.shape[0]:
271
+ image_positions = torch.where(input_ids[batch_idx] == IMAGE_TOKEN_INDEX)[0]
272
+
273
+ if len(image_positions) > 0:
274
+ # Visual tokens replace the IMAGE_TOKEN_INDEX
275
+ # The actual visual tokens start at this position
276
+ start_pos = image_positions[0].item()
277
+ end_pos = min(start_pos + num_patches, hidden_states.shape[1])
278
+
279
+ if end_pos > start_pos and (end_pos - start_pos) > 0:
280
+ visual_embeddings = hidden_states[batch_idx, start_pos:end_pos]
281
+
282
+ # Apply pooling
283
+ if visual_embeddings.shape[0] > 0:
284
+ if self.pooling_method == "mean":
285
+ pooled = visual_embeddings.mean(dim=0)
286
+ elif self.pooling_method == "max":
287
+ pooled = visual_embeddings.max(dim=0)[0]
288
+ elif self.pooling_method == "cls":
289
+ # Use the first visual token
290
+ pooled = visual_embeddings[0]
291
+ else:
292
+ pooled = visual_embeddings.mean(dim=0) # Default to mean
293
+
294
+ pooled_features.append(pooled)
295
+ else:
296
+ # Empty visual embeddings
297
+ pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
298
+ else:
299
+ # Invalid range
300
+ pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
301
+ else:
302
+ # No image token found, might be text-only sample
303
+ pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
304
+ else:
305
+ # No input_ids available
306
+ pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
307
+
308
+ except Exception as e:
309
+ logging.warning(f"Error pooling visual tokens for batch {batch_idx}: {str(e)}")
310
+ # Return zero vector on error
311
+ pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
312
+
313
+ # Stack all pooled features
314
+ pooled_features = torch.stack(pooled_features, dim=0)
315
+ return pooled_features
316
+
317
+ def forward(
318
+ self,
319
+ input_ids=None,
320
+ attention_mask=None,
321
+ position_ids=None,
322
+ past_key_values=None,
323
+ inputs_embeds=None,
324
+ labels=None,
325
+ use_cache=None,
326
+ output_attentions=None,
327
+ output_hidden_states=None,
328
+ images=None,
329
+ image_sizes=None,
330
+ return_dict=None,
331
+ do_safety=False,
332
+ **kwargs,
333
+ ) -> Union[Tuple, CausalLMOutputWithPast, SafetyCausalLMOutputWithPast]:
334
+ """
335
+ Forward method for SafeLLaVA.
336
+ When do_safety=True, extracts and pools visual tokens for safety classification.
337
+ """
338
+
339
+ # Store original input_ids for finding image token positions
340
+ original_input_ids = input_ids.clone() if input_ids is not None else None
341
+
342
+ # If do_safety is True, force output_hidden_states to True
343
+ if do_safety and (output_hidden_states is not True):
344
+ output_hidden_states = True
345
+ return_dict = True
346
+
347
+ # Prepare inputs for multimodal (handles image embedding)
348
+ if inputs_embeds is None:
349
+ (
350
+ input_ids,
351
+ position_ids,
352
+ attention_mask,
353
+ past_key_values,
354
+ inputs_embeds,
355
+ labels
356
+ ) = self.prepare_inputs_labels_for_multimodal(
357
+ input_ids,
358
+ position_ids,
359
+ attention_mask,
360
+ past_key_values,
361
+ labels,
362
+ images,
363
+ image_sizes
364
+ )
365
+
366
+ # Call parent's forward method
367
+ outputs = super(LlavaLlamaForCausalLM, self).forward(
368
+ input_ids=input_ids,
369
+ attention_mask=attention_mask,
370
+ position_ids=position_ids,
371
+ past_key_values=past_key_values,
372
+ inputs_embeds=inputs_embeds,
373
+ labels=labels,
374
+ use_cache=use_cache,
375
+ output_attentions=output_attentions,
376
+ output_hidden_states=output_hidden_states,
377
+ return_dict=True,
378
+ **kwargs
379
+ )
380
+
381
+ # If do_safety=False, just return the outputs
382
+ if not do_safety:
383
+ if return_dict is False:
384
+ return (outputs.loss, outputs.logits, outputs.past_key_values,
385
+ outputs.hidden_states, outputs.attentions)
386
+ return outputs
387
+
388
+ # Safety classification using pooled visual tokens
389
+ hidden_states = outputs.hidden_states[-1] # Last layer hidden states
390
+
391
+ # Check if we have images to process
392
+ if images is None:
393
+ # No images, return outputs without safety
394
+ return outputs
395
+
396
+ # Pool visual tokens
397
+ pooled_visual_features = self.pool_visual_tokens(hidden_states, original_input_ids, images)
398
+
399
+ # Pass through safety head
400
+ img_safety_logits = self.img_safety_head(pooled_visual_features)
401
+ img_safety_probs = torch.softmax(img_safety_logits, dim=-1)
402
+
403
+ # Return results with safety outputs
404
+ if not return_dict:
405
+ return (outputs.loss, outputs.logits, outputs.past_key_values,
406
+ outputs.hidden_states, outputs.attentions,
407
+ img_safety_logits, img_safety_probs)
408
+
409
+ return SafetyCausalLMOutputWithPast(
410
+ loss=outputs.loss,
411
+ logits=outputs.logits,
412
+ past_key_values=outputs.past_key_values,
413
+ hidden_states=outputs.hidden_states,
414
+ attentions=outputs.attentions,
415
+ img_safety_logits=img_safety_logits,
416
+ img_safety_probs=img_safety_probs,
417
+ txt_safety_logits=None, # Not used in Pool version
418
+ txt_safety_probs=None,
419
+ total_safety_logits=None,
420
+ total_safety_probs=None
421
+ )
422
+
423
+
424
+ # Register the model
425
+ AutoConfig.register("safe_llava_llama", SafetyConfig)
426
+ AutoModelForCausalLM.register(SafetyConfig, SafeLlavaLlamaForCausalLM)
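Below is a minimal inference sketch (not part of the commit) showing how the `do_safety` path above can be exercised. It assumes the module defining `SafeLlavaLlamaForCausalLM` has already been imported so the `AutoConfig`/`AutoModelForCausalLM` registration has run, that `path/to/SafeLLaVA` is a placeholder for a local checkpoint directory, and that the prompt template is only illustrative.

```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from safellava.constants import IMAGE_TOKEN_INDEX

model_path = "path/to/SafeLLaVA"  # placeholder: local checkpoint directory (assumption)
model = AutoModelForCausalLM.from_pretrained(model_path).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Preprocess the image with the vision tower's own CLIP processor.
image = Image.open("test_image.png").convert("RGB")
image_processor = model.get_vision_tower().image_processor
images = image_processor(image, return_tensors="pt")["pixel_values"].to(model.device, model.dtype)

# Splice the image placeholder token (IMAGE_TOKEN_INDEX) into the prompt ids.
prefix = tokenizer("USER: ", return_tensors="pt").input_ids
suffix = tokenizer("\nDescribe this image. ASSISTANT:", add_special_tokens=False, return_tensors="pt").input_ids
input_ids = torch.cat([prefix, torch.tensor([[IMAGE_TOKEN_INDEX]]), suffix], dim=1).to(model.device)

with torch.no_grad():
    out = model(input_ids=input_ids, images=images, do_safety=True)

print(out.img_safety_probs)  # distribution over image-safety classes from the VGM head
```

With `do_safety=False` the same call degrades to the standard LLaVA forward and returns a plain `CausalLMOutputWithPast`.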
safellava/model/llava_arch.py ADDED
@@ -0,0 +1,368 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ # Copyright 2023 Haotian Liu
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ from abc import ABC, abstractmethod
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+ from .multimodal_encoder.builder import build_vision_tower
29
+ from .multimodal_projector.builder import build_vision_projector
30
+
31
+ from safellava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
32
+
33
+ from safellava.mm_utils import get_anyres_image_grid_shape
34
+
35
+
36
+ class LlavaMetaModel:
37
+
38
+ def __init__(self, config):
39
+ super(LlavaMetaModel, self).__init__(config)
40
+
41
+ if hasattr(config, "mm_vision_tower"):
42
+ delay_load = getattr(config, 'delay_load_vision_tower', False)
43
+ self.vision_tower = build_vision_tower(config, delay_load=delay_load)
44
+ self.mm_projector = build_vision_projector(config)
45
+
46
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
47
+ self.image_newline = nn.Parameter(
48
+ torch.empty(config.hidden_size, dtype=self.dtype)
49
+ )
50
+
51
+ def get_vision_tower(self):
52
+ vision_tower = getattr(self, 'vision_tower', None)
53
+ if type(vision_tower) is list:
54
+ vision_tower = vision_tower[0]
55
+ return vision_tower
56
+
57
+ def initialize_vision_modules(self, model_args, fsdp=None):
58
+ vision_tower = model_args.vision_tower
59
+ mm_vision_select_layer = model_args.mm_vision_select_layer
60
+ mm_vision_select_feature = model_args.mm_vision_select_feature
61
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
62
+ mm_patch_merge_type = model_args.mm_patch_merge_type
63
+
64
+ self.config.mm_vision_tower = vision_tower
65
+
66
+ if self.get_vision_tower() is None:
67
+ vision_tower = build_vision_tower(model_args)
68
+
69
+ if fsdp is not None and len(fsdp) > 0:
70
+ self.vision_tower = [vision_tower]
71
+ else:
72
+ self.vision_tower = vision_tower
73
+ else:
74
+ if fsdp is not None and len(fsdp) > 0:
75
+ vision_tower = self.vision_tower[0]
76
+ else:
77
+ vision_tower = self.vision_tower
78
+ vision_tower.load_model()
79
+
80
+ self.config.use_mm_proj = True
81
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
82
+ self.config.mm_hidden_size = vision_tower.hidden_size
83
+ self.config.mm_vision_select_layer = mm_vision_select_layer
84
+ self.config.mm_vision_select_feature = mm_vision_select_feature
85
+ self.config.mm_patch_merge_type = mm_patch_merge_type
86
+
87
+ if getattr(self, 'mm_projector', None) is None:
88
+ self.mm_projector = build_vision_projector(self.config)
89
+
90
+ if 'unpad' in mm_patch_merge_type:
91
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
92
+ self.image_newline = nn.Parameter(
93
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
94
+ )
95
+ else:
96
+ for p in self.mm_projector.parameters():
97
+ p.requires_grad = True
98
+
99
+ if pretrain_mm_mlp_adapter is not None:
100
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
101
+ def get_w(weights, keyword):
102
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
103
+
104
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
105
+
106
+
107
+ def unpad_image(tensor, original_size):
108
+ """
109
+ Unpads a PyTorch tensor of a padded and resized image.
110
+
111
+ Args:
112
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
113
+ original_size (tuple): The original size of the PIL image (width, height).
114
+
115
+ Returns:
116
+ torch.Tensor: The unpadded image tensor.
117
+ """
118
+ original_width, original_height = original_size
119
+ current_height, current_width = tensor.shape[1:]
120
+
121
+ original_aspect_ratio = original_width / original_height
122
+ current_aspect_ratio = current_width / current_height
123
+
124
+ if original_aspect_ratio > current_aspect_ratio:
125
+ scale_factor = current_width / original_width
126
+ new_height = int(original_height * scale_factor)
127
+ padding = (current_height - new_height) // 2
128
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
129
+ else:
130
+ scale_factor = current_height / original_height
131
+ new_width = int(original_width * scale_factor)
132
+ padding = (current_width - new_width) // 2
133
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
134
+
135
+ return unpadded_tensor
136
+
137
+
138
+ class LlavaMetaForCausalLM(ABC):
139
+
140
+ @abstractmethod
141
+ def get_model(self):
142
+ pass
143
+
144
+ def get_vision_tower(self):
145
+ return self.get_model().get_vision_tower()
146
+
147
+ def encode_images(self, images):
148
+ vision_tower = self.get_model().get_vision_tower()
149
+ image_features = vision_tower(images)
150
+ image_features = self.get_model().mm_projector(image_features)
151
+ return image_features
152
+
153
+ def prepare_inputs_labels_for_multimodal(
154
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
155
+ images, image_sizes=None
156
+ ):
157
+ vision_tower = self.get_vision_tower()
158
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
159
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
160
+
161
+ if type(images) is list or images.ndim == 5:
162
+ if type(images) is list:
163
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
164
+ concat_images = torch.cat([image for image in images], dim=0)
165
+ image_features = self.encode_images(concat_images)
166
+ split_sizes = [image.shape[0] for image in images]
167
+ image_features = torch.split(image_features, split_sizes, dim=0)
168
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
169
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
170
+ if mm_patch_merge_type == 'flat':
171
+ image_features = [x.flatten(0, 1) for x in image_features]
172
+ elif mm_patch_merge_type.startswith('spatial'):
173
+ new_image_features = []
174
+ for image_idx, image_feature in enumerate(image_features):
175
+ if image_feature.shape[0] > 1:
176
+ base_image_feature = image_feature[0]
177
+ image_feature = image_feature[1:]
178
+ height = width = self.get_vision_tower().num_patches_per_side
179
+ assert height * width == base_image_feature.shape[0]
180
+ if image_aspect_ratio == 'anyres':
181
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
182
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
183
+ else:
184
+ raise NotImplementedError
185
+ if 'unpad' in mm_patch_merge_type:
186
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
187
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
188
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
189
+ image_feature = torch.cat((
190
+ image_feature,
191
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
192
+ ), dim=-1)
193
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
194
+ else:
195
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
196
+ image_feature = image_feature.flatten(0, 3)
197
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
198
+ else:
199
+ image_feature = image_feature[0]
200
+ if 'unpad' in mm_patch_merge_type:
201
+ image_feature = torch.cat((
202
+ image_feature,
203
+ self.model.image_newline[None].to(image_feature.device)
204
+ ), dim=0)
205
+ new_image_features.append(image_feature)
206
+ image_features = new_image_features
207
+ else:
208
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
209
+ else:
210
+ image_features = self.encode_images(images)
211
+
212
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
213
+ raise NotImplementedError
214
+
215
+ _labels = labels
216
+ _position_ids = position_ids
217
+ _attention_mask = attention_mask
218
+ if attention_mask is None:
219
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
220
+ else:
221
+ attention_mask = attention_mask.bool()
222
+ if position_ids is None:
223
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
224
+ if labels is None:
225
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
226
+
227
+ _input_ids = input_ids
228
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
229
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
230
+
231
+ new_input_embeds = []
232
+ new_labels = []
233
+ cur_image_idx = 0
234
+ for batch_idx, cur_input_ids in enumerate(input_ids):
235
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
236
+ if num_images == 0:
237
+ cur_image_features = image_features[cur_image_idx]
238
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
239
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
240
+ new_input_embeds.append(cur_input_embeds)
241
+ new_labels.append(labels[batch_idx])
242
+ cur_image_idx += 1
243
+ continue
244
+
245
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
246
+ cur_input_ids_noim = []
247
+ cur_labels = labels[batch_idx]
248
+ cur_labels_noim = []
249
+ for i in range(len(image_token_indices) - 1):
250
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
251
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
252
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
253
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
254
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
255
+ cur_new_input_embeds = []
256
+ cur_new_labels = []
257
+
258
+ for i in range(num_images + 1):
259
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
260
+ cur_new_labels.append(cur_labels_noim[i])
261
+ if i < num_images:
262
+ cur_image_features = image_features[cur_image_idx]
263
+ cur_image_idx += 1
264
+ cur_new_input_embeds.append(cur_image_features)
265
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
266
+
267
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
268
+
269
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
270
+ cur_new_labels = torch.cat(cur_new_labels)
271
+
272
+ new_input_embeds.append(cur_new_input_embeds)
273
+ new_labels.append(cur_new_labels)
274
+
275
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
276
+ if tokenizer_model_max_length is not None:
277
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
278
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
279
+
280
+ max_len = max(x.shape[0] for x in new_input_embeds)
281
+ batch_size = len(new_input_embeds)
282
+
283
+ new_input_embeds_padded = []
284
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
285
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
286
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
287
+
288
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
289
+ cur_len = cur_new_embed.shape[0]
290
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
291
+ new_input_embeds_padded.append(torch.cat((
292
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
293
+ cur_new_embed
294
+ ), dim=0))
295
+ if cur_len > 0:
296
+ new_labels_padded[i, -cur_len:] = cur_new_labels
297
+ attention_mask[i, -cur_len:] = True
298
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
299
+ else:
300
+ new_input_embeds_padded.append(torch.cat((
301
+ cur_new_embed,
302
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
303
+ ), dim=0))
304
+ if cur_len > 0:
305
+ new_labels_padded[i, :cur_len] = cur_new_labels
306
+ attention_mask[i, :cur_len] = True
307
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
308
+
309
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
310
+
311
+ if _labels is None:
312
+ new_labels = None
313
+ else:
314
+ new_labels = new_labels_padded
315
+
316
+ if _attention_mask is None:
317
+ attention_mask = None
318
+ else:
319
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
320
+
321
+ if _position_ids is None:
322
+ position_ids = None
323
+
324
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
325
+
326
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
327
+ if model_args.mm_use_im_patch_token:
328
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
329
+ self.resize_token_embeddings(len(tokenizer))
330
+
331
+ if model_args.mm_use_im_start_end:
332
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
333
+ self.resize_token_embeddings(len(tokenizer))
334
+
335
+ if num_new_tokens > 0:
336
+ input_embeddings = self.get_input_embeddings().weight.data
337
+ output_embeddings = self.get_output_embeddings().weight.data
338
+
339
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
340
+ dim=0, keepdim=True)
341
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
342
+ dim=0, keepdim=True)
343
+
344
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
345
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
346
+
347
+ if model_args.tune_mm_mlp_adapter:
348
+ for p in self.get_input_embeddings().parameters():
349
+ p.requires_grad = True
350
+ for p in self.get_output_embeddings().parameters():
351
+ p.requires_grad = False
352
+
353
+ if model_args.pretrain_mm_mlp_adapter:
354
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
355
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
356
+ assert num_new_tokens == 2
357
+ if input_embeddings.shape == embed_tokens_weight.shape:
358
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
359
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
360
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
361
+ else:
362
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
363
+ elif model_args.mm_use_im_patch_token:
364
+ if model_args.tune_mm_mlp_adapter:
365
+ for p in self.get_input_embeddings().parameters():
366
+ p.requires_grad = False
367
+ for p in self.get_output_embeddings().parameters():
368
+ p.requires_grad = False
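A quick sanity check of `unpad_image` (a sketch, not from the repo): a square 336x336 tensor whose source image was twice as wide as tall should have its vertical padding stripped.

```python
import torch

from safellava.model.llava_arch import unpad_image

padded = torch.zeros(3, 336, 336)   # CxHxW, square after pad-and-resize
original_size = (672, 336)          # (width, height) of the original image

unpadded = unpad_image(padded, original_size)
print(unpadded.shape)               # torch.Size([3, 168, 336]): 84 padding rows removed top and bottom
```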
safellava/model/multimodal_encoder/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ """
9
+ Multimodal encoder module for SafeLLaVA
10
+ """
11
+
12
+ from .builder import build_vision_tower
13
+ from .clip_encoder import CLIPVisionTower
14
+
15
+ __all__ = ['build_vision_tower', 'CLIPVisionTower']
safellava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ """
9
+ Vision tower builder for SafeLLaVA
10
+ """
11
+
12
+ import os
13
+ from .clip_encoder import CLIPVisionTower
14
+
15
+
16
+ def build_vision_tower(vision_tower_cfg, custom_vision_tower_path=None, **kwargs):
17
+ """
18
+ Build vision tower from config.
19
+ SafeLLaVA uses CLIPVisionTower.
20
+ """
21
+ vision_tower = getattr(
22
+ vision_tower_cfg,
23
+ 'mm_vision_tower',
24
+ getattr(vision_tower_cfg, 'vision_tower', None)
25
+ )
26
+
27
+ is_absolute_path_exists = os.path.exists(vision_tower)
28
+
29
+ if (is_absolute_path_exists or
30
+ vision_tower.startswith("openai") or
31
+ vision_tower.startswith("laion") or
32
+ "ShareGPT4V" in vision_tower):
33
+ return CLIPVisionTower(
34
+ vision_tower,
35
+ args=vision_tower_cfg,
36
+ custom_vision_tower_path=custom_vision_tower_path,
37
+ **kwargs
38
+ )
39
+
40
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
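A minimal sketch of how this builder can be driven (assumption: the attributes below are the only ones read by the tower; with `delay_load=True` only the CLIP config is fetched, which still needs network access or a cached copy).

```python
from types import SimpleNamespace

from safellava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",  # matches the "openai" prefix branch
    mm_vision_select_layer=-2,                            # penultimate layer, the usual LLaVA choice
    mm_vision_select_feature="patch",                     # drop the CLS token in feature_select
)

tower = build_vision_tower(cfg, delay_load=True)  # delay_load is forwarded to CLIPVisionTower via **kwargs
print(tower.hidden_size)                          # 1024 for CLIP ViT-L/14-336
```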
safellava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ """
9
+ CLIP Vision Encoder for SafeLLaVA
10
+
11
+ This is a minimal version containing only what SafeLLaVA needs.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
17
+ import logging
18
+
19
+ from safellava.utils import setup_simple_logging
20
+
21
+ setup_simple_logging()
22
+
23
+
24
+ class CLIPVisionTower(nn.Module):
25
+ """
26
+ Basic CLIP vision encoder wrapper for SafeLLaVA.
27
+ Uses standard CLIP ViT-L/14-336px without modifications.
28
+ """
29
+
30
+ def __init__(self, vision_tower, args, delay_load=False, custom_vision_tower_path=None, **kwargs):
31
+ super().__init__()
32
+
33
+ self.is_loaded = False
34
+ self.vision_tower_name = vision_tower
35
+ self.select_layer = args.mm_vision_select_layer
36
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
37
+
38
+ self.use_img_safety_token = False
39
+
40
+ # Store custom path if provided (not used in Pool version but needed for compatibility)
41
+ self.custom_vision_tower_path = custom_vision_tower_path
42
+
43
+ if not delay_load:
44
+ self.load_model()
45
+ else:
46
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
47
+
48
+ def load_model(self):
49
+ """Load the CLIP vision model"""
50
+ if self.is_loaded:
51
+ return
52
+
53
+ logging.info(f"Loading vision tower from: {self.vision_tower_name}")
54
+
55
+ # Load standard CLIP model
56
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
57
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
58
+ self.vision_tower.requires_grad_(False)
59
+
60
+ self.is_loaded = True
61
+ logging.info("Initialized CLIPVisionModel (no safety tokens)")
62
+
63
+ def feature_select(self, image_forward_outs):
64
+ """Select features from CLIP output"""
65
+ image_features = image_forward_outs.hidden_states[self.select_layer]
66
+
67
+ if self.select_feature == 'patch':
68
+ # Remove CLS token, keep only patch tokens
69
+ image_features = image_features[:, 1:]
70
+ elif self.select_feature == 'cls_patch':
71
+ # Keep both CLS and patch tokens
72
+ image_features = image_features
73
+ else:
74
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
75
+
76
+ return image_features
77
+
78
+ @torch.no_grad()
79
+ def forward(self, images):
80
+ """Forward pass through vision encoder"""
81
+ if not self.is_loaded:
82
+ self.load_model()
83
+
84
+ if type(images) is list:
85
+ image_features = []
86
+ for image in images:
87
+ image_forward_out = self.vision_tower(
88
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
89
+ output_hidden_states=True
90
+ )
91
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
92
+ image_features.append(image_feature)
93
+ else:
94
+ image_forward_outs = self.vision_tower(
95
+ images.to(device=self.device, dtype=self.dtype),
96
+ output_hidden_states=True
97
+ )
98
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
99
+
100
+ return image_features
101
+
102
+ @property
103
+ def dummy_feature(self):
104
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
105
+
106
+ @property
107
+ def dtype(self):
108
+ return self.vision_tower.dtype
109
+
110
+ @property
111
+ def device(self):
112
+ return self.vision_tower.device
113
+
114
+ @property
115
+ def config(self):
116
+ if self.is_loaded:
117
+ return self.vision_tower.config
118
+ else:
119
+ return self.cfg_only
120
+
121
+ @property
122
+ def hidden_size(self):
123
+ return self.config.hidden_size
124
+
125
+ @property
126
+ def num_patches(self):
127
+ return (self.config.image_size // self.config.patch_size) ** 2
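For orientation, the shapes this wrapper produces with the standard `openai/clip-vit-large-patch14-336` tower (an assumption; other CLIP variants change the numbers): the `'patch'` option keeps (image_size / patch_size)^2 tokens per image.

```python
image_size, patch_size, hidden_size = 336, 14, 1024   # CLIP ViT-L/14-336 values (assumed)
num_patches = (image_size // patch_size) ** 2          # same formula as the num_patches property
print(num_patches)                                     # 576 patch tokens after the CLS token is dropped
print((1, num_patches, hidden_size))                   # expected forward() output shape for one image
```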
safellava/model/multimodal_projector/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ from .builder import build_vision_projector, IdentityMap, SimpleResBlock
9
+
10
+ __all__ = ['build_vision_projector', 'IdentityMap', 'SimpleResBlock']
safellava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import re
11
+
12
+
13
+ class IdentityMap(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(self, x, *args, **kwargs):
18
+ return x
19
+
20
+ @property
21
+ def config(self):
22
+ return {"mm_projector_type": 'identity'}
23
+
24
+
25
+ class SimpleResBlock(nn.Module):
26
+ def __init__(self, channels):
27
+ super().__init__()
28
+ self.pre_norm = nn.LayerNorm(channels)
29
+
30
+ self.proj = nn.Sequential(
31
+ nn.Linear(channels, channels),
32
+ nn.GELU(),
33
+ nn.Linear(channels, channels)
34
+ )
35
+
36
+ def forward(self, x):
37
+ x = self.pre_norm(x)
38
+ return x + self.proj(x)
39
+
40
+
41
+ def build_vision_projector(config, delay_load=False, **kwargs):
42
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
43
+
44
+ if projector_type == 'linear':
45
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
46
+
47
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
48
+ if mlp_gelu_match:
49
+ mlp_depth = int(mlp_gelu_match.group(1))
50
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
51
+ for _ in range(1, mlp_depth):
52
+ modules.append(nn.GELU())
53
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
54
+ return nn.Sequential(*modules)
55
+
56
+ if projector_type == 'identity':
57
+ return IdentityMap()
58
+
59
+ raise ValueError(f'Unknown projector type: {projector_type}')
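A small sketch of the `'mlp2x_gelu'` branch (the dimensions are the usual LLaVA-v1.5 ones, assumed here for illustration): two linear layers with a GELU in between, mapping CLIP features into the LLM embedding space.

```python
import torch
from types import SimpleNamespace

from safellava.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)               # Linear(1024->4096) -> GELU -> Linear(4096->4096)

features = torch.randn(1, 576, cfg.mm_hidden_size)    # one image, 576 patch tokens
print(projector(features).shape)                      # torch.Size([1, 576, 4096])
```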
safellava/utils.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
+ Modified for SafeLLaVA
4
+
5
+ Original LLaVA License: Apache License 2.0
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import logging.handlers
11
+ import os
12
+ import sys
13
+
14
+ import requests
15
+
16
+ from safellava.constants import LOGDIR
17
+
18
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
19
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
20
+
21
+ handler = None
22
+
23
+
24
+ def build_logger(logger_name, logger_filename):
25
+ global handler
26
+
27
+ formatter = logging.Formatter(
28
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
29
+ datefmt="%Y-%m-%d %H:%M:%S",
30
+ )
31
+
32
+ if not logging.getLogger().handlers:
33
+ logging.basicConfig(level=logging.INFO)
34
+ logging.getLogger().handlers[0].setFormatter(formatter)
35
+
36
+ stdout_logger = logging.getLogger("stdout")
37
+ stdout_logger.setLevel(logging.INFO)
38
+ sl = StreamToLogger(stdout_logger, logging.INFO)
39
+ sys.stdout = sl
40
+
41
+ stderr_logger = logging.getLogger("stderr")
42
+ stderr_logger.setLevel(logging.ERROR)
43
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
44
+ sys.stderr = sl
45
+
46
+ logger = logging.getLogger(logger_name)
47
+ logger.setLevel(logging.INFO)
48
+
49
+ if handler is None:
50
+ os.makedirs(LOGDIR, exist_ok=True)
51
+ filename = os.path.join(LOGDIR, logger_filename)
52
+ handler = logging.handlers.TimedRotatingFileHandler(
53
+ filename, when='D', utc=True, encoding='UTF-8')
54
+ handler.setFormatter(formatter)
55
+
56
+ for name, item in logging.root.manager.loggerDict.items():
57
+ if isinstance(item, logging.Logger):
58
+ item.addHandler(handler)
59
+
60
+ return logger
61
+
62
+
63
+ class StreamToLogger(object):
64
+ """Fake file-like stream object that redirects writes to a logger instance."""
65
+
66
+ def __init__(self, logger, log_level=logging.INFO):
67
+ self.terminal = sys.stdout
68
+ self.logger = logger
69
+ self.log_level = log_level
70
+ self.linebuf = ''
71
+
72
+ def __getattr__(self, attr):
73
+ return getattr(self.terminal, attr)
74
+
75
+ def write(self, buf):
76
+ temp_linebuf = self.linebuf + buf
77
+ self.linebuf = ''
78
+ for line in temp_linebuf.splitlines(True):
79
+ if line[-1] == '\n':
80
+ self.logger.log(self.log_level, line.rstrip())
81
+ else:
82
+ self.linebuf += line
83
+
84
+ def flush(self):
85
+ if self.linebuf != '':
86
+ self.logger.log(self.log_level, self.linebuf.rstrip())
87
+ self.linebuf = ''
88
+
89
+
90
+ def disable_torch_init():
91
+ """Disable the redundant torch default initialization to accelerate model creation."""
92
+ import torch
93
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
94
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
95
+
96
+
97
+ def violates_moderation(text):
98
+ """Check whether the text violates OpenAI moderation API."""
99
+ url = "https://api.openai.com/v1/moderations"
100
+ headers = {"Content-Type": "application/json",
101
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
102
+ text = text.replace("\n", "")
103
+ data = "{" + '"input": ' + f'"{text}"' + "}"
104
+ data = data.encode("utf-8")
105
+ try:
106
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
107
+ flagged = ret.json()["results"][0]["flagged"]
108
+ except requests.exceptions.RequestException as e:
109
+ flagged = False
110
+ except KeyError as e:
111
+ flagged = False
112
+
113
+ return flagged
114
+
115
+
116
+ def pretty_print_semaphore(semaphore):
117
+ if semaphore is None:
118
+ return "None"
119
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
120
+
121
+
122
+ def setup_simple_logging():
123
+ logging.basicConfig(
124
+ level=logging.INFO,
125
+ format='[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s',
126
+ datefmt='%Y-%m-%d %H:%M:%S'  # strftime provides no millisecond directive for datefmt
127
+ )
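A short usage sketch for the two helpers inference scripts are most likely to touch (a sketch, not from the repo):

```python
import logging

from safellava.utils import disable_torch_init, setup_simple_logging

setup_simple_logging()        # root-logger format used across safellava
disable_torch_init()          # skip redundant weight init when weights come from from_pretrained
logging.info("SafeLLaVA utilities configured")
```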
test_image.png ADDED

Git LFS Details

  • SHA256: b76c421a37642461082bde65f52c653cffe81c2031c36aa77c18dc106cd3b866
  • Pointer size: 131 Bytes
  • Size of remote file: 564 kB
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723