import copy

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn.modules.linear import Linear
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.normalization import LayerNorm
from torch.nn.init import xavier_uniform_, constant_


class TransformerEncoderRPR(Module):
    """Stack of TransformerEncoderLayerRPR modules, mirroring nn.TransformerEncoder."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoderRPR, self).__init__()
        # Deep-copy the prototype layer so that each layer gets its own parameters;
        # reusing the same instance would make every layer share weights.
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):
        output = src
        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output


class TransformerEncoderLayerRPR(Module):
    """Transformer encoder layer whose self-attention uses relative position representations."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
        super(TransformerEncoderLayerRPR, self).__init__()
        self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len)
        # Position-wise feed-forward network
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Self-attention block with residual connection and post-layer-norm
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # Feed-forward block with residual connection and post-layer-norm
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class MultiheadAttentionRPR(Module):
    """Multi-head attention with learned relative position embeddings (Er), modeled after nn.MultiheadAttention."""

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
                 add_zero_attn=False, kdim=None, vdim=None, er_len=None):
        super(MultiheadAttentionRPR, self).__init__()
        # add_bias_kv is accepted for interface compatibility with
        # nn.MultiheadAttention, but no bias_k/bias_v parameters are created here.
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        # Packed q/k/v projection when query, key, and value share embed_dim,
        # otherwise one projection weight per input.
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
        self.add_zero_attn = add_zero_attn

        # Relative position embeddings: one head_dim-sized vector per relative
        # distance, shared across heads. The last row corresponds to distance 0.
        if er_len is not None:
            self.Er = Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32))
        else:
            self.Er = None

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        return multi_head_attention_forward_rpr(
            query, key, value, self.embed_dim, self.num_heads, self.head_dim,
            self.in_proj_weight, self.in_proj_bias,
            None, None, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, use_separate_proj_weight=not self._qkv_same_embed_dim,
            q_proj_weight=getattr(self, 'q_proj_weight', None),
            k_proj_weight=getattr(self, 'k_proj_weight', None),
            v_proj_weight=getattr(self, 'v_proj_weight', None),
            rpr_mat=self.Er)


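# ---------------------------------------------------------------------------
# Editorial note on the routine below: it follows the general flow of PyTorch's
# F.multi_head_attention_forward (input projections, scaled dot-product,
# masking, softmax, output projection) and adds a relative-position term S_rel
# computed with the "skew" procedure of Huang et al.'s Music Transformer,
# using the Er table passed in as rpr_mat. bias_k/bias_v, add_zero_attn,
# static_k and static_v are accepted for signature compatibility but are not
# used by this implementation.
# ---------------------------------------------------------------------------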
def multi_head_attention_forward_rpr(query, key, value, embed_dim_to_check, num_heads, head_dim,
                                     in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn,
                                     dropout_p, out_proj_weight, out_proj_bias, training=True,
                                     key_padding_mask=None, need_weights=True, attn_mask=None,
                                     use_separate_proj_weight=False, q_proj_weight=None,
                                     k_proj_weight=None, v_proj_weight=None, static_k=None,
                                     static_v=None, rpr_mat=None):

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    scaling = float(head_dim) ** -0.5

    # Input projections: a single packed projection when q/k/v share embed_dim,
    # otherwise one projection per input (guarding against a missing bias).
    if not use_separate_proj_weight:
        q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
    else:
        q_bias = in_proj_bias[0:embed_dim] if in_proj_bias is not None else None
        k_bias = in_proj_bias[embed_dim:(embed_dim * 2)] if in_proj_bias is not None else None
        v_bias = in_proj_bias[(embed_dim * 2):] if in_proj_bias is not None else None
        q = F.linear(query, q_proj_weight, q_bias)
        k = F.linear(key, k_proj_weight, k_bias)
        v = F.linear(value, v_proj_weight, v_bias)

    # Scale queries and fold heads into the batch dimension:
    # (seq_len, bsz, embed_dim) -> (bsz * num_heads, seq_len, head_dim)
    q = q * scaling
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    # Content-based attention logits: Q K^T
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))

    if rpr_mat is not None:
        # Relative position term S_rel. Only the last len_q rows of the
        # embedding table are needed; the last row encodes relative distance 0.
        len_q = q.shape[1]
        start_idx = rpr_mat.shape[0] - len_q
        rpr_mat_valid = rpr_mat[start_idx:, :]
        # qe[h, i, m] = q_i . e_m for every (head, batch) slice h
        qe = torch.einsum("hld,md->hlm", q, rpr_mat_valid)

        B, L, _ = qe.shape

        # Zero out entries whose relative distance falls outside the table
        # before skewing (row i keeps only columns m >= L - 1 - i).
        mask_tri = torch.triu(torch.ones((L, L), device=qe.device, dtype=torch.bool)).flip(0)
        qe = qe.masked_fill(~mask_tri, 0.0)

        # "Skew" trick: pad one zero column on the left, flatten, and re-index
        # so that srel[h, i, j] picks up the embedding for relative distance
        # i - j (future positions j > i end up as zeros).
        zeros = torch.zeros((B, L, 1), device=qe.device, dtype=qe.dtype)
        qe_pad = torch.cat([zeros, qe], dim=2).view(B, -1)

        offsets = torch.arange(L * L, device=qe.device, dtype=torch.int64) + L
        offsets = offsets.unsqueeze(0).expand(B, -1)
        srel = torch.gather(qe_pad, 1, offsets).view(B, L, L)

        attn_output_weights = attn_output_weights + srel

    if attn_mask is not None:
        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
        # Treat the additive float mask as a hard mask: any position set to
        # -inf (or a very large negative value) is fully masked out.
        masked_positions = (attn_mask == float('-inf')) | (attn_mask < -1e4)
        attn_output_weights = attn_output_weights.masked_fill(masked_positions, float('-inf'))

    if key_padding_mask is not None:
        # key_padding_mask: (bsz, src_len) with True at padded key positions
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, tgt_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, tgt_len)

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

    # Weighted sum of values, then merge heads and apply the output projection.
    attn_output = torch.bmm(attn_output_weights, v)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # Return the attention weights averaged over heads.
        return attn_output, attn_output_weights.view(bsz, num_heads, tgt_len, tgt_len).sum(dim=1) / num_heads
    return attn_output, None
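

# ---------------------------------------------------------------------------
# Minimal usage sketch (editorial addition, not part of the original module):
# builds a small RPR encoder stack and runs a forward pass with a causal
# additive mask. Shapes follow the (seq_len, batch, d_model) convention used
# by the forward passes above; all sizes here are arbitrary placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    seq_len, batch_size, d_model, n_heads = 32, 4, 128, 8

    layer = TransformerEncoderLayerRPR(d_model, n_heads, dim_feedforward=256,
                                       dropout=0.1, er_len=seq_len)
    encoder = TransformerEncoderRPR(layer, num_layers=2, norm=LayerNorm(d_model))

    # Additive causal mask: 0 on and below the diagonal, -inf strictly above it.
    causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

    x = torch.rand(seq_len, batch_size, d_model)
    out = encoder(x, mask=causal_mask)
    print(out.shape)  # torch.Size([32, 4, 128])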