Spaces:
Configuration error
Configuration error
StableVITON
/
preprocess
/detectron2
/projects
/Rethinking-BatchNorm
/configs
/retinanet_SyncBNhead_SharedTraining.py
| from typing import List | |
| import torch | |
| from torch import Tensor, nn | |
| from detectron2.modeling.meta_arch.retinanet import RetinaNetHead | |
def apply_sequential(inputs, modules):
    """Apply ``modules`` in order to every tensor in ``inputs``.

    Non-BatchNorm modules are applied to each input independently.  For
    ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` modules, the inputs are
    flattened over their spatial dims, concatenated into one tensor and
    normalized in a single call.  In training mode this means the batch
    statistics are computed over all inputs jointly instead of per input.
    The result is then split and reshaped back to each input's shape.

    Args:
        inputs: sequence of 4-D tensors ``(N, C, H, W)``.  ``N`` and ``C``
            must agree across inputs for the concatenation to succeed;
            ``H``/``W`` may differ.
        modules: iterable of callables (e.g. ``nn.Module``) applied in order.

    Returns:
        The transformed tensors, one per input, in the same order.  With an
        empty ``inputs`` the BN layers are skipped instead of failing inside
        ``torch.cat``.
    """
    for mod in modules:
        if isinstance(mod, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            if not inputs:
                # torch.cat rejects an empty list, and there is nothing to normalize.
                continue
            # For BN layers, normalize all inputs together.  Each input is
            # flattened to (N, C, H*W), and the inputs are concatenated along
            # that spatial axis.  A dummy trailing dim restores the 4-D
            # layout BatchNorm2d expects.
            shapes = [i.shape for i in inputs]
            spatial_sizes = [s[2] * s[3] for s in shapes]
            joint = torch.cat([i.flatten(2) for i in inputs], dim=2).unsqueeze(3)
            chunks = mod(joint).split(spatial_sizes, dim=2)
            # Give each chunk back its original (N, C, H, W) shape.
            inputs = [chunk.view(shape) for shape, chunk in zip(shapes, chunks)]
        else:
            inputs = [mod(i) for i in inputs]
    return inputs
class RetinaNetHead_SharedTrainingBN(RetinaNetHead):
    """RetinaNet head whose BN layers normalize all input feature maps together.

    Every ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` layer in the
    classification and box-regression towers receives the concatenation of
    all ``features`` in one call (see ``apply_sequential``), so they share a
    single set of batch statistics.  All other layers still run on each
    feature map separately.
    """

    def forward(self, features: List[Tensor]):
        """Run both prediction towers over ``features``.

        Args:
            features: list of 4-D feature maps (presumably one per FPN
                level; TODO confirm against the caller).

        Returns:
            ``(logits, bbox_reg)``: two lists, each holding one output
            tensor per entry of ``features``, in the same order.
        """
        # Tower layers followed by the final prediction layer.  These
        # attributes are presumably built by RetinaNetHead.__init__.
        cls_layers = [*self.cls_subnet, self.cls_score]
        box_layers = [*self.bbox_subnet, self.bbox_pred]
        logits = apply_sequential(features, cls_layers)
        bbox_reg = apply_sequential(features, box_layers)
        return logits, bbox_reg
# Take every top-level config object from the SyncBN-head RetinaNet config.
# Only ``model`` is used below.  The other names are presumably imported so
# that the config loader finds them as this file's own entries; confirm
# against detectron2's LazyConfig loading.
from .retinanet_SyncBNhead import model, dataloader, lr_multiplier, optimizer, train

# The one override in this file: swap the head class for the variant whose BN
# layers normalize all input feature maps jointly.
# NOTE(review): ``_target_`` is presumably the class LazyConfig instantiates
# for ``model.head``; verify against detectron2's LazyCall semantics.
model.head._target_ = RetinaNetHead_SharedTrainingBN