davda54 committed (verified)
Commit 60e56bf · Parent(s): ecb7a88

Update modeling_gptbert.py

Files changed (1)
  1. modeling_gptbert.py +148 -175
modeling_gptbert.py CHANGED
@@ -121,60 +121,6 @@ class GeGLU(nn.Module):
         return x * gelu_new(gate)
 
 
-class Encoder(nn.Module):
-    def __init__(self, config: GptBertConfig):
-        super().__init__()
-        self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
-        self.short_long_ratio = config.short_long_ratio
-
-    def set_window_length(self, config: GptBertConfig):
-        for i, layer in enumerate(self.layers):
-            if (i + 1) % self.local_global_ratio == 0:
-                layer.set_window_length(config.global_window_length)
-            else:
-                layer.set_window_length(config.local_window_length)
-
-    def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
-        hidden_layers = [hidden_layer] if output_hidden_states else None
-        v1 = None
-        embeddings = hidden_layer
-
-        for layer in self.layers:
-            if checkpoint_activations:
-                hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layers, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
-            else:
-                hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
-
-            if output_hidden_states:
-                hidden_layers.append(hidden_layer)
-
-        return hidden_layer, hidden_layers
-
-
-class Layer(nn.Module):
-    def __init__(self, config: GptBertConfig, layer_idx: int):
-        super().__init__()
-
-        self.attention = SelfAttention(config, layer_idx)
-        self.mlp = FeedForward(config)
-        self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
-
-    def set_window_length(self, window_length: int):
-        self.attention.set_window_length(window_length)
-
-    def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
-        attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
-        qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
-        mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
-
-        attention_output, v1 = self.attention(attention_output, qk_layer, v1, padding_info)
-        mlp_layer = mlp_layer + attention_output
-        hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
-        output = hidden_layer + attention_output + self.mlp(mlp_layer)
-
-        return output, v1
-
-
 class Embedding(nn.Module):
     def __init__(self, config: GptBertConfig):
         super().__init__()
@@ -246,6 +192,110 @@ def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbeddi
     return attn
 
 
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+class ApplyRotaryEmbUnpad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
+        # (total_nnz, 3, nheads, headdim)
+        qkv = qkv.contiguous()
+        total_nnz, _three, _nheads, headdim = qkv.shape
+        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
+        # we get the same tensor
+        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
+        qk = qkv[:, :2].view(total_nnz, -1, headdim)
+        apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
+
+        ctx.save_for_backward(cos, sin, cu_seqlens)
+        ctx.max_seqlen = max_seqlen
+        return qkv
+
+    @staticmethod
+    def backward(ctx, do):
+        cos, sin, cu_seqlens = ctx.saved_tensors
+        do = do.contiguous()
+        total_nnz, _three, _nheads, headdim = do.shape
+        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
+        # we get the same tensor
+        dqk = do[:, :2].view(total_nnz, -1, headdim)
+        apply_rotary(
+            dqk,
+            cos,
+            sin,
+            seqlen_offsets=0,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=ctx.max_seqlen,
+            interleaved=False,
+            inplace=True,
+            conjugate=True,
+        )
+
+        return do, None, None, None, None, None, None
+
+
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
+    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
+
+
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+class UnpaddedRotaryEmbedding(RotaryEmbedding):
+    def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
+        super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=None, interleaved=False)
+        self.max_seqlen = max_seqlen
+
+        if max_seqlen is not None and device is not None and dtype is not None:
+            self._update_cos_sin_cache(max_seqlen, device=device, dtype=None)
+
+    def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if max_seqlen is not None:
+            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+
+        qkv = apply_rotary_unpadded(
+            qkv,
+            self._cos_cached,
+            self._sin_cached,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+
+        return qkv
+
+
+class RotaryPositionalEmbeddings(nn.Module):
+    def __init__(self, config, theta: int):
+        super().__init__()
+
+        head_size = config.query_key_head_size
+        assert head_size % 2 == 0
+        max_seq_len = config.max_sequence_length
+
+        inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
+        pos = torch.arange(max_seq_len, dtype=torch.float32)
+        embedding = torch.einsum('n, d -> nd', pos, inv_freq)
+        embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
+        self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
+        self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
+
+    def forward(self, x: torch.Tensor):
+        hidden_layer = x.float()
+
+        seq_len = x.shape[2]
+
+        cos_matrix = self.cos_matrix[:, None, :seq_len, :]
+        sin_matrix = self.sin_matrix[:, None, :seq_len, :]
+
+        x_rotate_half = torch.cat(
+            [
+                -hidden_layer[:, :, :, x.size(-1) // 2:],
+                hidden_layer[:, :, :, :x.size(-1) // 2]
+            ],
+            dim=-1
+        )
+
+        out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
+        return out.type_as(x)
+
+
 class SelfAttention(nn.Module):
     def __init__(self, config: GptBertConfig, layer_idx: int):
         super().__init__()
@@ -280,7 +330,7 @@ class SelfAttention(nn.Module):
         theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
 
         # Initialize rotary embeddings based on whether FlashAttention is available
-        if self.config._attn_implementation == "flash_attention_2":
+        if is_flash_attn_2_available():
             self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
         else:
             self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
@@ -331,7 +381,7 @@ class SelfAttention(nn.Module):
 
     def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
         # Get original shape info
-        if self.config._attn_implementation == "flash_attention_2":
+        if is_flash_attn_2_available():
             # Unpadded case
             indices, cu_seqlens, max_seqlen = padding_info
             total_seqlen = hidden_layer.size(0)
@@ -346,7 +396,7 @@ class SelfAttention(nn.Module):
         query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
         value = self.v_proj(hidden_layer)
 
-        if self.config._attn_implementation == "flash_attention_2":
+        if is_flash_attn_2_available():
             # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
             query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
             key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
@@ -437,108 +487,58 @@ class FeedForward(nn.Module):
         return x
 
 
-# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
-class ApplyRotaryEmbUnpad(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
-        # (total_nnz, 3, nheads, headdim)
-        qkv = qkv.contiguous()
-        total_nnz, _three, _nheads, headdim = qkv.shape
-        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
-        # we get the same tensor
-        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
-        qk = qkv[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
-
-        ctx.save_for_backward(cos, sin, cu_seqlens)
-        ctx.max_seqlen = max_seqlen
-        return qkv
-
-    @staticmethod
-    def backward(ctx, do):
-        cos, sin, cu_seqlens = ctx.saved_tensors
-        do = do.contiguous()
-        total_nnz, _three, _nheads, headdim = do.shape
-        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
-        # we get the same tensor
-        dqk = do[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(
-            dqk,
-            cos,
-            sin,
-            seqlen_offsets=0,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=ctx.max_seqlen,
-            interleaved=False,
-            inplace=True,
-            conjugate=True,
-        )
-
-        return do, None, None, None, None, None, None
-
-
-# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
-def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
-    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
-
-
-# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
-class UnpaddedRotaryEmbedding(RotaryEmbedding):
-    def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
-        super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=None, interleaved=False)
-        self.max_seqlen = max_seqlen
-
-        if max_seqlen is not None and device is not None and dtype is not None:
-            self._update_cos_sin_cache(max_seqlen, device=device, dtype=None)
-
-    def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-
-        qkv = apply_rotary_unpadded(
-            qkv,
-            self._cos_cached,
-            self._sin_cached,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-        )
-
-        return qkv
-
-
-class RotaryPositionalEmbeddings(nn.Module):
-    def __init__(self, config, theta: int):
-        super().__init__()
-
-        head_size = config.query_key_head_size
-        assert head_size % 2 == 0
-        max_seq_len = config.max_sequence_length
-
-        inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
-        pos = torch.arange(max_seq_len, dtype=torch.float32)
-        embedding = torch.einsum('n, d -> nd', pos, inv_freq)
-        embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
-        self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
-        self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
-
-    def forward(self, x: torch.Tensor):
-        hidden_layer = x.float()
-
-        seq_len = x.shape[2]
-
-        cos_matrix = self.cos_matrix[:, None, :seq_len, :]
-        sin_matrix = self.sin_matrix[:, None, :seq_len, :]
-
-        x_rotate_half = torch.cat(
-            [
-                -hidden_layer[:, :, :, x.size(-1) // 2:],
-                hidden_layer[:, :, :, :x.size(-1) // 2]
-            ],
-            dim=-1
-        )
-
-        out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
-        return out.type_as(x)
+class Layer(nn.Module):
+    def __init__(self, config: GptBertConfig, layer_idx: int):
+        super().__init__()
+
+        self.attention = SelfAttention(config, layer_idx)
+        self.mlp = FeedForward(config)
+        self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
+
+    def set_window_length(self, window_length: int):
+        self.attention.set_window_length(window_length)
+
+    def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
+        attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
+        qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
+        mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
+
+        attention_output, v1 = self.attention(attention_output, qk_layer, v1, padding_info)
+        mlp_layer = mlp_layer + attention_output
+        hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
+        output = hidden_layer + attention_output + self.mlp(mlp_layer)
+
+        return output, v1
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: GptBertConfig):
+        super().__init__()
+        self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
+        self.short_long_ratio = config.short_long_ratio
+
+    def set_window_length(self, config: GptBertConfig):
+        for i, layer in enumerate(self.layers):
+            if (i + 1) % self.local_global_ratio == 0:
+                layer.set_window_length(config.global_window_length)
+            else:
+                layer.set_window_length(config.local_window_length)
+
+    def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
+        hidden_layers = [hidden_layer] if output_hidden_states else None
+        v1 = None
+        embeddings = hidden_layer
+
+        for layer in self.layers:
+            if checkpoint_activations:
+                hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layers, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
+            else:
+                hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
+
+            if output_hidden_states:
+                hidden_layers.append(hidden_layer)
+
+        return hidden_layer, hidden_layers
 
 
 #
@@ -565,33 +565,6 @@ class GptBertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    @classmethod
-    def _autoset_attn_implementation(
-        cls,
-        config,
-        torch_dtype: Optional[torch.dtype] = None,
-        device_map: Optional[Union[str, Dict[str, int]]] = None,
-        check_device_map: bool = True,
-    ):
-        if config._attn_implementation_internal is None:
-            config._attn_implementation_internal = "flash_attention_2"
-            try:
-                return cls._check_and_enable_flash_attn_2(
-                    config,
-                    torch_dtype=torch.float16,
-                    device_map=device_map,
-                    hard_check_only=False,
-                    check_device_map=check_device_map,
-                )
-            except (ValueError, ImportError):
-                config._attn_implementation_internal = None
-        return super()._autoset_attn_implementation(
-            config,
-            torch_dtype=torch_dtype,
-            device_map=device_map,
-            check_device_map=check_device_map,
-        )
-
 
 class GptBertModel(GptBertPreTrainedModel):
     def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
@@ -634,7 +607,7 @@ class GptBertModel(GptBertPreTrainedModel):
         else:
            attention_mask = attention_mask.bool()
 
-        if self.config._attn_implementation == "flash_attention_2":
+        if is_flash_attn_2_available():
            if len(attention_mask.size()) != 2:
                raise ValueError("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.")
            with torch.no_grad():
@@ -665,7 +638,7 @@ class GptBertModel(GptBertPreTrainedModel):
            contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
 
        # Pad output if using FlashAttention
-        if self.config._attn_implementation == "flash_attention_2":
+        if is_flash_attn_2_available():
            last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
            if output_hidden_states:
                contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
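
For orientation, a minimal usage sketch of the model this file implements is given below. It is hypothetical: the Hub checkpoint id, the masked-LM auto-class, and the `[MASK]` token are assumptions rather than anything stated in this commit; the only behaviour taken from the diff is that the FlashAttention code path is now chosen by `is_flash_attn_2_available()` instead of `config._attn_implementation`.

```python
# Hypothetical usage sketch -- the checkpoint id and auto-class mapping are placeholders.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers.utils import is_flash_attn_2_available

model_id = "ltg/gpt-bert-base"  # placeholder; substitute the real repository id

# trust_remote_code=True is what makes transformers load this modeling_gptbert.py.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)

# After this commit, the unpadded FlashAttention path (UnpaddedRotaryEmbedding +
# flash_attention_forward) is taken whenever flash_attn is installed; otherwise the
# model falls back to RotaryPositionalEmbeddings and padded attention.
print("FlashAttention 2 available:", is_flash_attn_2_available())

inputs = tokenizer("The capital of Norway is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # expected shape: (batch, seq_len, vocab_size)
```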