Commit 11599d0
1 Parent(s): 636ee83
Update modelling_RW.py
Better version, so config.use_cache = False can be set at the top level during model load and now reaches the bottom level (RotaryEmbedding). For https://github.com/h2oai/h2ogpt/pull/297
- modelling_RW.py +48 -48
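For context, a minimal loading sketch of the pattern this commit enables (assumptions: a Falcon-style checkpoint that ships this modelling_RW.py and is loaded with trust_remote_code; the checkpoint name below is only a placeholder):

    # Sketch only: set use_cache once on the config at load time; with this commit
    # the flag is read from the config all the way down in RotaryEmbedding.
    from transformers import AutoConfig, AutoModelForCausalLM

    model_name = "some-org/falcon-style-checkpoint"  # placeholder, not a real repo id

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.use_cache = False  # top-level setting

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        trust_remote_code=True,
    )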
modelling_RW.py
CHANGED
@@ -52,10 +52,11 @@ class RotaryEmbedding(torch.nn.Module):

    def __init__(
        self,
-        head_dim: int,
+        config,
        base=10000,
-        use_cache=False,
    ):
+        head_dim = config.head_dim
+        self.use_cache = config.use_cache
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -64,7 +65,6 @@ class RotaryEmbedding(torch.nn.Module):
        self.batch_size_cached = None
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None
-        self.use_cache = use_cache

    def cos_sin(
        self,
@@ -107,7 +107,10 @@ class RotaryEmbedding(torch.nn.Module):
    def forward(self, q, k):
        batch, seq_len, head_dim = q.shape
        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        try:
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        except Exception as e:
+            raise


def _make_causal_mask(
@@ -184,7 +187,7 @@ class Attention(nn.Module):
                f" {self.num_heads})."
            )

-        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
+        self.maybe_rotary = RotaryEmbedding(config) if config.rotary else lambda q, k: (q, k)

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -192,44 +195,34 @@ class Attention(nn.Module):

        self.query_key_value = Linear(
            self.hidden_size,
-            (config.n_head_kv * 2 + config.n_head) * self.head_dim,
+            3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
            bias=config.bias,
        )
+        self.multi_query = config.multi_query
        self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
        self.attention_dropout = nn.Dropout(config.attention_dropout)
-        self.num_kv = config.n_head_kv
+        self.num_kv = config.n_head if not self.multi_query else 1

    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
-        Split the last dimension into (num_heads, head_dim), results share same memory
+        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
        storage as `fused_qkv`

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
-            query: [batch_size, seq_length, num_heads, head_dim]
-            key: [batch_size, seq_length, num_heads, head_dim]
+            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
-        batch, seq_len, _ = fused_qkv.shape
-        qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, 64)
-        q = qkv[:, :, :, :-2]
-        k = qkv[:, :, :, [-2]]
-        v = qkv[:, :, :, [-1]]
-        k = torch.broadcast_to(k, q.shape)
-        v = torch.broadcast_to(v, q.shape)
-
-        q, k, v = [
-            rearrange(
-                x,
-                "batch seq_len group num_heads head_dim ->\
-                batch seq_len (group num_heads) head_dim",
-                head_dim=self.head_dim,
-            )
-            for x in [q, k, v]
-        ]
-        return q, k, v
+        if not self.multi_query:
+            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
+            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
+        else:
+            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
+            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
@@ -275,11 +268,11 @@ class Attention(nn.Module):

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * self.num_heads,
+            batch_size * self.num_kv,
            q_length,
            self.head_dim,
        )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)

        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)

@@ -300,12 +293,15 @@ class Attention(nn.Module):

        if alibi is None:
            query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)

-            attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-            )
+            try:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                )
+            except Exception as e:
+                raise

            x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
            x = x.permute(0, 2, 1, 3)
@@ -330,8 +326,7 @@ class Attention(nn.Module):
            attention_scores = attention_scores.to(torch.float32)
            # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
            attention_probs = F.softmax(
-                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
-                + attention_mask_float,
+                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
                dim=-1,
                dtype=hidden_states.dtype,
            )
@@ -380,12 +375,14 @@ class DecoderLayer(nn.Module):
        super().__init__()
        hidden_size = config.hidden_size

-        self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-
+        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.num_heads = config.n_head
        self.self_attention = Attention(config)

+        if not config.parallel_attn:
+            # unused if parallel attn
+            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
        self.mlp = MLP(config)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
@@ -404,14 +401,12 @@ class DecoderLayer(nn.Module):
        output_attentions: bool = False,
    ):

-        ln_attn = self.ln_attn(hidden_states)
-        ln_mlp = self.ln_mlp(hidden_states)
-
+        layernorm_output = self.input_layernorm(hidden_states)
        residual = hidden_states

        # Self attention.
        attn_outputs = self.self_attention(
-            ln_attn,
+            layernorm_output,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
@@ -422,14 +417,19 @@ class DecoderLayer(nn.Module):

        attention_output = attn_outputs[0]

+        if not self.config.parallel_attn:
+            residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
+            layernorm_output = self.post_attention_layernorm(residual)
+
        outputs = attn_outputs[1:]

        # MLP.
-        mlp_output = self.mlp(ln_mlp)
+        mlp_output = self.mlp(layernorm_output)

-        output = dropout_add(
-            mlp_output + attention_output, residual, self.config.hidden_dropout, training=self.training
-        )
+        if self.config.parallel_attn:
+            mlp_output += attention_output
+
+        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

        if use_cache:
            outputs = (output,) + outputs
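As an aside, the shapes produced by the new multi_query branch of _split_heads can be checked in isolation; a standalone sketch with made-up sizes (not part of the commit):

    # The fused projection packs num_heads query heads plus one shared key head and
    # one shared value head along the last dimension, mirroring the added code above.
    import torch

    batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8

    fused_qkv = torch.randn(batch_size, seq_length, (num_heads + 2) * head_dim)
    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads + 2, head_dim)

    query = fused_qkv[..., :-2, :]   # [batch_size, seq_length, num_heads, head_dim]
    key = fused_qkv[..., [-2], :]    # [batch_size, seq_length, 1, head_dim]
    value = fused_qkv[..., [-1], :]  # [batch_size, seq_length, 1, head_dim]

    print(query.shape, key.shape, value.shape)
    # torch.Size([2, 5, 4, 8]) torch.Size([2, 5, 1, 8]) torch.Size([2, 5, 1, 8])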