davda54 committed on
Commit 3bd5fb5 · verified · 1 parent: e071182

fixes and optimizations

Files changed (1)
  1. modeling_gptbert.py +172 -292
modeling_gptbert.py CHANGED
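For context, a minimal usage sketch of the module this commit touches. The repository id below is a placeholder, and it is assumed that the repository's auto_map points AutoModelForMaskedLM at GptBertForMaskedLM; trust_remote_code=True is required because the modeling code ships with the checkpoint.

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

repo_id = "<org>/<gpt-bert-checkpoint>"  # placeholder, replace with the actual repository id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(
    repo_id,
    trust_remote_code=True,       # loads modeling_gptbert.py from the repository
    torch_dtype=torch.bfloat16,   # half precision lets the FlashAttention path be used when available
)

text = f"The capital of Norway is {tokenizer.mask_token}."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits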
@@ -3,14 +3,13 @@ from __future__ import annotations
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
6
- from torch import _softmax_backward_data as _softmax_backward_data
7
 
8
  from functools import partial, lru_cache
9
 
10
  from .configuration_gptbert import GptBertConfig
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.activations import gelu_new
13
- from transformers.utils import is_flash_attn_2_available
14
  from transformers.modeling_outputs import (
15
  MaskedLMOutput,
16
  MultipleChoiceModelOutput,
@@ -23,6 +22,9 @@ from transformers.modeling_outputs import (
23
  import math
24
  from typing import TYPE_CHECKING, Optional, Union, Tuple, List
25
 
 
 
 
26
  # Workaround for transformers < 4.36.0 check_imports issue
27
  # See: https://github.com/huggingface/transformers/issues/28459
28
  try:
@@ -31,13 +33,15 @@ try:
31
  from flash_attn.layers.rotary import RotaryEmbedding
32
  from flash_attn.ops.triton.rotary import apply_rotary
33
  else:
34
- flash_attn_varlen_qkvpacked_func = None
35
- RotaryEmbedding = object
36
- apply_rotary = None
 
37
  except ImportError:
38
- flash_attn_varlen_qkvpacked_func = None
39
- RotaryEmbedding = object
40
- apply_rotary = None
 
41
 
42
 
43
  # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
@@ -89,34 +93,6 @@ class CastedLinearIn(nn.Linear):
89
  return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
90
 
91
 
92
- class CastedLinearOut(nn.Linear):
93
- def __init__(self, in_features, out_features, bias):
94
- super().__init__(in_features, out_features, bias=bias)
95
- self.scale = nn.Parameter(torch.ones(out_features))
96
-
97
- def forward(self, x):
98
- return F.linear(x, (self.scale.unsqueeze(1) * self.weight).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
99
-
100
-
101
- class MultiCastedLinearOrtho(nn.Module):
102
- def __init__(self, in_features, out_features, bias):
103
- super().__init__()
104
- self.in_features = in_features
105
- self.out_features = out_features
106
-
107
- self.weights = nn.ParameterList()
108
- for out_feature in out_features:
109
- self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
110
-
111
- if bias:
112
- self.bias = nn.Parameter(torch.zeros(sum(out_features)))
113
- else:
114
- self.bias = self.register_parameter("bias", None)
115
-
116
- def forward(self, x):
117
- return F.linear(x, torch.cat([weight for weight in self.weights], dim=0).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
118
-
119
-
120
  class MultiCastedLinearOrthoIn(nn.Module):
121
  def __init__(self, in_features, out_features, bias):
122
  super().__init__()
@@ -138,77 +114,40 @@ class MultiCastedLinearOrthoIn(nn.Module):
138
  return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
139
 
140
 
141
- class MultiCastedLinearOrthoOut(nn.Module):
142
- def __init__(self, in_features, out_features, bias):
143
- super().__init__()
144
-
145
- self.in_features = in_features
146
- self.out_features = out_features
147
-
148
- self.weights = nn.ParameterList()
149
- for out_feature in out_features:
150
- self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
151
-
152
- if bias:
153
- self.bias = nn.Parameter(torch.zeros(sum(out_features)))
154
- else:
155
- self.bias = self.register_parameter("bias", None)
156
-
157
- self.scale = nn.Parameter(torch.ones(sum(out_features)))
158
-
159
- def forward(self, x):
160
- return F.linear(x, (self.scale.unsqueeze(1) * torch.cat([weight for weight in self.weights], dim=0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
161
-
162
-
163
  class GeGLU(nn.Module):
164
  def forward(self, x):
165
  x, gate = x.chunk(2, dim=-1)
166
  return x * gelu_new(gate)
167
 
168
 
169
- class MaskedSoftmax(torch.autograd.Function):
170
- @staticmethod
171
- def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int):
172
- ctx.dim = dim
173
- x.masked_fill_(mask, float('-inf'))
174
- x = torch.softmax(x, ctx.dim)
175
- x.masked_fill_(mask, 0.0)
176
- ctx.save_for_backward(x)
177
- return x
178
-
179
- @staticmethod
180
- def backward(ctx, grad_output: torch.Tensor):
181
- output, = ctx.saved_tensors
182
- inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
183
- return inputGrad, None, None
184
-
185
-
186
  class Encoder(nn.Module):
187
  def __init__(self, config: GptBertConfig):
188
  super().__init__()
189
-
190
  self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
191
  self.short_long_ratio = config.short_long_ratio
192
 
193
  def set_window_length(self, config: GptBertConfig):
194
  for i, layer in enumerate(self.layers):
195
- if (i+1) % self.short_long_ratio == 0:
196
- layer.set_window_length(config.window_length)
197
  else:
198
- layer.set_window_length(256)
199
 
200
- def forward(self, hidden_layer: torch.Tensor, padding_info):
201
- hidden_states = []
202
- attention_probs = []
203
  v1 = None
204
  embeddings = hidden_layer
205
 
206
  for layer in self.layers:
207
- hidden_layer, v1, attention_p = layer(hidden_layer, embeddings, v1, padding_info)
208
- hidden_states.append(hidden_layer)
209
- attention_probs.append(attention_p)
 
210
 
211
- return hidden_states, attention_probs
 
 
 
212
 
213
 
214
  class Layer(nn.Module):
@@ -227,12 +166,12 @@ class Layer(nn.Module):
227
  qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
228
  mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
229
 
230
- attention_output, v1, attention_p = self.attention(attention_output, qk_layer, v1, padding_info)
231
  mlp_layer = mlp_layer + attention_output
232
  hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
233
  output = hidden_layer + attention_output + self.mlp(mlp_layer)
234
 
235
- return output, v1, attention_p
236
 
237
 
238
  class Embedding(nn.Module):
@@ -257,21 +196,22 @@ class Embedding(nn.Module):
257
  return self.dropout(word_embedding)
258
 
259
 
260
- class MaskClassifier(nn.Module):
261
- def __init__(self, config: GptBertConfig, embedding_weights: nn.Parameter):
262
  super().__init__()
263
 
264
  self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
265
  self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
266
  self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
267
- self.emb2vocab = CastedLinearIn(config.hidden_size, config.vocab_size, bias=True)
268
 
269
  def forward(self, x: torch.Tensor):
270
- x = self.pre_norm(x)
271
  x = self.projection(x)
272
  x = gelu_new(x)
273
- x = self.post_norm(x)
274
- return self.emb2vocab(x)
 
275
 
276
 
277
  def flash_attention_forward(
@@ -354,14 +294,8 @@ class SelfAttention(nn.Module):
354
  theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
355
 
356
  # Initialize rotary embeddings based on whether FlashAttention is available
357
- if is_flash_attn_2_available():
358
- self.rope_embedding = UnpaddedRotaryEmbedding(
359
- dim=config.d_qk,
360
- base=theta,
361
- max_seqlen=config.max_sequence_length,
362
- device=None,
363
- dtype=None
364
- )
365
  else:
366
  self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
367
 
@@ -380,10 +314,10 @@ class SelfAttention(nn.Module):
380
  """Create and cache window attention mask."""
381
  if self.is_causal:
382
  mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
383
- mask = ~mask.tril().triu(diagonal=-self.window_length)
384
  else:
385
  mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
386
- mask = ~mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
387
  return mask.view(1, 1, query_length, key_length)
388
 
389
  def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -394,26 +328,24 @@ class SelfAttention(nn.Module):
394
  # Use cached window mask
395
  with torch.no_grad():
396
  window_mask = self._get_window_mask(query_length, key_length, query.device)
397
-
398
  if padding_mask is not None:
399
- attention_mask = padding_mask | window_mask
400
  else:
401
  attention_mask = window_mask
402
 
403
- attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, T, T]
404
- attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
405
-
406
- attention_probabilities = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
407
- attention_probabilities = self.dropout(attention_probabilities)
408
-
409
- value = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
410
- value = value.view(batch_size, self.num_attention_heads, query_length, self.d_v)
411
-
412
- return value, attention_probabilities.detach()
413
 
414
  def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
415
  # Get original shape info
416
- if is_flash_attn_2_available() and isinstance(padding_info, tuple):
417
  # Unpadded case
418
  indices, cu_seqlens, max_seqlen = padding_info
419
  total_seqlen = hidden_layer.size(0)
@@ -421,16 +353,14 @@ class SelfAttention(nn.Module):
421
  else:
422
  # Padded case
423
  batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
424
- hidden_layer = hidden_layer.transpose(0, 1) # [seq_len, batch_size, hidden_size]
425
- qk_layer = qk_layer.transpose(0, 1)
426
 
427
- hidden_layer = self.pre_v_norm(hidden_layer)
428
- qk_layer = self.pre_qk_norm(qk_layer)
429
 
430
  query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
431
  value = self.v_proj(hidden_layer)
432
 
433
- if is_flash_attn_2_available() and isinstance(padding_info, tuple):
434
  # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
435
  query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
436
  key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
@@ -445,15 +375,7 @@ class SelfAttention(nn.Module):
445
  value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
446
 
447
  # Prepare qkv for FlashAttention
448
- if self.num_kv_heads == self.num_attention_heads:
449
- # Standard MHA
450
- qkv = torch.stack([query, key, value], dim=1) # (total_seqlen, 3, num_heads, head_dim)
451
- else:
452
- # GQA case - need to repeat k,v heads
453
- num_rep = self.num_attention_heads // self.num_kv_heads
454
- key = key.repeat_interleave(num_rep, dim=1)
455
- value = value.repeat_interleave(num_rep, dim=1)
456
- qkv = torch.stack([query, key, value], dim=1)
457
 
458
  # Determine window size for local attention
459
  if self.window_length is not None and self.window_length > 0:
@@ -478,16 +400,15 @@ class SelfAttention(nn.Module):
478
 
479
  # Reshape output back
480
  output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
481
- attention_probabilities = None
482
 
483
  else:
484
  # Standard attention path
485
  query_length = hidden_layer.size(0)
486
  key_length = hidden_layer.size(0)
487
 
488
- query = query.reshape(query_length, batch_size, self.num_attention_heads, self.d_qk).permute(1, 2, 0, 3)
489
- key = key.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3)
490
- value = value.reshape(key_length, batch_size, self.num_kv_heads, self.d_v).permute(1, 2, 0, 3)
491
 
492
  query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
493
  key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
@@ -500,27 +421,15 @@ class SelfAttention(nn.Module):
500
  query = self.rope_embedding(query)
501
  key = self.rope_embedding(key)
502
 
503
- # Handle GQA for standard attention
504
- if self.num_kv_heads != self.num_attention_heads:
505
- num_rep = self.num_attention_heads // self.num_kv_heads
506
- key = key.repeat_interleave(num_rep, dim=1)
507
- value = value.repeat_interleave(num_rep, dim=1)
508
 
509
- output, attention_probabilities = self.attention_operation(query, key, value, padding_info)
510
- output = output.permute(2, 0, 1, 3).flatten(2, 3) # shape: [T, B, H*D]
511
-
512
- output = self.inter_norm(output)
513
  output = self.out_proj(output)
 
514
 
515
- # Handle output padding if necessary
516
- if is_flash_attn_2_available() and isinstance(padding_info, tuple):
517
- # Already in correct format for unpadded
518
- pass
519
- else:
520
- # Transpose back to [batch_size, seq_len, hidden_size]
521
- output = output.transpose(0, 1)
522
 
523
- return self.dropout(output), v1, attention_probabilities
524
 
525
  class FeedForward(nn.Module):
526
  def __init__(self, config: GptBertConfig):
@@ -533,12 +442,13 @@ class FeedForward(nn.Module):
533
  self.dropout = nn.Dropout(config.feed_forward_dropout_p)
534
 
535
  def forward(self, x: torch.Tensor):
536
- x = self.pre_norm(x)
537
  x = self.up_proj(x)
538
  x = self.activation(x)
539
  x = self.inter_norm(x.float()).type_as(x)
540
  x = self.down_proj(x)
541
- return self.dropout(x)
 
542
 
543
 
544
  class ApplyRotaryEmbUnpad(torch.autograd.Function):
@@ -596,23 +506,17 @@ class ApplyRotaryEmbUnpad(torch.autograd.Function):
596
  return do, None, None, None, None, None, None
597
 
598
 
599
- def apply_rotary_unpadded(
600
- qkv,
601
- cos,
602
- sin,
603
- cu_seqlens: Optional[torch.Tensor] = None,
604
- max_seqlen: Optional[int] = None,
605
- ):
606
  return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
607
 
608
 
609
  class UnpaddedRotaryEmbedding(RotaryEmbedding):
610
- def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
611
- super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
612
  self.max_seqlen = max_seqlen
613
 
614
  if max_seqlen is not None and device is not None and dtype is not None:
615
- self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
616
 
617
  def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
618
  if max_seqlen is not None:
@@ -696,7 +600,7 @@ class GptBertPreTrainedModel(PreTrainedModel):
696
  def _init_weights(self, module):
697
  std = math.sqrt(2.0 / (5.0 * self.hidden_size))
698
 
699
- if isinstance(module, nn.Linear):
700
  nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
701
  if module.bias is not None:
702
  module.bias.data.zero_()
@@ -706,6 +610,33 @@ class GptBertPreTrainedModel(PreTrainedModel):
706
  module.bias.data.zero_()
707
  module.weight.data.fill_(1.0)
708

709
 
710
  class GptBertModel(GptBertPreTrainedModel):
711
  def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
@@ -715,8 +646,10 @@ class GptBertModel(GptBertPreTrainedModel):
715
 
716
  self.embedding = Embedding(config)
717
  self.encoder = Encoder(config)
718
- self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
719
  self.set_window_length(config)
 
 
720
 
721
  def set_window_length(self, config) -> None:
722
  self.encoder.set_window_length(config)
@@ -732,7 +665,7 @@ class GptBertModel(GptBertPreTrainedModel):
732
  input_ids: Optional[torch.Tensor] = None,
733
  attention_mask: Optional[torch.Tensor] = None,
734
  output_hidden_states: Optional[bool] = None
735
- ) -> List[torch.Tensor]:
736
  if input_ids is not None:
737
  input_shape = input_ids.size()
738
  else:
@@ -741,39 +674,50 @@ class GptBertModel(GptBertPreTrainedModel):
741
  batch_size, seq_length = input_shape
742
  device = input_ids.device
743
 
744
- if is_flash_attn_2_available():
745
- if attention_mask is None:
746
- attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
747
- elif attention_mask is not None and len(attention_mask.size()) != 2:
748
- raise ValueError("Only attention mask with two dimensions is supported now.")
749
- input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
 
 
 
 
750
  padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
751
  else:
752
- if attention_mask is None:
753
- attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
754
- if attention_mask is not None:
755
- attention_mask = ~attention_mask.bool()
756
- if len(attention_mask.size()) == 2:
757
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
758
- elif len(attention_mask.size()) == 3:
759
- attention_mask = attention_mask.unsqueeze(1)
760
- if self.config.is_decoder:
761
- attention_mask = attention_mask | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(1)
762
  padding_info = attention_mask
763
 
764
  static_embeddings = self.embedding(input_ids)
765
- contextualized_embeddings, attention_probs = self.encoder(static_embeddings, padding_info)
766
- last_layer = contextualized_embeddings[-1]
767
 
768
  # Pad output if using FlashAttention
769
- if is_flash_attn_2_available():
770
  last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
771
  if output_hidden_states:
772
  contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
773
  else:
774
  contextualized_embeddings = None
775
 
776
- return last_layer, contextualized_embeddings, attention_probs
777
 
778
  def forward(
779
  self,
@@ -786,26 +730,22 @@ class GptBertModel(GptBertPreTrainedModel):
786
  ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
787
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
788
 
789
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(
790
- input_ids, attention_mask, output_hidden_states
791
- )
792
 
793
  if not return_dict:
794
  return (
795
  sequence_output,
796
- *([contextualized_embeddings] if output_hidden_states else []),
797
- *([attention_probs] if output_attentions else [])
798
  )
799
 
800
  return BaseModelOutput(
801
  last_hidden_state=sequence_output,
802
- hidden_states=contextualized_embeddings if output_hidden_states else None,
803
- attentions=attention_probs if output_attentions else None
804
  )
805
 
806
 
807
  class GptBertForMaskedLM(GptBertModel):
808
- _keys_to_ignore_on_load_unexpected = ["head"]
809
 
810
  def __init__(self, config: GptBertConfig, **kwargs):
811
  super().__init__(config, add_mlm_layer=True, **kwargs)
@@ -820,17 +760,14 @@ class GptBertForMaskedLM(GptBertModel):
820
  self,
821
  input_ids: Optional[torch.Tensor] = None,
822
  attention_mask: Optional[torch.Tensor] = None,
823
- token_type_ids: Optional[torch.Tensor] = None,
824
- position_ids: Optional[torch.Tensor] = None,
825
  output_hidden_states: Optional[bool] = None,
826
- output_attentions: Optional[bool] = None,
827
  return_dict: Optional[bool] = None,
828
  labels: Optional[torch.LongTensor] = None,
829
  **kwargs
830
  ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
831
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
832
 
833
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
834
  subword_prediction = self.classifier(sequence_output)
835
  subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
836
 
@@ -847,60 +784,19 @@ class GptBertForMaskedLM(GptBertModel):
847
  if not return_dict:
848
  output = (
849
  subword_prediction,
850
- *([contextualized_embeddings] if output_hidden_states else []),
851
- *([attention_probs] if output_attentions else [])
852
  )
853
  return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
854
 
855
  return MaskedLMOutput(
856
  loss=masked_lm_loss,
857
  logits=subword_prediction,
858
- hidden_states=contextualized_embeddings if output_hidden_states else None,
859
- attentions=attention_probs if output_attentions else None
860
  )
861
 
862
 
863
- class Classifier(nn.Module):
864
- def __init__(self, config: GptBertConfig, num_labels: int):
865
- super().__init__()
866
-
867
- drop_out = getattr(config, "cls_dropout", None)
868
- drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
869
-
870
- self.projection: CastedLinear
871
- self.emb2vocab: CastedLinear
872
- self.pre_norm: nn.LayerNorm
873
- self.post_norm: nn.LayerNorm
874
-
875
- self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
876
- self.projection = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
877
- # self.projection = CastedLinear(config.hidden_size, config.hidden_size, bias=False)
878
- self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
879
- self.emb2vocab = nn.Linear(config.hidden_size, num_labels, bias=True)
880
- # self.emb2vocab = CastedLinear(config.hidden_size, num_labels, bias=True)
881
- self.dropout = nn.Dropout(drop_out)
882
-
883
- self.initialize(config.hidden_size, config.intermediate_size, num_labels)
884
-
885
- @torch.no_grad()
886
- def initialize(self, hidden_size: int, intermediate_size: int, vocab_size: int) -> None:
887
- proj_std: float = math.sqrt(2.0 / (hidden_size + intermediate_size))
888
-
889
- nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
890
- nn.init.trunc_normal_(self.emb2vocab.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
891
- self.emb2vocab.bias.zero_()
892
-
893
- def forward(self, x: torch.Tensor):
894
- x = self.pre_norm(x)
895
- x = self.dropout(x)
896
- x = self.projection(x)
897
- x = gelu_new(x)
898
- x = self.post_norm(x)
899
- return self.emb2vocab(x)
900
-
901
-
902
  class GptBertForCausalLM(GptBertModel):
903
- _keys_to_ignore_on_load_unexpected = ["head"]
904
 
905
  def __init__(self, config: GptBertConfig, **kwargs):
906
  config.is_decoder = True
@@ -947,29 +843,27 @@ class GptBertForCausalLM(GptBertModel):
947
  assert past_key_values is None, "past_key_values is not supported for now"
948
  assert not use_cache, "use_cache is not supported for now"
949
 
950
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
951
  subword_prediction = self.classifier(sequence_output)
952
  subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
953
 
954
- masked_lm_loss = None
955
  if labels is not None:
956
  labels_flatten = labels[:, 1:].flatten()
957
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
958
- masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
959
 
960
  if not return_dict:
961
  output = (
962
  subword_prediction,
963
- *([contextualized_embeddings] if output_hidden_states else []),
964
- *([attention_probs] if output_attentions else [])
965
  )
966
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
967
 
968
  return CausalLMOutput(
969
- loss=masked_lm_loss,
970
  logits=subword_prediction,
971
- hidden_states=contextualized_embeddings if output_hidden_states else None,
972
- attentions=attention_probs if output_attentions else None
973
  )
974
 
975
  def prepare_inputs_for_generation(
@@ -1025,21 +919,19 @@ class GptBertForCausalLM(GptBertModel):
1025
 
1026
 
1027
  class GptBertForSequenceClassification(GptBertModel):
1028
- _keys_to_ignore_on_load_unexpected = ["classifier"]
1029
 
1030
  def __init__(self, config: GptBertConfig, **kwargs):
1031
  super().__init__(config, add_mlm_layer=False, **kwargs)
1032
 
1033
  self.num_labels = config.num_labels
1034
- self.head = Classifier(config, self.num_labels)
 
1035
 
1036
  def forward(
1037
  self,
1038
  input_ids: Optional[torch.Tensor] = None,
1039
  attention_mask: Optional[torch.Tensor] = None,
1040
- token_type_ids: Optional[torch.Tensor] = None,
1041
- position_ids: Optional[torch.Tensor] = None,
1042
- output_attentions: Optional[bool] = None,
1043
  output_hidden_states: Optional[bool] = None,
1044
  return_dict: Optional[bool] = None,
1045
  labels: Optional[torch.LongTensor] = None,
@@ -1047,8 +939,8 @@ class GptBertForSequenceClassification(GptBertModel):
1047
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1048
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1049
 
1050
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1051
- logits = self.head(sequence_output[:, 0, :])
1052
 
1053
  loss = None
1054
  if labels is not None:
@@ -1076,35 +968,31 @@ class GptBertForSequenceClassification(GptBertModel):
1076
  if not return_dict:
1077
  output = (
1078
  logits,
1079
- *([contextualized_embeddings] if output_hidden_states else []),
1080
- *([attention_probs] if output_attentions else [])
1081
  )
1082
  return ((loss,) + output) if loss is not None else output
1083
 
1084
  return SequenceClassifierOutput(
1085
  loss=loss,
1086
  logits=logits,
1087
- hidden_states=contextualized_embeddings if output_hidden_states else None,
1088
- attentions=attention_probs if output_attentions else None
1089
  )
1090
 
1091
 
1092
  class GptBertForTokenClassification(GptBertModel):
1093
- _keys_to_ignore_on_load_unexpected = ["classifier"]
1094
 
1095
  def __init__(self, config: GptBertConfig, **kwargs):
1096
  super().__init__(config, add_mlm_layer=False, **kwargs)
1097
 
1098
  self.num_labels = config.num_labels
1099
- self.head = Classifier(config, self.num_labels)
 
1100
 
1101
  def forward(
1102
  self,
1103
  input_ids: Optional[torch.Tensor] = None,
1104
  attention_mask: Optional[torch.Tensor] = None,
1105
- token_type_ids: Optional[torch.Tensor] = None,
1106
- position_ids: Optional[torch.Tensor] = None,
1107
- output_attentions: Optional[bool] = None,
1108
  output_hidden_states: Optional[bool] = None,
1109
  return_dict: Optional[bool] = None,
1110
  labels: Optional[torch.LongTensor] = None,
@@ -1112,8 +1000,8 @@ class GptBertForTokenClassification(GptBertModel):
1112
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1113
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1114
 
1115
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1116
- logits = self.head(sequence_output)
1117
 
1118
  loss = None
1119
  if labels is not None:
@@ -1137,21 +1025,19 @@ class GptBertForTokenClassification(GptBertModel):
1137
 
1138
 
1139
  class GptBertForQuestionAnswering(GptBertModel):
1140
- _keys_to_ignore_on_load_unexpected = ["classifier"]
1141
 
1142
  def __init__(self, config: GptBertConfig, **kwargs):
1143
  super().__init__(config, add_mlm_layer=False, **kwargs)
1144
 
1145
  self.num_labels = config.num_labels
1146
- self.head = Classifier(config, self.num_labels)
 
1147
 
1148
  def forward(
1149
  self,
1150
  input_ids: Optional[torch.Tensor] = None,
1151
  attention_mask: Optional[torch.Tensor] = None,
1152
- token_type_ids: Optional[torch.Tensor] = None,
1153
- position_ids: Optional[torch.Tensor] = None,
1154
- output_attentions: Optional[bool] = None,
1155
  output_hidden_states: Optional[bool] = None,
1156
  return_dict: Optional[bool] = None,
1157
  start_positions: Optional[torch.Tensor] = None,
@@ -1160,8 +1046,8 @@ class GptBertForQuestionAnswering(GptBertModel):
1160
  ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1161
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1162
 
1163
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1164
- logits = self.head(sequence_output)
1165
 
1166
  start_logits, end_logits = logits.split(1, dim=-1)
1167
  start_logits = start_logits.squeeze(-1).contiguous()
@@ -1189,8 +1075,7 @@ class GptBertForQuestionAnswering(GptBertModel):
1189
  output = (
1190
  start_logits,
1191
  end_logits,
1192
- *([contextualized_embeddings] if output_hidden_states else []),
1193
- *([attention_probs] if output_attentions else [])
1194
  )
1195
  return ((total_loss,) + output) if total_loss is not None else output
1196
 
@@ -1198,28 +1083,25 @@ class GptBertForQuestionAnswering(GptBertModel):
1198
  loss=total_loss,
1199
  start_logits=start_logits,
1200
  end_logits=end_logits,
1201
- hidden_states=contextualized_embeddings if output_hidden_states else None,
1202
- attentions=attention_probs if output_attentions else None
1203
  )
1204
 
1205
 
1206
  class GptBertForMultipleChoice(GptBertModel):
1207
- _keys_to_ignore_on_load_unexpected = ["classifier"]
1208
 
1209
  def __init__(self, config: GptBertConfig, **kwargs):
1210
  super().__init__(config, add_mlm_layer=False, **kwargs)
1211
 
1212
  self.num_labels = getattr(config, "num_labels", 2)
1213
- self.head = Classifier(config, self.num_labels)
 
1214
 
1215
  def forward(
1216
  self,
1217
  input_ids: Optional[torch.Tensor] = None,
1218
  attention_mask: Optional[torch.Tensor] = None,
1219
- token_type_ids: Optional[torch.Tensor] = None,
1220
- position_ids: Optional[torch.Tensor] = None,
1221
  labels: Optional[torch.Tensor] = None,
1222
- output_attentions: Optional[bool] = None,
1223
  output_hidden_states: Optional[bool] = None,
1224
  return_dict: Optional[bool] = None,
1225
  **kwargs
@@ -1230,8 +1112,8 @@ class GptBertForMultipleChoice(GptBertModel):
1230
  flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1231
  flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1232
 
1233
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask)
1234
- logits = self.head(sequence_output)
1235
  reshaped_logits = logits.view(-1, num_choices)
1236
 
1237
  loss = None
@@ -1242,14 +1124,12 @@ class GptBertForMultipleChoice(GptBertModel):
1242
  if not return_dict:
1243
  output = (
1244
  reshaped_logits,
1245
- *([contextualized_embeddings] if output_hidden_states else []),
1246
- *([attention_probs] if output_attentions else [])
1247
  )
1248
  return ((loss,) + output) if loss is not None else output
1249
 
1250
  return MultipleChoiceModelOutput(
1251
  loss=loss,
1252
  logits=reshaped_logits,
1253
- hidden_states=contextualized_embeddings if output_hidden_states else None,
1254
- attentions=attention_probs if output_attentions else None
1255
  )
 
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
 
6
 
7
  from functools import partial, lru_cache
8
 
9
  from .configuration_gptbert import GptBertConfig
10
  from transformers.modeling_utils import PreTrainedModel
11
  from transformers.activations import gelu_new
12
+ from transformers.utils import is_flash_attn_2_available, logging
13
  from transformers.modeling_outputs import (
14
  MaskedLMOutput,
15
  MultipleChoiceModelOutput,
 
22
  import math
23
  from typing import TYPE_CHECKING, Optional, Union, Tuple, List
24
 
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
  # Workaround for transformers < 4.36.0 check_imports issue
29
  # See: https://github.com/huggingface/transformers/issues/28459
30
  try:
 
33
  from flash_attn.layers.rotary import RotaryEmbedding
34
  from flash_attn.ops.triton.rotary import apply_rotary
35
  else:
36
+ flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
37
+ logger.warning_once(
38
+ "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
39
+ )
40
  except ImportError:
41
+ flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
42
+ logger.warning_once(
43
+ "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
44
+ )
45
 
46
 
47
  # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
 
93
  return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
94
 
95

96
  class MultiCastedLinearOrthoIn(nn.Module):
97
  def __init__(self, in_features, out_features, bias):
98
  super().__init__()
 
114
  return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
115
 
116

117
  class GeGLU(nn.Module):
118
  def forward(self, x):
119
  x, gate = x.chunk(2, dim=-1)
120
  return x * gelu_new(gate)
121
 
122

123
  class Encoder(nn.Module):
124
  def __init__(self, config: GptBertConfig):
125
  super().__init__()
 
126
  self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
127
  self.short_long_ratio = config.short_long_ratio
128
 
129
  def set_window_length(self, config: GptBertConfig):
130
  for i, layer in enumerate(self.layers):
131
+ if (i + 1) % self.short_long_ratio == 0:
132
+ layer.set_window_length(config.global_window_length)
133
  else:
134
+ layer.set_window_length(config.local_window_length)
135
 
136
+ def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
137
+ hidden_layers = [hidden_layer] if output_hidden_states else None
 
138
  v1 = None
139
  embeddings = hidden_layer
140
 
141
  for layer in self.layers:
142
+ if checkpoint_activations:
143
+ hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
144
+ else:
145
+ hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
146
 
147
+ if output_hidden_states:
148
+ hidden_layers.append(hidden_layer)
149
+
150
+ return hidden_layer, hidden_layers
151
 
152
 
153
  class Layer(nn.Module):
 
166
  qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
167
  mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
168
 
169
+ attention_output, v1 = self.attention(attention_output, qk_layer, v1, padding_info)
170
  mlp_layer = mlp_layer + attention_output
171
  hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
172
  output = hidden_layer + attention_output + self.mlp(mlp_layer)
173
 
174
+ return output, v1
175
 
176
 
177
  class Embedding(nn.Module):
 
196
  return self.dropout(word_embedding)
197
 
198
 
199
+ class Classifier(nn.Module):
200
+ def __init__(self, config: GptBertConfig, n_labels: int):
201
  super().__init__()
202
 
203
  self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
204
  self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
205
  self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
206
+ self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
207
 
208
  def forward(self, x: torch.Tensor):
209
+ x = self.pre_norm(x.float()).type_as(x)
210
  x = self.projection(x)
211
  x = gelu_new(x)
212
+ x = self.post_norm(x.float()).type_as(x)
213
+ x = self.emb2vocab(x)
214
+ return x
215
 
216
 
217
  def flash_attention_forward(
 
294
  theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
295
 
296
  # Initialize rotary embeddings based on whether FlashAttention is available
297
+ if self.config._attn_implementation == "flash_attention_2":
298
+ self.rope_embedding = UnpaddedRotaryEmbedding(dim=config.d_qk, base=theta, max_seqlen=config.max_sequence_length)
299
  else:
300
  self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
301
 
 
314
  """Create and cache window attention mask."""
315
  if self.is_causal:
316
  mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
317
+ mask = mask.tril().triu(diagonal=-self.window_length)
318
  else:
319
  mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
320
+ mask = mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
321
  return mask.view(1, 1, query_length, key_length)
322
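The mask built above keeps a band of width window_length around the diagonal, with True marking positions that are allowed to attend (the convention F.scaled_dot_product_attention expects for boolean masks). A small non-causal example with window_length=1 and a sequence of length 4:

import torch

window_length, T = 1, 4
mask = torch.ones(T, T, dtype=torch.bool)
mask = mask.tril(diagonal=window_length).triu(diagonal=-window_length)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True, False],
#         [False,  True,  True,  True],
#         [False, False,  True,  True]])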
 
323
  def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
 
328
  # Use cached window mask
329
  with torch.no_grad():
330
  window_mask = self._get_window_mask(query_length, key_length, query.device)
 
331
  if padding_mask is not None:
332
+ attention_mask = padding_mask & window_mask
333
  else:
334
  attention_mask = window_mask
335
 
336
+ output = F.scaled_dot_product_attention(
337
+ query=query,
338
+ key=key,
339
+ value=value,
340
+ attn_mask=attention_mask,
341
+ dropout_p=self.attention_dropout if self.training else 0.0,
342
+ is_causal=False # causality is already encoded in the window mask; SDPA rejects an explicit attn_mask combined with is_causal=True
343
+ )
344
+ return output
 
345
 
346
  def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
347
  # Get original shape info
348
+ if self.config._attn_implementation == "flash_attention_2":
349
  # Unpadded case
350
  indices, cu_seqlens, max_seqlen = padding_info
351
  total_seqlen = hidden_layer.size(0)
 
353
  else:
354
  # Padded case
355
  batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
 
 
356
 
357
+ hidden_layer = self.pre_v_norm(hidden_layer.float()).type_as(hidden_layer)
358
+ qk_layer = self.pre_qk_norm(qk_layer.float()).type_as(qk_layer)
359
 
360
  query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
361
  value = self.v_proj(hidden_layer)
362
 
363
+ if self.config._attn_implementation == "flash_attention_2":
364
  # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
365
  query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
366
  key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
 
375
  value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
376
 
377
  # Prepare qkv for FlashAttention
378
+ qkv = torch.stack([query, key, value], dim=1) # (total_seqlen, 3, num_heads, head_dim)
379
 
380
  # Determine window size for local attention
381
  if self.window_length is not None and self.window_length > 0:
 
400
 
401
  # Reshape output back
402
  output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
 
403
 
404
  else:
405
  # Standard attention path
406
+ query_length = hidden_layer.size(1) # hidden_layer is [batch_size, seq_length, hidden_size] in the padded path
407
+ key_length = hidden_layer.size(1)
408
 
409
+ query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
410
+ key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)
411
+ value = value.reshape(batch_size, key_length, self.num_kv_heads, self.d_v).transpose(1, 2)
412
 
413
  query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
414
  key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
 
421
  query = self.rope_embedding(query)
422
  key = self.rope_embedding(key)
423
 
424
+ output = self.attention_operation(query, key, value, padding_info)
425
+ output = output.transpose(1, 2).flatten(2, 3) # shape: [B, T, H*D]
 
 
 
426
 
427
+ output = self.inter_norm(output.float()).type_as(output)
 
 
 
428
  output = self.out_proj(output)
429
+ output = self.dropout(output)
430
 
431
+ return output, v1
432

433
 
434
  class FeedForward(nn.Module):
435
  def __init__(self, config: GptBertConfig):
 
442
  self.dropout = nn.Dropout(config.feed_forward_dropout_p)
443
 
444
  def forward(self, x: torch.Tensor):
445
+ x = self.pre_norm(x.float()).type_as(x)
446
  x = self.up_proj(x)
447
  x = self.activation(x)
448
  x = self.inter_norm(x.float()).type_as(x)
449
  x = self.down_proj(x)
450
+ x = self.dropout(x)
451
+ return x
452
 
453
 
454
  class ApplyRotaryEmbUnpad(torch.autograd.Function):
 
506
  return do, None, None, None, None, None, None
507
 
508
 
509
+ def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
510
  return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
511
 
512
 
513
  class UnpaddedRotaryEmbedding(RotaryEmbedding):
514
+ def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
515
+ super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=None, interleaved=False)
516
  self.max_seqlen = max_seqlen
517
 
518
+ if max_seqlen is not None:
519
+ self._update_cos_sin_cache(max_seqlen) # device/dtype are no longer __init__ arguments; later calls can refresh the cache on the right device
520
 
521
  def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
522
  if max_seqlen is not None:
 
600
  def _init_weights(self, module):
601
  std = math.sqrt(2.0 / (5.0 * self.hidden_size))
602
 
603
+ if isinstance(module, nn.Linear) or isinstance(module, CastedLinearIn):
604
  nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
605
  if module.bias is not None:
606
  module.bias.data.zero_()
 
610
  module.bias.data.zero_()
611
  module.weight.data.fill_(1.0)
612
 
613
+ @classmethod
614
+ def _autoset_attn_implementation(
615
+ cls,
616
+ config,
617
+ torch_dtype: Optional[torch.dtype] = None,
618
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
619
+ check_device_map: bool = True,
620
+ ):
621
+ if config._attn_implementation_internal is None:
622
+ config._attn_implementation_internal = "flash_attention_2"
623
+ try:
624
+ return cls._check_and_enable_flash_attn_2(
625
+ config,
626
+ torch_dtype=torch.float16,
627
+ device_map=device_map,
628
+ hard_check_only=False,
629
+ check_device_map=check_device_map,
630
+ )
631
+ except (ValueError, ImportError):
632
+ config._attn_implementation_internal = None
633
+ return super()._autoset_attn_implementation(
634
+ config,
635
+ torch_dtype=torch_dtype,
636
+ device_map=device_map,
637
+ check_device_map=check_device_map,
638
+ )
639
+
640
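This override mirrors the ModernBERT one: when no attention implementation is requested, it tries to enable FlashAttention 2 and quietly falls back to the standard path otherwise. Assuming it behaves the same way here, an attn_implementation passed explicitly to from_pretrained should still take precedence, for example to force the non-FlashAttention path:

from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    repo_id,                       # placeholder repository id, as in the earlier sketch
    trust_remote_code=True,
    attn_implementation="eager",   # bypasses the flash_attention_2 default set above
)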
 
641
  class GptBertModel(GptBertPreTrainedModel):
642
  def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
 
646
 
647
  self.embedding = Embedding(config)
648
  self.encoder = Encoder(config)
649
+ self.classifier = Classifier(config, config.vocab_size) if add_mlm_layer else None
650
  self.set_window_length(config)
651
+ self.gradient_checkpointing = False
652
+ self.post_init()
653
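Gradient checkpointing is wired through this plain attribute (checked as self.gradient_checkpointing and self.training in get_contextualized_embeddings) rather than through the usual gradient_checkpointing_enable() hook. A sketch of how a training script would turn it on, under that assumption:

model = GptBertForMaskedLM(config)   # or loaded via from_pretrained as above
model.gradient_checkpointing = True  # recompute layer activations during backward to save memory
model.train()                        # checkpointing is only active in training mode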
 
654
  def set_window_length(self, config) -> None:
655
  self.encoder.set_window_length(config)
 
665
  input_ids: Optional[torch.Tensor] = None,
666
  attention_mask: Optional[torch.Tensor] = None,
667
  output_hidden_states: Optional[bool] = None
668
+ ):
669
  if input_ids is not None:
670
  input_shape = input_ids.size()
671
  else:
 
674
  batch_size, seq_length = input_shape
675
  device = input_ids.device
676
 
677
+ if attention_mask is None:
678
+ attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
679
+ else:
680
+ attention_mask = attention_mask.bool()
681
+
682
+ if self.config._attn_implementation == "flash_attention_2":
683
+ if len(attention_mask.size()) != 2:
684
+ raise ValueError("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.")
685
+ with torch.no_grad():
686
+ input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
687
  padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
688
  else:
689
+ if len(attention_mask.size()) == 2:
690
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
691
+ elif len(attention_mask.size()) == 3:
692
+ attention_mask = attention_mask.unsqueeze(1)
 
 
 
 
 
 
693
  padding_info = attention_mask
694
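In the FlashAttention branch above, _unpad_input (taken from the ModernBERT modeling code) packs all non-padding tokens of the batch into one flat sequence. Roughly, for a toy batch with sequence lengths 3 and 2, the returned quantities look like this (a sketch of the expected convention, not the actual helper):

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.bool)
seqlens = attention_mask.sum(dim=-1)                          # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten()).flatten()   # tensor([0, 1, 2, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0))      # tensor([0, 3, 5])
max_seqlen_in_batch = int(seqlens.max())                      # 3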
 
695
  static_embeddings = self.embedding(input_ids)
696
+
697
+ original_dtype = static_embeddings.dtype
698
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and static_embeddings.dtype == torch.float32:
699
+ static_embeddings = static_embeddings.bfloat16()
700
+
701
+ last_layer, contextualized_embeddings = self.encoder(
702
+ static_embeddings,
703
+ padding_info,
704
+ output_hidden_states=output_hidden_states,
705
+ checkpoint_activations=self.gradient_checkpointing and self.training
706
+ )
707
+
708
+ last_layer = last_layer.to(original_dtype)
709
+ if output_hidden_states:
710
+ contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
711
 
712
  # Pad output if using FlashAttention
713
+ if self.config._attn_implementation == "flash_attention_2":
714
  last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
715
  if output_hidden_states:
716
  contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
717
  else:
718
  contextualized_embeddings = None
719
 
720
+ return last_layer, contextualized_embeddings
721
 
722
  def forward(
723
  self,
 
730
  ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
731
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
732
 
733
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
 
 
734
 
735
  if not return_dict:
736
  return (
737
  sequence_output,
738
+ *([contextualized_embeddings] if output_hidden_states else [])
 
739
  )
740
 
741
  return BaseModelOutput(
742
  last_hidden_state=sequence_output,
743
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
744
  )
745
 
746
 
747
  class GptBertForMaskedLM(GptBertModel):
748
+ _tied_weights_keys = ["classifier.emb2vocab.weight"]
749
 
750
  def __init__(self, config: GptBertConfig, **kwargs):
751
  super().__init__(config, add_mlm_layer=True, **kwargs)
 
760
  self,
761
  input_ids: Optional[torch.Tensor] = None,
762
  attention_mask: Optional[torch.Tensor] = None,
 
 
763
  output_hidden_states: Optional[bool] = None,
 
764
  return_dict: Optional[bool] = None,
765
  labels: Optional[torch.LongTensor] = None,
766
  **kwargs
767
  ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
768
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
769
 
770
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
771
  subword_prediction = self.classifier(sequence_output)
772
  subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
773
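The 30 * sigmoid(x / 7.5) transform is a smooth logit cap: outputs are bounded to (0, 30), and the slope at x = 0 is (30 / 7.5) * 0.25 = 1, so small logits pass through essentially unchanged (up to a constant offset of 15 that cancels in the softmax), while very large logits saturate at 30. A quick numerical check:

import torch

x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(30 * torch.sigmoid(x / 7.5))
# approximately tensor([ 0.00, 10.18, 15.00, 19.82, 30.00])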
 
 
784
  if not return_dict:
785
  output = (
786
  subword_prediction,
787
+ *([contextualized_embeddings] if output_hidden_states else [])
 
788
  )
789
  return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
790
 
791
  return MaskedLMOutput(
792
  loss=masked_lm_loss,
793
  logits=subword_prediction,
794
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
795
  )
796
 
797

798
  class GptBertForCausalLM(GptBertModel):
799
+ _tied_weights_keys = ["classifier.emb2vocab.weight"]
800
 
801
  def __init__(self, config: GptBertConfig, **kwargs):
802
  config.is_decoder = True
 
843
  assert past_key_values is None, "past_key_values is not supported for now"
844
  assert not use_cache, "use_cache is not supported for now"
845
 
846
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
847
  subword_prediction = self.classifier(sequence_output)
848
  subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
849
 
850
+ causal_lm_loss = None
851
  if labels is not None:
852
  labels_flatten = labels[:, 1:].flatten()
853
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
854
+ causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
855
 
856
  if not return_dict:
857
  output = (
858
  subword_prediction,
859
+ *([contextualized_embeddings] if output_hidden_states else [])
 
860
  )
861
+ return ((causal_lm_loss,) + output) if causal_lm_loss is not None else output
862
 
863
  return CausalLMOutput(
864
+ loss=causal_lm_loss,
865
  logits=subword_prediction,
866
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
867
  )
868
 
869
  def prepare_inputs_for_generation(
 
919
 
920
 
921
  class GptBertForSequenceClassification(GptBertModel):
922
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab"]
923
 
924
  def __init__(self, config: GptBertConfig, **kwargs):
925
  super().__init__(config, add_mlm_layer=False, **kwargs)
926
 
927
  self.num_labels = config.num_labels
928
+ self.classifier = Classifier(config, self.num_labels)
929
+ self.post_init()
930
 
931
  def forward(
932
  self,
933
  input_ids: Optional[torch.Tensor] = None,
934
  attention_mask: Optional[torch.Tensor] = None,
 
 
 
935
  output_hidden_states: Optional[bool] = None,
936
  return_dict: Optional[bool] = None,
937
  labels: Optional[torch.LongTensor] = None,
 
939
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
940
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
941
 
942
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
943
+ logits = self.classifier(sequence_output[:, 0, :])
944
 
945
  loss = None
946
  if labels is not None:
 
968
  if not return_dict:
969
  output = (
970
  logits,
971
+ *([contextualized_embeddings] if output_hidden_states else [])
 
972
  )
973
  return ((loss,) + output) if loss is not None else output
974
 
975
  return SequenceClassifierOutput(
976
  loss=loss,
977
  logits=logits,
978
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
979
  )
980
 
981
 
982
  class GptBertForTokenClassification(GptBertModel):
983
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab"]
984
 
985
  def __init__(self, config: GptBertConfig, **kwargs):
986
  super().__init__(config, add_mlm_layer=False, **kwargs)
987
 
988
  self.num_labels = config.num_labels
989
+ self.classifier = Classifier(config, self.num_labels)
990
+ self.post_init()
991
 
992
  def forward(
993
  self,
994
  input_ids: Optional[torch.Tensor] = None,
995
  attention_mask: Optional[torch.Tensor] = None,
 
 
 
996
  output_hidden_states: Optional[bool] = None,
997
  return_dict: Optional[bool] = None,
998
  labels: Optional[torch.LongTensor] = None,
 
1000
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1001
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1002
 
1003
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
1004
+ logits = self.classifier(sequence_output)
1005
 
1006
  loss = None
1007
  if labels is not None:
 
1025
 
1026
 
1027
  class GptBertForQuestionAnswering(GptBertModel):
1028
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab"]
1029
 
1030
  def __init__(self, config: GptBertConfig, **kwargs):
1031
  super().__init__(config, add_mlm_layer=False, **kwargs)
1032
 
1033
  self.num_labels = config.num_labels
1034
+ self.classifier = Classifier(config, self.num_labels)
1035
+ self.post_init()
1036
 
1037
  def forward(
1038
  self,
1039
  input_ids: Optional[torch.Tensor] = None,
1040
  attention_mask: Optional[torch.Tensor] = None,
 
 
 
1041
  output_hidden_states: Optional[bool] = None,
1042
  return_dict: Optional[bool] = None,
1043
  start_positions: Optional[torch.Tensor] = None,
 
1046
  ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1047
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1048
 
1049
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
1050
+ logits = self.classifier(sequence_output)
1051
 
1052
  start_logits, end_logits = logits.split(1, dim=-1)
1053
  start_logits = start_logits.squeeze(-1).contiguous()
 
1075
  output = (
1076
  start_logits,
1077
  end_logits,
1078
+ *([contextualized_embeddings] if output_hidden_states else [])
 
1079
  )
1080
  return ((total_loss,) + output) if total_loss is not None else output
1081
 
 
1083
  loss=total_loss,
1084
  start_logits=start_logits,
1085
  end_logits=end_logits,
1086
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
1087
  )
1088
 
1089
 
1090
  class GptBertForMultipleChoice(GptBertModel):
1091
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab"]
1092
 
1093
  def __init__(self, config: GptBertConfig, **kwargs):
1094
  super().__init__(config, add_mlm_layer=False, **kwargs)
1095
 
1096
  self.num_labels = getattr(config, "num_labels", 2)
1097
+ self.classifier = Classifier(config, self.num_labels)
1098
+ self.post_init()
1099
 
1100
  def forward(
1101
  self,
1102
  input_ids: Optional[torch.Tensor] = None,
1103
  attention_mask: Optional[torch.Tensor] = None,
 
 
1104
  labels: Optional[torch.Tensor] = None,
 
1105
  output_hidden_states: Optional[bool] = None,
1106
  return_dict: Optional[bool] = None,
1107
  **kwargs
 
1112
  flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1113
  flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1114
 
1115
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask, output_hidden_states)
1116
+ logits = self.classifier(sequence_output)
1117
  reshaped_logits = logits.view(-1, num_choices)
1118
 
1119
  loss = None
 
1124
  if not return_dict:
1125
  output = (
1126
  reshaped_logits,
1127
+ *([contextualized_embeddings] if output_hidden_states else [])
 
1128
  )
1129
  return ((loss,) + output) if loss is not None else output
1130
 
1131
  return MultipleChoiceModelOutput(
1132
  loss=loss,
1133
  logits=reshaped_logits,
1134
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
1135
  )