Fixes for HF >4.53.3 cache refactoring
modeling_qwen.py  (+47, -10)  CHANGED
@@ -274,7 +274,9 @@ class Qwen2Attention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -378,7 +380,9 @@ class Qwen2FlashAttention2(Qwen2Attention):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1

@@ -676,7 +680,9 @@ class Qwen2SdpaAttention(Qwen2Attention):
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

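The three attention-class fixes above all swap past_key_value.get_usable_length(kv_seq_len, self.layer_idx) for past_key_value.get_seq_length(self.layer_idx). As a rough illustration of what that length means under the refactored Cache API, here is a minimal sketch (not part of the patch; the tensor shapes and layer index are made up, and it assumes a transformers version that exports DynamicCache):

import torch
from transformers import DynamicCache

cache = DynamicCache()
bsz, num_kv_heads, head_dim = 1, 2, 64

# Prefill: store keys/values for 5 tokens in layer 0.
k = torch.zeros(bsz, num_kv_heads, 5, head_dim)
v = torch.zeros(bsz, num_kv_heads, 5, head_dim)
cache.update(k, v, layer_idx=0)

# On the next decoding step, get_seq_length(layer_idx) reports the 5 already-cached
# tokens; the patched attention code adds exactly this number to kv_seq_len.
print(cache.get_seq_length(0))  # 5
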
@@ -993,12 +999,28 @@ class Qwen2Model(Qwen2PreTrainedModel):
             use_cache = False
 
         past_key_values_length = 0
+        use_legacy_cache = False
 
         if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
+            # OLD behavior (removed in HF >= 4.55): treat anything not Cache as "legacy" but then
+            # directly used legacy methods on it (would crash if None or new API).
+            # use_legacy_cache = not isinstance(past_key_values, Cache)
+            # if use_legacy_cache:
+            #     # past_key_values_length = past_key_values.get_seq_length()
+            #     past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            # NEW behavior: if a legacy tuple is passed, convert it to the new Cache API,
+            # compute length via .get_seq_length(), and remember to return legacy if that's what came in.
+            if past_key_values is not None and not isinstance(past_key_values, Cache):
+                use_legacy_cache = True  # remember input format for return
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            if isinstance(past_key_values, Cache):
+                # Layer-agnostic total length; cache_position is handled deeper if needed
+                past_key_values_length = past_key_values.get_seq_length()
+            else:
+                # No cache given on first forward, keep length at 0
+                past_key_values_length = 0
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device

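The conversion block above can be exercised in isolation. A small sketch under the same assumptions (dummy tensors, Cache and DynamicCache imported from transformers): a legacy tuple of per-layer (key, value) pairs is turned into a DynamicCache, whose get_seq_length() then supplies the layer-agnostic past_key_values_length.

import torch
from transformers import Cache, DynamicCache

# Legacy format: one (key, value) pair per layer, each [batch, kv_heads, seq_len, head_dim].
legacy = tuple(
    (torch.zeros(1, 2, 7, 64), torch.zeros(1, 2, 7, 64))
    for _ in range(2)  # two layers, purely illustrative
)

if legacy is not None and not isinstance(legacy, Cache):
    cache = DynamicCache.from_legacy_cache(legacy)

print(isinstance(cache, Cache))  # True
print(cache.get_seq_length())    # 7 -> used as past_key_values_length
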
@@ -1104,7 +1126,11 @@ class Qwen2Model(Qwen2PreTrainedModel):
 
         next_cache = None
         if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+            # If the caller passed legacy, return legacy. Otherwise return the Cache object.
+            next_cache = (
+                next_decoder_cache.to_legacy_cache() if
+                (use_legacy_cache and next_decoder_cache is not None) else next_decoder_cache
+            )
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

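The use_legacy_cache flag matters because of the round trip sketched below (again illustrative only): to_legacy_cache() converts a DynamicCache back into the tuple-of-tuples format, so a caller that passed a legacy cache in gets the same structure back.

import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 3, 64), torch.zeros(1, 2, 3, 64), layer_idx=0)

legacy_out = cache.to_legacy_cache()
print(type(legacy_out))        # <class 'tuple'>
print(legacy_out[0][0].shape)  # torch.Size([1, 2, 3, 64])
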
@@ -1243,10 +1269,21 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
+                # NEW API (HF >= 4.55): use Cache methods
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = cache_length  # `seen_tokens` removed; use total seq length instead
+                try:
+                    max_cache_length = past_key_values.get_max_cache_shape()
+                except Exception:
+                    max_cache_length = None
+
+                # OLD API (deprecated/removed):
+                # cache_length = past_key_values.get_seq_length()
+                # past_length = past_key_values.seen_tokens
+                # max_cache_length = past_key_values.get_max_length()
             else:
+                # Legacy tuple format: keep computing lengths directly from tensors
+                # (We keep it compatible without forcing a conversion here)
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None

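The length bookkeeping above can also be checked on its own. A hedged sketch, assuming get_max_cache_shape() exists on the installed transformers (which is exactly why the patch wraps it in try/except): for an unbounded DynamicCache the maximum shape is None, so only cache_length/past_length end up constraining how many prompt tokens are trimmed.

import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 9, 64), torch.zeros(1, 2, 9, 64), layer_idx=0)

cache_length = cache.get_seq_length()
past_length = cache_length  # `seen_tokens` is gone; the total cached length stands in for it
try:
    max_cache_length = cache.get_max_cache_shape()  # None for an unbounded DynamicCache
except AttributeError:
    max_cache_length = None

print(cache_length, past_length, max_cache_length)  # 9 9 None
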
@@ -1287,7 +1324,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         model_inputs.update(
             {
                 "position_ids": position_ids,
-                "past_key_values": past_key_values,
+                "past_key_values": past_key_values,  # pass through unchanged (legacy or new Cache object)
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
             }
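Taken together, the intent of the patch is that generation on newer transformers keeps working whether a legacy tuple or a Cache object flows through the model. A rough end-to-end usage sketch (the repository id below is a placeholder, not this repo's actual name):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "org/model-with-patched-modeling-qwen"  # placeholder id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
# use_cache=True exercises the DynamicCache path fixed above.
out = model.generate(**inputs, max_new_tokens=8, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))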