Upload 2 files
Browse files — modeling.py (+2 −1)
modeling.py
CHANGED
@@ -356,7 +356,7 @@ class ExtendedMptAttention(nn.Module):
 356             )
 357             attn_output = self.out_proj(context_states)
 358
 359 -           if not output_retrieved_memory_idx:
 360                 reshaped_idx = None
 361
 362             return attn_output, attn_weights, past_key_value, reshaped_idx
@@ -977,6 +977,7 @@ class ExtendedMptForCausalLM(MptPreTrainedModel):
 977                 "attention_mask": attention_mask,
 978                 "use_external_mind": kwargs.get("use_external_mind"),  # EM: Add config here
 979                 "topk": kwargs.get("topk"),
 980             }
 981         )
 982         return model_inputs
 356             )
 357             attn_output = self.out_proj(context_states)
 358
 359 +           if not output_retrieved_memory_idx or (long_range_past_key_value is None and faiss_indexes is None):
 360                 reshaped_idx = None
 361
 362             return attn_output, attn_weights, past_key_value, reshaped_idx
 977                 "attention_mask": attention_mask,
 978                 "use_external_mind": kwargs.get("use_external_mind"),  # EM: Add config here
 979                 "topk": kwargs.get("topk"),
 980 +               "output_retrieved_memory_idx": kwargs.get("output_retrieved_memory_idx"),
 981             }
 982         )
 983         return model_inputs