Upload modeling_ernie_45t_vl.py
#3
by liaojc · opened

modeling_ernie_45t_vl.py  +23 -24  CHANGED
@@ -27,6 +27,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed.tensor import DTensor
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from transformers.activations import ACT2FN

@@ -2072,6 +2073,10 @@ class MOEAllGatherLayerV2(MOELayer):
         top_k = self.k
         num_expert_per_rank_per_modality = gate_logits_lm.shape[-1]
         group_size = gate_logits_lm.shape[-1] // top_k
+        if self.use_correction_bias and isinstance(self.moe_statics.e_score_correction_bias, DTensor):
+            correction_bias = self.moe_statics.e_score_correction_bias.to_local()
+        elif self.use_correction_bias:
+            correction_bias = self.moe_statics.e_score_correction_bias
         if self.group_experts:
             assert not self.use_correction_bias
             gate_logits_lm = gate_logits_lm.reshape(

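Context for the lines added in the hunk above: after sharded loading (for example with tensor parallelism or FSDP2), `moe_statics.e_score_correction_bias` can arrive as a `torch.distributed.tensor.DTensor`, and `.to_local()` unwraps it to the plain tensor shard owned by the current rank before it enters the gating math. Below is a minimal sketch of that unwrapping pattern, assuming only the public DTensor API; the helper name `to_local_if_dtensor` is illustrative and not part of this file.

```python
# Illustrative sketch, not code from this PR: unwrap a parameter that may be a
# DTensor into the local torch.Tensor shard before doing plain tensor math.
import torch

try:
    from torch.distributed.tensor import DTensor
except ImportError:  # older PyTorch builds do not expose the public DTensor module
    DTensor = None


def to_local_if_dtensor(t: torch.Tensor) -> torch.Tensor:
    # DTensor.to_local() returns the shard held by the current rank;
    # ordinary tensors pass through unchanged.
    if DTensor is not None and isinstance(t, DTensor):
        return t.to_local()
    return t
```

The explicit `isinstance` branch mirrors the diff, which presumably avoids mixing a `DTensor` with the plain local tensors used later in the gating computation.
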
@@ -3457,33 +3462,27 @@ class VisionAttention(nn.Module):
         k = apply_rotary_pos_emb_vision(k.unsqueeze(dim=0), rotary_pos_emb).squeeze(
             dim=0
         )
-
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-
-        attention_mask = torch.full(
-            [1, seq_length, seq_length],
-            torch.finfo(q.dtype).min,
-            device=q.device,
-            dtype=q.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[
-                ...,
-                cu_seqlens[i - 1] : cu_seqlens[i],
-                cu_seqlens[i - 1] : cu_seqlens[i],
-            ] = 0
-
+
         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
         v = v.transpose(0, 1)
-
-
-
-
-
-
-        attn_output =
-
+
+        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+        splits = [
+            torch.split(tensor, lengths.tolist(), dim=1) for tensor in (q, k, v)
+        ]
+
+        attn_output = []
+        for q, k, v in zip(*splits):
+            attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+            attn_weights = nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_output_splited = torch.matmul(attn_weights, v)
+            attn_output_splited = attn_output_splited.transpose(0, 1)
+            attn_output.append(attn_output_splited)
+        attn_output = torch.cat(attn_output, dim=0)
+        attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output

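The rewritten attention above expresses the same block-diagonal attention as the deleted mask-based code, computed segment by segment: rather than materialising a `(1, seq_length, seq_length)` additive mask filled with `torch.finfo(q.dtype).min` outside the per-image blocks, `q`, `k` and `v` are split at the `cu_seqlens` boundaries and ordinary softmax attention runs on each slice, so tokens only ever attend within their own segment. Below is a small self-contained sketch of the pattern with toy shapes; the function name `segmented_attention` and the concrete sizes are illustrative, not taken from the file, and shapes are assumed to be `(num_heads, seq_length, head_dim)` as in the diff.

```python
# Illustrative sketch, not code from this PR: block-diagonal attention computed
# by splitting q/k/v at cu_seqlens boundaries instead of building a full mask.
import math
import torch


def segmented_attention(q, k, v, cu_seqlens):
    # q, k, v: (num_heads, seq_length, head_dim); cu_seqlens: cumulative segment ends, e.g. [0, 4, 7]
    head_dim = q.shape[-1]
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    outputs = []
    # iterate over per-segment slices of q, k and v in lockstep
    for qi, ki, vi in zip(*(torch.split(t, lengths, dim=1) for t in (q, k, v))):
        scores = torch.matmul(qi, ki.transpose(1, 2)) / math.sqrt(head_dim)
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(qi.dtype)
        outputs.append(torch.matmul(probs, vi).transpose(0, 1))  # (seg_len, num_heads, head_dim)
    return torch.cat(outputs, dim=0)  # (seq_length, num_heads, head_dim)


# Toy check: two segments of lengths 4 and 3, with 2 heads and head_dim 8.
cu_seqlens = torch.tensor([0, 4, 7])
q, k, v = (torch.randn(2, 7, 8) for _ in range(3))
print(segmented_attention(q, k, v, cu_seqlens).shape)  # torch.Size([7, 2, 8])
```

The Python loop over segments trades a few extra kernel launches for never allocating the full `seq_length x seq_length` mask, which grows quickly for high-resolution inputs with many vision tokens.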