Commit 1ed39ad
Parent(s): 2990a09

Optimize the storage of KV cache

Files changed:
- README.md +5 -0
- modeling_chatglm.py +21 -8
README.md CHANGED
@@ -15,8 +15,13 @@ tags:
 <p align="center">
 👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
 </p>
+## 更新/Update
+
+- 我们优化了KV Cache的存储方式,减少了显存碎片的产生。基于优化后的代码,模型可以在约**20G显存**的情况下处理32K长度的上下文(FP/BF16格式)。
+- We have optimized how the KV cache is stored, reducing GPU memory fragmentation. With the optimized code, the model can process a 32K-length context with approximately **20G** of GPU memory (FP16/BF16 format).
 
 ## 介绍
+
 ChatGLM**2**-6B-32K builds on [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b) and further strengthens its ability to understand long texts, handling contexts of up to 32K tokens. Specifically, we update the position encoding with [Positional Interpolation](https://arxiv.org/abs/2306.15595) and train with a context length of 32K during the dialogue stage. In practice, if your contexts are mostly **within 8K**, we recommend [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b); if you need to handle contexts **longer than 8K**, we recommend ChatGLM2-6B-32K.
 
 ChatGLM**2**-6B-32K is the long-context version of the open-source bilingual (Chinese-English) chat model [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B). While retaining the first-generation model's many strengths, such as smooth dialogue and a low deployment threshold, ChatGLM**2**-6B-32K introduces the following new features:
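As context for the change, here is a minimal usage sketch (not part of this commit) for loading THUDM/chatglm2-6b-32k via `transformers` with `trust_remote_code`, which pulls in the modeling_chatglm.py patched below; the prompt, precision, and device choices are illustrative, and actual memory use depends on context length and generation settings:

```python
# Minimal sketch: load ChatGLM2-6B-32K in half precision and run one chat turn.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True).half().cuda()
model = model.eval()

# `chat` is the helper defined in modeling_chatglm.py; `history` carries prior turns.
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)
```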
modeling_chatglm.py CHANGED
@@ -413,7 +413,10 @@ class SelfAttention(torch.nn.Module):
             key_layer = torch.cat((cache_k, key_layer), dim=0)
             value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            kv_cache = (key_layer, value_layer)
+            if kv_cache is None:
+                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
+            else:
+                kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
 
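The new prefill branch (`kv_cache is None`) packs each layer's key and value into a single tensor rather than a tuple of two tensors. Below is a small sketch of the resulting layout, using an assumed ChatGLM2-style `[seq_len, batch, num_kv_heads, head_dim]` activation layout and illustrative sizes (not values taken from the commit):

```python
import torch

# Illustrative shapes in a [seq_len, batch, num_kv_heads, head_dim] layout.
seq_len, batch, num_kv_heads, head_dim = 1024, 1, 2, 128
key_layer = torch.zeros(seq_len, batch, num_kv_heads, head_dim, dtype=torch.float16)
value_layer = torch.zeros_like(key_layer)

# Tuple format (kept for token-by-token decoding): two separate tensors per layer.
kv_tuple = (key_layer, value_layer)

# New prefill format: one contiguous tensor per layer, [1, 2, seq_len, batch, heads, head_dim].
kv_packed = torch.cat((key_layer.unsqueeze(0).unsqueeze(0),
                       value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
print(kv_packed.shape)  # torch.Size([1, 2, 1024, 1, 2, 128])
```

One contiguous allocation per layer (and, per the later hunks, one tensor for all layers during prefill) replaces many smaller allocations, which is presumably where the reduced fragmentation noted in the README comes from.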
@@ -612,12 +615,8 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
+        if self.training:
+            use_cache = False
 
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
@@ -645,7 +644,15 @@
                 )
             hidden_states, kv_cache = layer_ret
             if use_cache:
-                presents = presents + (kv_cache,)
+                # token by token decoding, use tuple format
+                if kv_caches[0] is not None:
+                    presents = presents + (kv_cache,)
+                # prefilling in decoding, use tensor format to save cuda memory
+                else:
+                    if len(presents) == 0:
+                        presents = kv_cache
+                    else:
+                        presents = torch.cat((presents, kv_cache), dim=0)
 
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
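The comments added in this hunk spell out the two formats: token-by-token decoding keeps `presents` as a tuple of per-layer caches, while prefill concatenates the packed per-layer tensors along dim 0 into one tensor covering all layers. A rough sketch of the prefill accumulation, with illustrative layer count and shapes:

```python
import torch

num_layers, seq_len, batch, heads, head_dim = 4, 16, 1, 2, 8

presents = ()  # initialized as an empty tuple when use_cache is True
for _ in range(num_layers):
    # Each layer's packed cache from the SelfAttention hunk: [1, 2, seq, batch, heads, head_dim].
    kv_cache = torch.zeros(1, 2, seq_len, batch, heads, head_dim)
    if len(presents) == 0:
        presents = kv_cache  # first layer: keep the tensor as-is
    else:
        presents = torch.cat((presents, kv_cache), dim=0)  # later layers: grow along dim 0

print(presents.shape)  # torch.Size([4, 2, 16, 1, 2, 8]) -- one tensor for all layers
```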
@@ -830,6 +837,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
         )
+        if presents is not None and type(presents) is torch.Tensor:
+            presents = presents.split(1, dim=0)
+            presents = list(presents)
+            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
+            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
+            presents = tuple(presents)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
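The block added here converts the packed all-layer tensor back into the per-layer `(key, value)` tuples expected downstream as `past_key_values`. A round-trip sketch of that split/squeeze chain, again with illustrative shapes:

```python
import torch

num_layers, seq_len, batch, heads, head_dim = 4, 16, 1, 2, 8
# Packed prefill cache produced by GLMTransformer: [num_layers, 2, seq_len, batch, heads, head_dim].
presents = torch.zeros(num_layers, 2, seq_len, batch, heads, head_dim)

# Mirrors the unpacking added to ChatGLMModel.forward.
presents = list(presents.split(1, dim=0))                          # one [1, 2, ...] tensor per layer
presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]  # per layer: two [1, seq, ...] tensors
presents = [tuple(x.squeeze(0) for x in y) for y in presents]      # per layer: (key, value), each [seq, batch, heads, head_dim]
presents = tuple(presents)

print(len(presents), presents[0][0].shape)  # 4 torch.Size([16, 1, 2, 8])
```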