Upload 2 files

- configuration.py  +5 -4
- modeling.py  +3 -2

configuration.py (CHANGED)
@@ -101,6 +101,7 @@ class ExtendedMptAttentionConfig(PretrainedConfig):
         sim_threshold=0.25,
         tokenizer_all_special_ids=[0, 50278],
         remove_special_ids=False,
+        use_external_mind_by_layer: list[bool] = [True for _ in range(32)],
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -121,6 +122,7 @@ class ExtendedMptAttentionConfig(PretrainedConfig):
         self.sim_threshold = sim_threshold
         self.tokenizer_all_special_ids = tokenizer_all_special_ids
         self.remove_special_ids = remove_special_ids
+        self.use_external_mind_by_layer = use_external_mind_by_layer

         if attn_type not in ["multihead_attention", "multiquery_attention"]:
             raise ValueError(
@@ -245,7 +247,6 @@ class ExtendedMptConfig(PretrainedConfig):
         n_layers: int = 32,
         expansion_ratio: int = 4,
         max_seq_len_inference: int = 2048,
-        max_seq_len_train: int = 2048,
         vocab_size: int = 50432,
         resid_pdrop: float = 0.0,
         layer_norm_epsilon: float = 1e-5,
@@ -261,11 +262,12 @@ class ExtendedMptConfig(PretrainedConfig):
         use_cache: bool = False,
         initializer_range=0.02,
         use_external_mind: bool = True,
-        use_external_mind_by_layer: list[bool] = [True for _ in range(32)],
         **kwargs,
     ):
         if attn_config is None:
-            self.attn_config = ExtendedMptAttentionConfig()
+            self.attn_config = ExtendedMptAttentionConfig(
+                use_external_mind_by_layer=[True for _ in range(n_layers)]
+            )
         elif not isinstance(attn_config, ExtendedMptAttentionConfig):
             self.attn_config = ExtendedMptAttentionConfig(**attn_config)
         else:
@@ -275,7 +277,6 @@ class ExtendedMptConfig(PretrainedConfig):
         self.n_layers = n_layers
         self.expansion_ratio = expansion_ratio
         self.max_seq_len = max_seq_len_inference
-        self.max_seq_len_train = max_seq_len_train
         self.vocab_size = vocab_size
         self.resid_pdrop = resid_pdrop
         self.emb_pdrop = emb_pdrop
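
In short, the per-layer external-memory flags move from ExtendedMptConfig onto ExtendedMptAttentionConfig, and the top-level config now derives them from n_layers when no attn_config is supplied. A minimal usage sketch of that behavior follows; the import path (a local module named configuration) and the surrounding usage are assumptions for illustration, not part of the commit.

# Sketch only: assumes configuration.py is importable as `configuration`.
from configuration import ExtendedMptAttentionConfig, ExtendedMptConfig

# Default path: attn_config is None, so the model config builds an attention
# config with one True flag per layer (n_layers defaults to 32).
config = ExtendedMptConfig()
assert len(config.attn_config.use_external_mind_by_layer) == config.n_layers

# Per-layer flags are now passed through the attention config, here as a
# dict, which hits the ExtendedMptAttentionConfig(**attn_config) branch.
config = ExtendedMptConfig(
    attn_config={"use_external_mind_by_layer": [False] * 16 + [True] * 16}
)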

modeling.py (CHANGED)
@@ -920,7 +920,7 @@ class ExtendedMptForCausalLM(MptPreTrainedModel):
 
     _tied_weights_keys = ["lm_head.weight"]
 
-    def __init__(self, config: ExtendedMptConfig, external_memories=None):
+    def __init__(self, config: ExtendedMptConfig, external_memories:list=None):
         super().__init__(config)
         self.transformer: ExtendedMptModel = ExtendedMptModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -1016,8 +1016,9 @@ class ExtendedMptForCausalLM(MptPreTrainedModel):
         if (
             self.memory_ids is not None and self.memories is None
         ):
+            self.memory_ids = torch.tensor([self.memory_ids], device=self.device) if type(self.memory_ids)==list else self.memory_ids
             self.memories = self.generate_cache(
-                self.memory_ids, cache_type=self.memory_type
+                self.memory_ids, cache_type=self.memory_type,
             )
             # EM: Remove special tokens from memory cache
             if self.remove_special_ids:
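
The second hunk means memory_ids may now be supplied as a plain Python list of token ids; it is wrapped into a batched tensor on the model's device before the memory cache is generated. A standalone sketch of just that conversion (variable names below are illustrative stand-ins, not the model's attributes):

import torch

device = "cpu"                       # stand-in for self.device
memory_ids = [101, 2054, 2003, 102]  # a plain list of token ids
if type(memory_ids) == list:         # same check the commit adds
    memory_ids = torch.tensor([memory_ids], device=device)
print(memory_ids.shape)              # torch.Size([1, 4]): a batch of one sequence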