# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Transformer class
"""
import math
import copy
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from rfdetr.models.ops.modules import MSDeformAttn
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
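
# A minimal usage sketch (hypothetical shapes): a 3-layer MLP as used for
# DETR-style box heads, mapping d_model features to 4 box parameters:
#
#   head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#   x = torch.randn(2, 300, 256)   # (bs, num_queries, d_model)
#   boxes = head(x)                # -> (2, 300, 4)
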
def gen_sineembed_for_position(pos_tensor, dim=128):
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(dim, dtype=pos_tensor.dtype, device=pos_tensor.device)
dim_t = 10000 ** (2 * (dim_t // 2) / dim)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
return pos
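
# Shape sketch for gen_sineembed_for_position (assuming the batch-first layout
# used by the decoder below): each of the 4 box coordinates is expanded into a
# `dim`-d sinusoidal vector and the results are concatenated as (y, x, w, h):
#
#   pos = torch.rand(2, 300, 4)                      # (bs, nq, 4), cxcywh in [0, 1]
#   emb = gen_sineembed_for_position(pos, dim=128)   # -> (2, 300, 512)
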
def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, unsigmoid=True):
    r"""
    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
        - unsigmoid: if True, return proposals in inverse-sigmoid (logit) space
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4
    """
N_, S_, C_ = memory.shape
base_scale = 4.0
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
if memory_padding_mask is not None:
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
else:
valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device)
        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
                                        indexing='ij')
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += (H_ * W_)
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
if unsigmoid:
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
if memory_padding_mask is not None:
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
else:
if memory_padding_mask is not None:
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float(0))
output_memory = memory
if memory_padding_mask is not None:
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
return output_memory.to(memory.dtype), output_proposals.to(memory.dtype)
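
# Numeric sketch: at level l every grid cell becomes an anchor centred on the
# cell with w = h = 0.05 * 2**l, and with unsigmoid=True the proposals come
# back in logit space, i.e. log(p / (1 - p)). Hypothetical shapes:
#
#   shapes = torch.tensor([[80, 80], [40, 40]])
#   mem = torch.randn(2, 80 * 80 + 40 * 40, 256)
#   out_mem, props = gen_encoder_output_proposals(mem, None, shapes)
#   # out_mem: (2, 8000, 256); props: (2, 8000, 4); border cells whose centres
#   # fall outside (0.01, 0.99) are filled with inf and discarded downstream
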
class Transformer(nn.Module):
def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.0,
activation="relu", normalize_before=False,
return_intermediate_dec=False, group_detr=1,
two_stage=False,
num_feature_levels=4, dec_n_points=4,
lite_refpoint_refine=False,
decoder_norm_type='LN',
bbox_reparam=False):
super().__init__()
        self.encoder = None  # encoder-free design: multi-scale features go straight to the decoder
decoder_layer = TransformerDecoderLayer(d_model, sa_nhead, ca_nhead, dim_feedforward,
dropout, activation, normalize_before,
group_detr=group_detr,
num_feature_levels=num_feature_levels,
dec_n_points=dec_n_points,
skip_self_attn=False,)
assert decoder_norm_type in ['LN', 'Identity']
norm = {
"LN": lambda channels: nn.LayerNorm(channels),
"Identity": lambda channels: nn.Identity(),
}
decoder_norm = norm[decoder_norm_type](d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model,
lite_refpoint_refine=lite_refpoint_refine,
bbox_reparam=bbox_reparam)
        self.two_stage = two_stage
        if two_stage:
            self.enc_output = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(group_detr)])
            self.enc_output_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(group_detr)])
            # enc_out_class_embed / enc_out_bbox_embed (used in forward) are
            # attached to this module externally by the owning detector
self._reset_parameters()
self.num_queries = num_queries
self.d_model = d_model
self.dec_layers = num_decoder_layers
self.group_detr = group_detr
self.num_feature_levels = num_feature_levels
self.bbox_reparam = bbox_reparam
self._export = False
def export(self):
self._export = True
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
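    # Valid-ratio sketch: the fraction of each level's mask that holds real
    # image content, used to scale reference points so deformable attention
    # samples inside the un-padded region. E.g. an image filling 30 of 40 rows
    # and all 40 columns:
    #
    #   mask = torch.zeros(1, 40, 40, dtype=torch.bool); mask[:, 30:, :] = True
    #   self.get_valid_ratio(mask)  # -> tensor([[1.0000, 0.7500]])  (w, h)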
def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
src_flatten = []
mask_flatten = [] if masks is not None else None
lvl_pos_embed_flatten = []
spatial_shapes = []
valid_ratios = [] if masks is not None else None
for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
lvl_pos_embed_flatten.append(pos_embed)
src_flatten.append(src)
if masks is not None:
mask = masks[lvl].flatten(1) # bs, hw
mask_flatten.append(mask)
memory = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
if masks is not None:
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
if self.two_stage:
output_memory, output_proposals = gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes, unsigmoid=not self.bbox_reparam)
# group detr for first stage
refpoint_embed_ts, memory_ts, boxes_ts = [], [], []
group_detr = self.group_detr if self.training else 1
for g_idx in range(group_detr):
output_memory_gidx = self.enc_output_norm[g_idx](self.enc_output[g_idx](output_memory))
enc_outputs_class_unselected_gidx = self.enc_out_class_embed[g_idx](output_memory_gidx)
if self.bbox_reparam:
enc_outputs_coord_delta_gidx = self.enc_out_bbox_embed[g_idx](output_memory_gidx)
                    enc_outputs_coord_cxcy_gidx = enc_outputs_coord_delta_gidx[..., :2] \
                        * output_proposals[..., 2:] + output_proposals[..., :2]
enc_outputs_coord_wh_gidx = enc_outputs_coord_delta_gidx[..., 2:].exp() * output_proposals[..., 2:]
enc_outputs_coord_unselected_gidx = torch.concat(
[enc_outputs_coord_cxcy_gidx, enc_outputs_coord_wh_gidx], dim=-1)
else:
enc_outputs_coord_unselected_gidx = self.enc_out_bbox_embed[g_idx](
output_memory_gidx) + output_proposals # (bs, \sum{hw}, 4) unsigmoid
topk = min(self.num_queries, enc_outputs_class_unselected_gidx.shape[-2])
topk_proposals_gidx = torch.topk(enc_outputs_class_unselected_gidx.max(-1)[0], topk, dim=1)[1] # bs, nq
refpoint_embed_gidx_undetach = torch.gather(
enc_outputs_coord_unselected_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid
# for decoder layer, detached as initial ones, (bs, nq, 4)
refpoint_embed_gidx = refpoint_embed_gidx_undetach.detach()
# get memory tgt
tgt_undetach_gidx = torch.gather(
output_memory_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, self.d_model))
refpoint_embed_ts.append(refpoint_embed_gidx)
memory_ts.append(tgt_undetach_gidx)
boxes_ts.append(refpoint_embed_gidx_undetach)
# concat on dim=1, the nq dimension, (bs, nq, d) --> (bs, nq, d)
refpoint_embed_ts = torch.cat(refpoint_embed_ts, dim=1)
# (bs, nq, d)
memory_ts = torch.cat(memory_ts, dim=1)#.transpose(0, 1)
boxes_ts = torch.cat(boxes_ts, dim=1)#.transpose(0, 1)
if self.dec_layers > 0:
tgt = query_feat.unsqueeze(0).repeat(bs, 1, 1)
refpoint_embed = refpoint_embed.unsqueeze(0).repeat(bs, 1, 1)
if self.two_stage:
ts_len = refpoint_embed_ts.shape[-2]
refpoint_embed_ts_subset = refpoint_embed[..., :ts_len, :]
refpoint_embed_subset = refpoint_embed[..., ts_len:, :]
if self.bbox_reparam:
refpoint_embed_cxcy = refpoint_embed_ts_subset[..., :2] * refpoint_embed_ts[..., 2:]
refpoint_embed_cxcy = refpoint_embed_cxcy + refpoint_embed_ts[..., :2]
refpoint_embed_wh = refpoint_embed_ts_subset[..., 2:].exp() * refpoint_embed_ts[..., 2:]
refpoint_embed_ts_subset = torch.concat(
[refpoint_embed_cxcy, refpoint_embed_wh], dim=-1
)
else:
refpoint_embed_ts_subset = refpoint_embed_ts_subset + refpoint_embed_ts
refpoint_embed = torch.concat(
[refpoint_embed_ts_subset, refpoint_embed_subset], dim=-2)
hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask_flatten,
pos=lvl_pos_embed_flatten, refpoints_unsigmoid=refpoint_embed,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios.to(memory.dtype) if valid_ratios is not None else valid_ratios)
else:
assert self.two_stage, "if not using decoder, two_stage must be True"
hs = None
references = None
if self.two_stage:
if self.bbox_reparam:
return hs, references, memory_ts, boxes_ts
else:
return hs, references, memory_ts, boxes_ts.sigmoid()
return hs, references, None, None
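
# Construction sketch (hypothetical config values). Note that MSDeformAttn
# requires the compiled deformable-attention op from rfdetr.models.ops, and
# two-stage mode additionally expects enc_out_class_embed / enc_out_bbox_embed
# to be attached by the owning model, so this is a shape-level sketch rather
# than a guaranteed-runnable snippet:
#
#   t = Transformer(d_model=256, sa_nhead=8, ca_nhead=16, num_queries=300,
#                   num_decoder_layers=3, two_stage=True, num_feature_levels=4,
#                   bbox_reparam=True)
#   # hs, refs, memory_ts, boxes_ts = t(srcs, masks, pos_embeds, refpoints, query_feat)
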
class TransformerDecoder(nn.Module):
def __init__(self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False,
d_model=256,
lite_refpoint_refine=False,
bbox_reparam=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.d_model = d_model
self.norm = norm
self.return_intermediate = return_intermediate
self.lite_refpoint_refine = lite_refpoint_refine
self.bbox_reparam = bbox_reparam
        self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)
        # self.bbox_embed is attached externally by the owning model and is
        # used below for iterative box refinement
        self._export = False
def export(self):
self._export = True
def refpoints_refine(self, refpoints_unsigmoid, new_refpoints_delta):
if self.bbox_reparam:
new_refpoints_cxcy = new_refpoints_delta[..., :2] * refpoints_unsigmoid[..., 2:] + refpoints_unsigmoid[..., :2]
new_refpoints_wh = new_refpoints_delta[..., 2:].exp() * refpoints_unsigmoid[..., 2:]
new_refpoints_unsigmoid = torch.concat(
[new_refpoints_cxcy, new_refpoints_wh], dim=-1
)
else:
new_refpoints_unsigmoid = refpoints_unsigmoid + new_refpoints_delta
return new_refpoints_unsigmoid
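    # Reparameterized refinement sketch: with bbox_reparam the delta is
    # relative to the current box, (dcx, dcy) scaled by (w, h) and (dw, dh)
    # applied multiplicatively via exp. E.g. box (0.5, 0.5, 0.2, 0.2) with
    # delta (0.1, 0.0, math.log(2.0), 0.0) refines to (0.52, 0.50, 0.4, 0.2).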
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
refpoints_unsigmoid: Optional[Tensor] = None,
# for memory
level_start_index: Optional[Tensor] = None, # num_levels
                spatial_shapes: Optional[Tensor] = None, # num_levels, 2
valid_ratios: Optional[Tensor] = None):
output = tgt
intermediate = []
hs_refpoints_unsigmoid = [refpoints_unsigmoid]
        def get_reference(refpoints):
            # refpoints: (bs, num_queries, 4)
            obj_center = refpoints[..., :4]
if self._export:
query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model / 2) # bs, nq, 256*2
refpoints_input = obj_center[:, :, None] # bs, nq, 1, 4
else:
refpoints_input = obj_center[:, :, None] \
* torch.cat([valid_ratios, valid_ratios], -1)[:, None] # bs, nq, nlevel, 4
query_sine_embed = gen_sineembed_for_position(
refpoints_input[:, :, 0, :], self.d_model / 2) # bs, nq, 256*2
query_pos = self.ref_point_head(query_sine_embed)
return obj_center, refpoints_input, query_pos, query_sine_embed
# always use init refpoints
if self.lite_refpoint_refine:
if self.bbox_reparam:
obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid)
else:
obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid())
for layer_id, layer in enumerate(self.layers):
# iter refine each layer
if not self.lite_refpoint_refine:
if self.bbox_reparam:
obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid)
else:
obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid())
# For the first decoder layer, we do not apply transformation over p_s
pos_transformation = 1
query_pos = query_pos * pos_transformation
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
is_first=(layer_id == 0),
reference_points=refpoints_input,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index)
if not self.lite_refpoint_refine:
# box iterative update
new_refpoints_delta = self.bbox_embed(output)
new_refpoints_unsigmoid = self.refpoints_refine(refpoints_unsigmoid, new_refpoints_delta)
if layer_id != self.num_layers - 1:
hs_refpoints_unsigmoid.append(new_refpoints_unsigmoid)
refpoints_unsigmoid = new_refpoints_unsigmoid.detach()
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
if self._export:
# to shape: B, N, C
hs = intermediate[-1]
if self.bbox_embed is not None:
ref = hs_refpoints_unsigmoid[-1]
else:
ref = refpoints_unsigmoid
return hs, ref
# box iterative update
if self.bbox_embed is not None:
return [
torch.stack(intermediate),
torch.stack(hs_refpoints_unsigmoid),
]
else:
return [
torch.stack(intermediate),
refpoints_unsigmoid.unsqueeze(0)
]
return output.unsqueeze(0)
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False, group_detr=1,
num_feature_levels=4, dec_n_points=4,
skip_self_attn=False):
super().__init__()
# Decoder Self-Attention
self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=sa_nhead, dropout=dropout, batch_first=True)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# Decoder Cross-Attention
self.cross_attn = MSDeformAttn(
d_model, n_levels=num_feature_levels, n_heads=ca_nhead, n_points=dec_n_points)
self.nhead = ca_nhead
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.group_detr = group_detr
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
query_sine_embed = None,
is_first = False,
reference_points = None,
spatial_shapes=None,
level_start_index=None,
):
bs, num_queries, _ = tgt.shape
# ========== Begin of Self-Attention =============
# Apply projections here
# shape: batch_size x num_queries x 256
q = k = tgt + query_pos
v = tgt
if self.training:
q = torch.cat(q.split(num_queries // self.group_detr, dim=1), dim=0)
k = torch.cat(k.split(num_queries // self.group_detr, dim=1), dim=0)
v = torch.cat(v.split(num_queries // self.group_detr, dim=1), dim=0)
tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
need_weights=False)[0]
if self.training:
tgt2 = torch.cat(tgt2.split(bs, dim=0), dim=1)
# ========== End of Self-Attention =============
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ========== Begin of Cross-Attention =============
tgt2 = self.cross_attn(
self.with_pos_embed(tgt, query_pos),
reference_points,
memory,
spatial_shapes,
level_start_index,
memory_key_padding_mask
)
# ========== End of Cross-Attention =============
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
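    # Group-DETR sketch: during training the query axis holds group_detr
    # independent groups; splitting them onto the batch axis before
    # self-attention, e.g. (bs=2, nq=1300, d) with group_detr=13 ->
    # (26, 100, d), keeps groups from attending to each other, and the
    # inverse split/cat after attention restores (2, 1300, d).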
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
query_sine_embed = None,
is_first = False,
reference_points = None,
spatial_shapes=None,
level_start_index=None):
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,
query_sine_embed, is_first,
reference_points, spatial_shapes, level_start_index)
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def build_transformer(args):
    two_stage = getattr(args, 'two_stage', False)
return Transformer(
d_model=args.hidden_dim,
sa_nhead=args.sa_nheads,
ca_nhead=args.ca_nheads,
num_queries=args.num_queries,
dropout=args.dropout,
dim_feedforward=args.dim_feedforward,
num_decoder_layers=args.dec_layers,
return_intermediate_dec=True,
group_detr=args.group_detr,
two_stage=two_stage,
num_feature_levels=args.num_feature_levels,
dec_n_points=args.dec_n_points,
lite_refpoint_refine=args.lite_refpoint_refine,
decoder_norm_type=args.decoder_norm,
bbox_reparam=args.bbox_reparam,
)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
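

if __name__ == "__main__":
    # Minimal smoke test for the self-contained pieces of this file (MLP and
    # the sinusoidal embedding). The decoder itself needs the compiled
    # deformable-attention op plus externally attached heads, so it is not
    # exercised here.
    head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    assert head(torch.randn(2, 300, 256)).shape == (2, 300, 4)
    emb = gen_sineembed_for_position(torch.rand(2, 300, 4), dim=128)
    assert emb.shape == (2, 300, 512)
    print("smoke tests passed")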