sdadas committed
Commit 18af555 · verified · 1 Parent(s): a6d84d0

Fixes for HF >4.53.3 cache refactoring

Files changed (1)
  1. modeling_qwen.py +47 -10
modeling_qwen.py CHANGED
@@ -274,7 +274,9 @@ class Qwen2Attention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -378,7 +380,9 @@ class Qwen2FlashAttention2(Qwen2Attention):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len

         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -676,7 +680,9 @@ class Qwen2SdpaAttention(Qwen2Attention):

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
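All three attention variants make the same substitution: the removed Cache.get_usable_length() call is replaced by Cache.get_seq_length(), which still reports how many tokens are already stored for that layer. Below is a minimal sketch of that equivalence on a DynamicCache; it is not part of the commit and the tensor shapes are made up for illustration.

import torch
from transformers import DynamicCache

cache = DynamicCache()
layer_idx = 0
# pretend a previous decoding step already stored 5 tokens for this layer
# (shape: batch, num_kv_heads, seq_len, head_dim)
cache.update(torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8), layer_idx)

q_len = 1  # tokens in the current step
kv_seq_len = q_len + cache.get_seq_length(layer_idx)
print(kv_seq_len)  # 6, matching what the old get_usable_length() used to add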
@@ -993,12 +999,28 @@ class Qwen2Model(Qwen2PreTrainedModel):
             use_cache = False

         past_key_values_length = 0
+        use_legacy_cache = False

         if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
+            # OLD behavior (removed in HF >= 4.55): treat anything not Cache as "legacy" but then
+            # directly used legacy methods on it (would crash if None or new API).
+            # use_legacy_cache = not isinstance(past_key_values, Cache)
+            # if use_legacy_cache:
+            # #     past_key_values_length = past_key_values.get_seq_length()
+            #     past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            # NEW behavior: if a legacy tuple is passed, convert it to the new Cache API,
+            # compute length via .get_seq_length(), and remember to return legacy if that’s what came in.
+            if past_key_values is not None and not isinstance(past_key_values, Cache):
+                use_legacy_cache = True  # remember input format for return
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            if isinstance(past_key_values, Cache):
+                # Layer-agnostic total length; cache_position is handled deeper if needed
+                past_key_values_length = past_key_values.get_seq_length()
+            else:
+                # No cache given on first forward, keep length at 0
+                past_key_values_length = 0

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
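The new branch in Qwen2Model.forward normalizes whatever the caller passed: a legacy tuple of (key, value) pairs is wrapped into a DynamicCache, and the prefix length is then read layer-agnostically with get_seq_length(). A hedged sketch of that conversion path follows, with invented shapes and layer count; it only mirrors the logic of the hunk above.

import torch
from transformers import Cache, DynamicCache

num_layers, past_tokens = 2, 4
legacy = tuple(
    (torch.zeros(1, 2, past_tokens, 8), torch.zeros(1, 2, past_tokens, 8))
    for _ in range(num_layers)
)

use_legacy_cache = False
past_key_values = legacy
if past_key_values is not None and not isinstance(past_key_values, Cache):
    use_legacy_cache = True  # remember the input format for the return path
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

past_key_values_length = past_key_values.get_seq_length()
print(use_legacy_cache, past_key_values_length)  # True 4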
@@ -1104,7 +1126,11 @@ class Qwen2Model(Qwen2PreTrainedModel):

         next_cache = None
         if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+            # If the caller passed legacy, return legacy. Otherwise return the Cache object.
+            next_cache = (
+                next_decoder_cache.to_legacy_cache() if
+                (use_legacy_cache and next_decoder_cache is not None) else next_decoder_cache
+            )

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
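On the return path the model now mirrors the input format: if a legacy tuple came in, the DynamicCache built from it is converted back with to_legacy_cache(); otherwise the Cache object is returned as is. A small round-trip sketch with illustrative values only:

import torch
from transformers import DynamicCache

legacy_in = ((torch.ones(1, 2, 3, 8), torch.ones(1, 2, 3, 8)),)  # one layer, 3 cached tokens
next_decoder_cache = DynamicCache.from_legacy_cache(legacy_in)

use_legacy_cache = True  # remembered because the caller passed a tuple
next_cache = (
    next_decoder_cache.to_legacy_cache()
    if (use_legacy_cache and next_decoder_cache is not None)
    else next_decoder_cache
)
print(type(next_cache))  # <class 'tuple'>, matching what the caller passed in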
@@ -1243,10 +1269,21 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
+                # NEW API (HF >= 4.55): use Cache methods
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = cache_length  # `seen_tokens` removed; use total seq length instead
+                try:
+                    max_cache_length = past_key_values.get_max_cache_shape()
+                except Exception:
+                    max_cache_length = None
+
+                # OLD API (deprecated/removed):
+                # cache_length = past_key_values.get_seq_length()
+                # past_length = past_key_values.seen_tokens
+                # max_cache_length = past_key_values.get_max_length()
             else:
+                # Legacy tuple format: keep computing lengths directly from tensors
+                # (We keep it compatible without forcing a conversion here)
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None

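In prepare_inputs_for_generation, Cache.seen_tokens and get_max_length() are gone in recent transformers releases, so the commit falls back to get_seq_length() for past_length and probes get_max_cache_shape() defensively. A version-tolerant sketch of that probe, assuming only that DynamicCache has no fixed capacity and therefore reports None:

from transformers import DynamicCache

past_key_values = DynamicCache()
cache_length = past_key_values.get_seq_length()
past_length = cache_length  # stands in for the removed seen_tokens attribute
try:
    max_cache_length = past_key_values.get_max_cache_shape()  # newer API; returns None for DynamicCache
except AttributeError:
    max_cache_length = None  # older releases that predate the method
print(cache_length, past_length, max_cache_length)  # 0 0 None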
@@ -1287,7 +1324,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         model_inputs.update(
             {
                 "position_ids": position_ids,
-                "past_key_values": past_key_values,
+                "past_key_values": past_key_values,  # pass through unchanged (legacy or new Cache object)
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
             }
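As a quick end-to-end check that the patched cache paths still work, one might run a short generation with caching enabled. This is a hedged sketch, not part of the commit: the repository id below is a placeholder (the Hub id is not shown on this page), and trust_remote_code=True is assumed because the repo ships its own modeling_qwen.py containing the Qwen2ForCausalLM patched above.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<this-repo-id>"  # placeholder, substitute the actual Hub id of this repository
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
# use_cache=True exercises the DynamicCache handling touched by this commit
out = model.generate(**inputs, max_new_tokens=8, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))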