variable_cache.py compatibility for v4.57.2 / python3.12
variable_cache.py  +17 -11
@@ -31,6 +31,9 @@ class VariableCache(Cache_4_44_2, Cache):
     The default implementation for the layer caches is StaticCache.
     The cache of each layer is allocated to the same gpu as the layer itself.
     """
+
+    max_batch_size = None
+    max_cache_len = None
 
     def __init__(
         self,
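The two new class-level defaults look like a small defensive shim: with max_batch_size and max_cache_len declared on the class, the names always resolve even if some code path reads them before __init__ has assigned instance values, and the instance assignments simply shadow the class attributes afterwards. A tiny, self-contained illustration of that shadowing (the Example class is made up for this sketch, not part of variable_cache.py):

# Hypothetical class, not from variable_cache.py: instance assignments
# made in __init__ shadow the class-level defaults.
class Example:
    max_cache_len = None  # class attribute, always resolvable

    def __init__(self, max_cache_len=None):
        self.max_cache_len = 4096 if max_cache_len is None else max_cache_len

print(Example.max_cache_len)    # None  -> the class-level default
print(Example().max_cache_len)  # 4096  -> the instance attribute shadows it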
@@ -50,7 +53,7 @@ class VariableCache(Cache_4_44_2, Cache):
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.dtype = dtype
 
-        self.
+        self.layers: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
         self.layer_devices: list[torch.device | None] = [None] * config.num_hidden_layers
 
     def update(
@@ -60,11 +63,11 @@ class VariableCache(Cache_4_44_2, Cache):
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.
+        if self.layers[layer_idx] is None:
             self.layer_devices[layer_idx] = key_states.device
             self._init_layer_cache(layer_idx)
 
-        layer_cache = self.
+        layer_cache = self.layers[layer_idx]
         assert layer_cache is not None, f"Trying to update the cache of a cache-less layer: {layer_idx=}"
 
         k_out, v_out = layer_cache.update(key_states=key_states,
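Together with the self.layers container introduced in the previous hunk (a name that presumably matches the layers attribute newer transformers Cache objects use for per-layer state), update() now materializes each layer's cache lazily: every slot starts as None, is allocated on the first update for that layer and pinned to the device of the incoming key states, and is reused from then on. Below is a self-contained sketch of that lazy per-layer pattern; LazyPerLayerCache and _DummyLayerCache are illustrative stand-ins, not the real classes:

# Self-contained sketch of the lazy per-layer cache pattern (illustrative
# names): a slot stays None until the first update touches that layer.
class _DummyLayerCache:
    """Stands in for StaticCache / SlidingWindowCache / SinkCache."""

    def __init__(self) -> None:
        self._seq_len = 0

    def update(self, key_states, value_states):
        self._seq_len += len(key_states)
        return key_states, value_states

    def get_seq_length(self) -> int:
        return self._seq_len


class LazyPerLayerCache:
    def __init__(self, num_layers: int) -> None:
        self.layers: list[_DummyLayerCache | None] = [None] * num_layers

    def update(self, key_states, value_states, layer_idx: int):
        if self.layers[layer_idx] is None:        # first touch -> allocate
            self.layers[layer_idx] = _DummyLayerCache()
        return self.layers[layer_idx].update(key_states, value_states)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        layer_cache = self.layers[layer_idx]
        return 0 if layer_cache is None else layer_cache.get_seq_length()


cache = LazyPerLayerCache(num_layers=2)
cache.update([1, 2, 3], [1, 2, 3], layer_idx=0)
print(cache.get_seq_length(0))  # 3
print(cache.get_seq_length(1))  # 0 -- layer 1 was never updated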
@@ -93,37 +96,37 @@ class VariableCache(Cache_4_44_2, Cache):
         if attention_config.window_length is not None:
             if not attention_config.is_sink:
                 config.sliding_window = attention_config.window_length
-                self.
+                self.layers[layer_idx] = SlidingWindowCache(config=config,
                                                             max_batch_size=self.max_batch_size,
                                                             max_cache_len=self.max_cache_len,
                                                             device=device,
                                                             dtype=self.dtype)
                 return
             elif not attention_config.unshifted_sink:
-                self.
+                self.layers[layer_idx] = SinkCache(window_length=attention_config.window_length,
                                                    num_sink_tokens=attention_config.num_sink_tokens)
                 return
 
-        self.
+        self.layers[layer_idx] = StaticCache(config=config,
                                              max_batch_size=self.max_batch_size,
                                              max_cache_len=self.max_cache_len,
                                              device=device,
                                              dtype=self.dtype)
 
     def _get_first_real_cache(self) -> Cache:
-        for layer_cache in self.
+        for layer_cache in self.layers:
             if layer_cache is not None:
                 return layer_cache
         raise ValueError(f"No real cache found, all layer caches are None.")
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        if layer_idx == 0 and self.
+        if layer_idx == 0 and self.layers[0] is None:
             try:
                 layer_cache = self._get_first_real_cache()
             except ValueError:
                 return 0
         else:
-            layer_cache = self.
+            layer_cache = self.layers[layer_idx]
         return layer_cache.get_seq_length()
 
     def get_max_length(self) -> Optional[int]:
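_init_layer_cache picks a concrete cache per layer from its attention config: a sliding-window layer gets a SlidingWindowCache, a sink layer (unless it uses unshifted sinks) gets a SinkCache, and everything else falls back to a StaticCache; get_seq_length then tolerates a cache-less layer 0 by falling back to the first layer that actually has a cache. A self-contained sketch of the dispatch order, using a made-up AttnConfig dataclass and returning class names as strings instead of constructing the real caches:

# Sketch of the per-layer dispatch above; AttnConfig and the returned
# strings are illustrative, not the real transformers types.
from dataclasses import dataclass


@dataclass
class AttnConfig:
    window_length: int | None = None
    is_sink: bool = False
    unshifted_sink: bool = False
    num_sink_tokens: int = 0


def pick_layer_cache(attn: AttnConfig) -> str:
    # Mirrors the branch order in _init_layer_cache: sliding window first,
    # then sink, otherwise fall back to a static (full-length) cache.
    if attn.window_length is not None:
        if not attn.is_sink:
            return "SlidingWindowCache"
        elif not attn.unshifted_sink:
            return "SinkCache"
    return "StaticCache"


print(pick_layer_cache(AttnConfig(window_length=512)))                # SlidingWindowCache
print(pick_layer_cache(AttnConfig(window_length=512, is_sink=True)))  # SinkCache
print(pick_layer_cache(AttnConfig(window_length=512, is_sink=True,
                                  unshifted_sink=True)))              # StaticCache
print(pick_layer_cache(AttnConfig()))                                 # StaticCache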
@@ -131,9 +134,12 @@ class VariableCache(Cache_4_44_2, Cache):
         return self.max_cache_len
 
     def reset(self):
-        for layer_idx in range(len(self.
-            layer_cache = self.
+        for layer_idx in range(len(self.layers)):
+            layer_cache = self.layers[layer_idx]
             if hasattr(layer_cache, "reset"):
                 layer_cache.reset()
             else:
                 self._init_layer_cache(layer_idx)
+
+    def is_compileable(self) -> bool:
+        return False
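The new is_compileable() hook returning False presumably tells callers that check this flag (for example, generation code deciding whether a cache is safe to use with torch.compile-style static decoding) that this heterogeneous, lazily-built cache should not be treated as compileable. For completeness, a hypothetical usage sketch follows; the VariableCache constructor arguments (config, max_batch_size, max_cache_len, dtype) are inferred from the lines shown above and should be checked against the actual __init__ signature, and the model id is a placeholder:

# Hypothetical usage sketch; constructor kwargs inferred from the diff above.
from transformers import AutoModelForCausalLM, AutoTokenizer

from variable_cache import VariableCache  # local module from this repo

model_id = "path/or/hub-id-of-the-model"  # placeholder
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

cache = VariableCache(config=model.config,
                      max_batch_size=1,
                      max_cache_len=4096,
                      dtype=model.dtype)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, past_key_values=cache, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))

In recent transformers versions, generate() accepts a pre-built Cache instance through past_key_values, which is how a custom cache like this one is typically plugged in.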