maxmeuer committed
Commit 6a30f3b · verified · 1 Parent(s): 61d4a54

Update code for latest transformers

Files changed (5):
  1. README.md +6 -8
  2. config.json +4 -4
  3. config.yaml +4 -4
  4. model.py +3 -13
  5. transformer_backbone.py +22 -267
README.md CHANGED
@@ -69,10 +69,10 @@ We do not collect PII (personally identifiable information) for any of these cha
 
 ## Inference
 
-We provide an inference module compatible with HuggingFace Transformers for running model inference. We recommend pinning the transformers library to version 4.46.3. Before executing the inference example below, make sure the [hat-splitter package](https://pypi.org/project/hat-splitter/) is installed in your environment.
+We provide an inference module compatible with HuggingFace Transformers for running model inference. Before executing the inference example below, make sure the [hat-splitter package](https://pypi.org/project/hat-splitter/) is installed in your environment.
 
 ```shell
-pip install 'hat-splitter>=0.1.9' 'transformers==4.46.3' torch
+pip install 'hat-splitter>=0.1.9' transformers torch
 pip install flash_attn
 ```
 
@@ -82,13 +82,12 @@ Download model weights and run inference using the following example:
 import torch
 from transformers import AutoModelForCausalLM
 INPUT ="When was Rome founded?"
-MODEL_ID = "Aleph-Alpha/TFree-HAT-Pretrained-7B-Base"
+MODEL_ID = "Aleph-Alpha/llama-tfree-hat-pretrained-7b-base"
 model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True,
-    pretrained_model_name_or_path=MODEL_ID,
-    attn_implementation="flash_attention_2",
+    pretrained_model_name_or_path=MODEL_ID
 ).to("cuda", torch.bfloat16)
-input_ids, cumulative_word_lengths = model._prepare_input(INPUT)
+input_ids, cumulative_word_lengths = model._prepare_input(INPUT, add_llama_template=True)
 model_output = model.generate(
     input_ids,
     cumulative_seq_lengths_per_word=cumulative_word_lengths,
@@ -99,8 +98,7 @@ print("Prompt: ", INPUT)
 print("Completion: ", model_output.completion_text)
 ```
 
-Please note that the realized inference speed strongly depends on the maturity of the inference implementation beyond the intrinsic text compression of any model. Besides this huggingface transformers-based inference solution, we are also releasing a [vLLM-based inference solution](https://github.com/Aleph-Alpha/vllm) for our models that is optimized for batched inference. Please note that this vLLM inference for HAT is still under active development.
-
+Please note that the realized inference speed strongly depends on the maturity of the inference implementation beyond the intrinsic text compression of any model. Besides this Hugging Face Transformers-based inference solution, we are also releasing a [vLLM-based inference solution](https://github.com/Aleph-Alpha/vllm) for our models that is optimized for batched inference. Our PR into vLLM is still awaiting review, but we strongly encourage using the optimized inference provided through vLLM.
 
 # Evaluation
 
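For convenience, here is the updated README example assembled end to end. This is a sketch only: it assumes the Hub repo `Aleph-Alpha/llama-tfree-hat-pretrained-7b-base`, a CUDA machine with `flash_attn` installed, and it omits any generation arguments that fall outside the hunks shown above.

```python
# Sketch of the post-commit inference flow from the README.
# Prerequisites (per the updated README):
#   pip install 'hat-splitter>=0.1.9' transformers torch
#   pip install flash_attn
import torch
from transformers import AutoModelForCausalLM

INPUT = "When was Rome founded?"
MODEL_ID = "Aleph-Alpha/llama-tfree-hat-pretrained-7b-base"

model = AutoModelForCausalLM.from_pretrained(
    trust_remote_code=True,                 # loads the custom HAT modeling code from the repo
    pretrained_model_name_or_path=MODEL_ID,
).to("cuda", torch.bfloat16)

# _prepare_input splits the raw text into words; add_llama_template=True wraps
# the prompt in the Llama chat template (new in this commit).
input_ids, cumulative_word_lengths = model._prepare_input(INPUT, add_llama_template=True)

# Additional sampling arguments from the full README are omitted here.
model_output = model.generate(
    input_ids,
    cumulative_seq_lengths_per_word=cumulative_word_lengths,
)

print("Prompt: ", INPUT)
print("Completion: ", model_output.completion_text)
```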
config.json CHANGED
@@ -12,7 +12,7 @@
     "is_neox_style": true,
     "key_query_norm": true,
     "key_query_norm_per_head": true,
-    "max_position_embeddings": 3500,
+    "max_position_embeddings": 32900,
     "mlp_bias": false,
     "num_attention_heads": 32,
     "num_hidden_layers": 32,
@@ -44,7 +44,7 @@
     "is_neox_style": true,
     "key_query_norm": true,
     "key_query_norm_per_head": true,
-    "max_position_embeddings": 28000,
+    "max_position_embeddings": 262144,
     "mlp_bias": false,
     "num_attention_heads": 8,
     "num_hidden_layers": 4,
@@ -75,7 +75,7 @@
     "is_neox_style": true,
     "key_query_norm": true,
     "key_query_norm_per_head": true,
-    "max_position_embeddings": 28000,
+    "max_position_embeddings": 262144,
     "mlp_bias": false,
     "num_attention_heads": 8,
     "num_hidden_layers": 6,
@@ -90,7 +90,7 @@
     "use_cache": true,
     "vocab_size": 256
   },
-  "max_position_embeddings": 28000,
+  "max_position_embeddings": 262144,
   "max_word_size": 100,
   "model_type": "hierarchical_autoregressive_transformer",
   "sliding_window": 768,
config.yaml CHANGED
@@ -6,7 +6,7 @@ encoder_config:
   num_key_value_heads: 8
   rms_norm_eps: 1.0e-05
   intermediate_size: 2816
-  max_position_embeddings: 28000
+  max_position_embeddings: 262144
   rope_scaling:
     rope_type: default
   rope_theta: 100000
@@ -34,7 +34,7 @@ backbone_config:
   num_key_value_heads: 8
   rms_norm_eps: 1.0e-05
   intermediate_size: 14336
-  max_position_embeddings: 3500
+  max_position_embeddings: 32900
   rope_scaling:
     rope_type: default
   rope_theta: 500000
@@ -53,7 +53,7 @@ decoder_config:
   num_key_value_heads: 8
   rms_norm_eps: 1.0e-05
   intermediate_size: 2816
-  max_position_embeddings: 28000
+  max_position_embeddings: 262144
   rope_scaling:
     rope_type: default
   rope_theta: 100000
@@ -82,7 +82,7 @@ auto_map:
 special_token_dict: {}
 max_word_size: 100
 sliding_window: 768
-max_position_embeddings: 28000
+max_position_embeddings: 262144
 torch_dtype: bfloat16
 architectures:
 - HATDecoderForCausalLM
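config.yaml mirrors the same change. A small sketch, assuming PyYAML is installed and both files sit in a local checkout, that checks the YAML and JSON settings stay in sync:

```python
# Sketch: confirm config.yaml and config.json agree on the new context limits.
import json
import yaml

with open("config.yaml") as f:
    y = yaml.safe_load(f)
with open("config.json") as f:
    j = json.load(f)

assert y["max_position_embeddings"] == j["max_position_embeddings"] == 262144
assert y["backbone_config"]["max_position_embeddings"] == 32900
assert y["encoder_config"]["max_position_embeddings"] == 262144
assert y["decoder_config"]["max_position_embeddings"] == 262144
print("context-length settings are consistent")
```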
model.py CHANGED
@@ -26,12 +26,6 @@ from .transformer_backbone import (
     LlamaRotaryEmbedding,
 )
 
-try:
-    transformers_version = version("transformers")
-    if transformers_version != "4.46.3":
-        print(f"Warning: Expecected transformers version 4.46.3, but found {transformers_version}. Outputs might be different.")
-except PackageNotFoundError:
-    print("transformers is not installed")
 
 
 def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
@@ -41,13 +35,12 @@ def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
 LLAMA_TEMPLATE = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 
 
-class HATCache(Cache):
+class HATCache:
     encoder_cache: DynamicCache
     backbone_cache: DynamicCache
     decoder_cache: DynamicCache
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
         self.encoder_cache = DynamicCache()
         self.backbone_cache = DynamicCache()
         self.decoder_cache = DynamicCache()
@@ -382,10 +375,7 @@ class HATCrossAttention(nn.Module):
 
         self.o_proj = nn.Linear(in_features=hidden_size, out_features=hidden_size_q, dtype=dtype, bias=False)
 
-        rope_theta = config.rope_theta
-        rope_type = config.rope_scaling["rope_type"]
-
-        self.rotary_emb = LlamaRotaryEmbedding(dim=self.head_dim, base=rope_theta, rope_type=rope_type)
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
 
     def forward(
         self,
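After this change HATCache no longer subclasses transformers' `Cache`; it is a plain container for the three per-module caches. A minimal sketch of the resulting class as it appears in the `+` lines above, with a usage note (nothing beyond the shown attributes is implied):

```python
# Sketch of HATCache after this commit: a plain wrapper around three
# DynamicCache instances instead of a Cache subclass.
from transformers.cache_utils import DynamicCache


class HATCache:
    encoder_cache: DynamicCache
    backbone_cache: DynamicCache
    decoder_cache: DynamicCache

    def __init__(self):
        self.encoder_cache = DynamicCache()
        self.backbone_cache = DynamicCache()
        self.decoder_cache = DynamicCache()


cache = HATCache()
# Each sub-cache is an ordinary DynamicCache, so the encoder, backbone and
# decoder stacks can each be fed their own past_key_values object.
print(isinstance(cache.backbone_cache, DynamicCache))  # True
print(cache.backbone_cache.get_seq_length())           # 0 before any forward pass
```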
transformer_backbone.py CHANGED
@@ -22,17 +22,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch import nn
 
+from transformers import PretrainedConfig, dynamic_rope_update
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.generation import GenerationMixin
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -86,117 +85,41 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
 
 
 class LlamaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim=None,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        rope_type="default",
-        config: Optional[LlamaConfig] = None,
-    ):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: LlamaConfig, device=None):
         super().__init__()
-        # TODO (joao): remove the `if` below, only used for BC
-        self.rope_kwargs = {}
-        if config is None:
-            logger.warning_once(
-                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
-                "`config` argument. All other arguments will be removed in v4.46"
-            )
-            self.rope_kwargs = {
-                "rope_type": rope_type,
-                "factor": scaling_factor,
-                "dim": dim,
-                "base": base,
-                "max_position_embeddings": max_position_embeddings,
-            }
-            self.rope_type = rope_type
-            self.max_seq_len_cached = max_position_embeddings
-            self.original_max_seq_len = max_position_embeddings
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         else:
-            # BC: "rope_type" was originally "type"
-            if config.rope_scaling is not None:
-                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-            else:
-                self.rope_type = "default"
-            self.max_seq_len_cached = config.max_position_embeddings
-            self.original_max_seq_len = config.max_position_embeddings
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
 
-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(
-                self.config, device, seq_len=seq_len, **self.rope_kwargs
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
-
     @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -565,110 +488,7 @@ class LlamaFlashAttention2(LlamaAttention):
         return attn_output, attn_weights, past_key_value
 
 
-class LlamaSdpaAttention(LlamaAttention):
-    """
-    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
-    SDPA API.
-    """
-
-    # Adapted from LlamaAttention.forward
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-            )
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
 
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, -1)
-
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-
-LLAMA_ATTENTION_CLASSES = {
-    "eager": LlamaAttention,
-    "flash_attention_2": LlamaFlashAttention2,
-    "sdpa": LlamaSdpaAttention,
-}
 
 
 class LlamaDecoderLayer(nn.Module):
@@ -676,7 +496,7 @@ class LlamaDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+        self.self_attn = LlamaFlashAttention2(config=config, layer_idx=layer_idx)
 
         self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -960,9 +780,8 @@ class LlamaModel(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+        # Not needed for Flash Attention 2
+        causal_mask = None
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
@@ -1028,70 +847,6 @@
             attentions=all_self_attns,
         )
 
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            device=device,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
 
     @staticmethod
     def _prepare_4d_causal_attention_mask_with_cache_position(
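With this commit the backbone hardcodes LlamaFlashAttention2 (the SDPA and eager paths are removed), and LlamaRotaryEmbedding is driven entirely by the config, with `@dynamic_rope_update` handling advanced RoPE types. A sketch of the slimmed-down rotary embedding in use, assuming the repo's transformer_backbone module is importable and that a transformers LlamaConfig supplies the fields it reads (rope_scaling, rope_theta, max_position_embeddings, hidden_size, num_attention_heads); the concrete values below are illustrative only:

```python
# Sketch: the config-driven rotary embedding after this commit.
import torch
from transformers import LlamaConfig

from transformer_backbone import LlamaRotaryEmbedding  # this repo's module

config = LlamaConfig(
    hidden_size=4096,
    num_attention_heads=32,
    max_position_embeddings=32900,  # backbone value from this commit
    rope_theta=500000,              # backbone rope_theta from config.yaml
)

# No more dim/base/rope_type arguments; everything comes from the config.
rotary = LlamaRotaryEmbedding(config=config)

# forward() only needs a tensor for dtype/device plus the position ids;
# scaling for advanced RoPE types is applied via @dynamic_rope_update.
x = torch.empty(1, 8, config.hidden_size, dtype=torch.bfloat16)
position_ids = torch.arange(8).unsqueeze(0)
cos, sin = rotary(x, position_ids)
print(cos.shape, sin.shape)  # (1, 8, head_dim), head_dim = 4096 // 32 = 128
```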