Gradient Checkpointing for HF Trainer
#23 opened by acon96

modeling_phi.py  +13 -7  CHANGED
@@ -605,9 +605,9 @@ class MHA(nn.Module):
                 # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
                 qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)

-            if self.checkpointing:
+            if self.checkpointing and self.training:
                 attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
+                    self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, use_reentrant=False
                 )
             else:
                 attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
@@ -615,8 +615,8 @@ class MHA(nn.Module):
             # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
             return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output

-        if self.checkpointing:
-            return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
+        if self.checkpointing and self.training:
+            return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask, use_reentrant=False)

         return self.inner_attn(qkv, key_padding_mask=key_padding_mask)

@@ -664,7 +664,7 @@ class MHA(nn.Module):

                 q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)

-            if self.checkpointing:
+            if self.checkpointing and self.training:
                 attn_output = torch.utils.checkpoint.checkpoint(
                     self.inner_cross_attn,
                     q,
@@ -674,6 +674,7 @@ class MHA(nn.Module):
                     max_seqlen=max_seqlen_q,
                     cu_seqlens_k=cu_seqlens_k,
                     max_seqlen_k=max_seqlen_k,
+                    use_reentrant=False
                 )
             else:
                 attn_output = self.inner_cross_attn(
@@ -692,13 +693,14 @@ class MHA(nn.Module):
                 else attn_output
             )

-        if self.checkpointing:
+        if self.checkpointing and self.training:
             return torch.utils.checkpoint.checkpoint(
                 self.inner_cross_attn,
                 q,
                 kv,
                 key_padding_mask=key_padding_mask,
                 causal=causal,
+                use_reentrant=False
             )

         return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
@@ -835,7 +837,7 @@ class PhiPreTrainedModel(PreTrainedModel):

     config_class = PhiConfig
     base_model_prefix = "transformer"
-    supports_gradient_checkpointing = False
+    supports_gradient_checkpointing = True
     _no_split_modules = ["ParallelBlock"]

     def __init__(self, *inputs, **kwargs) -> None:
@@ -855,6 +857,10 @@ class PhiPreTrainedModel(PreTrainedModel):
                 module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, MHA):
+            module.checkpointing = value
+
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,