manu02 commited on
Commit
65cc576
·
1 Parent(s): 110fd76

Upload 4 files

Browse files
utils/complete_model.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import AutoModel, GPT2Tokenizer
5
+
6
+ from utils.modifiedGPT2 import create_decoder
7
+
8
+ from utils.layer_mask import gaussian_layer_stack_pipeline
9
+
10
+
11
+ class DINOEncoder(nn.Module):
12
+ def __init__(self, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", freeze=True):
13
+ super().__init__()
14
+ self.model = AutoModel.from_pretrained(model_id)
15
+ if freeze:
16
+ for p in self.model.parameters():
17
+ p.requires_grad = False
18
+ @torch.no_grad()
19
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
20
+ """
21
+ pixel_values: [B, C, H, W]
22
+ returns patches: [B, Np, Cenc]
23
+ """
24
+ out = self.model(pixel_values=pixel_values)
25
+ tokens = out.last_hidden_state # [B, 1+Np, Cenc] (CLS + patches) for ViT-like
26
+ # Skip a few special tokens if your backbone adds them; adjust as needed.
27
+ patches = tokens[:, 5:, :] # [B, Np, Cenc]
28
+ return patches
29
+
30
+ class DinoUNet(nn.Module):
31
+ def __init__(self, model_name="facebook/dinov3-convnext-small-pretrain-lvd1689m", freeze=True):
32
+ super().__init__()
33
+ self.encoder = AutoModel.from_pretrained(model_name)
34
+ # NOTE: confirm channels of the chosen hidden state; 768 is common for small convnext/dinov3
35
+ self.channel_adapter = nn.Conv2d(768, 512, kernel_size=1)
36
+ self.decoder = nn.Sequential(
37
+ nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(inplace=True),
38
+ nn.ConvTranspose2d(256, 128, 2, stride=2), nn.ReLU(inplace=True),
39
+ nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU(inplace=True),
40
+ nn.Conv2d(64, 1, 1)
41
+ )
42
+ if freeze:
43
+ for m in (self.encoder, self.channel_adapter, self.decoder):
44
+ for p in m.parameters():
45
+ p.requires_grad = False
46
+
47
+ @torch.no_grad()
48
+ def forward(self, x: torch.Tensor, num_layers: int) -> torch.Tensor:
49
+ """
50
+ x: [B, C, H, W]; returns mask: [B, 1, H', W'] (your upsampling stack defines H',W')
51
+ """
52
+ enc_feats = self.encoder(x, output_hidden_states=True, return_dict=True)
53
+ # take the last 4D feature map from hidden_states
54
+ feats = next(h for h in reversed(enc_feats.hidden_states) if isinstance(h, torch.Tensor) and h.ndim == 4)
55
+ feats = self.channel_adapter(feats)
56
+ pred = self.decoder(feats) # (B,1,h,w)
57
+ _, _, segmentation_mask = gaussian_layer_stack_pipeline(pred, n_layers = num_layers)
58
+ return segmentation_mask # [B, num_layers, h, w]
59
+
60
+
61
+ class LinearProjection(nn.Module):
62
+ def __init__(self, input_dim=384, output_dim=768, freeze=False):
63
+ super().__init__()
64
+ self.proj = nn.Linear(input_dim, output_dim)
65
+ if freeze:
66
+ for p in self.proj.parameters():
67
+ p.requires_grad = False
68
+
69
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
70
+ # x: [B, Np, input_dim] -> [B, Np, output_dim]
71
+ return self.proj(x)
72
+
73
+
74
+ class CustomModel(nn.Module):
75
+ def __init__(
76
+ self,
77
+ device: str = "cuda",
78
+ ENCODER_MODEL_PATH: str | None = "dino_encoder.pth",
79
+ SEGMENTER_MODEL_PATH: str | None = "dino_segmenter.pth",
80
+ DECODER_MODEL_PATH: str | None = "dino_decoder.pth",
81
+ LINEAR_PROJECTION_PATH: str | None = "linear_projection.pth",
82
+ freeze_encoder: bool = True,
83
+ freeze_segmenter: bool = True,
84
+ freeze_linear_projection: bool = False,
85
+ freeze_decoder: bool = False,
86
+ attention_implementation: str = "sdpa",
87
+ ):
88
+ super().__init__()
89
+ self.device = torch.device(device)
90
+
91
+ # Encoder
92
+ self.encoder = DINOEncoder()
93
+ if ENCODER_MODEL_PATH and os.path.exists(ENCODER_MODEL_PATH):
94
+ self.encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location="cpu"), strict=False)
95
+ print("Loaded encoder weights from", ENCODER_MODEL_PATH)
96
+ if freeze_encoder:
97
+ self.encoder.eval()
98
+
99
+ # Segmenter
100
+ self.segmenter = DinoUNet()
101
+ if SEGMENTER_MODEL_PATH and os.path.exists(SEGMENTER_MODEL_PATH):
102
+ self.segmenter.load_state_dict(torch.load(SEGMENTER_MODEL_PATH, map_location="cpu"), strict=False)
103
+ print("Loaded segmenter weights from", SEGMENTER_MODEL_PATH)
104
+ if freeze_segmenter:
105
+ self.segmenter.eval()
106
+
107
+ # Decoder (modified GPT-2)
108
+ self.decoder = create_decoder(attention=attention_implementation) # must expose .config.hidden_size & .config.num_hidden_layers
109
+ if DECODER_MODEL_PATH and os.path.exists(DECODER_MODEL_PATH):
110
+ self.decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location="cpu"), strict=False)
111
+ print("Loaded decoder weights from", DECODER_MODEL_PATH)
112
+ if freeze_decoder:
113
+ self.decoder.eval()
114
+
115
+ # Linear projection: DINO hidden -> GPT2 hidden
116
+ enc_h = self.encoder.model.config.hidden_size
117
+ dec_h = self.decoder.config.hidden_size
118
+ self.linear_projection = LinearProjection(input_dim=enc_h, output_dim=dec_h)
119
+ if LINEAR_PROJECTION_PATH and os.path.exists(LINEAR_PROJECTION_PATH):
120
+ self.linear_projection.load_state_dict(torch.load(LINEAR_PROJECTION_PATH, map_location="cpu"), strict=False)
121
+ print("Loaded linear projection weights from", LINEAR_PROJECTION_PATH)
122
+ if freeze_linear_projection:
123
+ self.linear_projection.eval()
124
+
125
+ # Tokenizer (pad token for GPT-2)
126
+ self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
127
+ if self.tokenizer.pad_token_id is None:
128
+ self.tokenizer.pad_token = self.tokenizer.eos_token
129
+ self.pad_token_id = self.tokenizer.pad_token_id # ✅ use ID, not string
130
+
131
+ self.num_layers = self.decoder.config.num_hidden_layers
132
+
133
+ # move everything once
134
+ self.to(self.device)
135
+
136
+ def forward(self, pixel_values: torch.Tensor, tgt_ids: torch.Tensor | None = None, **kwargs) -> dict:
137
+ """
138
+ pixel_values: [B,C,H,W], float
139
+ tgt_ids: [B,T], long (token IDs), padded with pad_token_id if any padding is present
140
+ """
141
+ pixel_values = pixel_values.to(self.device, non_blocking=True)
142
+
143
+ # Visual path
144
+ patches = self.encoder(pixel_values) # [B,Np,Cenc]
145
+ projected_patches = self.linear_projection(patches) # [B,Np,n_embd]
146
+
147
+ # Segmentation path per layer
148
+ segmented_layers = self.segmenter(pixel_values, self.num_layers) # [B,n_layers,H,W] (per current decoder)
149
+
150
+ # Text path (optional teacher-forced training)
151
+ labels = None
152
+ if tgt_ids is not None:
153
+ if tgt_ids.dtype != torch.long:
154
+ tgt_ids = tgt_ids.long()
155
+ tgt_ids = tgt_ids.to(self.device, non_blocking=True) # [B,T]
156
+ text_embeds = self.decoder.transformer.wte(tgt_ids) # [B,T,n_embd]
157
+ inputs_embeds = torch.cat([projected_patches, text_embeds], dim=1) # [B,Np+T,n_embd]
158
+
159
+ # Labels: ignore prefix tokens (vision) and PADs in text
160
+ B, Np, _ = projected_patches.shape
161
+ labels_prefix = torch.full((B, Np), -100, device=self.device, dtype=torch.long)
162
+ text_labels = tgt_ids.clone()
163
+ text_labels[text_labels == self.pad_token_id] = -100 # ✅ compare to ID
164
+ labels = torch.cat([labels_prefix, text_labels], dim=1) # [B,Np+T]
165
+ else:
166
+ inputs_embeds = projected_patches
167
+
168
+ # Decoder forward
169
+ out = self.decoder(inputs_embeds=inputs_embeds, segmentation_mask=segmented_layers, labels=labels, **kwargs)
170
+ return out
171
+
172
+ @torch.inference_mode()
173
+ def generate(
174
+ self,
175
+ pixel_values: torch.Tensor,
176
+ max_new_tokens: int = 100,
177
+ output_attentions: bool = False,
178
+ ) -> torch.Tensor:
179
+ """
180
+ pixel_values: [B,C,H,W], float
181
+ returns generated_ids: [B, T]
182
+ """
183
+ pixel_values = pixel_values.to(self.device, non_blocking=True)
184
+
185
+ # Visual path
186
+ patches = self.encoder(pixel_values) # [B,Np,Cenc]
187
+ projected_patches = self.linear_projection(patches) # [B,Np,n_embd]
188
+
189
+ # Segmentation path per layer
190
+ segmented_layers = self.segmenter(pixel_values, self.num_layers) # [B,n_layers,H,W] (per current decoder)
191
+
192
+ # Generate
193
+ output = self.decoder.generate(
194
+ inputs_embeds=projected_patches,
195
+ max_new_tokens=max_new_tokens,
196
+ do_sample=False,
197
+ repetition_penalty=1.2,
198
+ eos_token_id=self.tokenizer.eos_token_id,
199
+ pad_token_id=self.pad_token_id,
200
+ use_cache=True,
201
+ segmentation_mask=segmented_layers,
202
+ prefix_allowed_length=0,
203
+ plot_attention_mask=False,
204
+ plot_attention_mask_layer=[],
205
+ plot_attention_map=False,
206
+ plot_attention_map_layer=[],
207
+ plot_attention_map_generation=0,
208
+ output_attentions=output_attentions,
209
+ return_dict_in_generate=True,
210
+ )
211
+ # Remove prefix tokens (vision)
212
+ generated_ids = output.sequences#[:, projected_patches.shape[1]:] # [B,T]
213
+ generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
214
+ return generated_ids, generated_text, output.attentions if output_attentions else None
215
+
216
+ def create_complete_model(device: str = "cuda", **kwargs) -> CustomModel:
217
+ model = CustomModel(device=device, **kwargs)
218
+ return model
219
+
220
+ def save_complete_model(model: CustomModel, save_path: str, device: str = "cuda") -> None:
221
+ # Ensure folder exists
222
+ os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
223
+
224
+ # Save on CPU to keep checkpoint portable
225
+ orig_device = next(model.parameters()).device
226
+ model.to("cpu")
227
+ torch.save(model.state_dict(), save_path)
228
+ print(f"Saved complete model weights to {save_path}")
229
+
230
+ # Restore model device
231
+ model.to(device if isinstance(device, str) else orig_device)
232
+
233
+ def save_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, save_path: str) -> None:
234
+ # Ensure folder exists
235
+ os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
236
+
237
+ checkpoint = {
238
+ "model_state_dict": model.state_dict(),
239
+ "optimizer_state_dict": optimizer.state_dict(),
240
+ }
241
+ torch.save(checkpoint, save_path)
242
+ print(f"Saved checkpoint to {save_path}")
243
+
244
+ def load_complete_model(model: CustomModel, load_path: str, device: str = "cpu", strict: bool = True) -> CustomModel:
245
+ if not os.path.exists(load_path):
246
+ print(f"No weights found at {load_path}")
247
+ model.to(device)
248
+ return model
249
+
250
+ # Load to CPU first, then move to target device
251
+ state = torch.load(load_path, map_location="cpu")
252
+ missing, unexpected = model.load_state_dict(state, strict=strict)
253
+ if not strict:
254
+ if missing:
255
+ print(f"[load warning] Missing keys: {missing}")
256
+ if unexpected:
257
+ print(f"[load warning] Unexpected keys: {unexpected}")
258
+
259
+ model.to(device)
260
+ print(f"Loaded complete model weights from {load_path}")
261
+ return model
262
+
263
+ def load_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, load_path: str, device: str = "cpu") -> tuple[CustomModel, torch.optim.Optimizer]:
264
+ if not os.path.exists(load_path):
265
+ print(f"No checkpoint found at {load_path}")
266
+ model.to(device)
267
+ return model, optimizer
268
+
269
+ # Load to CPU first, then move to target device
270
+ checkpoint = torch.load(load_path, map_location="cpu")
271
+ model.load_state_dict(checkpoint["model_state_dict"])
272
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
273
+
274
+ model.to(device)
275
+ print(f"Loaded checkpoint from {load_path}")
276
+ return model, optimizer
utils/layer_mask.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+
7
+ @torch.no_grad()
8
+ def gaussian_layer_stack_pipeline(
9
+ x: torch.Tensor,
10
+ n_layers: int,
11
+ base_ksize: int = 3,
12
+ ksize_growth: int = 2,
13
+ sigma: float | None = None,
14
+ eps: float = 1e-8,
15
+ ):
16
+ """
17
+ All-in-one GPU batch pipeline:
18
+ 1) Per-sample min-max normalize to [0,1]
19
+ 2) Resize to (32,32)
20
+ 3) Apply L Gaussian blurs with increasing kernel size in a single
21
+ horizontal conv + single vertical conv using depthwise groups
22
+ (via a shared max kernel padded with zeros)
23
+ 4) Renormalize each layer to [0,1]
24
+ 5) Return stacked (B,L,32,32), flat (B,L,1024), tiled (B,L,1024,1024 view)
25
+
26
+ Args:
27
+ x: (B,H,W) or (B,1,H,W) tensor (any device/dtype)
28
+ n_layers: number of layers
29
+ base_ksize: starting odd kernel size (e.g., 3)
30
+ ksize_growth: increment per layer (e.g., 2) -> ensures odd sizes
31
+ sigma: if None, uses (ksize-1)/6 per layer; else fixed sigma for all
32
+ eps: small number for safe division
33
+
34
+ Returns:
35
+ stacked: (B, n_layers, 32, 32) float on x.device
36
+ flat: (B, n_layers, 1024)
37
+ tiled: (B, n_layers, 1024, 1024) (expand view; memory-cheap)
38
+ """
39
+ assert n_layers >= 1, "n_layers must be >= 1"
40
+
41
+ # ---- Ensure 4D, 1 channel; cast to float (stay on same device) ----
42
+ if x.ndim == 3:
43
+ x = x.unsqueeze(1) # (B,1,H,W)
44
+ elif x.ndim != 4 or x.shape[1] not in (1,):
45
+ raise ValueError(f"Expected (B,H,W) or (B,1,H,W); got {tuple(x.shape)}")
46
+ x = x.float()
47
+
48
+ B, _, H, W = x.shape
49
+
50
+ # ---- Per-sample min-max normalize to [0,1] ----
51
+ xmin = x.amin(dim=(2, 3), keepdim=True)
52
+ xmax = x.amax(dim=(2, 3), keepdim=True)
53
+ denom = (xmax - xmin).clamp_min(eps)
54
+ x = (x - xmin) / denom # (B,1,H,W) in [0,1]
55
+
56
+ # ---- Resize to 32x32 on GPU ----
57
+ x = F.interpolate(x, size=(32, 32), mode="bilinear", align_corners=False) # (B,1,32,32)
58
+
59
+ # ---- Prepare per-layer kernel sizes (odd) ----
60
+ ksizes = []
61
+ for i in range(n_layers, 0, -1): # to keep your original ordering: L...1
62
+ k = base_ksize + i * ksize_growth
63
+ k = int(k)
64
+ if k % 2 == 0:
65
+ k += 1
66
+ k = max(k, 1)
67
+ ksizes.append(k)
68
+
69
+ Kmax = max(ksizes)
70
+ pad = Kmax // 2
71
+
72
+ # ---- Build per-layer 1D Gaussian vectors and embed into shared Kmax kernel ----
73
+ # We create horizontal weights of shape (L,1,1,Kmax) and vertical (L,1,Kmax,1)
74
+ device, dtype = x.device, x.dtype
75
+ weight_h = torch.zeros((n_layers, 1, 1, Kmax), device=device, dtype=dtype)
76
+ weight_v = torch.zeros((n_layers, 1, Kmax, 1), device=device, dtype=dtype)
77
+
78
+ for idx, k in enumerate(ksizes):
79
+ # choose sigma
80
+ sig = sigma if (sigma is not None and sigma > 0) else (k - 1) / 6.0
81
+ r = k // 2
82
+ xp = torch.arange(-r, r + 1, device=device, dtype=dtype)
83
+ g = torch.exp(-(xp * xp) / (2.0 * sig * sig))
84
+ g = g / g.sum() # (k,)
85
+
86
+ # center g into Kmax with zeros around
87
+ start = (Kmax - k) // 2
88
+ end = start + k
89
+
90
+ # horizontal row
91
+ weight_h[idx, 0, 0, start:end] = g # (1 x Kmax)
92
+
93
+ # vertical column
94
+ weight_v[idx, 0, start:end, 0] = g # (Kmax x 1)
95
+
96
+ # ---- Duplicate input across L channels (depthwise groups) ----
97
+ xL = x.expand(B, n_layers, 32, 32).contiguous() # (B,L,32,32)
98
+
99
+ # ---- Separable Gaussian blur with a single pass per axis (groups=L) ----
100
+ # Horizontal
101
+ xh = F.pad(xL, (pad, pad, 0, 0), mode="reflect")
102
+ xh = F.conv2d(xh, weight=weight_h, bias=None, stride=1, padding=0, groups=n_layers) # (B,L,32,32)
103
+
104
+ # Vertical
105
+ xv = F.pad(xh, (0, 0, pad, pad), mode="reflect")
106
+ yL = F.conv2d(xv, weight=weight_v, bias=None, stride=1, padding=0, groups=n_layers) # (B,L,32,32)
107
+
108
+ # ---- Renormalize each layer to [0,1] (per-sample, per-layer) ----
109
+ y_min = yL.amin(dim=(2, 3), keepdim=True)
110
+ y_max = yL.amax(dim=(2, 3), keepdim=True)
111
+ y_den = (y_max - y_min).clamp_min(eps)
112
+ stacked = (yL - y_min) / y_den # (B,L,32,32) in [0,1]
113
+
114
+ # ---- Flatten + tile (expand view; caution w/ later materialization) ----
115
+ flat = stacked.reshape(B, n_layers, 32 * 32) # (B,L,1024)
116
+ tiled = flat.unsqueeze(-2).expand(-1, -1, 32 * 32, -1) # (B,L,1024,1024) view
117
+
118
+ return stacked, flat, tiled
119
+
120
+ def plot_layers_any(
121
+ x,
122
+ *,
123
+ max_batches=None,
124
+ vlim=(0, 1),
125
+ one_indexed: bool = False,
126
+ max_cols: int = 6,
127
+ ):
128
+ """
129
+ Plot layers for each batch sample in separate figures.
130
+
131
+ Accepts:
132
+ - stacked: (B, L, H, W)
133
+ - flat: (B, L, HW)
134
+ - tiled: (B, L, HW, HW)
135
+
136
+ Behavior:
137
+ - Creates one figure PER BATCH (up to `max_batches`).
138
+ - At most `max_cols` layers per row (default 6).
139
+ - Column headers: 'Layer {i}' descending from n-1 -> 0 (or n -> 1 if one_indexed=True).
140
+ - Figure title per batch: 'Masks for input {i} out of {B}'.
141
+
142
+ Returns:
143
+ A list of (fig, axes) tuples, one per plotted batch.
144
+ """
145
+ # ---- Normalize input to torch ----
146
+ if isinstance(x, np.ndarray):
147
+ x = torch.from_numpy(x)
148
+ if not isinstance(x, torch.Tensor):
149
+ raise TypeError(f"Expected torch.Tensor or np.ndarray, got {type(x)}")
150
+
151
+ if x.ndim not in (3, 4):
152
+ raise ValueError(f"Expected ndim 3 or 4, got shape {tuple(x.shape)}")
153
+
154
+ # ---- Convert to (B, L, H, W) 'stacked' ----
155
+ if x.ndim == 4:
156
+ B, L, A, B_ = x.shape
157
+ if A == B_:
158
+ # Could be stacked (H==W) or tiled (HW x HW). Heuristic: if A is a perfect square
159
+ # and reasonably large (e.g., 1024), treat as tiled and collapse to flat.
160
+ s = int(math.isqrt(A))
161
+ if s * s == A and A >= 64:
162
+ flat = x[..., 0, :].detach() # (B, L, HW)
163
+ H = W = s
164
+ stacked = flat.reshape(B, L, H, W)
165
+ else:
166
+ stacked = x.detach()
167
+ else:
168
+ stacked = x.detach()
169
+ else:
170
+ # x.ndim == 3 -> (B, L, HW)
171
+ B, L, HW = x.shape
172
+ s = int(math.isqrt(HW))
173
+ if s * s != HW:
174
+ if HW != 32 * 32:
175
+ raise ValueError(
176
+ f"Cannot infer square image size from HW={HW}. "
177
+ f"Provide stacked (B,L,H,W) or flat with square HW."
178
+ )
179
+ s = 32
180
+ H = W = s
181
+ stacked = x.detach().reshape(B, L, H, W)
182
+
183
+ # Ensure float & CPU for plotting
184
+ stacked = stacked.to(torch.float32).cpu().numpy()
185
+
186
+ # ---- Batch selection ----
187
+ B, L, H, W = stacked.shape
188
+ plot_B = B if max_batches is None else max(1, min(B, int(max_batches)))
189
+
190
+ # ---- Layout params ----
191
+ cols = max(1, int(max_cols))
192
+ rows_needed = lambda L: (L + cols - 1) // cols
193
+
194
+ figs = []
195
+ for b in range(plot_B):
196
+ # number of rows for this batch
197
+ r = rows_needed(L)
198
+ fig, axes = plt.subplots(r, cols, figsize=(cols * 3, r * 3), squeeze=False)
199
+ fig.suptitle(f"Masks for input {b} out of {B}", fontsize=12, y=1.02)
200
+
201
+ for l in range(L):
202
+ rr = l // cols
203
+ cc = l % cols
204
+ ax = axes[rr, cc]
205
+ if vlim is None:
206
+ ax.imshow(stacked[b, l], cmap="gray")
207
+ else:
208
+ ax.imshow(stacked[b, l], cmap="gray", vmin=vlim[0], vmax=vlim[1])
209
+ ax.axis("off")
210
+
211
+ # Set column titles only on the first row of the grid
212
+ label_num = (l + 1) if one_indexed else l
213
+ ax.set_title(f"Layer {label_num}", fontsize=10)
214
+
215
+ # Hide any unused axes (when L is not a multiple of cols)
216
+ total_slots = r * cols
217
+ for empty_idx in range(L, total_slots):
218
+ rr = empty_idx // cols
219
+ cc = empty_idx % cols
220
+ axes[rr, cc].axis("off")
221
+
222
+ plt.tight_layout()
223
+ figs.append((fig, axes))
224
+ return figs
utils/modifiedGPT2.py ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import GPT2LMHeadModel, GPT2Model, GPT2Config
5
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
6
+ from transformers.masking_utils import create_causal_mask
7
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
8
+ from transformers.modeling_outputs import (
9
+ BaseModelOutputWithPastAndCrossAttentions,
10
+ CausalLMOutputWithCrossAttentions,
11
+ )
12
+ from transformers.utils import (
13
+ logging,
14
+ )
15
+
16
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Attention, eager_attention_forward
17
+ from torch import nn
18
+ from typing import Callable
19
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
20
+
21
+ import matplotlib.pyplot as plt
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ class GPT2AttentionModified(GPT2Attention):
26
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
27
+ super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
28
+ self.config = config
29
+ max_positions = 2048
30
+ self.register_buffer(
31
+ "bias",
32
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
33
+ 1, 1, max_positions, max_positions
34
+ ),
35
+ persistent=False,
36
+ )
37
+
38
+ def forward(
39
+ self,
40
+ hidden_states: Optional[tuple[torch.FloatTensor]],
41
+ past_key_values: Optional[Cache] = None,
42
+ cache_position: Optional[torch.LongTensor] = None,
43
+ attention_mask: Optional[torch.FloatTensor] = None,
44
+ head_mask: Optional[torch.FloatTensor] = None,
45
+ encoder_hidden_states: Optional[torch.Tensor] = None,
46
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
47
+ output_attentions: Optional[bool] = False,
48
+ **kwargs,
49
+ ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
50
+ is_cross_attention = encoder_hidden_states is not None
51
+ if past_key_values is not None:
52
+ if isinstance(past_key_values, EncoderDecoderCache):
53
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
54
+ if is_cross_attention:
55
+ # after the first generated id, we can subsequently re-use all key/value_layer from cache
56
+ curr_past_key_value = past_key_values.cross_attention_cache
57
+ else:
58
+ curr_past_key_value = past_key_values.self_attention_cache
59
+ else:
60
+ curr_past_key_value = past_key_values
61
+
62
+ if is_cross_attention:
63
+ if not hasattr(self, "q_attn"):
64
+ raise ValueError(
65
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
66
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
67
+ )
68
+ query_states = self.q_attn(hidden_states)
69
+ attention_mask = encoder_attention_mask
70
+
71
+ # Try to get key/value states from cache if possible
72
+ if past_key_values is not None and is_updated:
73
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
74
+ value_states = curr_past_key_value.layers[self.layer_idx].values
75
+ else:
76
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
77
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
78
+ key_states = key_states.view(shape_kv).transpose(1, 2)
79
+ value_states = value_states.view(shape_kv).transpose(1, 2)
80
+ else:
81
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
82
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
83
+ key_states = key_states.view(shape_kv).transpose(1, 2)
84
+ value_states = value_states.view(shape_kv).transpose(1, 2)
85
+
86
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
87
+ query_states = query_states.view(shape_q).transpose(1, 2)
88
+
89
+ if (past_key_values is not None and not is_cross_attention) or (
90
+ past_key_values is not None and is_cross_attention and not is_updated
91
+ ):
92
+ # save all key/value_layer to cache to be re-used for fast auto-regressive generation
93
+ cache_position = cache_position if not is_cross_attention else None
94
+ key_states, value_states = curr_past_key_value.update(
95
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
96
+ )
97
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
98
+ if is_cross_attention:
99
+ past_key_values.is_updated[self.layer_idx] = True
100
+
101
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
102
+
103
+ using_eager = self.config._attn_implementation == "eager"
104
+ attention_interface: Callable = eager_attention_forward
105
+ if self.config._attn_implementation != "eager":
106
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
107
+
108
+ if using_eager and self.reorder_and_upcast_attn:
109
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
110
+ query_states, key_states, value_states, attention_mask, head_mask
111
+ )
112
+ else:
113
+ if getattr(self.config, "prefix_allowed_length", None) is not None:
114
+ temp = self
115
+ temp.is_cross_attention = True
116
+ attn_output, attn_weights = attention_interface(
117
+ self if getattr(self.config, "prefix_allowed_length", None) is None else temp,
118
+ query_states,
119
+ key_states,
120
+ value_states,
121
+ attention_mask,
122
+ head_mask=head_mask,
123
+ dropout=self.attn_dropout.p if self.training else 0.0,
124
+ is_causal=is_causal if getattr(self.config, "is_prefix", None) is None else False,
125
+ **kwargs,
126
+ )
127
+ if getattr(self.config, "plot_attention_map", False) and self.layer_idx in getattr(self.config, "plot_attention_map_layer", []):
128
+ # pick batch=0, head=0
129
+ attn_bh = attn_weights[0, 0] # [L,S]
130
+ L, S = attn_bh.shape
131
+ if L > 1:
132
+ if getattr(self.config, "plot_attention_map_generation", 0) == 0:
133
+ print(f"Plotting attention map for inputs on layer {self.layer_idx}")
134
+ # full 2D heatmap
135
+ data = attn_bh.detach().float().cpu().numpy() # [L,S]
136
+ plt.figure(figsize=(6,5))
137
+ plt.imshow(data, aspect="auto", cmap="hot", vmin=0, vmax=0.01)
138
+ plt.colorbar()
139
+ plt.xlabel("Keys (S)")
140
+ plt.ylabel("Queries (L)")
141
+ plt.title(f"Attention map (B0,H0) L={L}, S={S}")
142
+ plt.show()
143
+ else:
144
+ if getattr(self.config, "plot_attention_map_generation", 0) == S:
145
+ print(f"Plotting attention row map for token {S} generation on layer {self.layer_idx}")
146
+ # attn_bh expected shape: [..., S] for the selected (B0, H0) row
147
+ row = attn_bh[0].detach().float().cpu().numpy() # -> np.ndarray shape [S]
148
+ n = row.shape[0]
149
+
150
+ # ----- First 1024 as 32x32 -----
151
+ head_1024 = row[:min(1024, n)]
152
+ grid = head_1024.reshape(32, 32)
153
+
154
+ plt.figure(figsize=(6, 5))
155
+ plt.imshow(grid, aspect="auto", cmap="hot", vmin=0, vmax=0.01)
156
+ plt.yticks([])
157
+ plt.colorbar()
158
+ plt.xlabel("Keys (S) [indices 0..1023]")
159
+ plt.title(f"Attention row (B0,H0) L={self.layer_idx}, S={S} — first 1024")
160
+ plt.tight_layout()
161
+ plt.show()
162
+
163
+ # ----- Tail (>=1024) as a single-row heatmap -----
164
+ tail = row[1024:]
165
+ if tail.size > 0:
166
+ plt.figure(figsize=(10, 1.2))
167
+ # one-row heatmap
168
+ plt.imshow(tail[None, :], aspect="auto", cmap="hot", vmin=0, vmax=0.01)
169
+ plt.yticks([])
170
+ plt.colorbar()
171
+ plt.xlabel(f"Keys (S) [indices 1024..{n-1}]")
172
+ plt.title(f"Attention row tail (B0,H0) L={self.layer_idx}, S={S}")
173
+ plt.tight_layout()
174
+ plt.show()
175
+
176
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
177
+ attn_output = self.c_proj(attn_output)
178
+ attn_output = self.resid_dropout(attn_output)
179
+
180
+ return attn_output, attn_weights
181
+
182
+ class GPT2BlockModified(GPT2Block):
183
+ def __init__(self, config, layer_idx=None):
184
+ super().__init__(config=config)
185
+ self.attn = GPT2AttentionModified(config=config, layer_idx=layer_idx)
186
+
187
+ def forward(
188
+ self,
189
+ hidden_states: Optional[tuple[torch.FloatTensor]],
190
+ past_key_values: Optional[Cache] = None,
191
+ cache_position: Optional[torch.LongTensor] = None,
192
+ attention_mask: Optional[torch.FloatTensor] = None,
193
+ head_mask: Optional[torch.FloatTensor] = None,
194
+ encoder_hidden_states: Optional[torch.Tensor] = None,
195
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
196
+ use_cache: Optional[bool] = False,
197
+ output_attentions: Optional[bool] = False,
198
+ **kwargs,
199
+ ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
200
+ residual = hidden_states
201
+ hidden_states = self.ln_1(hidden_states)
202
+ attn_output, self_attn_weights = self.attn(
203
+ hidden_states,
204
+ past_key_values=past_key_values,
205
+ cache_position=cache_position,
206
+ attention_mask=attention_mask,
207
+ head_mask=head_mask,
208
+ use_cache=use_cache,
209
+ output_attentions=output_attentions,
210
+ **kwargs,
211
+ )
212
+
213
+ # residual connection
214
+ hidden_states = attn_output + residual
215
+
216
+ if encoder_hidden_states is not None:
217
+ # add one self-attention block for cross-attention
218
+ if not hasattr(self, "crossattention"):
219
+ raise ValueError(
220
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
221
+ "cross-attention layers by setting `config.add_cross_attention=True`"
222
+ )
223
+ residual = hidden_states
224
+ hidden_states = self.ln_cross_attn(hidden_states)
225
+ cross_attn_output, cross_attn_weights = self.crossattention(
226
+ hidden_states,
227
+ past_key_values=past_key_values,
228
+ attention_mask=attention_mask,
229
+ head_mask=head_mask,
230
+ encoder_hidden_states=encoder_hidden_states,
231
+ encoder_attention_mask=encoder_attention_mask,
232
+ output_attentions=output_attentions,
233
+ )
234
+ # residual connection
235
+ hidden_states = residual + cross_attn_output
236
+
237
+ residual = hidden_states
238
+ hidden_states = self.ln_2(hidden_states)
239
+ feed_forward_hidden_states = self.mlp(hidden_states)
240
+ # residual connection
241
+ hidden_states = residual + feed_forward_hidden_states
242
+
243
+ outputs = (hidden_states,)
244
+ if output_attentions:
245
+ outputs += (self_attn_weights,)
246
+ if encoder_hidden_states is not None:
247
+ outputs += (cross_attn_weights,)
248
+
249
+ return outputs
250
+
251
+
252
+ class GPT2ModelModified(GPT2Model):
253
+ def __init__(self, config):
254
+ super().__init__(config)
255
+ self.config = config
256
+ self.config_causal = config
257
+ self.config_causal._attn_implementation = "eager" # Ensure causal mask creation uses eager implementation
258
+ # TEMPORARY: override the transformer blocks to pass segmentation masks
259
+ self.h = nn.ModuleList([GPT2BlockModified(config, layer_idx=i) for i in range(config.num_hidden_layers)])
260
+
261
+ def forward(
262
+ self,
263
+ input_ids: Optional[torch.LongTensor] = None,
264
+ past_key_values: Optional[Union[tuple[tuple[torch.Tensor]], Cache]] = None,
265
+ cache_position: Optional[torch.LongTensor] = None,
266
+ attention_mask: Optional[torch.FloatTensor] = None,
267
+ token_type_ids: Optional[torch.LongTensor] = None,
268
+ position_ids: Optional[torch.LongTensor] = None,
269
+ head_mask: Optional[torch.FloatTensor] = None,
270
+ inputs_embeds: Optional[torch.FloatTensor] = None,
271
+ encoder_hidden_states: Optional[torch.Tensor] = None,
272
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
273
+ use_cache: Optional[bool] = None,
274
+ output_attentions: Optional[bool] = None,
275
+ output_hidden_states: Optional[bool] = None,
276
+ return_dict: Optional[bool] = None,
277
+ segmentation_mask: Optional[torch.FloatTensor] = None,
278
+ **kwargs,
279
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
280
+ r"""
281
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
282
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
283
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
284
+ sequence tokens in the vocabulary.
285
+
286
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
287
+ `input_ids`.
288
+
289
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
290
+ [`PreTrainedTokenizer.__call__`] for details.
291
+
292
+ [What are input IDs?](../glossary#input-ids)
293
+ """
294
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
295
+ output_hidden_states = (
296
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
297
+ )
298
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
299
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
300
+
301
+ if input_ids is not None and inputs_embeds is not None:
302
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
303
+ elif input_ids is not None:
304
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
305
+ input_shape = input_ids.size()
306
+ input_ids = input_ids.view(-1, input_shape[-1])
307
+ batch_size = input_ids.shape[0]
308
+ elif inputs_embeds is not None:
309
+ input_shape = inputs_embeds.size()[:-1]
310
+ batch_size = inputs_embeds.shape[0]
311
+ else:
312
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
313
+
314
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
315
+
316
+ if token_type_ids is not None:
317
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
318
+
319
+ if self.gradient_checkpointing and self.training:
320
+ if use_cache:
321
+ logger.warning_once(
322
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
323
+ )
324
+ use_cache = False
325
+
326
+ # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
327
+ if use_cache:
328
+ if past_key_values is None:
329
+ past_key_values = DynamicCache()
330
+ elif isinstance(past_key_values, tuple):
331
+ logger.warning_once(
332
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
333
+ "You should pass an instance of `Cache` instead, e.g. "
334
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
335
+ )
336
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
337
+
338
+ if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
339
+ past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
340
+
341
+ if inputs_embeds is None:
342
+ inputs_embeds = self.wte(input_ids)
343
+
344
+ if cache_position is None:
345
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
346
+ cache_position = torch.arange(
347
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
348
+ )
349
+ if position_ids is None:
350
+ position_ids = cache_position.unsqueeze(0)
351
+
352
+ position_embeds = self.wpe(position_ids)
353
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
354
+
355
+ # Attention mask.
356
+ # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
357
+ if attention_mask is not None and attention_mask.ndim < 4:
358
+ attention_mask = attention_mask.view(batch_size, -1)
359
+
360
+ causal_mask = create_causal_mask(
361
+ config=self.config_causal,
362
+ input_embeds=inputs_embeds,
363
+ attention_mask=attention_mask,
364
+ cache_position=cache_position,
365
+ past_key_values=past_key_values,
366
+ position_ids=position_ids,
367
+ )
368
+
369
+ # If a 2D or 3D attention mask is provided for the cross-attention
370
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
371
+ _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
372
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
373
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
374
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
375
+ if encoder_attention_mask is None:
376
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
377
+ if _use_sdpa:
378
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
379
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
380
+ )
381
+ elif self._attn_implementation != "flash_attention_2":
382
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
383
+ else:
384
+ encoder_attention_mask = None
385
+
386
+ # Prepare head mask if needed
387
+ # 1.0 in head_mask indicate we keep the head
388
+ # attention_probs has shape bsz x n_heads x N x N
389
+ # head_mask has shape n_layer x batch x n_heads x N x N
390
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
391
+
392
+ if token_type_ids is not None:
393
+ token_type_embeds = self.wte(token_type_ids)
394
+ hidden_states = hidden_states + token_type_embeds
395
+
396
+ hidden_states = self.drop(hidden_states)
397
+
398
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
399
+
400
+ all_self_attentions = () if output_attentions else None
401
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
402
+ all_hidden_states = () if output_hidden_states else None
403
+ for i, block in enumerate(self.h):
404
+ # Model parallel
405
+ if self.model_parallel:
406
+ torch.cuda.set_device(hidden_states.device)
407
+ if isinstance(head_mask, torch.Tensor):
408
+ head_mask = head_mask.to(hidden_states.device)
409
+ if output_hidden_states:
410
+ all_hidden_states = all_hidden_states + (hidden_states,)
411
+ if segmentation_mask is not None and causal_mask is not None:
412
+ # Make a safe copy of the causal mask and ensure its spatial
413
+ # dimensions match the sequence length that the attention
414
+ # functions expect. This prevents off-by-one shape errors
415
+ # when using eager attention (torch.where requires same sizes).
416
+ causal_mask_modified = causal_mask.clone()
417
+ if getattr(self.config, "prefix_allowed_length", None) is not None:
418
+ causal_mask_modified[:, :, :, :self.config.prefix_allowed_length].zero_()
419
+
420
+ # Use the input sequence length to crop the causal mask if needed
421
+ seq_len = input_shape[-1]
422
+ if causal_mask_modified.shape[2] != seq_len or causal_mask_modified.shape[3] != seq_len:
423
+ causal_mask_modified = causal_mask_modified[:, :, :seq_len, :seq_len]
424
+
425
+ # Clip segmentation mask to fit into causal_mask_modified before adding.
426
+ _, _, M, N = segmentation_mask.shape
427
+ M = min(M, causal_mask_modified.shape[2])
428
+ N = min(N, causal_mask_modified.shape[3])
429
+ causal_mask_modified[:, :, :M, :N] += segmentation_mask[:, i, :M, :N].unsqueeze(1)
430
+ if getattr(self.config, "plot_attention_mask", False) and i in getattr(self.config, "plot_attention_mask_layer", [0]):
431
+ if segmentation_mask is not None and causal_mask is not None:
432
+ print(f"Block {i}: segmentation mask added to causal mask.")
433
+ plt.imshow(causal_mask_modified[0,0].detach().cpu(), aspect='auto', cmap='hot', vmin=-1, vmax=1)
434
+ plt.colorbar()
435
+ plt.title(f"Causal Mask with Segmentation (Block {i})")
436
+ plt.show()
437
+ else:
438
+ print(f"Block {i}: no segmentation mask applied.")
439
+ plt.imshow(causal_mask[0,0].detach().cpu(), aspect='auto', cmap='hot', vmin=-1, vmax=1)
440
+ plt.colorbar()
441
+ plt.title(f"Causal Mask (Block {i})")
442
+ plt.show()
443
+
444
+
445
+ outputs = block(
446
+ hidden_states,
447
+ past_key_values if not (self.gradient_checkpointing and self.training) else None,
448
+ cache_position,
449
+ causal_mask_modified if segmentation_mask is not None and causal_mask is not None else causal_mask,
450
+ head_mask[i],
451
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
452
+ encoder_attention_mask=encoder_attention_mask,
453
+ use_cache=use_cache,
454
+ output_attentions=output_attentions,
455
+ **kwargs,
456
+ )
457
+
458
+ hidden_states = outputs[0]
459
+
460
+ if output_attentions:
461
+ all_self_attentions = all_self_attentions + (outputs[1],)
462
+ if self.config.add_cross_attention:
463
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
464
+
465
+ # Model Parallel: If it's the last layer for that device, put things on the next device
466
+ if self.model_parallel:
467
+ for k, v in self.device_map.items():
468
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
469
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
470
+
471
+ hidden_states = self.ln_f(hidden_states)
472
+
473
+ hidden_states = hidden_states.view(output_shape)
474
+ # Add last hidden state
475
+ if output_hidden_states:
476
+ all_hidden_states = all_hidden_states + (hidden_states,)
477
+
478
+ past_key_values = past_key_values if use_cache else None
479
+ if not return_dict:
480
+ return tuple(
481
+ v
482
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
483
+ if v is not None
484
+ )
485
+
486
+ return BaseModelOutputWithPastAndCrossAttentions(
487
+ last_hidden_state=hidden_states,
488
+ past_key_values=past_key_values,
489
+ hidden_states=all_hidden_states,
490
+ attentions=all_self_attentions,
491
+ cross_attentions=all_cross_attentions,
492
+ )
493
+
494
+ class GPT2LMHeadModelModified(GPT2LMHeadModel):
495
+ def __init__(self, config):
496
+ super().__init__(config)
497
+ # replace the base transformer with our modified transformer implementation
498
+ self.transformer = GPT2ModelModified(config)
499
+ self.post_init()
500
+
501
+ def forward(
502
+ self,
503
+ input_ids: Optional[torch.LongTensor] = None,
504
+ past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
505
+ cache_position: Optional[torch.LongTensor] = None,
506
+ attention_mask: Optional[torch.FloatTensor] = None,
507
+ token_type_ids: Optional[torch.LongTensor] = None,
508
+ position_ids: Optional[torch.LongTensor] = None,
509
+ head_mask: Optional[torch.FloatTensor] = None,
510
+ inputs_embeds: Optional[torch.FloatTensor] = None,
511
+ encoder_hidden_states: Optional[torch.Tensor] = None,
512
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
513
+ labels: Optional[torch.LongTensor] = None,
514
+ use_cache: Optional[bool] = None,
515
+ output_attentions: Optional[bool] = None,
516
+ output_hidden_states: Optional[bool] = None,
517
+ return_dict: Optional[bool] = None,
518
+ logits_to_keep: Union[int, torch.Tensor] = 0,
519
+ segmentation_mask: Optional[torch.FloatTensor] = None,
520
+ prefix_allowed_length: Optional[int] = None,
521
+ plot_attention_mask: Optional[bool] = False,
522
+ plot_attention_mask_layer: Optional[list[int]] = [0],
523
+ plot_attention_map: Optional[bool] = False,
524
+ plot_attention_map_layer: Optional[list[int]] = [0],
525
+ plot_attention_map_generation: Optional[int] = 0,
526
+ **kwargs,
527
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
528
+ r"""
529
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
530
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
531
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
532
+ sequence tokens in the vocabulary.
533
+
534
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
535
+ `input_ids`.
536
+
537
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
538
+ [`PreTrainedTokenizer.__call__`] for details.
539
+
540
+ [What are input IDs?](../glossary#input-ids)
541
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
542
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
543
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
544
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
545
+ """
546
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
547
+
548
+ if prefix_allowed_length is not None:
549
+ self.config.prefix_allowed_length = prefix_allowed_length
550
+
551
+ if plot_attention_mask is not None:
552
+ self.config.plot_attention_mask = plot_attention_mask
553
+ if plot_attention_mask_layer is not None:
554
+ self.config.plot_attention_mask_layer = plot_attention_mask_layer
555
+
556
+ if plot_attention_map is not None:
557
+ if plot_attention_map_layer is not None:
558
+ self.config.plot_attention_map_layer = plot_attention_map_layer
559
+ if plot_attention_map_generation is not None:
560
+ self.config.plot_attention_map_generation = plot_attention_map_generation
561
+ self.config.plot_attention_map = plot_attention_map
562
+
563
+ transformer_outputs = self.transformer(
564
+ input_ids,
565
+ past_key_values=past_key_values,
566
+ attention_mask=attention_mask,
567
+ cache_position=cache_position,
568
+ token_type_ids=token_type_ids,
569
+ position_ids=position_ids,
570
+ head_mask=head_mask,
571
+ inputs_embeds=inputs_embeds,
572
+ encoder_hidden_states=encoder_hidden_states,
573
+ encoder_attention_mask=encoder_attention_mask,
574
+ use_cache=use_cache,
575
+ output_attentions=output_attentions,
576
+ output_hidden_states=output_hidden_states,
577
+ return_dict=return_dict,
578
+ segmentation_mask=segmentation_mask, #Added this parameter
579
+ **kwargs,
580
+ )
581
+ hidden_states = transformer_outputs[0]
582
+
583
+ # Set device for model parallelism
584
+ if self.model_parallel:
585
+ torch.cuda.set_device(self.transformer.first_device)
586
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
587
+
588
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
589
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
590
+
591
+ loss = None
592
+ if labels is not None:
593
+ # Flatten the tokens
594
+ loss = self.loss_function(
595
+ logits,
596
+ labels,
597
+ vocab_size=self.config.vocab_size,
598
+ **kwargs,
599
+ )
600
+
601
+ if not return_dict:
602
+ output = (logits,) + transformer_outputs[1:]
603
+ return ((loss,) + output) if loss is not None else output
604
+
605
+ return CausalLMOutputWithCrossAttentions(
606
+ loss=loss,
607
+ logits=logits,
608
+ past_key_values=transformer_outputs.past_key_values,
609
+ hidden_states=transformer_outputs.hidden_states,
610
+ attentions=transformer_outputs.attentions,
611
+ cross_attentions=transformer_outputs.cross_attentions,
612
+ )
613
+
614
+ @torch.no_grad()
615
+ def expand_gpt2_positional_embeddings(
616
+ model: torch.nn.Module,
617
+ new_max_positions: int,
618
+ mode: str = "linear", # "linear" | "copy_last" | "zeros"
619
+ align_corners: bool = True, # for linear interpolation
620
+ ):
621
+ """
622
+ Expand GPT-2's learned positional embeddings (wpe) to `new_max_positions`.
623
+
624
+ Works with GPT2LMHeadModel or GPT2Model (HF). Updates model.config.n_positions (and n_ctx if present).
625
+ Does NOT mutate token embeddings; only position table + config.
626
+
627
+ Args:
628
+ model: HF GPT2LMHeadModel or GPT2Model (already loaded).
629
+ new_max_positions: int, desired max sequence length (e.g., 1536 or 2048).
630
+ mode: how to initialize new rows if expanding:
631
+ - "linear": 1D linear interpolation along position dim (recommended)
632
+ - "copy_last": copy the last learned vector into all new rows
633
+ - "zeros": initialize new rows to zero
634
+ align_corners: passed to F.interpolate for "linear" mode.
635
+
636
+ Returns:
637
+ model (same instance) with expanded wpe and updated config.
638
+ """
639
+ # Locate the position embedding table.
640
+ # Support both:
641
+ # - GPT2LMHeadModel (has .transformer which is a GPT2Model with .wpe)
642
+ # - GPT2Model (exposes .wpe directly)
643
+ if hasattr(model, "transformer") and hasattr(model.transformer, "wpe"):
644
+ model_for_wpe = model.transformer
645
+ elif hasattr(model, "wpe"):
646
+ model_for_wpe = model
647
+ else:
648
+ raise ValueError("Model does not look like a GPT-2 family model with a position embedding 'wpe')")
649
+
650
+ wpe = model_for_wpe.wpe
651
+
652
+ old_n, d = wpe.weight.shape
653
+ if new_max_positions <= 0:
654
+ raise ValueError("new_max_positions must be positive")
655
+ if new_max_positions == old_n:
656
+ # Still update config for consistency
657
+ if hasattr(model.config, "n_positions"):
658
+ model.config.n_positions = new_max_positions
659
+ if hasattr(model.config, "n_ctx"):
660
+ model.config.n_ctx = new_max_positions
661
+ return model
662
+
663
+ device = wpe.weight.device
664
+ dtype = wpe.weight.dtype
665
+
666
+ if new_max_positions < old_n:
667
+ # Shrink (rare): just slice
668
+ new_weight = wpe.weight[:new_max_positions].clone()
669
+ else:
670
+ # Expand
671
+ if mode == "linear":
672
+ # Interpolate along position dimension.
673
+ # Treat embedding dim as channels: (1, d, old_n) -> (1, d, new_n) -> (new_n, d)
674
+ w = wpe.weight.transpose(0, 1).unsqueeze(0) # (1, d, old_n)
675
+ w_new = F.interpolate(w, size=new_max_positions, mode="linear", align_corners=align_corners)
676
+ new_weight = w_new.squeeze(0).transpose(0, 1).contiguous() # (new_n, d)
677
+ elif mode == "copy_last":
678
+ new_weight = torch.empty((new_max_positions, d), device=device, dtype=dtype)
679
+ new_weight[:old_n].copy_(wpe.weight)
680
+ new_weight[old_n:].copy_(wpe.weight[old_n - 1].expand(new_max_positions - old_n, d))
681
+ elif mode == "zeros":
682
+ new_weight = torch.zeros((new_max_positions, d), device=device, dtype=dtype)
683
+ new_weight[:old_n].copy_(wpe.weight)
684
+ else:
685
+ raise ValueError(f"Unknown mode '{mode}'")
686
+
687
+ # Replace embedding module on whichever object held the original table
688
+ new_wpe = torch.nn.Embedding(new_max_positions, d, device=device, dtype=dtype)
689
+ new_wpe.weight.copy_(new_weight)
690
+
691
+ # Keep requires_grad True (default). If you want to freeze, set .requires_grad_(False).
692
+ if hasattr(model, "transformer") and hasattr(model.transformer, "wpe"):
693
+ model.transformer.wpe = new_wpe
694
+ else:
695
+ model.wpe = new_wpe
696
+
697
+ # Update config fields used by HF
698
+ if hasattr(model.config, "n_positions"):
699
+ model.config.n_positions = new_max_positions
700
+ if hasattr(model.config, "n_ctx"):
701
+ model.config.n_ctx = new_max_positions
702
+
703
+ return model
704
+
705
+ def create_decoder(attention = "sdpa"):
706
+ config = GPT2Config.from_pretrained("gpt2")
707
+ config._attn_implementation = attention
708
+ new_max_positions = 2048
709
+ decoder = GPT2LMHeadModelModified.from_pretrained("gpt2", config=config)
710
+ decoder.config._attn_implementation = attention
711
+ decoder = expand_gpt2_positional_embeddings(decoder, new_max_positions=new_max_positions, mode="linear")
712
+ return decoder
utils/processing.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms as T
2
+ import fsspec
3
+ import io
4
+ from PIL import Image
5
+
6
+ def image_transform(img_size=512):
7
+ return T.Compose([
8
+ T.Resize((img_size, img_size), interpolation=T.InterpolationMode.BICUBIC),
9
+ T.ToTensor(),
10
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
11
+ ])
12
+
13
+ def open_binary(path: str):
14
+ """
15
+ Open any (local or gs://) file for binary reading.
16
+ Returns a file-like object (context manager).
17
+ """
18
+ return fsspec.open(path, mode="rb").open()
19
+
20
+ def pil_from_path(path: str) -> Image.Image:
21
+ """
22
+ Load an image from local or GCS; returns a PIL image in RGB.
23
+ """
24
+ with open_binary(path) as f:
25
+ img_bytes = f.read()
26
+ im = Image.open(io.BytesIO(img_bytes)).convert("RGB")
27
+ return im