Commit: Update modeling_gpt_refact.py
Changed file: modeling_gpt_refact.py (+12 lines, -0 lines)
|
@@ -503,6 +503,18 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
|
|
| 503 |
|
| 504 |
# Initialize weights and apply final processing
|
| 505 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
| 508 |
if inputs_embeds is not None and past_key_values is None:
|
|
|
|
| 503 |
|
| 504 |
# Initialize weights and apply final processing
|
| 505 |
self.post_init()
|
| 506 |
+
|
| 507 |
+
# gradient checkpointing support for lower versions of transformers
|
| 508 |
+
import transformers
|
| 509 |
+
from packaging import version
|
| 510 |
+
|
| 511 |
+
def _set_gradient_checkpointing(module, enable=False, value=None):
    """Enable/disable gradient checkpointing on a GPTRefactModel submodule.

    Compatibility shim for transformers < 4.35: there,
    ``PreTrainedModel.gradient_checkpointing_enable`` calls this hook as
    ``self.apply(partial(self._set_gradient_checkpointing, value=True))`` —
    i.e. with the keyword ``value``, not ``enable``.  Since this function is
    assigned to an instance attribute (and therefore called unbound), the
    original ``enable``-only signature would raise ``TypeError`` on that
    library call.  Accepting both keywords keeps any existing
    ``enable=...`` callers working while fixing the library path.

    Args:
        module: The submodule being visited by ``Model.apply``.
        enable: Desired checkpointing state (legacy keyword).
        value: Desired checkpointing state as passed by transformers < 4.35;
            takes precedence over ``enable`` when supplied.
    """
    # The library's keyword (`value`) wins when given; otherwise fall back
    # to the legacy `enable` flag.
    flag = enable if value is None else value
    if isinstance(module, GPTRefactModel):
        module.gradient_checkpointing = flag
|
| 514 |
+
|
| 515 |
+
v = version.parse(transformers.__version__)
|
| 516 |
+
if v.major <= 4 and v.minor < 35:
|
| 517 |
+
self._set_gradient_checkpointing = _set_gradient_checkpointing
|
| 518 |
|
| 519 |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
| 520 |
if inputs_embeds is not None and past_key_values is None:
|