dav1dliu committed
Commit 70f2069 · verified · 1 Parent(s): 15cdedf

Add kernel to accelerate

Files changed (1):
  1. modeling_sdar.py  +19 -55
modeling_sdar.py CHANGED
@@ -46,13 +46,7 @@ from transformers.processing_utils import Unpack
 from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
 from .configuration_sdar import SDARConfig

-from fla.modules.activations import swiglu_linear
-from fla.modules import (
-    FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
-    FusedLinearUnreducedCrossEntropyLoss,
-    FusedLinearDiffusionCrossEntropyLoss)
 from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
-from torch.distributed.tensor import DTensor

 import torch.nn.functional as F
 try:
@@ -61,12 +55,11 @@ try:
 except:
     pass

-
-def dtensor2local(dtensor):
-    if isinstance(dtensor, DTensor):
-        return dtensor.to_local()
-    else:
-        return dtensor
+try:
+    from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # noqa: F401
+    liger_kernel_is_available = True
+except ImportError:
+    liger_kernel_is_available = False


 if is_torch_flex_attn_available():
@@ -77,10 +70,6 @@ if is_torch_flex_attn_available():
 logger = logging.get_logger(__name__)


-@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
-def fused_flex_attention(query, key, value, attention_mask=None, **kwargs):
-    return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
-

 @use_kernel_forward_from_hub("RMSNorm")
 class SDARRMSNorm(nn.Module):
@@ -93,16 +82,16 @@ class SDARRMSNorm(nn.Module):
         self.variance_epsilon = eps

     def forward(self, hidden_states):
-        weight = dtensor2local(self.weight)
-        '''
-        return flash_rms_norm(hidden_states, weight=weight, bias=None, eps=self.variance_epsilon)
+        return flash_rms_norm(
+            hidden_states, weight=self.weight, bias=None, eps=self.variance_epsilon)
         '''
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * \
             torch.rsqrt(variance + self.variance_epsilon)
-        return weight * hidden_states.to(input_dtype)
+        return self.weight * hidden_states.to(input_dtype)
+        '''


     def extra_repr(self):
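The hunk above replaces the eager RMSNorm (kept in the docstring as a reference) with the flash-attn Triton kernel. Below is a minimal sanity-check sketch, not part of the commit, comparing the two paths; it assumes flash-attn with its Triton layer-norm ops is installed and a CUDA device is available, and the shapes, dtype, and tolerances are illustrative.

# Sanity-check sketch: flash-attn Triton RMSNorm vs. the eager reference path.
import torch
from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm


def eager_rms_norm(hidden_states, weight, eps):
    # Same math as the reference path kept in SDARRMSNorm.forward's docstring.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 16, 1024, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
    eps = 1e-6

    ref = eager_rms_norm(x, w, eps)
    out = flash_rms_norm(x, weight=w, bias=None, eps=eps)

    # The kernel and the fp32 eager path round differently in bf16, so compare loosely.
    print(torch.allclose(out.float(), ref.float(), atol=1e-2, rtol=1e-2))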
@@ -124,12 +113,11 @@ class SDARMLP(nn.Module):
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
-        down_proj_weight = dtensor2local(self.down_proj.weight)
-        down_proj_bias = dtensor2local(self.down_proj.bias)
-        # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        down_proj = swiglu_linear(self.gate_proj(x), self.up_proj(x),
-                                  down_proj_weight, down_proj_bias)
-        return down_proj
+        if liger_kernel_is_available:
+            return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+        else:
+            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+            return down_proj


 def rotate_half(x):
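The SDARMLP hunk above routes the SwiGLU through liger_kernel's fused SiLU-mul when it is available and otherwise falls back to the eager act_fn(gate) * up product. Below is a minimal sketch, not part of the commit, comparing the two paths; it assumes liger-kernel is installed, a CUDA device is available, and hidden_act is "silu", with illustrative shapes.

# Sanity-check sketch: Liger fused SiLU*mul vs. the eager SwiGLU fallback.
import torch
import torch.nn.functional as F
from liger_kernel.ops.swiglu import LigerSiLUMulFunction

if __name__ == "__main__":
    torch.manual_seed(0)
    gate = torch.randn(2, 16, 4096, dtype=torch.bfloat16, device="cuda")
    up = torch.randn(2, 16, 4096, dtype=torch.bfloat16, device="cuda")

    fused = LigerSiLUMulFunction.apply(gate, up)  # fast path used in SDARMLP.forward
    eager = F.silu(gate) * up                     # fallback path, assuming act_fn is SiLU

    print(torch.allclose(fused, eager, atol=1e-2, rtol=1e-2))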
@@ -856,35 +844,11 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            if self.config.fuse_cross_entropy:
-                if fuse_linear_and_cross_entropy:
-                    # Note: We use reduction='sum'
-                    # For 'mean' reduction, gradients are normalized by number of *non-ignored* elements
-                    # mean_loss = sum_loss / num_non_ignored_tokens, instead of all tokens (labels != -100)
-                    loss_fct = FusedLinearDiffusionCrossEntropyLoss(
-                        reduction='sum')
-                else:
-                    loss_fct = FusedCrossEntropyLoss(
-                        reduction='sum', inplace_backward=True)
-            else:
-                loss_fct = nn.CrossEntropyLoss()  # nn.CE
-
-            if fuse_linear_and_cross_entropy:
-                p_mask = kwargs.get('p_mask', None)
-                # loss: tuple of (sum_loss, unreduced_loss)
-                lm_head_weight = dtensor2local(self.lm_head.weight)
-                lm_head_bias = dtensor2local(self.lm_head.bias)
-                loss = loss_fct(
-                    x=hidden_states,  # `view(-1, V)` inside the kernel
-                    target=labels,
-                    weight=lm_head_weight,
-                    bias=lm_head_bias,
-                    p_mask=p_mask,
-                )
-            else:
-                raise RuntimeError("Do not support yet!")
-                loss = loss_fct(
-                    logits.view(-1, self.config.vocab_size), labels.view(-1))
+            # FusedLinearCrossEntropyLoss will be implemented by monkey patch when training
+            # We don't use it when inferencing
+            loss_fct = nn.CrossEntropyLoss()  # nn.CE
+            loss = loss_fct(
+                logits.view(-1, self.config.vocab_size), labels.view(-1))

         return CausalLMOutputWithPast(
             loss=loss,
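The loss hunk above drops the fused fla.modules losses and keeps a plain nn.CrossEntropyLoss for inference, deferring any fused linear cross-entropy to a training-time monkey patch. Below is a minimal sketch, not part of the commit, of the retained path on dummy tensors; shapes and values are illustrative, and it relies only on the default ignore_index of -100.

# Sketch of the inference-time loss path added above (illustrative shapes only).
import torch
import torch.nn as nn

if __name__ == "__main__":
    batch_size, seq_len, vocab_size = 2, 8, 32
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels[:, :3] = -100  # positions set to ignore_index are skipped by the loss

    loss_fct = nn.CrossEntropyLoss()  # defaults: ignore_index=-100, reduction="mean"
    loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))
    print(loss.item())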
 