Add kernel to accelerate
modeling_sdar.py  +19 -55
@@ -46,13 +46,7 @@ from transformers.processing_utils import Unpack
 from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
 from .configuration_sdar import SDARConfig
 
-from fla.modules.activations import swiglu_linear
-from fla.modules import (
-    FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
-    FusedLinearUnreducedCrossEntropyLoss,
-    FusedLinearDiffusionCrossEntropyLoss)
 from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
-from torch.distributed.tensor import DTensor
 
 import torch.nn.functional as F
 try:
@@ -61,12 +55,11 @@ try:
 except:
     pass
 
-
-
-
-
-
-    return dtensor
+try:
+    from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # noqa: F401
+    liger_kernel_is_available = True
+except ImportError:
+    liger_kernel_is_available = False
 
 
 if is_torch_flex_attn_available():
@@ -77,10 +70,6 @@ if is_torch_flex_attn_available():
 logger = logging.get_logger(__name__)
 
 
-@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
-def fused_flex_attention(query, key, value, attention_mask=None, **kwargs):
-    return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
-
 
 @use_kernel_forward_from_hub("RMSNorm")
 class SDARRMSNorm(nn.Module):
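The deleted `fused_flex_attention` helper was only a thin `torch.compile` wrapper around PyTorch's FlexAttention. For reference, a standalone sketch of the same idea (not the repo's code; it assumes PyTorch >= 2.5 with `torch.nn.attention.flex_attention`, a CUDA device, and an illustrative causal mask):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile the FlexAttention entry point once, as the removed helper did.
compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)

def causal(b, h, q_idx, kv_idx):
    # Illustrative mask: attend only to current and earlier positions.
    return q_idx >= kv_idx

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128)
out = compiled_flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```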
@@ -93,16 +82,16 @@ class SDARRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-
-
-        return flash_rms_norm(hidden_states, weight=weight, bias=None, eps=self.variance_epsilon)
+        return flash_rms_norm(
+            hidden_states, weight=self.weight, bias=None, eps=self.variance_epsilon)
         '''
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * \
             torch.rsqrt(variance + self.variance_epsilon)
-        return weight * hidden_states.to(input_dtype)
+        return self.weight * hidden_states.to(input_dtype)
+        '''
 
 
     def extra_repr(self):
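`SDARRMSNorm.forward` now routes through flash-attn's Triton `rms_norm_fn`, which computes the same normalization as the reference code kept in the docstring. A small sanity-check sketch (not part of the diff; it assumes a CUDA build of flash-attn and uses bf16-level tolerances):

```python
import torch
from flash_attn.ops.triton.layer_norm import rms_norm_fn

def reference_rms_norm(x, weight, eps):
    # Reference path from the docstring: normalize in fp32, rescale, cast back.
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(orig_dtype)

hidden = torch.randn(2, 16, 1024, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(1024, device="cuda", dtype=torch.bfloat16)

fused = rms_norm_fn(hidden, weight, None, eps=1e-6)
ref = reference_rms_norm(hidden, weight, 1e-6)
print(torch.allclose(fused, ref, atol=1e-2, rtol=1e-2))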
@@ -124,12 +113,11 @@ class SDARMLP(nn.Module):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
-
-
-
-
-
-        return down_proj
+        if liger_kernel_is_available:
+            return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+        else:
+            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+            return down_proj
 
 
 def rotate_half(x):
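`LigerSiLUMulFunction.apply(gate, up)` is an `autograd.Function` that fuses the elementwise `silu(gate) * up` step of SwiGLU into one Triton kernel, so the fast path in `SDARMLP.forward` computes the same values as the eager fallback. A minimal equivalence sketch (assuming liger-kernel and a CUDA device; shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from liger_kernel.ops.swiglu import LigerSiLUMulFunction

gate = torch.randn(4, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
up = torch.randn(4, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)

fused = LigerSiLUMulFunction.apply(gate, up)      # fused SiLU(gate) * up
eager = F.silu(gate.detach()) * up.detach()       # the eager fallback branch
print(torch.allclose(fused, eager, atol=1e-2, rtol=1e-2))

# Backward also runs through the fused op, since it is a torch.autograd.Function.
fused.sum().backward()
print(gate.grad.shape, up.grad.shape)
```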
@@ -856,35 +844,11 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-
-
-
-
-
-                    loss_fct = FusedLinearDiffusionCrossEntropyLoss(
-                        reduction='sum')
-                else:
-                    loss_fct = FusedCrossEntropyLoss(
-                        reduction='sum', inplace_backward=True)
-            else:
-                loss_fct = nn.CrossEntropyLoss()  # nn.CE
-
-            if fuse_linear_and_cross_entropy:
-                p_mask = kwargs.get('p_mask', None)
-                # loss: tuple of (sum_loss, unreduced_loss)
-                lm_head_weight = dtensor2local(self.lm_head.weight)
-                lm_head_bias = dtensor2local(self.lm_head.bias)
-                loss = loss_fct(
-                    x=hidden_states,  # `view(-1, V)` inside the kernel
-                    target=labels,
-                    weight=lm_head_weight,
-                    bias=lm_head_bias,
-                    p_mask=p_mask,
-                )
-            else:
-                raise RuntimeError("Do not support yet!")
-                loss = loss_fct(
-                    logits.view(-1, self.config.vocab_size), labels.view(-1))
+            # FusedLinearCrossEntropyLoss will be implemented by monkey patch when training
+            # We don't use it when inferencing
+            loss_fct = nn.CrossEntropyLoss()  # nn.CE
+            loss = loss_fct(
+                logits.view(-1, self.config.vocab_size), labels.view(-1))
 
         return CausalLMOutputWithPast(
             loss=loss,
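With the fused fla losses gone, the inference-time loss is plain flattened token-level cross entropy over the vocabulary; per the new comment, the fused linear+CE variant is monkey-patched in only for training. A minimal sketch of the call (sizes are illustrative; `-100` is `nn.CrossEntropyLoss`'s default `ignore_index`):

```python
import torch
import torch.nn as nn

batch, seq_len, vocab_size = 2, 8, 32000  # illustrative sizes
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))
labels[:, :3] = -100  # masked positions are skipped via the default ignore_index

loss_fct = nn.CrossEntropyLoss()  # mean over the non-ignored tokens
loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))
print(loss.item())
```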