"""
Transformer class
"""
import math
import copy
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from rfdetr.models.ops.modules import MSDeformAttn


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def gen_sineembed_for_position(pos_tensor, dim=128):
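    """Sinusoidal embedding for normalized box coordinates.

    `pos_tensor` is a tensor whose last dimension holds normalized (x, y) points or
    (cx, cy, w, h) boxes; the leading two dimensions are preserved. Each coordinate is
    mapped to a `dim`-dimensional sine/cosine embedding (so `dim` is assumed to be even),
    and the embeddings are concatenated in the order (y, x[, w, h]), giving a last
    dimension of 2 * dim or 4 * dim.
    """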
    scale = 2 * math.pi
    dim_t = torch.arange(dim, dtype=pos_tensor.dtype, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / dim)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1): {}".format(pos_tensor.size(-1)))
    return pos


def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, unsigmoid=True):
    """
    Input:
        - memory: bs, sum(hw), d_model
        - memory_padding_mask: bs, sum(hw)
        - spatial_shapes: nlevel, 2
    Output:
        - output_memory: bs, sum(hw), d_model
        - output_proposals: bs, sum(hw), 4
    """
    N_, S_, C_ = memory.shape
    base_scale = 4.0
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        if memory_padding_mask is not None:
            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
        else:
            valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
            valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device)

        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        # Proposal size grows with the feature level (coarser levels get larger boxes).
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)

        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)

    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)

    if unsigmoid:
        # Inverse sigmoid: proposals are returned in logit space.
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        if memory_padding_mask is not None:
            output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
    else:
        if memory_padding_mask is not None:
            output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float(0))

    output_memory = memory
    if memory_padding_mask is not None:
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))

    return output_memory.to(memory.dtype), output_proposals.to(memory.dtype)


class Transformer(nn.Module):
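    """Detection transformer built around a deformable-attention decoder.

    There is no separate transformer encoder here (``self.encoder`` is ``None``); the
    flattened multi-level features are consumed directly by the decoder. When
    ``two_stage`` is enabled, top-scoring proposals are selected from the feature
    memory and used to initialize part of the decoder reference points.
    """
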
    def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, group_detr=1,
                 two_stage=False,
                 num_feature_levels=4, dec_n_points=4,
                 lite_refpoint_refine=False,
                 decoder_norm_type='LN',
                 bbox_reparam=False):
        super().__init__()
        self.encoder = None

        decoder_layer = TransformerDecoderLayer(d_model, sa_nhead, ca_nhead, dim_feedforward,
                                                dropout, activation, normalize_before,
                                                group_detr=group_detr,
                                                num_feature_levels=num_feature_levels,
                                                dec_n_points=dec_n_points,
                                                skip_self_attn=False)
        assert decoder_norm_type in ['LN', 'Identity']
        norm = {
            "LN": lambda channels: nn.LayerNorm(channels),
            "Identity": lambda channels: nn.Identity(),
        }
        decoder_norm = norm[decoder_norm_type](d_model)

        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec,
                                          d_model=d_model,
                                          lite_refpoint_refine=lite_refpoint_refine,
                                          bbox_reparam=bbox_reparam)

        self.two_stage = two_stage
        if two_stage:
            self.enc_output = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(group_detr)])
            self.enc_output_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(group_detr)])

        self._reset_parameters()

        self.num_queries = num_queries
        self.d_model = d_model
        self.dec_layers = num_decoder_layers
        self.group_detr = group_detr
        self.num_feature_levels = num_feature_levels
        self.bbox_reparam = bbox_reparam

        self._export = False

    def export(self):
        self._export = True

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()

    def get_valid_ratio(self, mask):
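        """Per-axis fraction of each padding mask that is valid (not padded), shape (bs, 2) as (w_ratio, h_ratio)."""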
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
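        """Flatten the multi-level features and run the decoder.

        Shapes as used below:
            srcs: list of feature maps, each (bs, c, h, w).
            masks: list of padding masks, each (bs, h, w), or None.
            pos_embeds: list of positional encodings matching `srcs`.
            refpoint_embed: learned reference boxes, one (4-dim) row per query.
            query_feat: learned query features, one (d_model-dim) row per query.

        Returns `(hs, references, memory_ts, boxes_ts)`; the last two are None unless
        two-stage selection is enabled.
        """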
        src_flatten = []
        mask_flatten = [] if masks is not None else None
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        valid_ratios = [] if masks is not None else None
        for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed_flatten.append(pos_embed)
            src_flatten.append(src)
            if masks is not None:
                mask = masks[lvl].flatten(1)
                mask_flatten.append(mask)
        memory = torch.cat(src_flatten, 1)
        if masks is not None:
            mask_flatten = torch.cat(mask_flatten, 1)
            valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

        if self.two_stage:
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes, unsigmoid=not self.bbox_reparam)

            refpoint_embed_ts, memory_ts, boxes_ts = [], [], []
            group_detr = self.group_detr if self.training else 1
            for g_idx in range(group_detr):
                output_memory_gidx = self.enc_output_norm[g_idx](self.enc_output[g_idx](output_memory))

                # enc_out_class_embed / enc_out_bbox_embed are not defined in this module;
                # they are expected to be attached externally (e.g. by the detection model
                # that owns this transformer) before the forward pass.
                enc_outputs_class_unselected_gidx = self.enc_out_class_embed[g_idx](output_memory_gidx)
                if self.bbox_reparam:
                    enc_outputs_coord_delta_gidx = self.enc_out_bbox_embed[g_idx](output_memory_gidx)
                    enc_outputs_coord_cxcy_gidx = enc_outputs_coord_delta_gidx[..., :2] * output_proposals[..., 2:] + output_proposals[..., :2]
                    enc_outputs_coord_wh_gidx = enc_outputs_coord_delta_gidx[..., 2:].exp() * output_proposals[..., 2:]
                    enc_outputs_coord_unselected_gidx = torch.concat(
                        [enc_outputs_coord_cxcy_gidx, enc_outputs_coord_wh_gidx], dim=-1)
                else:
                    enc_outputs_coord_unselected_gidx = self.enc_out_bbox_embed[g_idx](
                        output_memory_gidx) + output_proposals

                topk = min(self.num_queries, enc_outputs_class_unselected_gidx.shape[-2])
                topk_proposals_gidx = torch.topk(enc_outputs_class_unselected_gidx.max(-1)[0], topk, dim=1)[1]

                refpoint_embed_gidx_undetach = torch.gather(
                    enc_outputs_coord_unselected_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, 4))

                refpoint_embed_gidx = refpoint_embed_gidx_undetach.detach()

                tgt_undetach_gidx = torch.gather(
                    output_memory_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, self.d_model))

                refpoint_embed_ts.append(refpoint_embed_gidx)
                memory_ts.append(tgt_undetach_gidx)
                boxes_ts.append(refpoint_embed_gidx_undetach)

            refpoint_embed_ts = torch.cat(refpoint_embed_ts, dim=1)
            memory_ts = torch.cat(memory_ts, dim=1)
            boxes_ts = torch.cat(boxes_ts, dim=1)

        if self.dec_layers > 0:
            tgt = query_feat.unsqueeze(0).repeat(bs, 1, 1)
            refpoint_embed = refpoint_embed.unsqueeze(0).repeat(bs, 1, 1)
            if self.two_stage:
                ts_len = refpoint_embed_ts.shape[-2]
                refpoint_embed_ts_subset = refpoint_embed[..., :ts_len, :]
                refpoint_embed_subset = refpoint_embed[..., ts_len:, :]

                if self.bbox_reparam:
                    refpoint_embed_cxcy = refpoint_embed_ts_subset[..., :2] * refpoint_embed_ts[..., 2:]
                    refpoint_embed_cxcy = refpoint_embed_cxcy + refpoint_embed_ts[..., :2]
                    refpoint_embed_wh = refpoint_embed_ts_subset[..., 2:].exp() * refpoint_embed_ts[..., 2:]
                    refpoint_embed_ts_subset = torch.concat(
                        [refpoint_embed_cxcy, refpoint_embed_wh], dim=-1
                    )
                else:
                    refpoint_embed_ts_subset = refpoint_embed_ts_subset + refpoint_embed_ts

                refpoint_embed = torch.concat(
                    [refpoint_embed_ts_subset, refpoint_embed_subset], dim=-2)

            hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask_flatten,
                                          pos=lvl_pos_embed_flatten, refpoints_unsigmoid=refpoint_embed,
                                          level_start_index=level_start_index,
                                          spatial_shapes=spatial_shapes,
                                          valid_ratios=valid_ratios.to(memory.dtype) if valid_ratios is not None else valid_ratios)
        else:
            assert self.two_stage, "if not using the decoder, two_stage must be True"
            hs = None
            references = None

        if self.two_stage:
            if self.bbox_reparam:
                return hs, references, memory_ts, boxes_ts
            else:
                return hs, references, memory_ts, boxes_ts.sigmoid()
        return hs, references, None, None


class TransformerDecoder(nn.Module):
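    """Stack of decoder layers with optional per-layer reference-point refinement.

    ``self.bbox_embed`` is referenced but not created here; it is expected to be
    attached externally before the forward pass when refinement is used.
    """
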
    def __init__(self,
                 decoder_layer,
                 num_layers,
                 norm=None,
                 return_intermediate=False,
                 d_model=256,
                 lite_refpoint_refine=False,
                 bbox_reparam=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.d_model = d_model
        self.norm = norm
        self.return_intermediate = return_intermediate
        self.lite_refpoint_refine = lite_refpoint_refine
        self.bbox_reparam = bbox_reparam

        self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)

        self._export = False

    def export(self):
        self._export = True

    def refpoints_refine(self, refpoints_unsigmoid, new_refpoints_delta):
        if self.bbox_reparam:
            new_refpoints_cxcy = new_refpoints_delta[..., :2] * refpoints_unsigmoid[..., 2:] + refpoints_unsigmoid[..., :2]
            new_refpoints_wh = new_refpoints_delta[..., 2:].exp() * refpoints_unsigmoid[..., 2:]
            new_refpoints_unsigmoid = torch.concat(
                [new_refpoints_cxcy, new_refpoints_wh], dim=-1
            )
        else:
            new_refpoints_unsigmoid = refpoints_unsigmoid + new_refpoints_delta
        return new_refpoints_unsigmoid

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,
                level_start_index: Optional[Tensor] = None,
                spatial_shapes: Optional[Tensor] = None,
                valid_ratios: Optional[Tensor] = None):
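        """Run all decoder layers over `tgt`, refining the reference points between layers
        (unless `lite_refpoint_refine` is set), and return the stacked intermediate outputs
        together with the corresponding reference points when `return_intermediate` is True.
        """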
        output = tgt

        intermediate = []
        hs_refpoints_unsigmoid = [refpoints_unsigmoid]

        def get_reference(refpoints):
            obj_center = refpoints[..., :4]

            if self._export:
                query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model / 2)
                refpoints_input = obj_center[:, :, None]
            else:
                refpoints_input = obj_center[:, :, None] \
                    * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
                query_sine_embed = gen_sineembed_for_position(
                    refpoints_input[:, :, 0, :], self.d_model / 2)
            query_pos = self.ref_point_head(query_sine_embed)
            return obj_center, refpoints_input, query_pos, query_sine_embed

        # With lite refinement, the reference points (and their embeddings) are computed
        # once and shared by all layers instead of being recomputed per layer.
        if self.lite_refpoint_refine:
            if self.bbox_reparam:
                obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid)
            else:
                obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid())

        for layer_id, layer in enumerate(self.layers):
            if not self.lite_refpoint_refine:
                if self.bbox_reparam:
                    obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid)
                else:
                    obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid())

            # No conditional scaling of the query position embedding is applied here.
            pos_transformation = 1
            query_pos = query_pos * pos_transformation

            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0),
                           reference_points=refpoints_input,
                           spatial_shapes=spatial_shapes,
                           level_start_index=level_start_index)

            if not self.lite_refpoint_refine:
                # Iterative refinement: self.bbox_embed (attached externally) predicts a
                # box update from the layer output.
                new_refpoints_delta = self.bbox_embed(output)
                new_refpoints_unsigmoid = self.refpoints_refine(refpoints_unsigmoid, new_refpoints_delta)
                if layer_id != self.num_layers - 1:
                    hs_refpoints_unsigmoid.append(new_refpoints_unsigmoid)
                refpoints_unsigmoid = new_refpoints_unsigmoid.detach()

            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            if self._export:
                # For export, only the last layer output and its reference points are returned.
                hs = intermediate[-1]
                if self.bbox_embed is not None:
                    ref = hs_refpoints_unsigmoid[-1]
                else:
                    ref = refpoints_unsigmoid
                return hs, ref

            if self.bbox_embed is not None:
                return [
                    torch.stack(intermediate),
                    torch.stack(hs_refpoints_unsigmoid),
                ]
            else:
                return [
                    torch.stack(intermediate),
                    refpoints_unsigmoid.unsqueeze(0)
                ]

        return output.unsqueeze(0)


class TransformerDecoderLayer(nn.Module):
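    """A single decoder layer: self-attention over the object queries, multi-scale
    deformable cross-attention into the flattened feature memory, and a feed-forward
    block, each followed by dropout, a residual connection, and LayerNorm.
    """
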
    def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, group_detr=1,
                 num_feature_levels=4, dec_n_points=4,
                 skip_self_attn=False):
        super().__init__()

        # Decoder self-attention over the object queries.
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=sa_nhead, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Multi-scale deformable cross-attention from queries to the feature memory.
        self.cross_attn = MSDeformAttn(
            d_model, n_levels=num_feature_levels, n_heads=ca_nhead, n_points=dec_n_points)

        self.nhead = ca_nhead

        # Feed-forward block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self.group_detr = group_detr

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None,
                     query_sine_embed=None,
                     is_first=False,
                     reference_points=None,
                     spatial_shapes=None,
                     level_start_index=None,
                     ):
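        """Post-norm decoder layer: query self-attention (computed per query group during
        training when ``group_detr`` > 1), deformable cross-attention, then the
        feed-forward block.
        """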
        bs, num_queries, _ = tgt.shape

        # Self-attention among the object queries.
        q = k = tgt + query_pos
        v = tgt
        if self.training:
            # Split the queries into independent groups so each group only attends to itself.
            q = torch.cat(q.split(num_queries // self.group_detr, dim=1), dim=0)
            k = torch.cat(k.split(num_queries // self.group_detr, dim=1), dim=0)
            v = torch.cat(v.split(num_queries // self.group_detr, dim=1), dim=0)

        tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask,
                              need_weights=False)[0]

        if self.training:
            tgt2 = torch.cat(tgt2.split(bs, dim=0), dim=1)

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # Deformable cross-attention from queries to the multi-level feature memory.
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, query_pos),
            reference_points,
            memory,
            spatial_shapes,
            level_start_index,
            memory_key_padding_mask
        )

        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # Feed-forward block.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                query_sine_embed=None,
                is_first=False,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None):
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,
                                 query_sine_embed, is_first,
                                 reference_points, spatial_shapes, level_start_index)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def build_transformer(args):
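    """Construct a Transformer from an ``args`` namespace (e.g. the parsed training config)."""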
    try:
        two_stage = args.two_stage
    except AttributeError:
        two_stage = False

    return Transformer(
        d_model=args.hidden_dim,
        sa_nhead=args.sa_nheads,
        ca_nhead=args.ca_nheads,
        num_queries=args.num_queries,
        dropout=args.dropout,
        dim_feedforward=args.dim_feedforward,
        num_decoder_layers=args.dec_layers,
        return_intermediate_dec=True,
        group_detr=args.group_detr,
        two_stage=two_stage,
        num_feature_levels=args.num_feature_levels,
        dec_n_points=args.dec_n_points,
        lite_refpoint_refine=args.lite_refpoint_refine,
        decoder_norm_type=args.decoder_norm,
        bbox_reparam=args.bbox_reparam,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")