NePe committed
Commit 36f62ed · verified · 1 Parent(s): 420ba7d

variable_cache.py compatibility for v4.57.2 / python3.12

Files changed (1)
  1. variable_cache.py +17 -11
variable_cache.py CHANGED

@@ -31,6 +31,9 @@ class VariableCache(Cache_4_44_2, Cache):
     The default implementation for the layer caches is StaticCache.
     The cache of each layer is allocated to the same gpu as the layer itself.
     """
+
+    max_batch_size = None
+    max_cache_len = None
 
     def __init__(
         self,
@@ -50,7 +53,7 @@ class VariableCache(Cache_4_44_2, Cache):
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.dtype = dtype
 
-        self.layer_caches: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
+        self.layers: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
         self.layer_devices: list[torch.device | None] = [None] * config.num_hidden_layers
 
     def update(
@@ -60,11 +63,11 @@ class VariableCache(Cache_4_44_2, Cache):
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.layer_caches[layer_idx] is None:
+        if self.layers[layer_idx] is None:
             self.layer_devices[layer_idx] = key_states.device
             self._init_layer_cache(layer_idx)
 
-        layer_cache = self.layer_caches[layer_idx]
+        layer_cache = self.layers[layer_idx]
         assert layer_cache is not None, f"Trying to update the cache of a cache-less layer: {layer_idx=}"
 
         k_out, v_out = layer_cache.update(key_states=key_states,
@@ -93,37 +96,37 @@ class VariableCache(Cache_4_44_2, Cache):
         if attention_config.window_length is not None:
             if not attention_config.is_sink:
                 config.sliding_window = attention_config.window_length
-                self.layer_caches[layer_idx] = SlidingWindowCache(config=config,
+                self.layers[layer_idx] = SlidingWindowCache(config=config,
                                                                   max_batch_size=self.max_batch_size,
                                                                   max_cache_len=self.max_cache_len,
                                                                   device=device,
                                                                   dtype=self.dtype)
                 return
             elif not attention_config.unshifted_sink:
-                self.layer_caches[layer_idx] = SinkCache(window_length=attention_config.window_length,
+                self.layers[layer_idx] = SinkCache(window_length=attention_config.window_length,
                                                          num_sink_tokens=attention_config.num_sink_tokens)
                 return
 
-        self.layer_caches[layer_idx] = StaticCache(config=config,
+        self.layers[layer_idx] = StaticCache(config=config,
                                                    max_batch_size=self.max_batch_size,
                                                    max_cache_len=self.max_cache_len,
                                                    device=device,
                                                    dtype=self.dtype)
 
     def _get_first_real_cache(self) -> Cache:
-        for layer_cache in self.layer_caches:
+        for layer_cache in self.layers:
             if layer_cache is not None:
                 return layer_cache
         raise ValueError(f"No real cache found, all layer caches are None.")
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        if layer_idx == 0 and self.layer_caches[0] is None:
+        if layer_idx == 0 and self.layers[0] is None:
             try:
                 layer_cache = self._get_first_real_cache()
             except ValueError:
                 return 0
         else:
-            layer_cache = self.layer_caches[layer_idx]
+            layer_cache = self.layers[layer_idx]
         return layer_cache.get_seq_length()
 
     def get_max_length(self) -> Optional[int]:
@@ -131,9 +134,12 @@ class VariableCache(Cache_4_44_2, Cache):
         return self.max_cache_len
 
     def reset(self):
-        for layer_idx in range(len(self.layer_caches)):
-            layer_cache = self.layer_caches[layer_idx]
+        for layer_idx in range(len(self.layers)):
+            layer_cache = self.layers[layer_idx]
             if hasattr(layer_cache, "reset"):
                 layer_cache.reset()
             else:
                 self._init_layer_cache(layer_idx)
+
+    def is_compileable(self) -> bool:
+        return False
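
For readers wiring this cache into their own generation loop, here is a minimal usage sketch. It is an assumption-laden illustration, not this repo's documented API: the import path, the checkpoint path, the constructor's max_batch_size argument, and the config fields used to shape the dummy key/value tensors are all hypothetical; only update(), get_seq_length(), reset(), and is_compileable() are taken from the diff above.

# Hypothetical usage sketch -- everything except the methods visible in the
# diff (update, get_seq_length, reset, is_compileable) is an assumption.
import torch
from transformers import AutoConfig

from variable_cache import VariableCache  # assumed module path within this repo

config = AutoConfig.from_pretrained("path/to/this/checkpoint", trust_remote_code=True)

cache = VariableCache(
    config=config,
    max_batch_size=1,        # assumed parameter, mirrors self.max_batch_size in the diff
    max_cache_len=2048,      # per the diff, falls back to config.max_position_embeddings when None
    dtype=torch.float16,
)

# Layer caches are created lazily: the first update() for a layer allocates a
# StaticCache / SlidingWindowCache / SinkCache on the device of the incoming states.
head_dim = config.hidden_size // config.num_attention_heads   # assumed config fields
key = torch.zeros(1, config.num_attention_heads, 4, head_dim, dtype=torch.float16)
value = torch.zeros_like(key)
k_out, v_out = cache.update(key, value, layer_idx=0)

print(cache.get_seq_length())   # sequence length tracked by the first initialized layer cache
print(cache.is_compileable())   # False: heterogeneous per-layer caches opt out of compiled decoding
cache.reset()                   # resets layers that implement reset(), re-initializes the rest

The last point is the substance of this commit: renaming layer_caches to layers, exposing max_batch_size / max_cache_len at class level, and reporting is_compileable() as False keep the class usable against the newer transformers cache interface (v4.57.2) without changing how individual layer caches behave.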