Update code for latest transformers

Files changed:
- README.md (+6 -8)
- config.json (+4 -4)
- config.yaml (+4 -4)
- model.py (+3 -13)
- transformer_backbone.py (+22 -267)
README.md CHANGED

@@ -69,10 +69,10 @@ We do not collect PII (personally identifiable information) for any of these cha

 ## Inference

-We provide an inference module compatible with HuggingFace Transformers for running model inference.
+We provide an inference module compatible with HuggingFace Transformers for running model inference. Before executing the inference example below, make sure the [hat-splitter package](https://pypi.org/project/hat-splitter/) is installed in your environment.

 ```shell
-pip install 'hat-splitter>=0.1.9'
+pip install 'hat-splitter>=0.1.9' transformers torch
 pip install flash_attn
 ```

@@ -82,13 +82,12 @@ Download model weights and run inference using the following example:
 import torch
 from transformers import AutoModelForCausalLM
 INPUT ="When was Rome founded?"
-MODEL_ID = "Aleph-Alpha/
+MODEL_ID = "Aleph-Alpha/llama-tfree-hat-pretrained-7b-base"
 model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True,
-    pretrained_model_name_or_path=MODEL_ID
-    attn_implementation="flash_attention_2",
+    pretrained_model_name_or_path=MODEL_ID
 ).to("cuda", torch.bfloat16)
-input_ids, cumulative_word_lengths = model._prepare_input(INPUT)
+input_ids, cumulative_word_lengths = model._prepare_input(INPUT, add_llama_template=True)
 model_output = model.generate(
     input_ids,
     cumulative_seq_lengths_per_word=cumulative_word_lengths,
@@ -99,8 +98,7 @@ print("Prompt: ", INPUT)
 print("Completion: ", model_output.completion_text)
 ```

-Please note that the realized inference speed strongly depends on the maturity of the inference implementation beyond the intrinsic text compression of any model. Besides this huggingface transformers-based inference solution, we are also releasing a [vLLM-based inference solution](https://github.com/Aleph-Alpha/vllm) for our models that is optimized for batched inference.
-
+Please note that the realized inference speed strongly depends on the maturity of the inference implementation beyond the intrinsic text compression of any model. Besides this huggingface transformers-based inference solution, we are also releasing a [vLLM-based inference solution](https://github.com/Aleph-Alpha/vllm) for our models that is optimized for batched inference. Our PR into vLLM is still under review, but we strongly encourage you to use the optimized inference provided through vLLM.

 # Evaluation

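Pieced together from the two README hunks above, the updated inference example now reads roughly as follows. The lines that fall between the hunks (the remaining `generate` arguments and the prompt printout) are not part of this diff, so the `max_new_tokens` value below is an illustrative assumption rather than the repository's actual default.

```python
import torch
from transformers import AutoModelForCausalLM

INPUT = "When was Rome founded?"
MODEL_ID = "Aleph-Alpha/llama-tfree-hat-pretrained-7b-base"

model = AutoModelForCausalLM.from_pretrained(
    trust_remote_code=True,
    pretrained_model_name_or_path=MODEL_ID,
).to("cuda", torch.bfloat16)

# The new add_llama_template=True flag appears to apply the LLAMA_TEMPLATE chat
# format defined in model.py before the prompt is split into words and bytes.
input_ids, cumulative_word_lengths = model._prepare_input(INPUT, add_llama_template=True)
model_output = model.generate(
    input_ids,
    cumulative_seq_lengths_per_word=cumulative_word_lengths,
    max_new_tokens=100,  # assumption: this argument sits in the unchanged lines omitted from the diff
)
print("Prompt: ", INPUT)
print("Completion: ", model_output.completion_text)
```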
config.json CHANGED

@@ -12,7 +12,7 @@
     "is_neox_style": true,
     "key_query_norm": true,
     "key_query_norm_per_head": true,
-    "max_position_embeddings":
+    "max_position_embeddings": 32900,
     "mlp_bias": false,
     "num_attention_heads": 32,
     "num_hidden_layers": 32,
@@ -44,7 +44,7 @@
     "is_neox_style": true,
     "key_query_norm": true,
     "key_query_norm_per_head": true,
-    "max_position_embeddings":
+    "max_position_embeddings": 262144,
     "mlp_bias": false,
     "num_attention_heads": 8,
     "num_hidden_layers": 4,
@@ -75,7 +75,7 @@
     "is_neox_style": true,
     "key_query_norm": true,
     "key_query_norm_per_head": true,
-    "max_position_embeddings":
+    "max_position_embeddings": 262144,
     "mlp_bias": false,
     "num_attention_heads": 8,
     "num_hidden_layers": 6,
@@ -90,7 +90,7 @@
     "use_cache": true,
     "vocab_size": 256
   },
-  "max_position_embeddings":
+  "max_position_embeddings": 262144,
   "max_word_size": 100,
   "model_type": "hierarchical_autoregressive_transformer",
   "sliding_window": 768,
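The four edits above give the byte-level modules a 262144-position limit and the 32-layer backbone a 32900-position limit. A quick way to confirm the values once the repository is available is to load the config with `AutoConfig`; the nested attribute name below is an assumption based on the keys visible in config.yaml, so adjust it if the loaded object exposes the sub-configs differently.

```python
from transformers import AutoConfig

# trust_remote_code is required because the config class ships with the repository.
config = AutoConfig.from_pretrained(
    "Aleph-Alpha/llama-tfree-hat-pretrained-7b-base", trust_remote_code=True
)

print(config.max_position_embeddings)  # expected: 262144 (top-level limit)

# The nested configs may be plain dicts or config objects depending on the remote code.
backbone = getattr(config, "backbone_config", None)
if isinstance(backbone, dict):
    print(backbone["max_position_embeddings"])  # expected: 32900
elif backbone is not None:
    print(backbone.max_position_embeddings)
```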
config.yaml CHANGED

@@ -6,7 +6,7 @@ encoder_config:
   num_key_value_heads: 8
   rms_norm_eps: 1.0e-05
   intermediate_size: 2816
-  max_position_embeddings:
+  max_position_embeddings: 262144
   rope_scaling:
     rope_type: default
   rope_theta: 100000
@@ -34,7 +34,7 @@ backbone_config:
   num_key_value_heads: 8
   rms_norm_eps: 1.0e-05
   intermediate_size: 14336
-  max_position_embeddings:
+  max_position_embeddings: 32900
   rope_scaling:
     rope_type: default
   rope_theta: 500000
@@ -53,7 +53,7 @@ decoder_config:
   num_key_value_heads: 8
   rms_norm_eps: 1.0e-05
   intermediate_size: 2816
-  max_position_embeddings:
+  max_position_embeddings: 262144
   rope_scaling:
     rope_type: default
   rope_theta: 100000
@@ -82,7 +82,7 @@ auto_map:
 special_token_dict: {}
 max_word_size: 100
 sliding_window: 768
-max_position_embeddings:
+max_position_embeddings: 262144
 torch_dtype: bfloat16
 architectures:
 - HATDecoderForCausalLM
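config.yaml mirrors the JSON values per module: encoder and decoder keep the 262144 limit, while the backbone, which presumably operates on word positions rather than byte positions, is capped at 32900. A small sketch for printing the three limits locally, assuming the file sits next to the checkpoint and PyYAML is installed:

```python
import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

for section in ("encoder_config", "backbone_config", "decoder_config"):
    print(section, cfg[section]["max_position_embeddings"])
print("top-level", cfg["max_position_embeddings"])
# expected: 262144 / 32900 / 262144 / 262144
```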
model.py CHANGED

@@ -26,12 +26,6 @@ from .transformer_backbone import (
     LlamaRotaryEmbedding,
 )

-try:
-    transformers_version = version("transformers")
-    if transformers_version != "4.46.3":
-        print(f"Warning: Expecected transformers version 4.46.3, but found {transformers_version}. Outputs might be different.")
-except PackageNotFoundError:
-    print("transformers is not installed")


 def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
@@ -41,13 +35,12 @@ def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
 LLAMA_TEMPLATE = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"


-class HATCache
+class HATCache:
     encoder_cache: DynamicCache
     backbone_cache: DynamicCache
     decoder_cache: DynamicCache

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
         self.encoder_cache = DynamicCache()
         self.backbone_cache = DynamicCache()
         self.decoder_cache = DynamicCache()
@@ -382,10 +375,7 @@ class HATCrossAttention(nn.Module):

         self.o_proj = nn.Linear(in_features=hidden_size, out_features=hidden_size_q, dtype=dtype, bias=False)

-
-        rope_type = config.rope_scaling["rope_type"]
-
-        self.rotary_emb = LlamaRotaryEmbedding(dim=self.head_dim, base=rope_theta, rope_type=rope_type)
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)

     def forward(
         self,
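With the transformers==4.46.3 pin and its warning removed, model.py no longer assumes a specific release, and HATCache becomes a plain class that eagerly builds one DynamicCache per sub-module instead of forwarding *args/**kwargs to a parent constructor. Below is a minimal sketch of the resulting class as it stands after the diff; the final two lines are only an illustrative usage, not code from the repository.

```python
from transformers.cache_utils import DynamicCache


class HATCache:
    encoder_cache: DynamicCache
    backbone_cache: DynamicCache
    decoder_cache: DynamicCache

    def __init__(self):
        # One independent cache per stage of the hierarchical model.
        self.encoder_cache = DynamicCache()
        self.backbone_cache = DynamicCache()
        self.decoder_cache = DynamicCache()


cache = HATCache()
print(cache.backbone_cache.get_seq_length())  # 0: nothing cached yet
```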
transformer_backbone.py CHANGED

@@ -22,17 +22,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch import nn

+from transformers import PretrainedConfig, dynamic_rope_update
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.generation import GenerationMixin
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -86,117 +85,41 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)


 class LlamaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim=None,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        rope_type="default",
-        config: Optional[LlamaConfig] = None,
-    ):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: LlamaConfig, device=None):
         super().__init__()
-        # TODO (joao): remove the `if` below, only used for BC
-        self.rope_kwargs = {}
-        if config is None:
-            logger.warning_once(
-                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
-                "`config` argument. All other arguments will be removed in v4.46"
-            )
-            self.rope_kwargs = {
-                "rope_type": rope_type,
-                "factor": scaling_factor,
-                "dim": dim,
-                "base": base,
-                "max_position_embeddings": max_position_embeddings,
-            }
-            self.rope_type = rope_type
-            self.max_seq_len_cached = max_position_embeddings
-            self.original_max_seq_len = max_position_embeddings
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         else:
-            # BC: "rope_type" was originally "type"
-            if config.rope_scaling is not None:
-                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-            else:
-                self.rope_type = "default"
-            self.max_seq_len_cached = config.max_position_embeddings
-            self.original_max_seq_len = config.max_position_embeddings
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings

         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(
-                self.config, device, seq_len=seq_len, **self.rope_kwargs
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
-
     @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+        cos = emb.cos() * self.attention_scaling
+        sin = emb.sin() * self.attention_scaling

         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -565,110 +488,7 @@ class LlamaFlashAttention2(LlamaAttention):
         return attn_output, attn_weights, past_key_value


-class LlamaSdpaAttention(LlamaAttention):
-    """
-    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
-    SDPA API.
-    """
-
-    # Adapted from LlamaAttention.forward
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-            )
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False

-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, -1)
-
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-
-LLAMA_ATTENTION_CLASSES = {
-    "eager": LlamaAttention,
-    "flash_attention_2": LlamaFlashAttention2,
-    "sdpa": LlamaSdpaAttention,
-}


 class LlamaDecoderLayer(nn.Module):
@@ -676,7 +496,7 @@ class LlamaDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+        self.self_attn = LlamaFlashAttention2(config=config, layer_idx=layer_idx)

         self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -960,9 +780,8 @@ class LlamaModel(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+        # Not needed for Flash Attention 2
+        causal_mask = None
         hidden_states = inputs_embeds

         # create position embeddings to be shared across the decoder layers
@@ -1028,70 +847,6 @@
             attentions=all_self_attns,
         )

-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            device=device,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask

     @staticmethod
     def _prepare_4d_causal_attention_mask_with_cache_position(