ywlee88 committed
Commit 0c444ae · verified · 1 Parent(s): 2c809d7

Remove safe_llava_llama_pool.py (Pool naming removed)

safellava/model/language_model/safe_llava_llama_pool.py DELETED
@@ -1,618 +0,0 @@
- from typing import List, Optional, Tuple, Union, Dict
-
- import torch
- import torch.nn as nn
- from transformers import AutoConfig, AutoModelForCausalLM
- from transformers.modeling_outputs import CausalLMOutputWithPast
-
- from llava.model.language_model.llava_llama import (
-     LlavaConfig, LlavaLlamaModel, LlavaLlamaForCausalLM
- )
- from llava.constants import IMAGE_TOKEN_INDEX
-
- from dataclasses import dataclass
-
- import logging
- from llava.utils import setup_simple_logging
-
- setup_simple_logging()
-
-
- @dataclass
- class SafetyCausalLMOutputWithPast(CausalLMOutputWithPast):
-     """
-     Base class for causal language model (or autoregressive) outputs with safety predictions.
-     """
-     img_safety_logits: Optional[torch.FloatTensor] = None
-     img_safety_probs: Optional[torch.FloatTensor] = None
-     txt_safety_logits: Optional[torch.FloatTensor] = None
-     txt_safety_probs: Optional[torch.FloatTensor] = None
-     total_safety_logits: Optional[torch.FloatTensor] = None
-     total_safety_probs: Optional[torch.FloatTensor] = None
-
-
- class SafetyMLP(nn.Module):
-     """
-     Safety classification head implemented as a multi-layer perceptron.
-     """
-
-     def __init__(self, input_size: int, hidden_size: int, output_size: int,
-                  safety_num_hidden_layers: int = 1):
-         super().__init__()
-
-         layers = []
-
-         layers.append(nn.Linear(input_size, hidden_size))
-         layers.append(nn.GELU())
-
-         for _ in range(safety_num_hidden_layers - 1):
-             layers.append(nn.Linear(hidden_size, hidden_size))
-             layers.append(nn.GELU())
-
-         layers.append(nn.Linear(hidden_size, output_size))
-
-         self.mlp = nn.Sequential(*layers)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return self.mlp(x)
-
-
- class SafetyConfig(LlavaConfig):
-     """Safety-aware configuration for the pooling version without meta tokens."""
-     model_type = "safe_llava_llama_pool"
-
-     def __init__(
-         self,
-         safety_categories=None,
-         safety_num_hidden_layers=1,
-         unfreeze_mm_vision_tower=True,
-         delay_load_vision_tower=False,
-         safety_head_hidden_scale=4.0,
-         pooling_method="mean",  # mean, max, or cls
-         attention_dropout=0.0,  # Add missing attribute for compatibility
-         **kwargs
-     ):
-         # Ensure attention_dropout is in kwargs if not provided
-         if 'attention_dropout' not in kwargs:
-             kwargs['attention_dropout'] = attention_dropout
-
-         super().__init__(**kwargs)
-
-         # Default safety categories if not provided (from original SafeLLaVA)
-         self.safety_categories = safety_categories or [
-             "safe",
-             "gender",
-             "race",
-             "religion",
-             "harassment",
-             "disability_discrimination",
-             "drug_crime",
-             "property_crime",
-             "facial_data",
-             "identity_data",
-             "physical_self_injury",
-             "suicide",
-             "animal_abuse",
-             "obscene_gestures",
-             "physical_altercation",
-             "terrorism",
-             "weapon_related_violence",
-             "sexual_content",
-             "financial_advice",
-             "medical_advice"
-         ]
-
-         self.safety_num_hidden_layers = safety_num_hidden_layers
-         self.unfreeze_mm_vision_tower = unfreeze_mm_vision_tower
-         self.delay_load_vision_tower = delay_load_vision_tower
-         self.safety_head_hidden_scale = safety_head_hidden_scale
-         self.pooling_method = pooling_method
-
-         # Pool version doesn't use meta tokens
-         self.use_img_safety_meta_token = False
-         self.use_txt_safety_meta_token = False
-         self.use_total_safety_meta_token = False
-
-
- class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
-     """
-     SafeLLaVA-Pool: A simplified version without meta tokens.
-     Pools visual tokens directly for safety classification.
-     """
-
-     config_class = SafetyConfig
-
-     def __init__(self, config: SafetyConfig):
-         super().__init__(config)
-
-         # Safety head for image classification (using pooled visual tokens)
-         self.img_safety_head = SafetyMLP(
-             input_size=config.hidden_size,
-             hidden_size=int(config.hidden_size * config.safety_head_hidden_scale),
-             output_size=len(config.safety_categories),
-             safety_num_hidden_layers=config.safety_num_hidden_layers
-         )
-         logging.info("Created img_safety_head for SafeLLaVA-Pool")
-
-         # Store pooling method
-         self.pooling_method = config.pooling_method
-
-         # Safety warning template
-         self.safety_warning_template = (
-             "I apologize, but I cannot provide a response as the content appears to be {category}. "
-             "I aim to maintain ethical and safe interactions. "
-             "Please feel free to ask about other topics that do not involve potentially harmful or inappropriate content."
-         )
-
-     def get_model(self):
-         return self.model
-
-     def get_safety_warning(self, unsafe_categories):
-         if len(unsafe_categories) == 1:
-             category_str = f"related to {unsafe_categories[0]}"
-         else:
-             category_str = "related to " + ", ".join(unsafe_categories[:-1]) + f" and {unsafe_categories[-1]}"
-         return self.safety_warning_template.format(category=category_str)
-
-     def pool_visual_tokens(self, hidden_states, input_ids, images):
-         """
-         Pool visual tokens from hidden states.
-
-         Args:
-             hidden_states: Last layer hidden states [batch_size, seq_len, hidden_size]
-             input_ids: Original input token IDs to locate image positions
-             images: Input images tensor
-
-         Returns:
-             Pooled visual features [batch_size, hidden_size]
-         """
-         batch_size = hidden_states.shape[0]
-         device = hidden_states.device
-
-         # If no images, return zeros
-         if images is None:
-             return torch.zeros(batch_size, hidden_states.shape[-1], device=device)
-
-         # Get the number of visual patches
-         vision_tower = self.get_vision_tower()
-         if vision_tower is not None and hasattr(vision_tower, 'config'):
-             # Calculate based on vision config
-             image_size = vision_tower.config.image_size
-             patch_size = vision_tower.config.patch_size
-             num_patches = (image_size // patch_size) ** 2
-         else:
-             num_patches = 576  # Default for CLIP ViT-L/14-336px
-
-         pooled_features = []
-
-         for batch_idx in range(batch_size):
-             try:
-                 # Find where IMAGE_TOKEN_INDEX was in the original input
-                 if input_ids is not None and batch_idx < input_ids.shape[0]:
-                     image_positions = torch.where(input_ids[batch_idx] == IMAGE_TOKEN_INDEX)[0]
-
-                     if len(image_positions) > 0:
-                         # Visual tokens replace the IMAGE_TOKEN_INDEX
-                         # The actual visual tokens start at this position
-                         start_pos = image_positions[0].item()
-                         end_pos = min(start_pos + num_patches, hidden_states.shape[1])
-
-                         if end_pos > start_pos and (end_pos - start_pos) > 0:
-                             visual_embeddings = hidden_states[batch_idx, start_pos:end_pos]
-
-                             # Apply pooling
-                             if visual_embeddings.shape[0] > 0:
-                                 if self.pooling_method == "mean":
-                                     pooled = visual_embeddings.mean(dim=0)
-                                 elif self.pooling_method == "max":
-                                     pooled = visual_embeddings.max(dim=0)[0]
-                                 elif self.pooling_method == "cls":
-                                     # Use the first visual token
-                                     pooled = visual_embeddings[0]
-                                 else:
-                                     pooled = visual_embeddings.mean(dim=0)  # Default to mean
-
-                                 pooled_features.append(pooled)
-                             else:
-                                 # Empty visual embeddings
-                                 pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
-                         else:
-                             # Invalid range
-                             pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
-                     else:
-                         # No image token found, might be text-only sample
-                         pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
-                 else:
-                     # No input_ids available
-                     pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
-
-             except Exception as e:
-                 logging.warning(f"Error pooling visual tokens for batch {batch_idx}: {str(e)}")
-                 # Return zero vector on error
-                 pooled_features.append(torch.zeros(hidden_states.shape[-1], device=device))
-
-         # Stack all pooled features
-         pooled_features = torch.stack(pooled_features, dim=0)
-         return pooled_features
-
-     def compute_gradcam(
-         self,
-         input_ids=None,
-         attention_mask=None,
-         images=None,
-         image_sizes=None,
-         target_class=None,
-         use_pre_pooling=False,
-         **kwargs,
-     ):
-         """
-         Compute Grad-CAM for the image safety classification.
-
-         Args:
-             input_ids: Input token IDs
-             attention_mask: Attention mask
-             images: Input images tensor [batch_size, 3, H, W]
-             image_sizes: Image sizes
-             target_class: Target class index for Grad-CAM. If None, uses the predicted class.
-             use_pre_pooling: If True, compute Grad-CAM before pooling for better spatial resolution
-
-         Returns:
-             dict with keys:
-                 - 'heatmap': Grad-CAM heatmap [batch_size, H_feat, W_feat]
-                 - 'predicted_class': Predicted class index
-                 - 'predicted_prob': Probability of predicted class
-                 - 'class_name': Name of the target class
-         """
-         if images is None:
-             raise ValueError("Images are required for Grad-CAM computation")
-
-         # Enable gradient computation for images
-         # Note: We need to enable train mode for vision tower to compute gradients
-         was_training = self.training
-         was_vision_training = self.get_vision_tower().training
-
-         # Set vision tower to train mode to enable gradients
-         vision_tower = self.get_vision_tower()
-         vision_tower.train()
-
-         # CRITICAL: Enable gradients for vision tower parameters
-         # This is necessary because merged LoRA models might have frozen parameters
-         for param in vision_tower.parameters():
-             param.requires_grad = True
-
-         # Note: We keep model in eval mode for other components (dropout, batchnorm)
-         # but vision tower is in train mode for gradient computation
-
-         # Ensure images require grad
-         if not images.requires_grad:
-             images = images.clone().detach().requires_grad_(True)
-
-         logging.info(f"Images requires_grad: {images.requires_grad}")
-
-         # Store activations and gradients for Grad-CAM
-         activations = []
-         gradients = []
-
-         def save_gradient(grad):
-             """Backward hook to capture gradients"""
-             logging.info(f"Gradient hook called! Grad shape: {grad.shape}")
-             gradients.append(grad.detach())
-
-         def forward_hook(module, input, output):
-             """Forward hook to save activations and register backward hook"""
-             if isinstance(output, tuple):
-                 activation = output[0]
-             else:
-                 activation = output
-
-             logging.info(f"Forward hook: activation shape={activation.shape}, requires_grad={activation.requires_grad}")
-
-             # Register backward hook on the activation tensor itself BEFORE saving
-             if activation.requires_grad:
-                 activation.register_hook(save_gradient)
-                 logging.info("Registered backward hook on activation")
-             else:
-                 logging.warning("Activation does not require grad, cannot register backward hook!")
-
-             # Save activation (keep gradient connection for now, will detach later if needed)
-             activations.append(activation)
-
-         # Register hook on vision tower
-         vision_tower = self.get_vision_tower()
-         if vision_tower is None:
-             raise AttributeError("Vision tower not found")
-
-         hook_handle = vision_tower.register_forward_hook(forward_hook)
-
-         try:
-             # Forward pass - Do normal forward but intercept and modify vision features
-             # CRITICAL: Use autograd.enable_grad() to force gradient tracking
-
-             # Store original vision tower forward
-             vision_tower = self.get_vision_tower()
-             original_forward = vision_tower.forward
-
-             # Create a wrapper that forces requires_grad on output
-             def forward_with_grad(*args, **kwargs):
-                 output = original_forward(*args, **kwargs)
-                 if not output.requires_grad:
-                     output = output.clone().requires_grad_(True)
-                 # Register hook on this tensor
-                 output.register_hook(save_gradient)
-                 # Save to activations
-                 activations.append(output)
-                 return output
-
-             # Temporarily replace forward
-             vision_tower.forward = forward_with_grad
-
-             try:
-                 with torch.enable_grad():
-                     if use_pre_pooling:
-                         # For pre-pooling Grad-CAM, we need to capture the visual tokens from hidden_states
-                         # before they are pooled
-                         pre_pool_activations = []
-                         pre_pool_gradients = []
-
-                         def save_pre_pool_gradient(grad):
-                             pre_pool_gradients.append(grad)
-
-                         # Store original pool_visual_tokens method
-                         original_pool_method = self.pool_visual_tokens
-
-                         # Replace with a wrapper that captures pre-pooling features
-                         def pool_with_capture(hidden_states, input_ids, images):
-                             # Extract visual tokens before pooling
-                             # Visual tokens are typically in the positions where image tokens were
-                             batch_size = hidden_states.shape[0]
-
-                             # Find image token positions
-                             # The image token index is -200 by default in LLaVA
-                             IMAGE_TOKEN_INDEX = -200
-                             image_token_indices = []
-                             for batch_idx in range(batch_size):
-                                 image_positions = (input_ids[batch_idx] == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
-                                 if len(image_positions) > 0:
-                                     image_token_indices.append(image_positions)
-
-                             # Extract visual features before pooling
-                             if len(image_token_indices) > 0:
-                                 visual_features = hidden_states[0, image_token_indices[0]]  # [num_patches, hidden_dim]
-                                 visual_features = visual_features.clone().requires_grad_(True)
-                                 pre_pool_activations.append(visual_features)
-                                 visual_features.register_hook(save_pre_pool_gradient)
-
-                             # Call original pooling method
-                             return original_pool_method(hidden_states, input_ids, images)
-
-                         # Temporarily replace the pooling method
-                         self.pool_visual_tokens = pool_with_capture
-
-                     # Now do the full forward pass
-                     outputs = self.forward(
-                         input_ids=input_ids,
-                         attention_mask=attention_mask,
-                         images=images,
-                         image_sizes=image_sizes,
-                         do_safety=True,
-                         return_dict=True,
-                         **kwargs
-                     )
-
-                     img_safety_logits = outputs.img_safety_logits
-                     img_safety_probs = outputs.img_safety_probs
-
-                     if use_pre_pooling:
-                         # Restore original pooling method
-                         self.pool_visual_tokens = original_pool_method
-             finally:
-                 # Restore original forward
-                 vision_tower.forward = original_forward
-
-             # Get predicted class if not specified
-             if target_class is None:
-                 # Use the class with highest probability
-                 target_class = img_safety_probs.argmax(dim=-1)
-             else:
-                 # Ensure target_class is a tensor
-                 if isinstance(target_class, int):
-                     target_class = torch.tensor([target_class], device=img_safety_probs.device)
-
-             # Get the logit for the target class
-             batch_size = img_safety_probs.shape[0]
-             target_logits = img_safety_logits[torch.arange(batch_size), target_class]
-
-             # Backward pass to compute gradients
-             self.zero_grad()
-             target_logits.sum().backward()
-
-             # Choose which activations and gradients to use
-             if use_pre_pooling:
-                 # Use pre-pooling features for better spatial resolution
-                 if 'pre_pool_activations' not in locals() or len(pre_pool_activations) == 0:
-                     raise RuntimeError("Failed to capture pre-pooling activations")
-                 if 'pre_pool_gradients' not in locals() or len(pre_pool_gradients) == 0:
-                     raise RuntimeError("Failed to capture pre-pooling gradients")
-
-                 # Get the pre-pooling features
-                 # These have spatial structure: [num_patches, hidden_dim]
-                 activation = pre_pool_activations[0].detach()
-                 gradient = pre_pool_gradients[0]
-
-                 # Add batch dimension if needed for consistency
-                 if activation.dim() == 2:
-                     activation = activation.unsqueeze(0)  # [1, num_patches, hidden_dim]
-                     gradient = gradient.unsqueeze(0)
-             else:
-                 # Use post-pooling features (original behavior - from vision tower)
-                 if len(activations) == 0:
-                     raise RuntimeError("Failed to capture activations")
-                 if len(gradients) == 0:
-                     raise RuntimeError("Failed to capture gradients")
-
-                 activation = activations[0].detach()  # [batch_size, num_patches, hidden_dim]
-                 gradient = gradients[0]  # [batch_size, num_patches, hidden_dim]
-
-             # Compute Grad-CAM with correct formula
-             # For Vision Transformer: gradients and activations are [batch, num_patches, hidden_dim]
-             # Standard Grad-CAM: compute importance by averaging gradients across hidden dimension
-             # Then weight the activations
-
-             # Option 1: Standard Grad-CAM - use gradient magnitude as importance
-             # This captures which patches have the strongest gradient signal
-             cam = (gradient * activation).sum(dim=-1)  # [batch_size, num_patches]
-
-             # Alternative would be:
-             # weights = gradient.mean(dim=1, keepdim=True)  # Average across patches
-             # cam = (activation * weights).sum(dim=-1)
-
-             # Apply ReLU (only positive contributions)
-             cam = torch.nn.functional.relu(cam)
-
-             # Reshape to 2D spatial grid
-             # CLIP ViT-L/14-336px has 24x24 patches
-             num_patches_per_side = int(cam.shape[1] ** 0.5)
-             cam = cam.reshape(batch_size, num_patches_per_side, num_patches_per_side)
-
-             # Normalize to [0, 1]
-             for i in range(batch_size):
-                 cam_min = cam[i].min()
-                 cam_max = cam[i].max()
-                 if cam_max > cam_min:
-                     cam[i] = (cam[i] - cam_min) / (cam_max - cam_min)
-
-             # Get class names
-             if isinstance(target_class, torch.Tensor):
-                 target_class_idx = target_class[0].item()
-             else:
-                 target_class_idx = target_class
-
-             class_name = self.config.safety_categories[target_class_idx]
-
-             return {
-                 'heatmap': cam.detach().cpu().numpy(),
-                 'predicted_class': target_class.cpu().numpy() if isinstance(target_class, torch.Tensor) else target_class,
-                 'predicted_prob': img_safety_probs[torch.arange(batch_size), target_class].detach().cpu().numpy(),
-                 'class_name': class_name,
-                 'all_probs': img_safety_probs.detach().cpu().numpy()
-             }
-
-         finally:
-             # Remove hook
-             hook_handle.remove()
-             # Restore training state
-             if not was_vision_training:
-                 self.get_vision_tower().eval()
-             if was_training:
-                 self.train()
-
-     def forward(
-         self,
-         input_ids=None,
-         attention_mask=None,
-         position_ids=None,
-         past_key_values=None,
-         inputs_embeds=None,
-         labels=None,
-         use_cache=None,
-         output_attentions=None,
-         output_hidden_states=None,
-         images=None,
-         image_sizes=None,
-         return_dict=None,
-         do_safety=False,
-         **kwargs,
-     ) -> Union[Tuple, CausalLMOutputWithPast, SafetyCausalLMOutputWithPast]:
-         """
-         Forward method for SafeLLaVA-Pool.
-         When do_safety=True, extracts and pools visual tokens for safety classification.
-         """
-
-         # Store original input_ids for finding image token positions
-         original_input_ids = input_ids.clone() if input_ids is not None else None
-
-         # If do_safety is True, force output_hidden_states to True
-         if do_safety and (output_hidden_states is not True):
-             output_hidden_states = True
-             return_dict = True
-
-         # Prepare inputs for multimodal (handles image embedding)
-         if inputs_embeds is None:
-             (
-                 input_ids,
-                 position_ids,
-                 attention_mask,
-                 past_key_values,
-                 inputs_embeds,
-                 labels
-             ) = self.prepare_inputs_labels_for_multimodal(
-                 input_ids,
-                 position_ids,
-                 attention_mask,
-                 past_key_values,
-                 labels,
-                 images,
-                 image_sizes
-             )
-
-         # Call parent's forward method
-         outputs = super(LlavaLlamaForCausalLM, self).forward(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             labels=labels,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=True,
-             **kwargs
-         )
-
-         # If do_safety=False, just return the outputs
-         if not do_safety:
-             if return_dict is False:
-                 return (outputs.loss, outputs.logits, outputs.past_key_values,
-                         outputs.hidden_states, outputs.attentions)
-             return outputs
-
-         # Safety classification using pooled visual tokens
-         hidden_states = outputs.hidden_states[-1]  # Last layer hidden states
-
-         # Check if we have images to process
-         if images is None:
-             # No images, return outputs without safety
-             return outputs
-
-         # Pool visual tokens
-         pooled_visual_features = self.pool_visual_tokens(hidden_states, original_input_ids, images)
-
-         # Pass through safety head
-         img_safety_logits = self.img_safety_head(pooled_visual_features)
-         img_safety_probs = torch.softmax(img_safety_logits, dim=-1)
-
-         # Return results with safety outputs
-         if not return_dict:
-             return (outputs.loss, outputs.logits, outputs.past_key_values,
-                     outputs.hidden_states, outputs.attentions,
-                     img_safety_logits, img_safety_probs)
-
-         return SafetyCausalLMOutputWithPast(
-             loss=outputs.loss,
-             logits=outputs.logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-             img_safety_logits=img_safety_logits,
-             img_safety_probs=img_safety_probs,
-             txt_safety_logits=None,  # Not used in Pool version
-             txt_safety_probs=None,
-             total_safety_logits=None,
-             total_safety_probs=None
-         )
-
-
- # Register the model
- AutoConfig.register("safe_llava_llama_pool", SafetyConfig)
- AutoModelForCausalLM.register(SafetyConfig, SafeLlavaLlamaForCausalLM)
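
For context, the two register() calls at the end of the removed file hooked the custom config/model pair into Transformers' Auto classes. A minimal, hypothetical sketch of how a checkpoint saved with model_type "safe_llava_llama_pool" could have been loaded while this module still existed; the import path is inferred from the file location above, and the checkpoint path is a placeholder:

from transformers import AutoConfig, AutoModelForCausalLM

# Importing the module executes the AutoConfig / AutoModelForCausalLM registration at the bottom of the file.
import safellava.model.language_model.safe_llava_llama_pool  # noqa: F401  (hypothetical import path)

checkpoint = "path/to/safe-llava-pool-checkpoint"  # placeholder path
config = AutoConfig.from_pretrained(checkpoint)  # config.json must declare model_type "safe_llava_llama_pool"
model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config)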