Spaces:
Running
on
Zero
Running
on
Zero
Upload 56 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- moge/__init__.py +0 -0
- moge/model/__init__.py +18 -0
- moge/model/dinov2/__init__.py +6 -0
- moge/model/dinov2/hub/__init__.py +4 -0
- moge/model/dinov2/hub/backbones.py +156 -0
- moge/model/dinov2/hub/utils.py +39 -0
- moge/model/dinov2/layers/__init__.py +11 -0
- moge/model/dinov2/layers/attention.py +100 -0
- moge/model/dinov2/layers/block.py +259 -0
- moge/model/dinov2/layers/dino_head.py +58 -0
- moge/model/dinov2/layers/drop_path.py +34 -0
- moge/model/dinov2/layers/layer_scale.py +27 -0
- moge/model/dinov2/layers/mlp.py +40 -0
- moge/model/dinov2/layers/patch_embed.py +88 -0
- moge/model/dinov2/layers/swiglu_ffn.py +72 -0
- moge/model/dinov2/models/__init__.py +43 -0
- moge/model/dinov2/models/vision_transformer.py +407 -0
- moge/model/dinov2/utils/__init__.py +4 -0
- moge/model/dinov2/utils/cluster.py +95 -0
- moge/model/dinov2/utils/config.py +72 -0
- moge/model/dinov2/utils/dtype.py +37 -0
- moge/model/dinov2/utils/param_groups.py +103 -0
- moge/model/dinov2/utils/utils.py +95 -0
- moge/model/modules.py +254 -0
- moge/model/utils.py +49 -0
- moge/model/v1.py +392 -0
- moge/model/v2.py +303 -0
- moge/scripts/__init__.py +0 -0
- moge/scripts/app.py +301 -0
- moge/scripts/cli.py +27 -0
- moge/scripts/eval_baseline.py +165 -0
- moge/scripts/infer.py +170 -0
- moge/scripts/infer_baseline.py +140 -0
- moge/scripts/infer_panorama.py +162 -0
- moge/scripts/train.py +452 -0
- moge/scripts/vis_data.py +84 -0
- moge/test/__init__.py +0 -0
- moge/test/baseline.py +43 -0
- moge/test/dataloader.py +221 -0
- moge/test/metrics.py +343 -0
- moge/train/__init__.py +0 -0
- moge/train/dataloader.py +338 -0
- moge/train/losses.py +270 -0
- moge/train/utils.py +57 -0
- moge/utils/__init__.py +0 -0
- moge/utils/alignment.py +416 -0
- moge/utils/download.py +55 -0
- moge/utils/geometry_numpy.py +406 -0
- moge/utils/geometry_torch.py +354 -0
- moge/utils/io.py +236 -0
moge/__init__.py
ADDED
|
File without changes
|
moge/model/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from typing import *
|
| 3 |
+
|
| 4 |
+
if TYPE_CHECKING:
|
| 5 |
+
from .v1 import MoGeModel as MoGeModelV1
|
| 6 |
+
from .v2 import MoGeModel as MoGeModelV2
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]:
|
| 10 |
+
assert version in ['v1', 'v2'], f'Unsupported model version: {version}'
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
module = importlib.import_module(f'.{version}', __package__)
|
| 14 |
+
except ModuleNotFoundError:
|
| 15 |
+
raise ValueError(f'Model version "{version}" not found.')
|
| 16 |
+
|
| 17 |
+
cls = getattr(module, 'MoGeModel')
|
| 18 |
+
return cls
|
moge/model/dinov2/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
__version__ = "0.0.1"
|
moge/model/dinov2/hub/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
moge/model/dinov2/hub/backbones.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Weights(Enum):
|
| 15 |
+
LVD142M = "LVD142M"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _make_dinov2_model(
|
| 19 |
+
*,
|
| 20 |
+
arch_name: str = "vit_large",
|
| 21 |
+
img_size: int = 518,
|
| 22 |
+
patch_size: int = 14,
|
| 23 |
+
init_values: float = 1.0,
|
| 24 |
+
ffn_layer: str = "mlp",
|
| 25 |
+
block_chunks: int = 0,
|
| 26 |
+
num_register_tokens: int = 0,
|
| 27 |
+
interpolate_antialias: bool = False,
|
| 28 |
+
interpolate_offset: float = 0.1,
|
| 29 |
+
pretrained: bool = True,
|
| 30 |
+
weights: Union[Weights, str] = Weights.LVD142M,
|
| 31 |
+
**kwargs,
|
| 32 |
+
):
|
| 33 |
+
from ..models import vision_transformer as vits
|
| 34 |
+
|
| 35 |
+
if isinstance(weights, str):
|
| 36 |
+
try:
|
| 37 |
+
weights = Weights[weights]
|
| 38 |
+
except KeyError:
|
| 39 |
+
raise AssertionError(f"Unsupported weights: {weights}")
|
| 40 |
+
|
| 41 |
+
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
|
| 42 |
+
vit_kwargs = dict(
|
| 43 |
+
img_size=img_size,
|
| 44 |
+
patch_size=patch_size,
|
| 45 |
+
init_values=init_values,
|
| 46 |
+
ffn_layer=ffn_layer,
|
| 47 |
+
block_chunks=block_chunks,
|
| 48 |
+
num_register_tokens=num_register_tokens,
|
| 49 |
+
interpolate_antialias=interpolate_antialias,
|
| 50 |
+
interpolate_offset=interpolate_offset,
|
| 51 |
+
)
|
| 52 |
+
vit_kwargs.update(**kwargs)
|
| 53 |
+
model = vits.__dict__[arch_name](**vit_kwargs)
|
| 54 |
+
|
| 55 |
+
if pretrained:
|
| 56 |
+
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
|
| 57 |
+
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
|
| 58 |
+
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
| 59 |
+
model.load_state_dict(state_dict, strict=True)
|
| 60 |
+
|
| 61 |
+
return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 65 |
+
"""
|
| 66 |
+
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 67 |
+
"""
|
| 68 |
+
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 72 |
+
"""
|
| 73 |
+
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 74 |
+
"""
|
| 75 |
+
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 79 |
+
"""
|
| 80 |
+
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 81 |
+
"""
|
| 82 |
+
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 86 |
+
"""
|
| 87 |
+
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 88 |
+
"""
|
| 89 |
+
return _make_dinov2_model(
|
| 90 |
+
arch_name="vit_giant2",
|
| 91 |
+
ffn_layer="swiglufused",
|
| 92 |
+
weights=weights,
|
| 93 |
+
pretrained=pretrained,
|
| 94 |
+
**kwargs,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 99 |
+
"""
|
| 100 |
+
DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 101 |
+
"""
|
| 102 |
+
return _make_dinov2_model(
|
| 103 |
+
arch_name="vit_small",
|
| 104 |
+
pretrained=pretrained,
|
| 105 |
+
weights=weights,
|
| 106 |
+
num_register_tokens=4,
|
| 107 |
+
interpolate_antialias=True,
|
| 108 |
+
interpolate_offset=0.0,
|
| 109 |
+
**kwargs,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 114 |
+
"""
|
| 115 |
+
DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 116 |
+
"""
|
| 117 |
+
return _make_dinov2_model(
|
| 118 |
+
arch_name="vit_base",
|
| 119 |
+
pretrained=pretrained,
|
| 120 |
+
weights=weights,
|
| 121 |
+
num_register_tokens=4,
|
| 122 |
+
interpolate_antialias=True,
|
| 123 |
+
interpolate_offset=0.0,
|
| 124 |
+
**kwargs,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 129 |
+
"""
|
| 130 |
+
DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 131 |
+
"""
|
| 132 |
+
return _make_dinov2_model(
|
| 133 |
+
arch_name="vit_large",
|
| 134 |
+
pretrained=pretrained,
|
| 135 |
+
weights=weights,
|
| 136 |
+
num_register_tokens=4,
|
| 137 |
+
interpolate_antialias=True,
|
| 138 |
+
interpolate_offset=0.0,
|
| 139 |
+
**kwargs,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 144 |
+
"""
|
| 145 |
+
DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 146 |
+
"""
|
| 147 |
+
return _make_dinov2_model(
|
| 148 |
+
arch_name="vit_giant2",
|
| 149 |
+
ffn_layer="swiglufused",
|
| 150 |
+
weights=weights,
|
| 151 |
+
pretrained=pretrained,
|
| 152 |
+
num_register_tokens=4,
|
| 153 |
+
interpolate_antialias=True,
|
| 154 |
+
interpolate_offset=0.0,
|
| 155 |
+
**kwargs,
|
| 156 |
+
)
|
moge/model/dinov2/hub/utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
| 18 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
| 19 |
+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
| 20 |
+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CenterPadding(nn.Module):
|
| 24 |
+
def __init__(self, multiple):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.multiple = multiple
|
| 27 |
+
|
| 28 |
+
def _get_pad(self, size):
|
| 29 |
+
new_size = math.ceil(size / self.multiple) * self.multiple
|
| 30 |
+
pad_size = new_size - size
|
| 31 |
+
pad_size_left = pad_size // 2
|
| 32 |
+
pad_size_right = pad_size - pad_size_left
|
| 33 |
+
return pad_size_left, pad_size_right
|
| 34 |
+
|
| 35 |
+
@torch.inference_mode()
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
|
| 38 |
+
output = F.pad(x, pads)
|
| 39 |
+
return output
|
moge/model/dinov2/layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_head import DINOHead
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
moge/model/dinov2/layers/attention.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
from torch import nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger("dinov2")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 23 |
+
try:
|
| 24 |
+
if XFORMERS_ENABLED:
|
| 25 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 26 |
+
|
| 27 |
+
XFORMERS_AVAILABLE = True
|
| 28 |
+
# warnings.warn("xFormers is available (Attention)")
|
| 29 |
+
else:
|
| 30 |
+
# warnings.warn("xFormers is disabled (Attention)")
|
| 31 |
+
raise ImportError
|
| 32 |
+
except ImportError:
|
| 33 |
+
XFORMERS_AVAILABLE = False
|
| 34 |
+
# warnings.warn("xFormers is not available (Attention)")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Attention(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
dim: int,
|
| 41 |
+
num_heads: int = 8,
|
| 42 |
+
qkv_bias: bool = False,
|
| 43 |
+
proj_bias: bool = True,
|
| 44 |
+
attn_drop: float = 0.0,
|
| 45 |
+
proj_drop: float = 0.0,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.num_heads = num_heads
|
| 49 |
+
head_dim = dim // num_heads
|
| 50 |
+
self.scale = head_dim**-0.5
|
| 51 |
+
|
| 52 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 53 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 54 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 55 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 56 |
+
|
| 57 |
+
# # Deprecated implementation, extremely slow
|
| 58 |
+
# def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 59 |
+
# B, N, C = x.shape
|
| 60 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 61 |
+
# q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 62 |
+
# attn = q @ k.transpose(-2, -1)
|
| 63 |
+
# attn = attn.softmax(dim=-1)
|
| 64 |
+
# attn = self.attn_drop(attn)
|
| 65 |
+
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 66 |
+
# x = self.proj(x)
|
| 67 |
+
# x = self.proj_drop(x)
|
| 68 |
+
# return x
|
| 69 |
+
|
| 70 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 71 |
+
B, N, C = x.shape
|
| 72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
|
| 73 |
+
|
| 74 |
+
q, k, v = qkv.unbind(0) # (B, H, N, C // H)
|
| 75 |
+
|
| 76 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_bias)
|
| 77 |
+
x = x.permute(0, 2, 1, 3).reshape(B, N, C)
|
| 78 |
+
|
| 79 |
+
x = self.proj(x)
|
| 80 |
+
x = self.proj_drop(x)
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
class MemEffAttention(Attention):
|
| 84 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 85 |
+
if not XFORMERS_AVAILABLE:
|
| 86 |
+
if attn_bias is not None:
|
| 87 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 88 |
+
return super().forward(x)
|
| 89 |
+
|
| 90 |
+
B, N, C = x.shape
|
| 91 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 92 |
+
|
| 93 |
+
q, k, v = unbind(qkv, 2)
|
| 94 |
+
|
| 95 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
| 96 |
+
x = x.reshape([B, N, C])
|
| 97 |
+
|
| 98 |
+
x = self.proj(x)
|
| 99 |
+
x = self.proj_drop(x)
|
| 100 |
+
return x
|
moge/model/dinov2/layers/block.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention, MemEffAttention
|
| 19 |
+
from .drop_path import DropPath
|
| 20 |
+
from .layer_scale import LayerScale
|
| 21 |
+
from .mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 28 |
+
try:
|
| 29 |
+
if XFORMERS_ENABLED:
|
| 30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
| 31 |
+
|
| 32 |
+
XFORMERS_AVAILABLE = True
|
| 33 |
+
# warnings.warn("xFormers is available (Block)")
|
| 34 |
+
else:
|
| 35 |
+
# warnings.warn("xFormers is disabled (Block)")
|
| 36 |
+
raise ImportError
|
| 37 |
+
except ImportError:
|
| 38 |
+
XFORMERS_AVAILABLE = False
|
| 39 |
+
# warnings.warn("xFormers is not available (Block)")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Block(nn.Module):
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dim: int,
|
| 46 |
+
num_heads: int,
|
| 47 |
+
mlp_ratio: float = 4.0,
|
| 48 |
+
qkv_bias: bool = False,
|
| 49 |
+
proj_bias: bool = True,
|
| 50 |
+
ffn_bias: bool = True,
|
| 51 |
+
drop: float = 0.0,
|
| 52 |
+
attn_drop: float = 0.0,
|
| 53 |
+
init_values=None,
|
| 54 |
+
drop_path: float = 0.0,
|
| 55 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 56 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 57 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 58 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 59 |
+
) -> None:
|
| 60 |
+
super().__init__()
|
| 61 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
| 62 |
+
self.norm1 = norm_layer(dim)
|
| 63 |
+
self.attn = attn_class(
|
| 64 |
+
dim,
|
| 65 |
+
num_heads=num_heads,
|
| 66 |
+
qkv_bias=qkv_bias,
|
| 67 |
+
proj_bias=proj_bias,
|
| 68 |
+
attn_drop=attn_drop,
|
| 69 |
+
proj_drop=drop,
|
| 70 |
+
)
|
| 71 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 72 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 73 |
+
|
| 74 |
+
self.norm2 = norm_layer(dim)
|
| 75 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 76 |
+
self.mlp = ffn_layer(
|
| 77 |
+
in_features=dim,
|
| 78 |
+
hidden_features=mlp_hidden_dim,
|
| 79 |
+
act_layer=act_layer,
|
| 80 |
+
drop=drop,
|
| 81 |
+
bias=ffn_bias,
|
| 82 |
+
)
|
| 83 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 84 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 85 |
+
|
| 86 |
+
self.sample_drop_ratio = drop_path
|
| 87 |
+
|
| 88 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 89 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
| 90 |
+
return self.ls1(self.attn(self.norm1(x)))
|
| 91 |
+
|
| 92 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 93 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 94 |
+
|
| 95 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 96 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 97 |
+
x = drop_add_residual_stochastic_depth(
|
| 98 |
+
x,
|
| 99 |
+
residual_func=attn_residual_func,
|
| 100 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 101 |
+
)
|
| 102 |
+
x = drop_add_residual_stochastic_depth(
|
| 103 |
+
x,
|
| 104 |
+
residual_func=ffn_residual_func,
|
| 105 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 106 |
+
)
|
| 107 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 108 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
| 109 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 110 |
+
else:
|
| 111 |
+
x = x + attn_residual_func(x)
|
| 112 |
+
x = x + ffn_residual_func(x)
|
| 113 |
+
return x
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def drop_add_residual_stochastic_depth(
|
| 117 |
+
x: Tensor,
|
| 118 |
+
residual_func: Callable[[Tensor], Tensor],
|
| 119 |
+
sample_drop_ratio: float = 0.0,
|
| 120 |
+
) -> Tensor:
|
| 121 |
+
# 1) extract subset using permutation
|
| 122 |
+
b, n, d = x.shape
|
| 123 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 124 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 125 |
+
x_subset = x[brange]
|
| 126 |
+
|
| 127 |
+
# 2) apply residual_func to get residual
|
| 128 |
+
residual = residual_func(x_subset)
|
| 129 |
+
|
| 130 |
+
x_flat = x.flatten(1)
|
| 131 |
+
residual = residual.flatten(1)
|
| 132 |
+
|
| 133 |
+
residual_scale_factor = b / sample_subset_size
|
| 134 |
+
|
| 135 |
+
# 3) add the residual
|
| 136 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 137 |
+
return x_plus_residual.view_as(x)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 141 |
+
b, n, d = x.shape
|
| 142 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 143 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 144 |
+
residual_scale_factor = b / sample_subset_size
|
| 145 |
+
return brange, residual_scale_factor
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 149 |
+
if scaling_vector is None:
|
| 150 |
+
x_flat = x.flatten(1)
|
| 151 |
+
residual = residual.flatten(1)
|
| 152 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 153 |
+
else:
|
| 154 |
+
x_plus_residual = scaled_index_add(
|
| 155 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
| 156 |
+
)
|
| 157 |
+
return x_plus_residual
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
| 164 |
+
"""
|
| 165 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
| 166 |
+
"""
|
| 167 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
| 168 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 169 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 170 |
+
seqlens = []
|
| 171 |
+
for b, x in zip(batch_sizes, x_list):
|
| 172 |
+
for _ in range(b):
|
| 173 |
+
seqlens.append(x.shape[1])
|
| 174 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 175 |
+
attn_bias._batch_sizes = batch_sizes
|
| 176 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 177 |
+
|
| 178 |
+
if branges is not None:
|
| 179 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
| 180 |
+
else:
|
| 181 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 182 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 183 |
+
|
| 184 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def drop_add_residual_stochastic_depth_list(
|
| 188 |
+
x_list: List[Tensor],
|
| 189 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
| 190 |
+
sample_drop_ratio: float = 0.0,
|
| 191 |
+
scaling_vector=None,
|
| 192 |
+
) -> Tensor:
|
| 193 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 194 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
| 195 |
+
branges = [s[0] for s in branges_scales]
|
| 196 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 197 |
+
|
| 198 |
+
# 2) get attention bias and index+concat the tensors
|
| 199 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 200 |
+
|
| 201 |
+
# 3) apply residual_func to get residual, and split the result
|
| 202 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 203 |
+
|
| 204 |
+
outputs = []
|
| 205 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
| 206 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
| 207 |
+
return outputs
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class NestedTensorBlock(Block):
|
| 211 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
| 212 |
+
"""
|
| 213 |
+
x_list contains a list of tensors to nest together and run
|
| 214 |
+
"""
|
| 215 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 216 |
+
|
| 217 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 218 |
+
|
| 219 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 220 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 221 |
+
|
| 222 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 223 |
+
return self.mlp(self.norm2(x))
|
| 224 |
+
|
| 225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 226 |
+
x_list,
|
| 227 |
+
residual_func=attn_residual_func,
|
| 228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 229 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 230 |
+
)
|
| 231 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 232 |
+
x_list,
|
| 233 |
+
residual_func=ffn_residual_func,
|
| 234 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 235 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 236 |
+
)
|
| 237 |
+
return x_list
|
| 238 |
+
else:
|
| 239 |
+
|
| 240 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 241 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 242 |
+
|
| 243 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 244 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 245 |
+
|
| 246 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 247 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 248 |
+
x = x + ffn_residual_func(x)
|
| 249 |
+
return attn_bias.split(x)
|
| 250 |
+
|
| 251 |
+
def forward(self, x_or_x_list):
|
| 252 |
+
if isinstance(x_or_x_list, Tensor):
|
| 253 |
+
return super().forward(x_or_x_list)
|
| 254 |
+
elif isinstance(x_or_x_list, list):
|
| 255 |
+
if not XFORMERS_AVAILABLE:
|
| 256 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 257 |
+
return self.forward_nested(x_or_x_list)
|
| 258 |
+
else:
|
| 259 |
+
raise AssertionError
|
moge/model/dinov2/layers/dino_head.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn.init import trunc_normal_
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DINOHead(nn.Module):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
in_dim,
|
| 16 |
+
out_dim,
|
| 17 |
+
use_bn=False,
|
| 18 |
+
nlayers=3,
|
| 19 |
+
hidden_dim=2048,
|
| 20 |
+
bottleneck_dim=256,
|
| 21 |
+
mlp_bias=True,
|
| 22 |
+
):
|
| 23 |
+
super().__init__()
|
| 24 |
+
nlayers = max(nlayers, 1)
|
| 25 |
+
self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
|
| 26 |
+
self.apply(self._init_weights)
|
| 27 |
+
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
| 28 |
+
self.last_layer.weight_g.data.fill_(1)
|
| 29 |
+
|
| 30 |
+
def _init_weights(self, m):
|
| 31 |
+
if isinstance(m, nn.Linear):
|
| 32 |
+
trunc_normal_(m.weight, std=0.02)
|
| 33 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 34 |
+
nn.init.constant_(m.bias, 0)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
x = self.mlp(x)
|
| 38 |
+
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
| 39 |
+
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
| 40 |
+
x = self.last_layer(x)
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
| 45 |
+
if nlayers == 1:
|
| 46 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
| 47 |
+
else:
|
| 48 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
| 49 |
+
if use_bn:
|
| 50 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 51 |
+
layers.append(nn.GELU())
|
| 52 |
+
for _ in range(nlayers - 2):
|
| 53 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
| 54 |
+
if use_bn:
|
| 55 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 56 |
+
layers.append(nn.GELU())
|
| 57 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
| 58 |
+
return nn.Sequential(*layers)
|
moge/model/dinov2/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 15 |
+
if drop_prob == 0.0 or not training:
|
| 16 |
+
return x
|
| 17 |
+
keep_prob = 1 - drop_prob
|
| 18 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 19 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 20 |
+
if keep_prob > 0.0:
|
| 21 |
+
random_tensor.div_(keep_prob)
|
| 22 |
+
output = x * random_tensor
|
| 23 |
+
return output
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DropPath(nn.Module):
|
| 27 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 28 |
+
|
| 29 |
+
def __init__(self, drop_prob=None):
|
| 30 |
+
super(DropPath, self).__init__()
|
| 31 |
+
self.drop_prob = drop_prob
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
return drop_path(x, self.drop_prob, self.training)
|
moge/model/dinov2/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 7 |
+
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerScale(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
dim: int,
|
| 19 |
+
init_values: Union[float, Tensor] = 1e-5,
|
| 20 |
+
inplace: bool = False,
|
| 21 |
+
) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.inplace = inplace
|
| 24 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 25 |
+
|
| 26 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 27 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
moge/model/dinov2/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
in_features: int,
|
| 20 |
+
hidden_features: Optional[int] = None,
|
| 21 |
+
out_features: Optional[int] = None,
|
| 22 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 23 |
+
drop: float = 0.0,
|
| 24 |
+
bias: bool = True,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
out_features = out_features or in_features
|
| 28 |
+
hidden_features = hidden_features or in_features
|
| 29 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 30 |
+
self.act = act_layer()
|
| 31 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 32 |
+
self.drop = nn.Dropout(drop)
|
| 33 |
+
|
| 34 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 35 |
+
x = self.fc1(x)
|
| 36 |
+
x = self.act(x)
|
| 37 |
+
x = self.drop(x)
|
| 38 |
+
x = self.fc2(x)
|
| 39 |
+
x = self.drop(x)
|
| 40 |
+
return x
|
moge/model/dinov2/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
|
| 17 |
+
if isinstance(x, tuple):
|
| 18 |
+
assert len(x) == 2
|
| 19 |
+
return x
|
| 20 |
+
|
| 21 |
+
assert isinstance(x, int)
|
| 22 |
+
return (x, x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PatchEmbed(nn.Module):
|
| 26 |
+
"""
|
| 27 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
img_size: Image size.
|
| 31 |
+
patch_size: Patch token size.
|
| 32 |
+
in_chans: Number of input image channels.
|
| 33 |
+
embed_dim: Number of linear projection output channels.
|
| 34 |
+
norm_layer: Normalization layer.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 40 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 41 |
+
in_chans: int = 3,
|
| 42 |
+
embed_dim: int = 768,
|
| 43 |
+
norm_layer: Optional[Callable] = None,
|
| 44 |
+
flatten_embedding: bool = True,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
image_HW = make_2tuple(img_size)
|
| 49 |
+
patch_HW = make_2tuple(patch_size)
|
| 50 |
+
patch_grid_size = (
|
| 51 |
+
image_HW[0] // patch_HW[0],
|
| 52 |
+
image_HW[1] // patch_HW[1],
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
self.img_size = image_HW
|
| 56 |
+
self.patch_size = patch_HW
|
| 57 |
+
self.patches_resolution = patch_grid_size
|
| 58 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 59 |
+
|
| 60 |
+
self.in_chans = in_chans
|
| 61 |
+
self.embed_dim = embed_dim
|
| 62 |
+
|
| 63 |
+
self.flatten_embedding = flatten_embedding
|
| 64 |
+
|
| 65 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
| 66 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 67 |
+
|
| 68 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 69 |
+
_, _, H, W = x.shape
|
| 70 |
+
patch_H, patch_W = self.patch_size
|
| 71 |
+
|
| 72 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
| 73 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
| 74 |
+
|
| 75 |
+
x = self.proj(x) # B C H W
|
| 76 |
+
H, W = x.size(2), x.size(3)
|
| 77 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 78 |
+
x = self.norm(x)
|
| 79 |
+
if not self.flatten_embedding:
|
| 80 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
def flops(self) -> float:
|
| 84 |
+
Ho, Wo = self.patches_resolution
|
| 85 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
| 86 |
+
if self.norm is not None:
|
| 87 |
+
flops += Ho * Wo * self.embed_dim
|
| 88 |
+
return flops
|
moge/model/dinov2/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
in_features: int,
|
| 18 |
+
hidden_features: Optional[int] = None,
|
| 19 |
+
out_features: Optional[int] = None,
|
| 20 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 21 |
+
drop: float = 0.0,
|
| 22 |
+
bias: bool = True,
|
| 23 |
+
) -> None:
|
| 24 |
+
super().__init__()
|
| 25 |
+
out_features = out_features or in_features
|
| 26 |
+
hidden_features = hidden_features or in_features
|
| 27 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
| 28 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 29 |
+
|
| 30 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 31 |
+
x12 = self.w12(x)
|
| 32 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 33 |
+
hidden = F.silu(x1) * x2
|
| 34 |
+
return self.w3(hidden)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 38 |
+
try:
|
| 39 |
+
if XFORMERS_ENABLED:
|
| 40 |
+
from xformers.ops import SwiGLU
|
| 41 |
+
|
| 42 |
+
XFORMERS_AVAILABLE = True
|
| 43 |
+
# warnings.warn("xFormers is available (SwiGLU)")
|
| 44 |
+
else:
|
| 45 |
+
# warnings.warn("xFormers is disabled (SwiGLU)")
|
| 46 |
+
raise ImportError
|
| 47 |
+
except ImportError:
|
| 48 |
+
SwiGLU = SwiGLUFFN
|
| 49 |
+
XFORMERS_AVAILABLE = False
|
| 50 |
+
|
| 51 |
+
# warnings.warn("xFormers is not available (SwiGLU)")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SwiGLUFFNFused(SwiGLU):
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
in_features: int,
|
| 58 |
+
hidden_features: Optional[int] = None,
|
| 59 |
+
out_features: Optional[int] = None,
|
| 60 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 61 |
+
drop: float = 0.0,
|
| 62 |
+
bias: bool = True,
|
| 63 |
+
) -> None:
|
| 64 |
+
out_features = out_features or in_features
|
| 65 |
+
hidden_features = hidden_features or in_features
|
| 66 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
| 67 |
+
super().__init__(
|
| 68 |
+
in_features=in_features,
|
| 69 |
+
hidden_features=hidden_features,
|
| 70 |
+
out_features=out_features,
|
| 71 |
+
bias=bias,
|
| 72 |
+
)
|
moge/model/dinov2/models/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from . import vision_transformer as vits
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("dinov2")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_model(args, only_teacher=False, img_size=224):
|
| 15 |
+
args.arch = args.arch.removesuffix("_memeff")
|
| 16 |
+
if "vit" in args.arch:
|
| 17 |
+
vit_kwargs = dict(
|
| 18 |
+
img_size=img_size,
|
| 19 |
+
patch_size=args.patch_size,
|
| 20 |
+
init_values=args.layerscale,
|
| 21 |
+
ffn_layer=args.ffn_layer,
|
| 22 |
+
block_chunks=args.block_chunks,
|
| 23 |
+
qkv_bias=args.qkv_bias,
|
| 24 |
+
proj_bias=args.proj_bias,
|
| 25 |
+
ffn_bias=args.ffn_bias,
|
| 26 |
+
num_register_tokens=args.num_register_tokens,
|
| 27 |
+
interpolate_offset=args.interpolate_offset,
|
| 28 |
+
interpolate_antialias=args.interpolate_antialias,
|
| 29 |
+
)
|
| 30 |
+
teacher = vits.__dict__[args.arch](**vit_kwargs)
|
| 31 |
+
if only_teacher:
|
| 32 |
+
return teacher, teacher.embed_dim
|
| 33 |
+
student = vits.__dict__[args.arch](
|
| 34 |
+
**vit_kwargs,
|
| 35 |
+
drop_path_rate=args.drop_path_rate,
|
| 36 |
+
drop_path_uniform=args.drop_path_uniform,
|
| 37 |
+
)
|
| 38 |
+
embed_dim = student.embed_dim
|
| 39 |
+
return student, teacher, embed_dim
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def build_model_from_cfg(cfg, only_teacher=False):
|
| 43 |
+
return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
|
moge/model/dinov2/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable, Optional, List
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.utils.checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
|
| 20 |
+
from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("dinov2")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
| 27 |
+
if not depth_first and include_root:
|
| 28 |
+
fn(module=module, name=name)
|
| 29 |
+
for child_name, child_module in module.named_children():
|
| 30 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
| 31 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 32 |
+
if depth_first and include_root:
|
| 33 |
+
fn(module=module, name=name)
|
| 34 |
+
return module
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BlockChunk(nn.ModuleList):
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
for b in self:
|
| 40 |
+
x = b(x)
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DinoVisionTransformer(nn.Module):
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
img_size=224,
|
| 48 |
+
patch_size=16,
|
| 49 |
+
in_chans=3,
|
| 50 |
+
embed_dim=768,
|
| 51 |
+
depth=12,
|
| 52 |
+
num_heads=12,
|
| 53 |
+
mlp_ratio=4.0,
|
| 54 |
+
qkv_bias=True,
|
| 55 |
+
ffn_bias=True,
|
| 56 |
+
proj_bias=True,
|
| 57 |
+
drop_path_rate=0.0,
|
| 58 |
+
drop_path_uniform=False,
|
| 59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
| 60 |
+
embed_layer=PatchEmbed,
|
| 61 |
+
act_layer=nn.GELU,
|
| 62 |
+
block_fn=Block,
|
| 63 |
+
ffn_layer="mlp",
|
| 64 |
+
block_chunks=1,
|
| 65 |
+
num_register_tokens=0,
|
| 66 |
+
interpolate_antialias=False,
|
| 67 |
+
interpolate_offset=0.1,
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Args:
|
| 71 |
+
img_size (int, tuple): input image size
|
| 72 |
+
patch_size (int, tuple): patch size
|
| 73 |
+
in_chans (int): number of input channels
|
| 74 |
+
embed_dim (int): embedding dimension
|
| 75 |
+
depth (int): depth of transformer
|
| 76 |
+
num_heads (int): number of attention heads
|
| 77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 78 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 80 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 81 |
+
drop_path_rate (float): stochastic depth rate
|
| 82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
| 83 |
+
weight_init (str): weight init scheme
|
| 84 |
+
init_values (float): layer-scale init values
|
| 85 |
+
embed_layer (nn.Module): patch embedding layer
|
| 86 |
+
act_layer (nn.Module): MLP activation layer
|
| 87 |
+
block_fn (nn.Module): transformer block class
|
| 88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
| 91 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
| 92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
| 93 |
+
"""
|
| 94 |
+
super().__init__()
|
| 95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 96 |
+
|
| 97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 98 |
+
self.num_tokens = 1
|
| 99 |
+
self.n_blocks = depth
|
| 100 |
+
self.num_heads = num_heads
|
| 101 |
+
self.patch_size = patch_size
|
| 102 |
+
self.num_register_tokens = num_register_tokens
|
| 103 |
+
self.interpolate_antialias = interpolate_antialias
|
| 104 |
+
self.interpolate_offset = interpolate_offset
|
| 105 |
+
|
| 106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 107 |
+
num_patches = self.patch_embed.num_patches
|
| 108 |
+
|
| 109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
| 111 |
+
assert num_register_tokens >= 0
|
| 112 |
+
self.register_tokens = (
|
| 113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if drop_path_uniform is True:
|
| 117 |
+
dpr = [drop_path_rate] * depth
|
| 118 |
+
else:
|
| 119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 120 |
+
|
| 121 |
+
if ffn_layer == "mlp":
|
| 122 |
+
logger.info("using MLP layer as FFN")
|
| 123 |
+
ffn_layer = Mlp
|
| 124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 125 |
+
logger.info("using SwiGLU layer as FFN")
|
| 126 |
+
ffn_layer = SwiGLUFFNFused
|
| 127 |
+
elif ffn_layer == "identity":
|
| 128 |
+
logger.info("using Identity layer as FFN")
|
| 129 |
+
|
| 130 |
+
def f(*args, **kwargs):
|
| 131 |
+
return nn.Identity()
|
| 132 |
+
|
| 133 |
+
ffn_layer = f
|
| 134 |
+
else:
|
| 135 |
+
raise NotImplementedError
|
| 136 |
+
|
| 137 |
+
blocks_list = [
|
| 138 |
+
block_fn(
|
| 139 |
+
dim=embed_dim,
|
| 140 |
+
num_heads=num_heads,
|
| 141 |
+
mlp_ratio=mlp_ratio,
|
| 142 |
+
qkv_bias=qkv_bias,
|
| 143 |
+
proj_bias=proj_bias,
|
| 144 |
+
ffn_bias=ffn_bias,
|
| 145 |
+
drop_path=dpr[i],
|
| 146 |
+
norm_layer=norm_layer,
|
| 147 |
+
act_layer=act_layer,
|
| 148 |
+
ffn_layer=ffn_layer,
|
| 149 |
+
init_values=init_values,
|
| 150 |
+
)
|
| 151 |
+
for i in range(depth)
|
| 152 |
+
]
|
| 153 |
+
if block_chunks > 0:
|
| 154 |
+
self.chunked_blocks = True
|
| 155 |
+
chunked_blocks = []
|
| 156 |
+
chunksize = depth // block_chunks
|
| 157 |
+
for i in range(0, depth, chunksize):
|
| 158 |
+
# this is to keep the block index consistent if we chunk the block list
|
| 159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
| 160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
| 161 |
+
else:
|
| 162 |
+
self.chunked_blocks = False
|
| 163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 164 |
+
|
| 165 |
+
self.norm = norm_layer(embed_dim)
|
| 166 |
+
self.head = nn.Identity()
|
| 167 |
+
|
| 168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
| 169 |
+
|
| 170 |
+
self.init_weights()
|
| 171 |
+
|
| 172 |
+
@property
|
| 173 |
+
def onnx_compatible_mode(self):
|
| 174 |
+
return getattr(self, "_onnx_compatible_mode", False)
|
| 175 |
+
|
| 176 |
+
@onnx_compatible_mode.setter
|
| 177 |
+
def onnx_compatible_mode(self, value: bool):
|
| 178 |
+
self._onnx_compatible_mode = value
|
| 179 |
+
|
| 180 |
+
def init_weights(self):
|
| 181 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 182 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 183 |
+
if self.register_tokens is not None:
|
| 184 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
| 185 |
+
named_apply(init_weights_vit_timm, self)
|
| 186 |
+
|
| 187 |
+
def interpolate_pos_encoding(self, x, h, w):
|
| 188 |
+
previous_dtype = x.dtype
|
| 189 |
+
npatch = x.shape[1] - 1
|
| 190 |
+
batch_size = x.shape[0]
|
| 191 |
+
N = self.pos_embed.shape[1] - 1
|
| 192 |
+
if not self.onnx_compatible_mode and npatch == N and w == h:
|
| 193 |
+
return self.pos_embed
|
| 194 |
+
pos_embed = self.pos_embed.float()
|
| 195 |
+
class_pos_embed = pos_embed[:, 0, :]
|
| 196 |
+
patch_pos_embed = pos_embed[:, 1:, :]
|
| 197 |
+
dim = x.shape[-1]
|
| 198 |
+
h0, w0 = h // self.patch_size, w // self.patch_size
|
| 199 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 200 |
+
assert N == M * M
|
| 201 |
+
kwargs = {}
|
| 202 |
+
if not self.onnx_compatible_mode and self.interpolate_offset > 0:
|
| 203 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 204 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
| 205 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 206 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 207 |
+
kwargs["scale_factor"] = (sy, sx)
|
| 208 |
+
else:
|
| 209 |
+
# Simply specify an output size instead of a scale factor
|
| 210 |
+
kwargs["size"] = (h0, w0)
|
| 211 |
+
|
| 212 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 213 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 214 |
+
mode="bicubic",
|
| 215 |
+
antialias=self.interpolate_antialias,
|
| 216 |
+
**kwargs,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
assert (h0, w0) == patch_pos_embed.shape[-2:]
|
| 220 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
|
| 221 |
+
return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype)
|
| 222 |
+
|
| 223 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
| 224 |
+
B, nc, h, w = x.shape
|
| 225 |
+
x = self.patch_embed(x)
|
| 226 |
+
|
| 227 |
+
if masks is not None:
|
| 228 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 229 |
+
|
| 230 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
| 231 |
+
x = x + self.interpolate_pos_encoding(x, h, w)
|
| 232 |
+
|
| 233 |
+
if self.register_tokens is not None:
|
| 234 |
+
x = torch.cat(
|
| 235 |
+
(
|
| 236 |
+
x[:, :1],
|
| 237 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
| 238 |
+
x[:, 1:],
|
| 239 |
+
),
|
| 240 |
+
dim=1,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
return x
|
| 244 |
+
|
| 245 |
+
def forward_features_list(self, x_list, masks_list):
|
| 246 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks, ar in zip(x_list, masks_list)]
|
| 247 |
+
for blk in self.blocks:
|
| 248 |
+
x = blk(x)
|
| 249 |
+
|
| 250 |
+
all_x = x
|
| 251 |
+
output = []
|
| 252 |
+
for x, masks in zip(all_x, masks_list):
|
| 253 |
+
x_norm = self.norm(x)
|
| 254 |
+
output.append(
|
| 255 |
+
{
|
| 256 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 257 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 258 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 259 |
+
"x_prenorm": x,
|
| 260 |
+
"masks": masks,
|
| 261 |
+
}
|
| 262 |
+
)
|
| 263 |
+
return output
|
| 264 |
+
|
| 265 |
+
def forward_features(self, x, masks=None):
|
| 266 |
+
if isinstance(x, list):
|
| 267 |
+
return self.forward_features_list(x, masks)
|
| 268 |
+
|
| 269 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
| 270 |
+
|
| 271 |
+
for blk in self.blocks:
|
| 272 |
+
x = blk(x)
|
| 273 |
+
|
| 274 |
+
x_norm = self.norm(x)
|
| 275 |
+
return {
|
| 276 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 277 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 278 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 279 |
+
"x_prenorm": x,
|
| 280 |
+
"masks": masks,
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 284 |
+
x = self.prepare_tokens_with_masks(x)
|
| 285 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 286 |
+
output, total_block_len = [], len(self.blocks)
|
| 287 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 288 |
+
for i, blk in enumerate(self.blocks):
|
| 289 |
+
x = blk(x)
|
| 290 |
+
if i in blocks_to_take:
|
| 291 |
+
output.append(x)
|
| 292 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 293 |
+
return output
|
| 294 |
+
|
| 295 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
| 296 |
+
x = self.prepare_tokens_with_masks(x)
|
| 297 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 298 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 299 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 300 |
+
for block_chunk in self.blocks:
|
| 301 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 302 |
+
x = blk(x)
|
| 303 |
+
if i in blocks_to_take:
|
| 304 |
+
output.append(x)
|
| 305 |
+
i += 1
|
| 306 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 307 |
+
return output
|
| 308 |
+
|
| 309 |
+
def get_intermediate_layers(
|
| 310 |
+
self,
|
| 311 |
+
x: torch.Tensor,
|
| 312 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
| 313 |
+
reshape: bool = False,
|
| 314 |
+
return_class_token: bool = False,
|
| 315 |
+
norm=True,
|
| 316 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 317 |
+
if self.chunked_blocks:
|
| 318 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
| 319 |
+
else:
|
| 320 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
| 321 |
+
if norm:
|
| 322 |
+
outputs = [self.norm(out) for out in outputs]
|
| 323 |
+
class_tokens = [out[:, 0] for out in outputs]
|
| 324 |
+
outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
|
| 325 |
+
if reshape:
|
| 326 |
+
B, _, w, h = x.shape
|
| 327 |
+
outputs = [
|
| 328 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
| 329 |
+
for out in outputs
|
| 330 |
+
]
|
| 331 |
+
if return_class_token:
|
| 332 |
+
return tuple(zip(outputs, class_tokens))
|
| 333 |
+
return tuple(outputs)
|
| 334 |
+
|
| 335 |
+
def forward(self, *args, is_training=False, **kwargs):
|
| 336 |
+
ret = self.forward_features(*args, **kwargs)
|
| 337 |
+
if is_training:
|
| 338 |
+
return ret
|
| 339 |
+
else:
|
| 340 |
+
return self.head(ret["x_norm_clstoken"])
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 344 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 345 |
+
if isinstance(module, nn.Linear):
|
| 346 |
+
trunc_normal_(module.weight, std=0.02)
|
| 347 |
+
if module.bias is not None:
|
| 348 |
+
nn.init.zeros_(module.bias)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
| 352 |
+
model = DinoVisionTransformer(
|
| 353 |
+
patch_size=patch_size,
|
| 354 |
+
embed_dim=384,
|
| 355 |
+
depth=12,
|
| 356 |
+
num_heads=6,
|
| 357 |
+
mlp_ratio=4,
|
| 358 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 359 |
+
num_register_tokens=num_register_tokens,
|
| 360 |
+
**kwargs,
|
| 361 |
+
)
|
| 362 |
+
return model
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
| 366 |
+
model = DinoVisionTransformer(
|
| 367 |
+
patch_size=patch_size,
|
| 368 |
+
embed_dim=768,
|
| 369 |
+
depth=12,
|
| 370 |
+
num_heads=12,
|
| 371 |
+
mlp_ratio=4,
|
| 372 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 373 |
+
num_register_tokens=num_register_tokens,
|
| 374 |
+
**kwargs,
|
| 375 |
+
)
|
| 376 |
+
return model
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
| 380 |
+
model = DinoVisionTransformer(
|
| 381 |
+
patch_size=patch_size,
|
| 382 |
+
embed_dim=1024,
|
| 383 |
+
depth=24,
|
| 384 |
+
num_heads=16,
|
| 385 |
+
mlp_ratio=4,
|
| 386 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 387 |
+
num_register_tokens=num_register_tokens,
|
| 388 |
+
**kwargs,
|
| 389 |
+
)
|
| 390 |
+
return model
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
| 394 |
+
"""
|
| 395 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
| 396 |
+
"""
|
| 397 |
+
model = DinoVisionTransformer(
|
| 398 |
+
patch_size=patch_size,
|
| 399 |
+
embed_dim=1536,
|
| 400 |
+
depth=40,
|
| 401 |
+
num_heads=24,
|
| 402 |
+
mlp_ratio=4,
|
| 403 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 404 |
+
num_register_tokens=num_register_tokens,
|
| 405 |
+
**kwargs,
|
| 406 |
+
)
|
| 407 |
+
return model
|
moge/model/dinov2/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
moge/model/dinov2/utils/cluster.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ClusterType(Enum):
|
| 13 |
+
AWS = "aws"
|
| 14 |
+
FAIR = "fair"
|
| 15 |
+
RSC = "rsc"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _guess_cluster_type() -> ClusterType:
|
| 19 |
+
uname = os.uname()
|
| 20 |
+
if uname.sysname == "Linux":
|
| 21 |
+
if uname.release.endswith("-aws"):
|
| 22 |
+
# Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
|
| 23 |
+
return ClusterType.AWS
|
| 24 |
+
elif uname.nodename.startswith("rsc"):
|
| 25 |
+
# Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
|
| 26 |
+
return ClusterType.RSC
|
| 27 |
+
|
| 28 |
+
return ClusterType.FAIR
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
|
| 32 |
+
if cluster_type is None:
|
| 33 |
+
return _guess_cluster_type()
|
| 34 |
+
|
| 35 |
+
return cluster_type
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
|
| 39 |
+
cluster_type = get_cluster_type(cluster_type)
|
| 40 |
+
if cluster_type is None:
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
CHECKPOINT_DIRNAMES = {
|
| 44 |
+
ClusterType.AWS: "checkpoints",
|
| 45 |
+
ClusterType.FAIR: "checkpoint",
|
| 46 |
+
ClusterType.RSC: "checkpoint/dino",
|
| 47 |
+
}
|
| 48 |
+
return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
|
| 52 |
+
checkpoint_path = get_checkpoint_path(cluster_type)
|
| 53 |
+
if checkpoint_path is None:
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
username = os.environ.get("USER")
|
| 57 |
+
assert username is not None
|
| 58 |
+
return checkpoint_path / username
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
|
| 62 |
+
cluster_type = get_cluster_type(cluster_type)
|
| 63 |
+
if cluster_type is None:
|
| 64 |
+
return None
|
| 65 |
+
|
| 66 |
+
SLURM_PARTITIONS = {
|
| 67 |
+
ClusterType.AWS: "learnlab",
|
| 68 |
+
ClusterType.FAIR: "learnlab",
|
| 69 |
+
ClusterType.RSC: "learn",
|
| 70 |
+
}
|
| 71 |
+
return SLURM_PARTITIONS[cluster_type]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_slurm_executor_parameters(
|
| 75 |
+
nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
|
| 76 |
+
) -> Dict[str, Any]:
|
| 77 |
+
# create default parameters
|
| 78 |
+
params = {
|
| 79 |
+
"mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
|
| 80 |
+
"gpus_per_node": num_gpus_per_node,
|
| 81 |
+
"tasks_per_node": num_gpus_per_node, # one task per GPU
|
| 82 |
+
"cpus_per_task": 10,
|
| 83 |
+
"nodes": nodes,
|
| 84 |
+
"slurm_partition": get_slurm_partition(cluster_type),
|
| 85 |
+
}
|
| 86 |
+
# apply cluster-specific adjustments
|
| 87 |
+
cluster_type = get_cluster_type(cluster_type)
|
| 88 |
+
if cluster_type == ClusterType.AWS:
|
| 89 |
+
params["cpus_per_task"] = 12
|
| 90 |
+
del params["mem_gb"]
|
| 91 |
+
elif cluster_type == ClusterType.RSC:
|
| 92 |
+
params["cpus_per_task"] = 12
|
| 93 |
+
# set additional parameters / apply overrides
|
| 94 |
+
params.update(kwargs)
|
| 95 |
+
return params
|
moge/model/dinov2/utils/config.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
|
| 12 |
+
import dinov2.distributed as distributed
|
| 13 |
+
from dinov2.logging import setup_logging
|
| 14 |
+
from dinov2.utils import utils
|
| 15 |
+
from dinov2.configs import dinov2_default_config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_scaling_rules_to_cfg(cfg): # to fix
|
| 22 |
+
if cfg.optim.scaling_rule == "sqrt_wrt_1024":
|
| 23 |
+
base_lr = cfg.optim.base_lr
|
| 24 |
+
cfg.optim.lr = base_lr
|
| 25 |
+
cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
|
| 26 |
+
logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
|
| 27 |
+
else:
|
| 28 |
+
raise NotImplementedError
|
| 29 |
+
return cfg
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def write_config(cfg, output_dir, name="config.yaml"):
|
| 33 |
+
logger.info(OmegaConf.to_yaml(cfg))
|
| 34 |
+
saved_cfg_path = os.path.join(output_dir, name)
|
| 35 |
+
with open(saved_cfg_path, "w") as f:
|
| 36 |
+
OmegaConf.save(config=cfg, f=f)
|
| 37 |
+
return saved_cfg_path
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_cfg_from_args(args):
|
| 41 |
+
args.output_dir = os.path.abspath(args.output_dir)
|
| 42 |
+
args.opts += [f"train.output_dir={args.output_dir}"]
|
| 43 |
+
default_cfg = OmegaConf.create(dinov2_default_config)
|
| 44 |
+
cfg = OmegaConf.load(args.config_file)
|
| 45 |
+
cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
|
| 46 |
+
return cfg
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def default_setup(args):
|
| 50 |
+
distributed.enable(overwrite=True)
|
| 51 |
+
seed = getattr(args, "seed", 0)
|
| 52 |
+
rank = distributed.get_global_rank()
|
| 53 |
+
|
| 54 |
+
global logger
|
| 55 |
+
setup_logging(output=args.output_dir, level=logging.INFO)
|
| 56 |
+
logger = logging.getLogger("dinov2")
|
| 57 |
+
|
| 58 |
+
utils.fix_random_seeds(seed + rank)
|
| 59 |
+
logger.info("git:\n {}\n".format(utils.get_sha()))
|
| 60 |
+
logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def setup(args):
|
| 64 |
+
"""
|
| 65 |
+
Create configs and perform basic setups.
|
| 66 |
+
"""
|
| 67 |
+
cfg = get_cfg_from_args(args)
|
| 68 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 69 |
+
default_setup(args)
|
| 70 |
+
apply_scaling_rules_to_cfg(cfg)
|
| 71 |
+
write_config(cfg, args.output_dir)
|
| 72 |
+
return cfg
|
moge/model/dinov2/utils/dtype.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from typing import Dict, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
TypeSpec = Union[str, np.dtype, torch.dtype]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
|
| 17 |
+
np.dtype("bool"): torch.bool,
|
| 18 |
+
np.dtype("uint8"): torch.uint8,
|
| 19 |
+
np.dtype("int8"): torch.int8,
|
| 20 |
+
np.dtype("int16"): torch.int16,
|
| 21 |
+
np.dtype("int32"): torch.int32,
|
| 22 |
+
np.dtype("int64"): torch.int64,
|
| 23 |
+
np.dtype("float16"): torch.float16,
|
| 24 |
+
np.dtype("float32"): torch.float32,
|
| 25 |
+
np.dtype("float64"): torch.float64,
|
| 26 |
+
np.dtype("complex64"): torch.complex64,
|
| 27 |
+
np.dtype("complex128"): torch.complex128,
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
|
| 32 |
+
if isinstance(dtype, torch.dtype):
|
| 33 |
+
return dtype
|
| 34 |
+
if isinstance(dtype, str):
|
| 35 |
+
dtype = np.dtype(dtype)
|
| 36 |
+
assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}"
|
| 37 |
+
return _NUMPY_TO_TORCH_DTYPE[dtype]
|
moge/model/dinov2/utils/param_groups.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger("dinov2")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
|
| 14 |
+
"""
|
| 15 |
+
Calculate lr decay rate for different ViT blocks.
|
| 16 |
+
Args:
|
| 17 |
+
name (string): parameter name.
|
| 18 |
+
lr_decay_rate (float): base lr decay rate.
|
| 19 |
+
num_layers (int): number of ViT blocks.
|
| 20 |
+
Returns:
|
| 21 |
+
lr decay rate for the given parameter.
|
| 22 |
+
"""
|
| 23 |
+
layer_id = num_layers + 1
|
| 24 |
+
if name.startswith("backbone") or force_is_backbone:
|
| 25 |
+
if (
|
| 26 |
+
".pos_embed" in name
|
| 27 |
+
or ".patch_embed" in name
|
| 28 |
+
or ".mask_token" in name
|
| 29 |
+
or ".cls_token" in name
|
| 30 |
+
or ".register_tokens" in name
|
| 31 |
+
):
|
| 32 |
+
layer_id = 0
|
| 33 |
+
elif force_is_backbone and (
|
| 34 |
+
"pos_embed" in name
|
| 35 |
+
or "patch_embed" in name
|
| 36 |
+
or "mask_token" in name
|
| 37 |
+
or "cls_token" in name
|
| 38 |
+
or "register_tokens" in name
|
| 39 |
+
):
|
| 40 |
+
layer_id = 0
|
| 41 |
+
elif ".blocks." in name and ".residual." not in name:
|
| 42 |
+
layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
|
| 43 |
+
elif chunked_blocks and "blocks." in name and "residual." not in name:
|
| 44 |
+
layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
|
| 45 |
+
elif "blocks." in name and "residual." not in name:
|
| 46 |
+
layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
|
| 47 |
+
|
| 48 |
+
return lr_decay_rate ** (num_layers + 1 - layer_id)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
|
| 52 |
+
chunked_blocks = False
|
| 53 |
+
if hasattr(model, "n_blocks"):
|
| 54 |
+
logger.info("chunked fsdp")
|
| 55 |
+
n_blocks = model.n_blocks
|
| 56 |
+
chunked_blocks = model.chunked_blocks
|
| 57 |
+
elif hasattr(model, "blocks"):
|
| 58 |
+
logger.info("first code branch")
|
| 59 |
+
n_blocks = len(model.blocks)
|
| 60 |
+
elif hasattr(model, "backbone"):
|
| 61 |
+
logger.info("second code branch")
|
| 62 |
+
n_blocks = len(model.backbone.blocks)
|
| 63 |
+
else:
|
| 64 |
+
logger.info("else code branch")
|
| 65 |
+
n_blocks = 0
|
| 66 |
+
all_param_groups = []
|
| 67 |
+
|
| 68 |
+
for name, param in model.named_parameters():
|
| 69 |
+
name = name.replace("_fsdp_wrapped_module.", "")
|
| 70 |
+
if not param.requires_grad:
|
| 71 |
+
continue
|
| 72 |
+
decay_rate = get_vit_lr_decay_rate(
|
| 73 |
+
name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
|
| 74 |
+
)
|
| 75 |
+
d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
|
| 76 |
+
|
| 77 |
+
if "last_layer" in name:
|
| 78 |
+
d.update({"is_last_layer": True})
|
| 79 |
+
|
| 80 |
+
if name.endswith(".bias") or "norm" in name or "gamma" in name:
|
| 81 |
+
d.update({"wd_multiplier": 0.0})
|
| 82 |
+
|
| 83 |
+
if "patch_embed" in name:
|
| 84 |
+
d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
|
| 85 |
+
|
| 86 |
+
all_param_groups.append(d)
|
| 87 |
+
logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
|
| 88 |
+
|
| 89 |
+
return all_param_groups
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
|
| 93 |
+
fused_params_groups = defaultdict(lambda: {"params": []})
|
| 94 |
+
for d in all_params_groups:
|
| 95 |
+
identifier = ""
|
| 96 |
+
for k in keys:
|
| 97 |
+
identifier += k + str(d[k]) + "_"
|
| 98 |
+
|
| 99 |
+
for k in keys:
|
| 100 |
+
fused_params_groups[identifier][k] = d[k]
|
| 101 |
+
fused_params_groups[identifier]["params"].append(d["params"])
|
| 102 |
+
|
| 103 |
+
return fused_params_groups.values()
|
moge/model/dinov2/utils/utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import subprocess
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
|
| 21 |
+
if urlparse(pretrained_weights).scheme: # If it looks like an URL
|
| 22 |
+
state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
|
| 23 |
+
else:
|
| 24 |
+
state_dict = torch.load(pretrained_weights, map_location="cpu")
|
| 25 |
+
if checkpoint_key is not None and checkpoint_key in state_dict:
|
| 26 |
+
logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
|
| 27 |
+
state_dict = state_dict[checkpoint_key]
|
| 28 |
+
# remove `module.` prefix
|
| 29 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
| 30 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
| 31 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
| 32 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
| 33 |
+
logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def fix_random_seeds(seed=31):
|
| 37 |
+
"""
|
| 38 |
+
Fix random seeds.
|
| 39 |
+
"""
|
| 40 |
+
torch.manual_seed(seed)
|
| 41 |
+
torch.cuda.manual_seed_all(seed)
|
| 42 |
+
np.random.seed(seed)
|
| 43 |
+
random.seed(seed)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_sha():
|
| 47 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
| 48 |
+
|
| 49 |
+
def _run(command):
|
| 50 |
+
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
| 51 |
+
|
| 52 |
+
sha = "N/A"
|
| 53 |
+
diff = "clean"
|
| 54 |
+
branch = "N/A"
|
| 55 |
+
try:
|
| 56 |
+
sha = _run(["git", "rev-parse", "HEAD"])
|
| 57 |
+
subprocess.check_output(["git", "diff"], cwd=cwd)
|
| 58 |
+
diff = _run(["git", "diff-index", "HEAD"])
|
| 59 |
+
diff = "has uncommitted changes" if diff else "clean"
|
| 60 |
+
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
| 61 |
+
except Exception:
|
| 62 |
+
pass
|
| 63 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
| 64 |
+
return message
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class CosineScheduler(object):
|
| 68 |
+
def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.final_value = final_value
|
| 71 |
+
self.total_iters = total_iters
|
| 72 |
+
|
| 73 |
+
freeze_schedule = np.zeros((freeze_iters))
|
| 74 |
+
|
| 75 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
| 76 |
+
|
| 77 |
+
iters = np.arange(total_iters - warmup_iters - freeze_iters)
|
| 78 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
|
| 79 |
+
self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
|
| 80 |
+
|
| 81 |
+
assert len(self.schedule) == self.total_iters
|
| 82 |
+
|
| 83 |
+
def __getitem__(self, it):
|
| 84 |
+
if it >= self.total_iters:
|
| 85 |
+
return self.final_value
|
| 86 |
+
else:
|
| 87 |
+
return self.schedule[it]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def has_batchnorms(model):
|
| 91 |
+
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
|
| 92 |
+
for name, module in model.named_modules():
|
| 93 |
+
if isinstance(module, bn_types):
|
| 94 |
+
return True
|
| 95 |
+
return False
|
moge/model/modules.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from numbers import Number
|
| 3 |
+
import importlib
|
| 4 |
+
import itertools
|
| 5 |
+
import functools
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from .dinov2.models.vision_transformer import DinoVisionTransformer
|
| 14 |
+
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
|
| 15 |
+
from ..utils.geometry_torch import normalized_view_plane_uv
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ResidualConvBlock(nn.Module):
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
in_channels: int,
|
| 22 |
+
out_channels: int = None,
|
| 23 |
+
hidden_channels: int = None,
|
| 24 |
+
kernel_size: int = 3,
|
| 25 |
+
padding_mode: str = 'replicate',
|
| 26 |
+
activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
|
| 27 |
+
in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
|
| 28 |
+
hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
|
| 29 |
+
):
|
| 30 |
+
super(ResidualConvBlock, self).__init__()
|
| 31 |
+
if out_channels is None:
|
| 32 |
+
out_channels = in_channels
|
| 33 |
+
if hidden_channels is None:
|
| 34 |
+
hidden_channels = in_channels
|
| 35 |
+
|
| 36 |
+
if activation =='relu':
|
| 37 |
+
activation_cls = nn.ReLU
|
| 38 |
+
elif activation == 'leaky_relu':
|
| 39 |
+
activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
|
| 40 |
+
elif activation =='silu':
|
| 41 |
+
activation_cls = nn.SiLU
|
| 42 |
+
elif activation == 'elu':
|
| 43 |
+
activation_cls = nn.ELU
|
| 44 |
+
else:
|
| 45 |
+
raise ValueError(f'Unsupported activation function: {activation}')
|
| 46 |
+
|
| 47 |
+
self.layers = nn.Sequential(
|
| 48 |
+
nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \
|
| 49 |
+
nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \
|
| 50 |
+
nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \
|
| 51 |
+
nn.Identity(),
|
| 52 |
+
activation_cls(),
|
| 53 |
+
nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode),
|
| 54 |
+
nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \
|
| 55 |
+
nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' else \
|
| 56 |
+
nn.InstanceNorm2d(hidden_channels) if hidden_norm == 'instance_norm' else\
|
| 57 |
+
nn.Identity(),
|
| 58 |
+
activation_cls(),
|
| 59 |
+
nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
|
| 63 |
+
|
| 64 |
+
def forward(self, x):
|
| 65 |
+
skip = self.skip_connection(x)
|
| 66 |
+
x = self.layers(x)
|
| 67 |
+
x = x + skip
|
| 68 |
+
return x
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class DINOv2Encoder(nn.Module):
|
| 72 |
+
"Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."
|
| 73 |
+
backbone: DinoVisionTransformer
|
| 74 |
+
image_mean: torch.Tensor
|
| 75 |
+
image_std: torch.Tensor
|
| 76 |
+
dim_features: int
|
| 77 |
+
|
| 78 |
+
def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, **deprecated_kwargs):
|
| 79 |
+
super(DINOv2Encoder, self).__init__()
|
| 80 |
+
|
| 81 |
+
self.intermediate_layers = intermediate_layers
|
| 82 |
+
|
| 83 |
+
# Load the backbone
|
| 84 |
+
self.hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), backbone)
|
| 85 |
+
self.backbone_name = backbone
|
| 86 |
+
self.backbone = self.hub_loader(pretrained=False)
|
| 87 |
+
|
| 88 |
+
self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
|
| 89 |
+
self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)
|
| 90 |
+
|
| 91 |
+
self.output_projections = nn.ModuleList([
|
| 92 |
+
nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
|
| 93 |
+
for _ in range(self.num_features)
|
| 94 |
+
])
|
| 95 |
+
|
| 96 |
+
self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
| 97 |
+
self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def onnx_compatible_mode(self):
|
| 101 |
+
return getattr(self, "_onnx_compatible_mode", False)
|
| 102 |
+
|
| 103 |
+
@onnx_compatible_mode.setter
|
| 104 |
+
def onnx_compatible_mode(self, value: bool):
|
| 105 |
+
self._onnx_compatible_mode = value
|
| 106 |
+
self.backbone.onnx_compatible_mode = value
|
| 107 |
+
|
| 108 |
+
def init_weights(self):
|
| 109 |
+
pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
|
| 110 |
+
self.backbone.load_state_dict(pretrained_backbone_state_dict)
|
| 111 |
+
|
| 112 |
+
def enable_gradient_checkpointing(self):
|
| 113 |
+
for i in range(len(self.backbone.blocks)):
|
| 114 |
+
wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
|
| 115 |
+
|
| 116 |
+
def enable_pytorch_native_sdpa(self):
|
| 117 |
+
for i in range(len(self.backbone.blocks)):
|
| 118 |
+
wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
|
| 119 |
+
|
| 120 |
+
def forward(self, image: torch.Tensor, token_rows: Union[int, torch.LongTensor], token_cols: Union[int, torch.LongTensor], return_class_token: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 121 |
+
image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
|
| 122 |
+
image_14 = (image_14 - self.image_mean) / self.image_std
|
| 123 |
+
|
| 124 |
+
# Get intermediate layers from the backbone
|
| 125 |
+
features = self.backbone.get_intermediate_layers(image_14, n=self.intermediate_layers, return_class_token=True)
|
| 126 |
+
|
| 127 |
+
# Project features to the desired dimensionality
|
| 128 |
+
x = torch.stack([
|
| 129 |
+
proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
|
| 130 |
+
for proj, (feat, clstoken) in zip(self.output_projections, features)
|
| 131 |
+
], dim=1).sum(dim=1)
|
| 132 |
+
|
| 133 |
+
if return_class_token:
|
| 134 |
+
return x, features[-1][1]
|
| 135 |
+
else:
|
| 136 |
+
return x
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class Resampler(nn.Sequential):
|
| 140 |
+
def __init__(self,
|
| 141 |
+
in_channels: int,
|
| 142 |
+
out_channels: int,
|
| 143 |
+
type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
|
| 144 |
+
scale_factor: int = 2,
|
| 145 |
+
):
|
| 146 |
+
if type_ == 'pixel_shuffle':
|
| 147 |
+
nn.Sequential.__init__(self,
|
| 148 |
+
nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
|
| 149 |
+
nn.PixelShuffle(scale_factor),
|
| 150 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
|
| 151 |
+
)
|
| 152 |
+
for i in range(1, scale_factor ** 2):
|
| 153 |
+
self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2]
|
| 154 |
+
self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2]
|
| 155 |
+
elif type_ in ['nearest', 'bilinear']:
|
| 156 |
+
nn.Sequential.__init__(self,
|
| 157 |
+
nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
|
| 158 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
|
| 159 |
+
)
|
| 160 |
+
elif type_ == 'conv_transpose':
|
| 161 |
+
nn.Sequential.__init__(self,
|
| 162 |
+
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
|
| 163 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
|
| 164 |
+
)
|
| 165 |
+
self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
|
| 166 |
+
elif type_ == 'pixel_unshuffle':
|
| 167 |
+
nn.Sequential.__init__(self,
|
| 168 |
+
nn.PixelUnshuffle(scale_factor),
|
| 169 |
+
nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
|
| 170 |
+
)
|
| 171 |
+
elif type_ == 'avg_pool':
|
| 172 |
+
nn.Sequential.__init__(self,
|
| 173 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
|
| 174 |
+
nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
|
| 175 |
+
)
|
| 176 |
+
elif type_ == 'max_pool':
|
| 177 |
+
nn.Sequential.__init__(self,
|
| 178 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
|
| 179 |
+
nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
raise ValueError(f'Unsupported resampler type: {type_}')
|
| 183 |
+
|
| 184 |
+
class MLP(nn.Sequential):
|
| 185 |
+
def __init__(self, dims: Sequence[int]):
|
| 186 |
+
nn.Sequential.__init__(self,
|
| 187 |
+
*itertools.chain(*[
|
| 188 |
+
(nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True))
|
| 189 |
+
for dim_in, dim_out in zip(dims[:-2], dims[1:-1])
|
| 190 |
+
]),
|
| 191 |
+
nn.Linear(dims[-2], dims[-1]),
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ConvStack(nn.Module):
|
| 196 |
+
def __init__(self,
|
| 197 |
+
dim_in: List[Optional[int]],
|
| 198 |
+
dim_res_blocks: List[int],
|
| 199 |
+
dim_out: List[Optional[int]],
|
| 200 |
+
resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
|
| 201 |
+
dim_times_res_block_hidden: int = 1,
|
| 202 |
+
num_res_blocks: int = 1,
|
| 203 |
+
res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
|
| 204 |
+
res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
|
| 205 |
+
activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
|
| 206 |
+
):
|
| 207 |
+
super().__init__()
|
| 208 |
+
self.input_blocks = nn.ModuleList([
|
| 209 |
+
nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
|
| 210 |
+
for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
|
| 211 |
+
])
|
| 212 |
+
self.resamplers = nn.ModuleList([
|
| 213 |
+
Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
|
| 214 |
+
for i, (dim_prev, dim_succ, resampler) in enumerate(zip(
|
| 215 |
+
dim_res_blocks[:-1],
|
| 216 |
+
dim_res_blocks[1:],
|
| 217 |
+
resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers)
|
| 218 |
+
))
|
| 219 |
+
])
|
| 220 |
+
self.res_blocks = nn.ModuleList([
|
| 221 |
+
nn.Sequential(
|
| 222 |
+
*(
|
| 223 |
+
ResidualConvBlock(
|
| 224 |
+
dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
|
| 225 |
+
activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
|
| 226 |
+
) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
|
| 227 |
+
)
|
| 228 |
+
) for i, dim_res_block_ in enumerate(dim_res_blocks)
|
| 229 |
+
])
|
| 230 |
+
self.output_blocks = nn.ModuleList([
|
| 231 |
+
nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
|
| 232 |
+
for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
|
| 233 |
+
])
|
| 234 |
+
|
| 235 |
+
def enable_gradient_checkpointing(self):
|
| 236 |
+
for i in range(len(self.resamplers)):
|
| 237 |
+
self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
|
| 238 |
+
for i in range(len(self.res_blocks)):
|
| 239 |
+
for j in range(len(self.res_blocks[i])):
|
| 240 |
+
self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])
|
| 241 |
+
|
| 242 |
+
def forward(self, in_features: List[torch.Tensor]):
|
| 243 |
+
out_features = []
|
| 244 |
+
for i in range(len(self.res_blocks)):
|
| 245 |
+
feature = self.input_blocks[i](in_features[i])
|
| 246 |
+
if i == 0:
|
| 247 |
+
x = feature
|
| 248 |
+
elif feature is not None:
|
| 249 |
+
x = x + feature
|
| 250 |
+
x = self.res_blocks[i](x)
|
| 251 |
+
out_features.append(self.output_blocks[i](x))
|
| 252 |
+
if i < len(self.res_blocks) - 1:
|
| 253 |
+
x = self.resamplers[i](x)
|
| 254 |
+
return out_features
|
moge/model/utils.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
def wrap_module_with_gradient_checkpointing(module: nn.Module):
|
| 8 |
+
from torch.utils.checkpoint import checkpoint
|
| 9 |
+
class _CheckpointingWrapper(module.__class__):
|
| 10 |
+
_restore_cls = module.__class__
|
| 11 |
+
def forward(self, *args, **kwargs):
|
| 12 |
+
return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
|
| 13 |
+
|
| 14 |
+
module.__class__ = _CheckpointingWrapper
|
| 15 |
+
return module
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def unwrap_module_with_gradient_checkpointing(module: nn.Module):
|
| 19 |
+
module.__class__ = module.__class__._restore_cls
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def wrap_dinov2_attention_with_sdpa(module: nn.Module):
|
| 23 |
+
assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
|
| 24 |
+
class _AttentionWrapper(module.__class__):
|
| 25 |
+
def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
| 26 |
+
B, N, C = x.shape
|
| 27 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
|
| 28 |
+
|
| 29 |
+
q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
|
| 30 |
+
|
| 31 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_bias)
|
| 32 |
+
x = x.permute(0, 2, 1, 3).reshape(B, N, C)
|
| 33 |
+
|
| 34 |
+
x = self.proj(x)
|
| 35 |
+
x = self.proj_drop(x)
|
| 36 |
+
return x
|
| 37 |
+
module.__class__ = _AttentionWrapper
|
| 38 |
+
return module
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
|
| 42 |
+
group_to_use = torch.distributed.group.WORLD
|
| 43 |
+
world_size = group_to_use.size()
|
| 44 |
+
grad = bucket.buffer()
|
| 45 |
+
grad.div_(world_size)
|
| 46 |
+
torch.distributed.all_reduce(grad, group=group_to_use)
|
| 47 |
+
fut = torch.futures.Future()
|
| 48 |
+
fut.set_result(grad)
|
| 49 |
+
return fut
|
moge/model/v1.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from numbers import Number
|
| 3 |
+
from functools import partial
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import importlib
|
| 6 |
+
import warnings
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torch.utils
|
| 13 |
+
import torch.utils.checkpoint
|
| 14 |
+
import torch.version
|
| 15 |
+
import utils3d
|
| 16 |
+
from huggingface_hub import hf_hub_download
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask
|
| 20 |
+
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
|
| 21 |
+
from ..utils.tools import timeit
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ResidualConvBlock(nn.Module):
|
| 25 |
+
def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
|
| 26 |
+
super(ResidualConvBlock, self).__init__()
|
| 27 |
+
if out_channels is None:
|
| 28 |
+
out_channels = in_channels
|
| 29 |
+
if hidden_channels is None:
|
| 30 |
+
hidden_channels = in_channels
|
| 31 |
+
|
| 32 |
+
if activation =='relu':
|
| 33 |
+
activation_cls = lambda: nn.ReLU(inplace=True)
|
| 34 |
+
elif activation == 'leaky_relu':
|
| 35 |
+
activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 36 |
+
elif activation =='silu':
|
| 37 |
+
activation_cls = lambda: nn.SiLU(inplace=True)
|
| 38 |
+
elif activation == 'elu':
|
| 39 |
+
activation_cls = lambda: nn.ELU(inplace=True)
|
| 40 |
+
else:
|
| 41 |
+
raise ValueError(f'Unsupported activation function: {activation}')
|
| 42 |
+
|
| 43 |
+
self.layers = nn.Sequential(
|
| 44 |
+
nn.GroupNorm(1, in_channels),
|
| 45 |
+
activation_cls(),
|
| 46 |
+
nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
|
| 47 |
+
nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
|
| 48 |
+
activation_cls(),
|
| 49 |
+
nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
skip = self.skip_connection(x)
|
| 56 |
+
x = self.layers(x)
|
| 57 |
+
x = x + skip
|
| 58 |
+
return x
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Head(nn.Module):
|
| 62 |
+
def __init__(
|
| 63 |
+
self,
|
| 64 |
+
num_features: int,
|
| 65 |
+
dim_in: int,
|
| 66 |
+
dim_out: List[int],
|
| 67 |
+
dim_proj: int = 512,
|
| 68 |
+
dim_upsample: List[int] = [256, 128, 128],
|
| 69 |
+
dim_times_res_block_hidden: int = 1,
|
| 70 |
+
num_res_blocks: int = 1,
|
| 71 |
+
res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
|
| 72 |
+
last_res_blocks: int = 0,
|
| 73 |
+
last_conv_channels: int = 32,
|
| 74 |
+
last_conv_size: int = 1
|
| 75 |
+
):
|
| 76 |
+
super().__init__()
|
| 77 |
+
|
| 78 |
+
self.projects = nn.ModuleList([
|
| 79 |
+
nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features)
|
| 80 |
+
])
|
| 81 |
+
|
| 82 |
+
self.upsample_blocks = nn.ModuleList([
|
| 83 |
+
nn.Sequential(
|
| 84 |
+
self._make_upsampler(in_ch + 2, out_ch),
|
| 85 |
+
*(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
|
| 86 |
+
) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
|
| 87 |
+
])
|
| 88 |
+
|
| 89 |
+
self.output_block = nn.ModuleList([
|
| 90 |
+
self._make_output_block(
|
| 91 |
+
dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
|
| 92 |
+
) for dim_out_ in dim_out
|
| 93 |
+
])
|
| 94 |
+
|
| 95 |
+
def _make_upsampler(self, in_channels: int, out_channels: int):
|
| 96 |
+
upsampler = nn.Sequential(
|
| 97 |
+
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
|
| 98 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
|
| 99 |
+
)
|
| 100 |
+
upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
|
| 101 |
+
return upsampler
|
| 102 |
+
|
| 103 |
+
def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
|
| 104 |
+
return nn.Sequential(
|
| 105 |
+
nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
|
| 106 |
+
*(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
|
| 107 |
+
nn.ReLU(inplace=True),
|
| 108 |
+
nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
|
| 112 |
+
img_h, img_w = image.shape[-2:]
|
| 113 |
+
patch_h, patch_w = img_h // 14, img_w // 14
|
| 114 |
+
|
| 115 |
+
# Process the hidden states
|
| 116 |
+
x = torch.stack([
|
| 117 |
+
proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
|
| 118 |
+
for proj, (feat, clstoken) in zip(self.projects, hidden_states)
|
| 119 |
+
], dim=1).sum(dim=1)
|
| 120 |
+
|
| 121 |
+
# Upsample stage
|
| 122 |
+
# (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
|
| 123 |
+
for i, block in enumerate(self.upsample_blocks):
|
| 124 |
+
# UV coordinates is for awareness of image aspect ratio
|
| 125 |
+
uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
|
| 126 |
+
uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
|
| 127 |
+
x = torch.cat([x, uv], dim=1)
|
| 128 |
+
for layer in block:
|
| 129 |
+
x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
|
| 130 |
+
|
| 131 |
+
# (patch_h * 8, patch_w * 8) -> (img_h, img_w)
|
| 132 |
+
x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
|
| 133 |
+
uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
|
| 134 |
+
uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
|
| 135 |
+
x = torch.cat([x, uv], dim=1)
|
| 136 |
+
|
| 137 |
+
if isinstance(self.output_block, nn.ModuleList):
|
| 138 |
+
output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
|
| 139 |
+
else:
|
| 140 |
+
output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)
|
| 141 |
+
|
| 142 |
+
return output
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class MoGeModel(nn.Module):
|
| 146 |
+
image_mean: torch.Tensor
|
| 147 |
+
image_std: torch.Tensor
|
| 148 |
+
|
| 149 |
+
def __init__(self,
|
| 150 |
+
encoder: str = 'dinov2_vitb14',
|
| 151 |
+
intermediate_layers: Union[int, List[int]] = 4,
|
| 152 |
+
dim_proj: int = 512,
|
| 153 |
+
dim_upsample: List[int] = [256, 128, 128],
|
| 154 |
+
dim_times_res_block_hidden: int = 1,
|
| 155 |
+
num_res_blocks: int = 1,
|
| 156 |
+
remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
|
| 157 |
+
res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
|
| 158 |
+
num_tokens_range: Tuple[Number, Number] = [1200, 2500],
|
| 159 |
+
last_res_blocks: int = 0,
|
| 160 |
+
last_conv_channels: int = 32,
|
| 161 |
+
last_conv_size: int = 1,
|
| 162 |
+
mask_threshold: float = 0.5,
|
| 163 |
+
**deprecated_kwargs
|
| 164 |
+
):
|
| 165 |
+
super(MoGeModel, self).__init__()
|
| 166 |
+
|
| 167 |
+
if deprecated_kwargs:
|
| 168 |
+
# Process legacy arguments
|
| 169 |
+
if 'trained_area_range' in deprecated_kwargs:
|
| 170 |
+
num_tokens_range = [deprecated_kwargs['trained_area_range'][0] // 14 ** 2, deprecated_kwargs['trained_area_range'][1] // 14 ** 2]
|
| 171 |
+
del deprecated_kwargs['trained_area_range']
|
| 172 |
+
warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
|
| 173 |
+
|
| 174 |
+
self.encoder = encoder
|
| 175 |
+
self.remap_output = remap_output
|
| 176 |
+
self.intermediate_layers = intermediate_layers
|
| 177 |
+
self.num_tokens_range = num_tokens_range
|
| 178 |
+
self.mask_threshold = mask_threshold
|
| 179 |
+
|
| 180 |
+
# NOTE: We have copied the DINOv2 code in torchhub to this repository.
|
| 181 |
+
# Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
|
| 182 |
+
hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
|
| 183 |
+
self.backbone = hub_loader(pretrained=False)
|
| 184 |
+
dim_feature = self.backbone.blocks[0].attn.qkv.in_features
|
| 185 |
+
|
| 186 |
+
self.head = Head(
|
| 187 |
+
num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
|
| 188 |
+
dim_in=dim_feature,
|
| 189 |
+
dim_out=[3, 1],
|
| 190 |
+
dim_proj=dim_proj,
|
| 191 |
+
dim_upsample=dim_upsample,
|
| 192 |
+
dim_times_res_block_hidden=dim_times_res_block_hidden,
|
| 193 |
+
num_res_blocks=num_res_blocks,
|
| 194 |
+
res_block_norm=res_block_norm,
|
| 195 |
+
last_res_blocks=last_res_blocks,
|
| 196 |
+
last_conv_channels=last_conv_channels,
|
| 197 |
+
last_conv_size=last_conv_size
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
| 201 |
+
image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
| 202 |
+
|
| 203 |
+
self.register_buffer("image_mean", image_mean)
|
| 204 |
+
self.register_buffer("image_std", image_std)
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def device(self) -> torch.device:
|
| 208 |
+
return next(self.parameters()).device
|
| 209 |
+
|
| 210 |
+
@property
|
| 211 |
+
def dtype(self) -> torch.dtype:
|
| 212 |
+
return next(self.parameters()).dtype
|
| 213 |
+
|
| 214 |
+
@classmethod
|
| 215 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
|
| 216 |
+
"""
|
| 217 |
+
Load a model from a checkpoint file.
|
| 218 |
+
|
| 219 |
+
### Parameters:
|
| 220 |
+
- `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
|
| 221 |
+
- `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
|
| 222 |
+
- `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
|
| 223 |
+
|
| 224 |
+
### Returns:
|
| 225 |
+
- A new instance of `MoGe` with the parameters loaded from the checkpoint.
|
| 226 |
+
"""
|
| 227 |
+
if Path(pretrained_model_name_or_path).exists():
|
| 228 |
+
checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
|
| 229 |
+
else:
|
| 230 |
+
cached_checkpoint_path = hf_hub_download(
|
| 231 |
+
repo_id=pretrained_model_name_or_path,
|
| 232 |
+
repo_type="model",
|
| 233 |
+
filename="model.pt",
|
| 234 |
+
**hf_kwargs
|
| 235 |
+
)
|
| 236 |
+
checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
|
| 237 |
+
model_config = checkpoint['model_config']
|
| 238 |
+
if model_kwargs is not None:
|
| 239 |
+
model_config.update(model_kwargs)
|
| 240 |
+
model = cls(**model_config)
|
| 241 |
+
model.load_state_dict(checkpoint['model'])
|
| 242 |
+
return model
|
| 243 |
+
|
| 244 |
+
def init_weights(self):
|
| 245 |
+
"Load the backbone with pretrained dinov2 weights from torch hub"
|
| 246 |
+
state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
|
| 247 |
+
self.backbone.load_state_dict(state_dict)
|
| 248 |
+
|
| 249 |
+
def enable_gradient_checkpointing(self):
|
| 250 |
+
for i in range(len(self.backbone.blocks)):
|
| 251 |
+
self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
|
| 252 |
+
|
| 253 |
+
def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
|
| 254 |
+
if self.remap_output == 'linear':
|
| 255 |
+
pass
|
| 256 |
+
elif self.remap_output =='sinh':
|
| 257 |
+
points = torch.sinh(points)
|
| 258 |
+
elif self.remap_output == 'exp':
|
| 259 |
+
xy, z = points.split([2, 1], dim=-1)
|
| 260 |
+
z = torch.exp(z)
|
| 261 |
+
points = torch.cat([xy * z, z], dim=-1)
|
| 262 |
+
elif self.remap_output =='sinh_exp':
|
| 263 |
+
xy, z = points.split([2, 1], dim=-1)
|
| 264 |
+
points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
|
| 265 |
+
else:
|
| 266 |
+
raise ValueError(f"Invalid remap output type: {self.remap_output}")
|
| 267 |
+
return points
|
| 268 |
+
|
| 269 |
+
def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
|
| 270 |
+
original_height, original_width = image.shape[-2:]
|
| 271 |
+
|
| 272 |
+
# Resize to expected resolution defined by num_tokens
|
| 273 |
+
resize_factor = ((num_tokens * 14 ** 2) / (original_height * original_width)) ** 0.5
|
| 274 |
+
resized_width, resized_height = int(original_width * resize_factor), int(original_height * resize_factor)
|
| 275 |
+
image = F.interpolate(image, (resized_height, resized_width), mode="bicubic", align_corners=False, antialias=True)
|
| 276 |
+
|
| 277 |
+
# Apply image transformation for DINOv2
|
| 278 |
+
image = (image - self.image_mean) / self.image_std
|
| 279 |
+
image_14 = F.interpolate(image, (resized_height // 14 * 14, resized_width // 14 * 14), mode="bilinear", align_corners=False, antialias=True)
|
| 280 |
+
|
| 281 |
+
# Get intermediate layers from the backbone
|
| 282 |
+
features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
|
| 283 |
+
|
| 284 |
+
# Predict points (and mask)
|
| 285 |
+
output = self.head(features, image)
|
| 286 |
+
points, mask = output
|
| 287 |
+
|
| 288 |
+
# Make sure fp32 precision for output
|
| 289 |
+
with torch.autocast(device_type=image.device.type, dtype=torch.float32):
|
| 290 |
+
# Resize to original resolution
|
| 291 |
+
points = F.interpolate(points, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
|
| 292 |
+
mask = F.interpolate(mask, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
|
| 293 |
+
|
| 294 |
+
# Post-process points and mask
|
| 295 |
+
points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
|
| 296 |
+
points = self._remap_points(points) # slightly improves the performance in case of very large output values
|
| 297 |
+
|
| 298 |
+
return_dict = {'points': points, 'mask': mask}
|
| 299 |
+
return return_dict
|
| 300 |
+
|
| 301 |
+
@torch.inference_mode()
|
| 302 |
+
def infer(
|
| 303 |
+
self,
|
| 304 |
+
image: torch.Tensor,
|
| 305 |
+
fov_x: Union[Number, torch.Tensor] = None,
|
| 306 |
+
resolution_level: int = 9,
|
| 307 |
+
num_tokens: int = None,
|
| 308 |
+
apply_mask: bool = True,
|
| 309 |
+
force_projection: bool = True,
|
| 310 |
+
use_fp16: bool = True,
|
| 311 |
+
) -> Dict[str, torch.Tensor]:
|
| 312 |
+
"""
|
| 313 |
+
User-friendly inference function
|
| 314 |
+
|
| 315 |
+
### Parameters
|
| 316 |
+
- `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\
|
| 317 |
+
- `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
|
| 318 |
+
- `resolution_level`: An integer [0-9] for the resolution level for inference.
|
| 319 |
+
The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size.
|
| 320 |
+
`resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.
|
| 321 |
+
- `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`.
|
| 322 |
+
`resolution_level` will be ignored if `num_tokens` is provided. Default: None
|
| 323 |
+
- `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
|
| 324 |
+
- `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True
|
| 325 |
+
- `use_fp16`: if True, use mixed precision to speed up inference. Default: True
|
| 326 |
+
|
| 327 |
+
### Returns
|
| 328 |
+
|
| 329 |
+
A dictionary containing the following keys:
|
| 330 |
+
- `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
|
| 331 |
+
- `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
|
| 332 |
+
- `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
|
| 333 |
+
"""
|
| 334 |
+
if image.dim() == 3:
|
| 335 |
+
omit_batch_dim = True
|
| 336 |
+
image = image.unsqueeze(0)
|
| 337 |
+
else:
|
| 338 |
+
omit_batch_dim = False
|
| 339 |
+
image = image.to(dtype=self.dtype, device=self.device)
|
| 340 |
+
|
| 341 |
+
original_height, original_width = image.shape[-2:]
|
| 342 |
+
aspect_ratio = original_width / original_height
|
| 343 |
+
|
| 344 |
+
if num_tokens is None:
|
| 345 |
+
min_tokens, max_tokens = self.num_tokens_range
|
| 346 |
+
num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
|
| 347 |
+
|
| 348 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
|
| 349 |
+
output = self.forward(image, num_tokens)
|
| 350 |
+
points, mask = output['points'], output['mask']
|
| 351 |
+
|
| 352 |
+
# Always process the output in fp32 precision
|
| 353 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.float32):
|
| 354 |
+
points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x])
|
| 355 |
+
|
| 356 |
+
mask_binary = mask > self.mask_threshold
|
| 357 |
+
|
| 358 |
+
# Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
|
| 359 |
+
if fov_x is None:
|
| 360 |
+
focal, shift = recover_focal_shift(points, mask_binary)
|
| 361 |
+
else:
|
| 362 |
+
focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
|
| 363 |
+
if focal.ndim == 0:
|
| 364 |
+
focal = focal[None].expand(points.shape[0])
|
| 365 |
+
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
| 366 |
+
fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
|
| 367 |
+
fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
|
| 368 |
+
intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
|
| 369 |
+
depth = points[..., 2] + shift[..., None, None]
|
| 370 |
+
|
| 371 |
+
# If projection constraint is forced, recompute the point map using the actual depth map
|
| 372 |
+
if force_projection:
|
| 373 |
+
points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
|
| 374 |
+
else:
|
| 375 |
+
points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]
|
| 376 |
+
|
| 377 |
+
# Apply mask if needed
|
| 378 |
+
if apply_mask:
|
| 379 |
+
points = torch.where(mask_binary[..., None], points, torch.inf)
|
| 380 |
+
depth = torch.where(mask_binary, depth, torch.inf)
|
| 381 |
+
|
| 382 |
+
return_dict = {
|
| 383 |
+
'points': points,
|
| 384 |
+
'intrinsics': intrinsics,
|
| 385 |
+
'depth': depth,
|
| 386 |
+
'mask': mask_binary,
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
if omit_batch_dim:
|
| 390 |
+
return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
|
| 391 |
+
|
| 392 |
+
return return_dict
|
moge/model/v2.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from numbers import Number
|
| 3 |
+
from functools import partial
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.utils
|
| 11 |
+
import torch.utils.checkpoint
|
| 12 |
+
import torch.amp
|
| 13 |
+
import torch.version
|
| 14 |
+
import utils3d
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
|
| 17 |
+
from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3
|
| 18 |
+
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
|
| 19 |
+
from .modules import DINOv2Encoder, MLP, ConvStack
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MoGeModel(nn.Module):
|
| 23 |
+
encoder: DINOv2Encoder
|
| 24 |
+
neck: ConvStack
|
| 25 |
+
points_head: ConvStack
|
| 26 |
+
mask_head: ConvStack
|
| 27 |
+
scale_head: MLP
|
| 28 |
+
onnx_compatible_mode: bool
|
| 29 |
+
|
| 30 |
+
def __init__(self,
|
| 31 |
+
encoder: Dict[str, Any],
|
| 32 |
+
neck: Dict[str, Any],
|
| 33 |
+
points_head: Dict[str, Any] = None,
|
| 34 |
+
mask_head: Dict[str, Any] = None,
|
| 35 |
+
normal_head: Dict[str, Any] = None,
|
| 36 |
+
scale_head: Dict[str, Any] = None,
|
| 37 |
+
remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
|
| 38 |
+
num_tokens_range: List[int] = [1200, 3600],
|
| 39 |
+
**deprecated_kwargs
|
| 40 |
+
):
|
| 41 |
+
super(MoGeModel, self).__init__()
|
| 42 |
+
if deprecated_kwargs:
|
| 43 |
+
warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
|
| 44 |
+
|
| 45 |
+
self.remap_output = remap_output
|
| 46 |
+
self.num_tokens_range = num_tokens_range
|
| 47 |
+
|
| 48 |
+
self.encoder = DINOv2Encoder(**encoder)
|
| 49 |
+
self.neck = ConvStack(**neck)
|
| 50 |
+
if points_head is not None:
|
| 51 |
+
self.points_head = ConvStack(**points_head)
|
| 52 |
+
if mask_head is not None:
|
| 53 |
+
self.mask_head = ConvStack(**mask_head)
|
| 54 |
+
if normal_head is not None:
|
| 55 |
+
self.normal_head = ConvStack(**normal_head)
|
| 56 |
+
if scale_head is not None:
|
| 57 |
+
self.scale_head = MLP(**scale_head)
|
| 58 |
+
|
| 59 |
+
@property
|
| 60 |
+
def device(self) -> torch.device:
|
| 61 |
+
return next(self.parameters()).device
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def dtype(self) -> torch.dtype:
|
| 65 |
+
return next(self.parameters()).dtype
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def onnx_compatible_mode(self) -> bool:
|
| 69 |
+
return getattr(self, "_onnx_compatible_mode", False)
|
| 70 |
+
|
| 71 |
+
@onnx_compatible_mode.setter
|
| 72 |
+
def onnx_compatible_mode(self, value: bool):
|
| 73 |
+
self._onnx_compatible_mode = value
|
| 74 |
+
self.encoder.onnx_compatible_mode = value
|
| 75 |
+
|
| 76 |
+
@classmethod
|
| 77 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
|
| 78 |
+
"""
|
| 79 |
+
Load a model from a checkpoint file.
|
| 80 |
+
|
| 81 |
+
### Parameters:
|
| 82 |
+
- `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
|
| 83 |
+
- `compiled`
|
| 84 |
+
- `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
|
| 85 |
+
- `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
|
| 86 |
+
|
| 87 |
+
### Returns:
|
| 88 |
+
- A new instance of `MoGe` with the parameters loaded from the checkpoint.
|
| 89 |
+
"""
|
| 90 |
+
if Path(pretrained_model_name_or_path).exists():
|
| 91 |
+
checkpoint_path = pretrained_model_name_or_path
|
| 92 |
+
else:
|
| 93 |
+
checkpoint_path = hf_hub_download(
|
| 94 |
+
repo_id=pretrained_model_name_or_path,
|
| 95 |
+
repo_type="model",
|
| 96 |
+
filename="model.pt",
|
| 97 |
+
**hf_kwargs
|
| 98 |
+
)
|
| 99 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
|
| 100 |
+
|
| 101 |
+
model_config = checkpoint['model_config']
|
| 102 |
+
if model_kwargs is not None:
|
| 103 |
+
model_config.update(model_kwargs)
|
| 104 |
+
model = cls(**model_config)
|
| 105 |
+
model.load_state_dict(checkpoint['model'], strict=False)
|
| 106 |
+
|
| 107 |
+
return model
|
| 108 |
+
|
| 109 |
+
def init_weights(self):
|
| 110 |
+
self.encoder.init_weights()
|
| 111 |
+
|
| 112 |
+
def enable_gradient_checkpointing(self):
|
| 113 |
+
self.encoder.enable_gradient_checkpointing()
|
| 114 |
+
self.neck.enable_gradient_checkpointing()
|
| 115 |
+
for head in ['points_head', 'normal_head', 'mask_head']:
|
| 116 |
+
if hasattr(self, head):
|
| 117 |
+
getattr(self, head).enable_gradient_checkpointing()
|
| 118 |
+
|
| 119 |
+
def enable_pytorch_native_sdpa(self):
|
| 120 |
+
self.encoder.enable_pytorch_native_sdpa()
|
| 121 |
+
|
| 122 |
+
def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
if self.remap_output == 'linear':
|
| 124 |
+
pass
|
| 125 |
+
elif self.remap_output =='sinh':
|
| 126 |
+
points = torch.sinh(points)
|
| 127 |
+
elif self.remap_output == 'exp':
|
| 128 |
+
xy, z = points.split([2, 1], dim=-1)
|
| 129 |
+
z = torch.exp(z)
|
| 130 |
+
points = torch.cat([xy * z, z], dim=-1)
|
| 131 |
+
elif self.remap_output =='sinh_exp':
|
| 132 |
+
xy, z = points.split([2, 1], dim=-1)
|
| 133 |
+
points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
|
| 134 |
+
else:
|
| 135 |
+
raise ValueError(f"Invalid remap output type: {self.remap_output}")
|
| 136 |
+
return points
|
| 137 |
+
|
| 138 |
+
def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
|
| 139 |
+
batch_size, _, img_h, img_w = image.shape
|
| 140 |
+
device, dtype = image.device, image.dtype
|
| 141 |
+
|
| 142 |
+
aspect_ratio = img_w / img_h
|
| 143 |
+
base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
|
| 144 |
+
num_tokens = base_h * base_w
|
| 145 |
+
|
| 146 |
+
# Backbones encoding
|
| 147 |
+
features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
|
| 148 |
+
features = [features, None, None, None, None]
|
| 149 |
+
|
| 150 |
+
# Concat UVs for aspect ratio input
|
| 151 |
+
for level in range(5):
|
| 152 |
+
uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
|
| 153 |
+
uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
|
| 154 |
+
if features[level] is None:
|
| 155 |
+
features[level] = uv
|
| 156 |
+
else:
|
| 157 |
+
features[level] = torch.concat([features[level], uv], dim=1)
|
| 158 |
+
|
| 159 |
+
# Shared neck
|
| 160 |
+
features = self.neck(features)
|
| 161 |
+
|
| 162 |
+
# Heads decoding
|
| 163 |
+
|
| 164 |
+
points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head'])
|
| 165 |
+
metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None
|
| 166 |
+
|
| 167 |
+
# Resize
|
| 168 |
+
points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask])
|
| 169 |
+
|
| 170 |
+
# Remap output
|
| 171 |
+
if points is not None:
|
| 172 |
+
points = points.permute(0, 2, 3, 1)
|
| 173 |
+
points = self._remap_points(points) # slightly improves the performance in case of very large output values
|
| 174 |
+
if normal is not None:
|
| 175 |
+
normal = normal.permute(0, 2, 3, 1)
|
| 176 |
+
normal = F.normalize(normal, dim=-1)
|
| 177 |
+
if mask is not None:
|
| 178 |
+
mask = mask.squeeze(1).sigmoid()
|
| 179 |
+
if metric_scale is not None:
|
| 180 |
+
metric_scale = metric_scale.squeeze(1).exp()
|
| 181 |
+
|
| 182 |
+
return_dict = {
|
| 183 |
+
'points': points,
|
| 184 |
+
'normal': normal,
|
| 185 |
+
'mask': mask,
|
| 186 |
+
'metric_scale': metric_scale
|
| 187 |
+
}
|
| 188 |
+
return_dict = {k: v for k, v in return_dict.items() if v is not None}
|
| 189 |
+
|
| 190 |
+
return return_dict
|
| 191 |
+
|
| 192 |
+
@torch.inference_mode()
|
| 193 |
+
def infer(
|
| 194 |
+
self,
|
| 195 |
+
image: torch.Tensor,
|
| 196 |
+
num_tokens: int = None,
|
| 197 |
+
resolution_level: int = 9,
|
| 198 |
+
force_projection: bool = True,
|
| 199 |
+
apply_mask: Literal[False, True, 'blend'] = True,
|
| 200 |
+
fov_x: Optional[Union[Number, torch.Tensor]] = None,
|
| 201 |
+
use_fp16: bool = True,
|
| 202 |
+
) -> Dict[str, torch.Tensor]:
|
| 203 |
+
"""
|
| 204 |
+
User-friendly inference function
|
| 205 |
+
|
| 206 |
+
### Parameters
|
| 207 |
+
- `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
|
| 208 |
+
- `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500.
|
| 209 |
+
More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`.
|
| 210 |
+
- `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
|
| 211 |
+
- `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
|
| 212 |
+
- `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
|
| 213 |
+
- `use_fp16`: if True, use mixed precision to speed up inference. Default: True
|
| 214 |
+
|
| 215 |
+
### Returns
|
| 216 |
+
|
| 217 |
+
A dictionary containing the following keys:
|
| 218 |
+
- `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
|
| 219 |
+
- `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
|
| 220 |
+
- `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
|
| 221 |
+
"""
|
| 222 |
+
if image.dim() == 3:
|
| 223 |
+
omit_batch_dim = True
|
| 224 |
+
image = image.unsqueeze(0)
|
| 225 |
+
else:
|
| 226 |
+
omit_batch_dim = False
|
| 227 |
+
image = image.to(dtype=self.dtype, device=self.device)
|
| 228 |
+
|
| 229 |
+
original_height, original_width = image.shape[-2:]
|
| 230 |
+
area = original_height * original_width
|
| 231 |
+
aspect_ratio = original_width / original_height
|
| 232 |
+
|
| 233 |
+
# Determine the number of base tokens to use
|
| 234 |
+
if num_tokens is None:
|
| 235 |
+
min_tokens, max_tokens = self.num_tokens_range
|
| 236 |
+
num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
|
| 237 |
+
|
| 238 |
+
# Forward pass
|
| 239 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
|
| 240 |
+
output = self.forward(image, num_tokens=num_tokens)
|
| 241 |
+
points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale'])
|
| 242 |
+
|
| 243 |
+
# Always process the output in fp32 precision
|
| 244 |
+
points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x])
|
| 245 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.float32):
|
| 246 |
+
if mask is not None:
|
| 247 |
+
mask_binary = mask > 0.5
|
| 248 |
+
else:
|
| 249 |
+
mask_binary = None
|
| 250 |
+
|
| 251 |
+
if points is not None:
|
| 252 |
+
# Convert affine point map to camera-space. Recover depth and intrinsics from point map.
|
| 253 |
+
# NOTE: Focal here is the focal length relative to half the image diagonal
|
| 254 |
+
if fov_x is None:
|
| 255 |
+
# Recover focal and shift from predicted point map
|
| 256 |
+
focal, shift = recover_focal_shift(points, mask_binary)
|
| 257 |
+
else:
|
| 258 |
+
# Focal is known, recover shift only
|
| 259 |
+
focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
|
| 260 |
+
if focal.ndim == 0:
|
| 261 |
+
focal = focal[None].expand(points.shape[0])
|
| 262 |
+
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
| 263 |
+
fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
|
| 264 |
+
intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
|
| 265 |
+
points[..., 2] += shift[..., None, None]
|
| 266 |
+
if mask_binary is not None:
|
| 267 |
+
mask_binary &= points[..., 2] > 0 # in case depth is contains negative values (which should never happen in practice)
|
| 268 |
+
depth = points[..., 2].clone()
|
| 269 |
+
else:
|
| 270 |
+
depth, intrinsics = None, None
|
| 271 |
+
|
| 272 |
+
# If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
|
| 273 |
+
if force_projection and depth is not None:
|
| 274 |
+
points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
|
| 275 |
+
|
| 276 |
+
# Apply metric scale
|
| 277 |
+
if metric_scale is not None:
|
| 278 |
+
if points is not None:
|
| 279 |
+
points *= metric_scale[:, None, None, None]
|
| 280 |
+
if depth is not None:
|
| 281 |
+
depth *= metric_scale[:, None, None]
|
| 282 |
+
|
| 283 |
+
# Apply mask
|
| 284 |
+
if apply_mask and mask_binary is not None:
|
| 285 |
+
points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
|
| 286 |
+
depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
|
| 287 |
+
normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None
|
| 288 |
+
|
| 289 |
+
return depth.squeeze().cpu().numpy(), mask_binary.squeeze().cpu().numpy(), intrinsics.squeeze().cpu().numpy()
|
| 290 |
+
|
| 291 |
+
# return_dict = {
|
| 292 |
+
# 'points': points,
|
| 293 |
+
# 'intrinsics': intrinsics,
|
| 294 |
+
# 'depth': depth,
|
| 295 |
+
# 'mask': mask_binary,
|
| 296 |
+
# 'normal': normal
|
| 297 |
+
# }
|
| 298 |
+
# return_dict = {k: v for k, v in return_dict.items() if v is not None}
|
| 299 |
+
|
| 300 |
+
# if omit_batch_dim:
|
| 301 |
+
# return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
|
| 302 |
+
|
| 303 |
+
# return return_dict
|
moge/scripts/__init__.py
ADDED
|
File without changes
|
moge/scripts/app.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
|
| 6 |
+
sys.path.insert(0, _package_root)
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
import tempfile
|
| 10 |
+
import itertools
|
| 11 |
+
from typing import *
|
| 12 |
+
import atexit
|
| 13 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 14 |
+
import shutil
|
| 15 |
+
|
| 16 |
+
import click
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@click.command(help='Web demo')
|
| 20 |
+
@click.option('--share', is_flag=True, help='Whether to run the app in shared mode.')
|
| 21 |
+
@click.option('--pretrained', 'pretrained_model_name_or_path', default=None, help='The name or path of the pre-trained model.')
|
| 22 |
+
@click.option('--version', 'model_version', default='v2', help='The version of the model.')
|
| 23 |
+
@click.option('--fp16', 'use_fp16', is_flag=True, help='Whether to use fp16 inference.')
|
| 24 |
+
def main(share: bool, pretrained_model_name_or_path: str, model_version: str, use_fp16: bool):
|
| 25 |
+
print("Import modules...")
|
| 26 |
+
# Lazy import
|
| 27 |
+
import cv2
|
| 28 |
+
import torch
|
| 29 |
+
import numpy as np
|
| 30 |
+
import trimesh
|
| 31 |
+
import trimesh.visual
|
| 32 |
+
from PIL import Image
|
| 33 |
+
import gradio as gr
|
| 34 |
+
try:
|
| 35 |
+
import spaces # This is for deployment at huggingface.co/spaces
|
| 36 |
+
HUGGINFACE_SPACES_INSTALLED = True
|
| 37 |
+
except ImportError:
|
| 38 |
+
HUGGINFACE_SPACES_INSTALLED = False
|
| 39 |
+
|
| 40 |
+
import utils3d
|
| 41 |
+
from moge.utils.io import write_normal
|
| 42 |
+
from moge.utils.vis import colorize_depth, colorize_normal
|
| 43 |
+
from moge.model import import_model_class_by_version
|
| 44 |
+
from moge.utils.geometry_numpy import depth_occlusion_edge_numpy
|
| 45 |
+
from moge.utils.tools import timeit
|
| 46 |
+
|
| 47 |
+
print("Load model...")
|
| 48 |
+
if pretrained_model_name_or_path is None:
|
| 49 |
+
DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = {
|
| 50 |
+
"v1": "Ruicheng/moge-vitl",
|
| 51 |
+
"v2": "Ruicheng/moge-2-vitl-normal",
|
| 52 |
+
}
|
| 53 |
+
pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
|
| 54 |
+
model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).cuda().eval()
|
| 55 |
+
if use_fp16:
|
| 56 |
+
model.half()
|
| 57 |
+
thread_pool_executor = ThreadPoolExecutor(max_workers=1)
|
| 58 |
+
|
| 59 |
+
def delete_later(path: Union[str, os.PathLike], delay: int = 300):
|
| 60 |
+
def _delete():
|
| 61 |
+
try:
|
| 62 |
+
os.remove(path)
|
| 63 |
+
except FileNotFoundError:
|
| 64 |
+
pass
|
| 65 |
+
def _wait_and_delete():
|
| 66 |
+
time.sleep(delay)
|
| 67 |
+
_delete(path)
|
| 68 |
+
thread_pool_executor.submit(_wait_and_delete)
|
| 69 |
+
atexit.register(_delete)
|
| 70 |
+
|
| 71 |
+
# Inference on GPU.
|
| 72 |
+
@(spaces.GPU if HUGGINFACE_SPACES_INSTALLED else lambda x: x)
|
| 73 |
+
def run_with_gpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]:
|
| 74 |
+
image_tensor = torch.tensor(image, dtype=torch.float32 if not use_fp16 else torch.float16, device=torch.device('cuda')).permute(2, 0, 1) / 255
|
| 75 |
+
output = model.infer(image_tensor, apply_mask=apply_mask, resolution_level=resolution_level, use_fp16=use_fp16)
|
| 76 |
+
output = {k: v.cpu().numpy() for k, v in output.items()}
|
| 77 |
+
return output
|
| 78 |
+
|
| 79 |
+
# Full inference pipeline
|
| 80 |
+
def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None):
|
| 81 |
+
larger_size = max(image.shape[:2])
|
| 82 |
+
if larger_size > max_size:
|
| 83 |
+
scale = max_size / larger_size
|
| 84 |
+
image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
| 85 |
+
|
| 86 |
+
height, width = image.shape[:2]
|
| 87 |
+
|
| 88 |
+
resolution_level_int = {'Low': 0, 'Medium': 5, 'High': 9, 'Ultra': 30}.get(resolution_level, 9)
|
| 89 |
+
output = run_with_gpu(image, resolution_level_int, apply_mask)
|
| 90 |
+
|
| 91 |
+
points, depth, mask, normal = output['points'], output['depth'], output['mask'], output.get('normal', None)
|
| 92 |
+
|
| 93 |
+
if remove_edge:
|
| 94 |
+
mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=0.04)
|
| 95 |
+
else:
|
| 96 |
+
mask_cleaned = mask
|
| 97 |
+
|
| 98 |
+
results = {
|
| 99 |
+
**output,
|
| 100 |
+
'mask_cleaned': mask_cleaned,
|
| 101 |
+
'image': image
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
# depth & normal visualization
|
| 105 |
+
depth_vis = colorize_depth(depth)
|
| 106 |
+
if normal is not None:
|
| 107 |
+
normal_vis = colorize_normal(normal)
|
| 108 |
+
else:
|
| 109 |
+
normal_vis = gr.update(label="Normal map (not avalable for this model)")
|
| 110 |
+
|
| 111 |
+
# mesh & pointcloud
|
| 112 |
+
if normal is None:
|
| 113 |
+
faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
|
| 114 |
+
points,
|
| 115 |
+
image.astype(np.float32) / 255,
|
| 116 |
+
utils3d.numpy.image_uv(width=width, height=height),
|
| 117 |
+
mask=mask_cleaned,
|
| 118 |
+
tri=True
|
| 119 |
+
)
|
| 120 |
+
vertex_normals = None
|
| 121 |
+
else:
|
| 122 |
+
faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
|
| 123 |
+
points,
|
| 124 |
+
image.astype(np.float32) / 255,
|
| 125 |
+
utils3d.numpy.image_uv(width=width, height=height),
|
| 126 |
+
normal,
|
| 127 |
+
mask=mask_cleaned,
|
| 128 |
+
tri=True
|
| 129 |
+
)
|
| 130 |
+
vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
|
| 131 |
+
vertex_uvs = vertex_uvs * np.array([1, -1], dtype=np.float32) + np.array([0, 1], dtype=np.float32)
|
| 132 |
+
if vertex_normals is not None:
|
| 133 |
+
vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
|
| 134 |
+
|
| 135 |
+
tempdir = Path(tempfile.gettempdir(), 'moge')
|
| 136 |
+
tempdir.mkdir(exist_ok=True)
|
| 137 |
+
output_path = Path(tempdir, request.session_hash)
|
| 138 |
+
shutil.rmtree(output_path, ignore_errors=True)
|
| 139 |
+
output_path.mkdir(exist_ok=True, parents=True)
|
| 140 |
+
trimesh.Trimesh(
|
| 141 |
+
vertices=vertices,
|
| 142 |
+
faces=faces,
|
| 143 |
+
visual = trimesh.visual.texture.TextureVisuals(
|
| 144 |
+
uv=vertex_uvs,
|
| 145 |
+
material=trimesh.visual.material.PBRMaterial(
|
| 146 |
+
baseColorTexture=Image.fromarray(image),
|
| 147 |
+
metallicFactor=0.5,
|
| 148 |
+
roughnessFactor=1.0
|
| 149 |
+
)
|
| 150 |
+
),
|
| 151 |
+
vertex_normals=vertex_normals,
|
| 152 |
+
process=False
|
| 153 |
+
).export(output_path / 'mesh.glb')
|
| 154 |
+
pointcloud = trimesh.PointCloud(
|
| 155 |
+
vertices=vertices,
|
| 156 |
+
colors=vertex_colors,
|
| 157 |
+
)
|
| 158 |
+
pointcloud.vertex_normals = vertex_normals
|
| 159 |
+
pointcloud.export(output_path / 'pointcloud.ply', vertex_normal=True)
|
| 160 |
+
trimesh.PointCloud(
|
| 161 |
+
vertices=vertices,
|
| 162 |
+
colors=vertex_colors,
|
| 163 |
+
).export(output_path / 'pointcloud.glb', include_normals=True)
|
| 164 |
+
cv2.imwrite(str(output_path /'mask.png'), mask.astype(np.uint8) * 255)
|
| 165 |
+
cv2.imwrite(str(output_path / 'depth.exr'), depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 166 |
+
cv2.imwrite(str(output_path / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 167 |
+
if normal is not None:
|
| 168 |
+
cv2.imwrite(str(output_path / 'normal.exr'), cv2.cvtColor(normal.astype(np.float32) * np.array([1, -1, -1], dtype=np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
|
| 169 |
+
|
| 170 |
+
files = ['mesh.glb', 'pointcloud.ply', 'depth.exr', 'points.exr', 'mask.png']
|
| 171 |
+
if normal is not None:
|
| 172 |
+
files.append('normal.exr')
|
| 173 |
+
|
| 174 |
+
for f in files:
|
| 175 |
+
delete_later(output_path / f)
|
| 176 |
+
|
| 177 |
+
# FOV
|
| 178 |
+
intrinsics = results['intrinsics']
|
| 179 |
+
fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
|
| 180 |
+
fov_x, fov_y = np.rad2deg([fov_x, fov_y])
|
| 181 |
+
|
| 182 |
+
# messages
|
| 183 |
+
viewer_message = f'**Note:** Inference has been completed. It may take a few seconds to download the 3D model.'
|
| 184 |
+
if resolution_level != 'Ultra':
|
| 185 |
+
depth_message = f'**Note:** Want sharper depth map? Try increasing the `maximum image size` and setting the `inference resolution level` to `Ultra` in the settings.'
|
| 186 |
+
else:
|
| 187 |
+
depth_message = ""
|
| 188 |
+
|
| 189 |
+
return (
|
| 190 |
+
results,
|
| 191 |
+
depth_vis,
|
| 192 |
+
normal_vis,
|
| 193 |
+
output_path / 'pointcloud.glb',
|
| 194 |
+
[(output_path / f).as_posix() for f in files if (output_path / f).exists()],
|
| 195 |
+
f'- **Horizontal FOV: {fov_x:.1f}°**. \n - **Vertical FOV: {fov_y:.1f}°**',
|
| 196 |
+
viewer_message,
|
| 197 |
+
depth_message
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
def reset_measure(results: Dict[str, np.ndarray]):
|
| 201 |
+
return [results['image'], [], ""]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def measure(results: Dict[str, np.ndarray], measure_points: List[Tuple[int, int]], event: gr.SelectData):
|
| 205 |
+
point2d = event.index[0], event.index[1]
|
| 206 |
+
measure_points.append(point2d)
|
| 207 |
+
|
| 208 |
+
image = results['image'].copy()
|
| 209 |
+
for p in measure_points:
|
| 210 |
+
image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
|
| 211 |
+
|
| 212 |
+
depth_text = ""
|
| 213 |
+
for i, p in enumerate(measure_points):
|
| 214 |
+
d = results['depth'][p[1], p[0]]
|
| 215 |
+
depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
|
| 216 |
+
|
| 217 |
+
if len(measure_points) == 2:
|
| 218 |
+
point1, point2 = measure_points
|
| 219 |
+
image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
|
| 220 |
+
distance = np.linalg.norm(results['points'][point1[1], point1[0]] - results['points'][point2[1], point2[0]])
|
| 221 |
+
measure_points = []
|
| 222 |
+
|
| 223 |
+
distance_text = f"- **Distance: {distance:.2f}m**"
|
| 224 |
+
|
| 225 |
+
text = depth_text + distance_text
|
| 226 |
+
return [image, measure_points, text]
|
| 227 |
+
else:
|
| 228 |
+
return [image, measure_points, depth_text]
|
| 229 |
+
|
| 230 |
+
print("Create Gradio app...")
|
| 231 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 232 |
+
gr.Markdown(
|
| 233 |
+
f'''
|
| 234 |
+
<div align="center">
|
| 235 |
+
<h1> Turn a 2D image into 3D with MoGe <a title="Github" href="https://github.com/microsoft/MoGe" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> <img src="https://img.shields.io/github/stars/microsoft/MoGe?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> </a> </h1>
|
| 236 |
+
</div>
|
| 237 |
+
''')
|
| 238 |
+
results = gr.State(value=None)
|
| 239 |
+
measure_points = gr.State(value=[])
|
| 240 |
+
|
| 241 |
+
with gr.Row():
|
| 242 |
+
with gr.Column():
|
| 243 |
+
input_image = gr.Image(type="numpy", image_mode="RGB", label="Input Image")
|
| 244 |
+
with gr.Accordion(label="Settings", open=False):
|
| 245 |
+
max_size_input = gr.Number(value=800, label="Maximum Image Size", precision=0, minimum=256, maximum=2048)
|
| 246 |
+
resolution_level = gr.Dropdown(['Low', 'Medium', 'High', 'Ultra'], label="Inference Resolution Level", value='High')
|
| 247 |
+
apply_mask = gr.Checkbox(value=True, label="Apply mask")
|
| 248 |
+
remove_edges = gr.Checkbox(value=True, label="Remove edges")
|
| 249 |
+
submit_btn = gr.Button("Submit", variant='primary')
|
| 250 |
+
|
| 251 |
+
with gr.Column():
|
| 252 |
+
with gr.Tabs():
|
| 253 |
+
with gr.Tab("3D View"):
|
| 254 |
+
viewer_message = gr.Markdown("")
|
| 255 |
+
model_3d = gr.Model3D(display_mode="solid", label="3D Point Map", clear_color=[1.0, 1.0, 1.0, 1.0], height="60vh")
|
| 256 |
+
fov = gr.Markdown()
|
| 257 |
+
with gr.Tab("Depth"):
|
| 258 |
+
depth_message = gr.Markdown("")
|
| 259 |
+
depth_map = gr.Image(type="numpy", label="Colorized Depth Map", format='png', interactive=False)
|
| 260 |
+
with gr.Tab("Normal", interactive=hasattr(model, 'normal_head')):
|
| 261 |
+
normal_map = gr.Image(type="numpy", label="Normal Map", format='png', interactive=False)
|
| 262 |
+
with gr.Tab("Measure", interactive=hasattr(model, 'scale_head')):
|
| 263 |
+
gr.Markdown("### Click on the image to measure the distance between two points. \n"
|
| 264 |
+
"**Note:** Metric scale is most reliable for typical indoor or street scenes, and may degrade for contents unfamiliar to the model (e.g., stylized or close-up images).")
|
| 265 |
+
measure_image = gr.Image(type="numpy", show_label=False, format='webp', interactive=False, sources=[])
|
| 266 |
+
measure_text = gr.Markdown("")
|
| 267 |
+
with gr.Tab("Download"):
|
| 268 |
+
files = gr.File(type='filepath', label="Output Files")
|
| 269 |
+
|
| 270 |
+
if Path('example_images').exists():
|
| 271 |
+
example_image_paths = sorted(list(itertools.chain(*[Path('example_images').glob(f'*.{ext}') for ext in ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']])))
|
| 272 |
+
examples = gr.Examples(
|
| 273 |
+
examples = example_image_paths,
|
| 274 |
+
inputs=input_image,
|
| 275 |
+
label="Examples"
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
submit_btn.click(
|
| 279 |
+
fn=lambda: [None, None, None, None, None, "", "", ""],
|
| 280 |
+
outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message]
|
| 281 |
+
).then(
|
| 282 |
+
fn=run,
|
| 283 |
+
inputs=[input_image, max_size_input, resolution_level, apply_mask, remove_edges],
|
| 284 |
+
outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message]
|
| 285 |
+
).then(
|
| 286 |
+
fn=reset_measure,
|
| 287 |
+
inputs=[results],
|
| 288 |
+
outputs=[measure_image, measure_points, measure_text]
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
measure_image.select(
|
| 292 |
+
fn=measure,
|
| 293 |
+
inputs=[results, measure_points],
|
| 294 |
+
outputs=[measure_image, measure_points, measure_text]
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
demo.launch(share=share)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
if __name__ == '__main__':
|
| 301 |
+
main()
|
moge/scripts/cli.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import sys
|
| 5 |
+
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
|
| 6 |
+
sys.path.insert(0, _package_root)
|
| 7 |
+
|
| 8 |
+
import click
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@click.group(help='MoGe command line interface.')
|
| 12 |
+
def cli():
|
| 13 |
+
pass
|
| 14 |
+
|
| 15 |
+
def main():
|
| 16 |
+
from moge.scripts import app, infer, infer_baseline, infer_panorama, eval_baseline, vis_data
|
| 17 |
+
cli.add_command(app.main, name='app')
|
| 18 |
+
cli.add_command(infer.main, name='infer')
|
| 19 |
+
cli.add_command(infer_baseline.main, name='infer_baseline')
|
| 20 |
+
cli.add_command(infer_panorama.main, name='infer_panorama')
|
| 21 |
+
cli.add_command(eval_baseline.main, name='eval_baseline')
|
| 22 |
+
cli.add_command(vis_data.main, name='vis_data')
|
| 23 |
+
cli()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if __name__ == '__main__':
|
| 27 |
+
main()
|
moge/scripts/eval_baseline.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
|
| 5 |
+
sys.path.insert(0, _package_root)
|
| 6 |
+
import json
|
| 7 |
+
from typing import *
|
| 8 |
+
import importlib
|
| 9 |
+
import importlib.util
|
| 10 |
+
|
| 11 |
+
import click
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Evaluation script.')
|
| 15 |
+
@click.option('--baseline', 'baseline_code_path', type=click.Path(), required=True, help='Path to the baseline model python code.')
|
| 16 |
+
@click.option('--config', 'config_path', type=click.Path(), default='configs/eval/all_benchmarks.json', help='Path to the evaluation configurations. '
|
| 17 |
+
'Defaults to "configs/eval/all_benchmarks.json".')
|
| 18 |
+
@click.option('--output', '-o', 'output_path', type=click.Path(), required=True, help='Path to the output json file.')
|
| 19 |
+
@click.option('--oracle', 'oracle_mode', is_flag=True, help='Use oracle mode for evaluation, i.e., use the GT intrinsics input.')
|
| 20 |
+
@click.option('--dump_pred', is_flag=True, help='Dump predition results.')
|
| 21 |
+
@click.option('--dump_gt', is_flag=True, help='Dump ground truth.')
|
| 22 |
+
@click.pass_context
|
| 23 |
+
def main(ctx: click.Context, baseline_code_path: str, config_path: str, oracle_mode: bool, output_path: Union[str, Path], dump_pred: bool, dump_gt: bool):
|
| 24 |
+
# Lazy import
|
| 25 |
+
import cv2
|
| 26 |
+
import numpy as np
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
import utils3d
|
| 31 |
+
|
| 32 |
+
from moge.test.baseline import MGEBaselineInterface
|
| 33 |
+
from moge.test.dataloader import EvalDataLoaderPipeline
|
| 34 |
+
from moge.test.metrics import compute_metrics
|
| 35 |
+
from moge.utils.geometry_torch import intrinsics_to_fov
|
| 36 |
+
from moge.utils.vis import colorize_depth, colorize_normal
|
| 37 |
+
from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module
|
| 38 |
+
|
| 39 |
+
# Load the baseline model
|
| 40 |
+
module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem)
|
| 41 |
+
baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline')
|
| 42 |
+
baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False)
|
| 43 |
+
|
| 44 |
+
# Load the evaluation configurations
|
| 45 |
+
with open(config_path, 'r') as f:
|
| 46 |
+
config = json.load(f)
|
| 47 |
+
|
| 48 |
+
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 49 |
+
all_metrics = {}
|
| 50 |
+
# Iterate over the dataset
|
| 51 |
+
for benchmark_name, benchmark_config in tqdm(list(config.items()), desc='Benchmarks'):
|
| 52 |
+
filenames, metrics_list = [], []
|
| 53 |
+
with (
|
| 54 |
+
EvalDataLoaderPipeline(**benchmark_config) as eval_data_pipe,
|
| 55 |
+
tqdm(total=len(eval_data_pipe), desc=benchmark_name, leave=False) as pbar
|
| 56 |
+
):
|
| 57 |
+
# Iterate over the samples in the dataset
|
| 58 |
+
for i in range(len(eval_data_pipe)):
|
| 59 |
+
sample = eval_data_pipe.get()
|
| 60 |
+
sample = {k: v.to(baseline.device) if isinstance(v, torch.Tensor) else v for k, v in sample.items()}
|
| 61 |
+
image = sample['image']
|
| 62 |
+
gt_intrinsics = sample['intrinsics']
|
| 63 |
+
|
| 64 |
+
# Inference
|
| 65 |
+
torch.cuda.synchronize()
|
| 66 |
+
with torch.inference_mode(), timeit('_inference_timer', verbose=False) as timer:
|
| 67 |
+
if oracle_mode:
|
| 68 |
+
pred = baseline.infer_for_evaluation(image, gt_intrinsics)
|
| 69 |
+
else:
|
| 70 |
+
pred = baseline.infer_for_evaluation(image)
|
| 71 |
+
torch.cuda.synchronize()
|
| 72 |
+
|
| 73 |
+
# Compute metrics
|
| 74 |
+
metrics, misc = compute_metrics(pred, sample, vis=dump_pred or dump_gt)
|
| 75 |
+
metrics['inference_time'] = timer.time
|
| 76 |
+
metrics_list.append(metrics)
|
| 77 |
+
|
| 78 |
+
# Dump results
|
| 79 |
+
dump_path = Path(output_path.replace(".json", f"_dump"), f'{benchmark_name}', sample['filename'].replace('.zip', ''))
|
| 80 |
+
if dump_pred:
|
| 81 |
+
dump_path.joinpath('pred').mkdir(parents=True, exist_ok=True)
|
| 82 |
+
cv2.imwrite(str(dump_path / 'pred' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
|
| 83 |
+
|
| 84 |
+
with Path(dump_path, 'pred', 'metrics.json').open('w') as f:
|
| 85 |
+
json.dump(metrics, f, indent=4)
|
| 86 |
+
|
| 87 |
+
if 'pred_points' in misc:
|
| 88 |
+
points = misc['pred_points'].cpu().numpy()
|
| 89 |
+
cv2.imwrite(str(dump_path / 'pred' / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 90 |
+
|
| 91 |
+
if 'pred_depth' in misc:
|
| 92 |
+
depth = misc['pred_depth'].cpu().numpy()
|
| 93 |
+
if 'mask' in pred:
|
| 94 |
+
mask = pred['mask'].cpu().numpy()
|
| 95 |
+
depth = np.where(mask, depth, np.inf)
|
| 96 |
+
cv2.imwrite(str(dump_path / 'pred' / 'depth.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR))
|
| 97 |
+
|
| 98 |
+
if 'mask' in pred:
|
| 99 |
+
mask = pred['mask'].cpu().numpy()
|
| 100 |
+
cv2.imwrite(str(dump_path / 'pred' / 'mask.png'), (mask * 255).astype(np.uint8))
|
| 101 |
+
|
| 102 |
+
if 'normal' in pred:
|
| 103 |
+
normal = pred['normal'].cpu().numpy()
|
| 104 |
+
cv2.imwrite(str(dump_path / 'pred' / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR))
|
| 105 |
+
|
| 106 |
+
if 'intrinsics' in pred:
|
| 107 |
+
intrinsics = pred['intrinsics']
|
| 108 |
+
fov_x, fov_y = intrinsics_to_fov(intrinsics)
|
| 109 |
+
with open(dump_path / 'pred' / 'fov.json', 'w') as f:
|
| 110 |
+
json.dump({
|
| 111 |
+
'fov_x': np.rad2deg(fov_x.item()),
|
| 112 |
+
'fov_y': np.rad2deg(fov_y.item()),
|
| 113 |
+
'intrinsics': intrinsics.cpu().numpy().tolist(),
|
| 114 |
+
}, f)
|
| 115 |
+
|
| 116 |
+
if dump_gt:
|
| 117 |
+
dump_path.joinpath('gt').mkdir(parents=True, exist_ok=True)
|
| 118 |
+
cv2.imwrite(str(dump_path / 'gt' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
|
| 119 |
+
|
| 120 |
+
if 'points' in sample:
|
| 121 |
+
points = sample['points']
|
| 122 |
+
cv2.imwrite(str(dump_path / 'gt' / 'points.exr'), cv2.cvtColor(points.cpu().numpy().astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 123 |
+
|
| 124 |
+
if 'depth' in sample:
|
| 125 |
+
depth = sample['depth']
|
| 126 |
+
mask = sample['depth_mask']
|
| 127 |
+
cv2.imwrite(str(dump_path / 'gt' / 'depth.png'), cv2.cvtColor(colorize_depth(depth.cpu().numpy(), mask=mask.cpu().numpy()), cv2.COLOR_RGB2BGR))
|
| 128 |
+
|
| 129 |
+
if 'normal' in sample:
|
| 130 |
+
normal = sample['normal']
|
| 131 |
+
cv2.imwrite(str(dump_path / 'gt' / 'normal.png'), cv2.cvtColor(colorize_normal(normal.cpu().numpy()), cv2.COLOR_RGB2BGR))
|
| 132 |
+
|
| 133 |
+
if 'depth_mask' in sample:
|
| 134 |
+
mask = sample['depth_mask']
|
| 135 |
+
cv2.imwrite(str(dump_path / 'gt' /'mask.png'), (mask.cpu().numpy() * 255).astype(np.uint8))
|
| 136 |
+
|
| 137 |
+
if 'intrinsics' in sample:
|
| 138 |
+
intrinsics = sample['intrinsics']
|
| 139 |
+
fov_x, fov_y = intrinsics_to_fov(intrinsics)
|
| 140 |
+
with open(dump_path / 'gt' / 'info.json', 'w') as f:
|
| 141 |
+
json.dump({
|
| 142 |
+
'fov_x': np.rad2deg(fov_x.item()),
|
| 143 |
+
'fov_y': np.rad2deg(fov_y.item()),
|
| 144 |
+
'intrinsics': intrinsics.cpu().numpy().tolist(),
|
| 145 |
+
}, f)
|
| 146 |
+
|
| 147 |
+
# Save intermediate results
|
| 148 |
+
if i % 100 == 0 or i == len(eval_data_pipe) - 1:
|
| 149 |
+
Path(output_path).write_text(
|
| 150 |
+
json.dumps({
|
| 151 |
+
**all_metrics,
|
| 152 |
+
benchmark_name: key_average(metrics_list)
|
| 153 |
+
}, indent=4)
|
| 154 |
+
)
|
| 155 |
+
pbar.update(1)
|
| 156 |
+
|
| 157 |
+
all_metrics[benchmark_name] = key_average(metrics_list)
|
| 158 |
+
|
| 159 |
+
# Save final results
|
| 160 |
+
all_metrics['mean'] = key_average(list(all_metrics.values()))
|
| 161 |
+
Path(output_path).write_text(json.dumps(all_metrics, indent=4))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
if __name__ == '__main__':
|
| 165 |
+
main()
|
moge/scripts/infer.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import sys
|
| 5 |
+
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
|
| 6 |
+
sys.path.insert(0, _package_root)
|
| 7 |
+
from typing import *
|
| 8 |
+
import itertools
|
| 9 |
+
import json
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import click
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@click.command(help='Inference script')
|
| 17 |
+
@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), help='Input image or folder path. "jpg" and "png" are supported.')
|
| 18 |
+
@click.option('--fov_x', 'fov_x_', type=float, default=None, help='If camera parameters are known, set the horizontal field of view in degrees. Otherwise, MoGe will estimate it.')
|
| 19 |
+
@click.option('--output', '-o', 'output_path', default='./output', type=click.Path(), help='Output folder path')
|
| 20 |
+
@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default=None, help='Pretrained model name or path. If not provided, the corresponding default model will be chosen.')
|
| 21 |
+
@click.option('--version', 'model_version', type=click.Choice(['v1', 'v2']), default='v2', help='Model version. Defaults to "v2"')
|
| 22 |
+
@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"')
|
| 23 |
+
@click.option('--fp16', 'use_fp16', is_flag=True, help='Use fp16 precision for much faster inference.')
|
| 24 |
+
@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).')
|
| 25 |
+
@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level for inference. \
|
| 26 |
+
Higher value means more tokens and the finer details will be captured, but inference can be slower. \
|
| 27 |
+
Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. \
|
| 28 |
+
`resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.')
|
| 29 |
+
@click.option('--num_tokens', type=int, default=None, help='number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. \
|
| 30 |
+
`resolution_level` will be ignored if `num_tokens` is provided. Default: None')
|
| 31 |
+
@click.option('--threshold', type=float, default=0.04, help='Threshold for removing edges. Defaults to 0.01. Smaller value removes more edges. "inf" means no thresholding.')
|
| 32 |
+
@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps (image, point map, depth map, normal map, mask) and fov.')
|
| 33 |
+
@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.')
|
| 34 |
+
@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.')
|
| 35 |
+
@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.')
|
| 36 |
+
def main(
|
| 37 |
+
input_path: str,
|
| 38 |
+
fov_x_: float,
|
| 39 |
+
output_path: str,
|
| 40 |
+
pretrained_model_name_or_path: str,
|
| 41 |
+
model_version: str,
|
| 42 |
+
device_name: str,
|
| 43 |
+
use_fp16: bool,
|
| 44 |
+
resize_to: int,
|
| 45 |
+
resolution_level: int,
|
| 46 |
+
num_tokens: int,
|
| 47 |
+
threshold: float,
|
| 48 |
+
save_maps_: bool,
|
| 49 |
+
save_glb_: bool,
|
| 50 |
+
save_ply_: bool,
|
| 51 |
+
show: bool,
|
| 52 |
+
):
|
| 53 |
+
import cv2
|
| 54 |
+
import numpy as np
|
| 55 |
+
import torch
|
| 56 |
+
from PIL import Image
|
| 57 |
+
from tqdm import tqdm
|
| 58 |
+
import trimesh
|
| 59 |
+
import trimesh.visual
|
| 60 |
+
import click
|
| 61 |
+
|
| 62 |
+
from moge.model import import_model_class_by_version
|
| 63 |
+
from moge.utils.io import save_glb, save_ply
|
| 64 |
+
from moge.utils.vis import colorize_depth, colorize_normal
|
| 65 |
+
from moge.utils.geometry_numpy import depth_occlusion_edge_numpy
|
| 66 |
+
import utils3d
|
| 67 |
+
|
| 68 |
+
device = torch.device(device_name)
|
| 69 |
+
|
| 70 |
+
include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
|
| 71 |
+
if Path(input_path).is_dir():
|
| 72 |
+
image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
|
| 73 |
+
else:
|
| 74 |
+
image_paths = [Path(input_path)]
|
| 75 |
+
|
| 76 |
+
if len(image_paths) == 0:
|
| 77 |
+
raise FileNotFoundError(f'No image files found in {input_path}')
|
| 78 |
+
|
| 79 |
+
if pretrained_model_name_or_path is None:
|
| 80 |
+
DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = {
|
| 81 |
+
"v1": "Ruicheng/moge-vitl",
|
| 82 |
+
"v2": "Ruicheng/moge-2-vitl-normal",
|
| 83 |
+
}
|
| 84 |
+
pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
|
| 85 |
+
model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).to(device).eval()
|
| 86 |
+
if use_fp16:
|
| 87 |
+
model.half()
|
| 88 |
+
|
| 89 |
+
if not any([save_maps_, save_glb_, save_ply_]):
|
| 90 |
+
warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.')
|
| 91 |
+
save_maps_ = save_glb_ = save_ply_ = True
|
| 92 |
+
|
| 93 |
+
for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)):
|
| 94 |
+
image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
|
| 95 |
+
height, width = image.shape[:2]
|
| 96 |
+
if resize_to is not None:
|
| 97 |
+
height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height))
|
| 98 |
+
image = cv2.resize(image, (width, height), cv2.INTER_AREA)
|
| 99 |
+
image_tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
|
| 100 |
+
|
| 101 |
+
# Inference
|
| 102 |
+
output = model.infer(image_tensor, fov_x=fov_x_, resolution_level=resolution_level, num_tokens=num_tokens, use_fp16=use_fp16)
|
| 103 |
+
points, depth, mask, intrinsics = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy(), output['intrinsics'].cpu().numpy()
|
| 104 |
+
normal = output['normal'].cpu().numpy() if 'normal' in output else None
|
| 105 |
+
|
| 106 |
+
save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
|
| 107 |
+
save_path.mkdir(exist_ok=True, parents=True)
|
| 108 |
+
|
| 109 |
+
# Save images / maps
|
| 110 |
+
if save_maps_:
|
| 111 |
+
cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
|
| 112 |
+
cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR))
|
| 113 |
+
cv2.imwrite(str(save_path / 'depth.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 114 |
+
cv2.imwrite(str(save_path / 'mask.png'), (mask * 255).astype(np.uint8))
|
| 115 |
+
cv2.imwrite(str(save_path / 'points.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 116 |
+
if normal is not None:
|
| 117 |
+
cv2.imwrite(str(save_path / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR))
|
| 118 |
+
fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
|
| 119 |
+
with open(save_path / 'fov.json', 'w') as f:
|
| 120 |
+
json.dump({
|
| 121 |
+
'fov_x': round(float(np.rad2deg(fov_x)), 2),
|
| 122 |
+
'fov_y': round(float(np.rad2deg(fov_y)), 2),
|
| 123 |
+
}, f)
|
| 124 |
+
|
| 125 |
+
# Export mesh & visulization
|
| 126 |
+
if save_glb_ or save_ply_ or show:
|
| 127 |
+
mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=threshold)
|
| 128 |
+
if normal is None:
|
| 129 |
+
faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
|
| 130 |
+
points,
|
| 131 |
+
image.astype(np.float32) / 255,
|
| 132 |
+
utils3d.numpy.image_uv(width=width, height=height),
|
| 133 |
+
mask=mask_cleaned,
|
| 134 |
+
tri=True
|
| 135 |
+
)
|
| 136 |
+
vertex_normals = None
|
| 137 |
+
else:
|
| 138 |
+
faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
|
| 139 |
+
points,
|
| 140 |
+
image.astype(np.float32) / 255,
|
| 141 |
+
utils3d.numpy.image_uv(width=width, height=height),
|
| 142 |
+
normal,
|
| 143 |
+
mask=mask_cleaned,
|
| 144 |
+
tri=True
|
| 145 |
+
)
|
| 146 |
+
# When exporting the model, follow the OpenGL coordinate conventions:
|
| 147 |
+
# - world coordinate system: x right, y up, z backward.
|
| 148 |
+
# - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top.
|
| 149 |
+
vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1]
|
| 150 |
+
if normal is not None:
|
| 151 |
+
vertex_normals = vertex_normals * [1, -1, -1]
|
| 152 |
+
|
| 153 |
+
if save_glb_:
|
| 154 |
+
save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image, vertex_normals)
|
| 155 |
+
|
| 156 |
+
if save_ply_:
|
| 157 |
+
save_ply(save_path / 'pointcloud.ply', vertices, np.zeros((0, 3), dtype=np.int32), vertex_colors, vertex_normals)
|
| 158 |
+
|
| 159 |
+
if show:
|
| 160 |
+
trimesh.Trimesh(
|
| 161 |
+
vertices=vertices,
|
| 162 |
+
vertex_colors=vertex_colors,
|
| 163 |
+
vertex_normals=vertex_normals,
|
| 164 |
+
faces=faces,
|
| 165 |
+
process=False
|
| 166 |
+
).show()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == '__main__':
|
| 170 |
+
main()
|
moge/scripts/infer_baseline.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import sys
|
| 5 |
+
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
|
| 6 |
+
sys.path.insert(0, _package_root)
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import *
|
| 10 |
+
import itertools
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
import click
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Inference script for wrapped baselines methods')
|
| 17 |
+
@click.option('--baseline', 'baseline_code_path', required=True, type=click.Path(), help='Path to the baseline model python code.')
|
| 18 |
+
@click.option('--input', '-i', 'input_path', type=str, required=True, help='Input image or folder')
|
| 19 |
+
@click.option('--output', '-o', 'output_path', type=str, default='./output', help='Output folder')
|
| 20 |
+
@click.option('--size', 'image_size', type=int, default=None, help='Resize input image')
|
| 21 |
+
@click.option('--skip', is_flag=True, help='Skip existing output')
|
| 22 |
+
@click.option('--maps', 'save_maps_', is_flag=True, help='Save output point / depth maps')
|
| 23 |
+
@click.option('--ply', 'save_ply_', is_flag=True, help='Save mesh in PLY format')
|
| 24 |
+
@click.option('--glb', 'save_glb_', is_flag=True, help='Save mesh in GLB format')
|
| 25 |
+
@click.option('--threshold', type=float, default=0.03, help='Depth edge detection threshold for saving mesh')
|
| 26 |
+
@click.pass_context
|
| 27 |
+
def main(ctx: click.Context, baseline_code_path: str, input_path: str, output_path: str, image_size: int, skip: bool, save_maps_, save_ply_: bool, save_glb_: bool, threshold: float):
|
| 28 |
+
# Lazy import
|
| 29 |
+
import cv2
|
| 30 |
+
import numpy as np
|
| 31 |
+
from tqdm import tqdm
|
| 32 |
+
import torch
|
| 33 |
+
import utils3d
|
| 34 |
+
|
| 35 |
+
from moge.utils.io import save_ply, save_glb
|
| 36 |
+
from moge.utils.geometry_numpy import intrinsics_to_fov_numpy
|
| 37 |
+
from moge.utils.vis import colorize_depth, colorize_depth_affine, colorize_disparity
|
| 38 |
+
from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module
|
| 39 |
+
from moge.test.baseline import MGEBaselineInterface
|
| 40 |
+
|
| 41 |
+
# Load the baseline model
|
| 42 |
+
module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem)
|
| 43 |
+
baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline')
|
| 44 |
+
baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False)
|
| 45 |
+
|
| 46 |
+
# Input images list
|
| 47 |
+
include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
|
| 48 |
+
if Path(input_path).is_dir():
|
| 49 |
+
image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
|
| 50 |
+
else:
|
| 51 |
+
image_paths = [Path(input_path)]
|
| 52 |
+
|
| 53 |
+
if not any([save_maps_, save_glb_, save_ply_]):
|
| 54 |
+
warnings.warn('No output format specified. Defaults to saving maps only. Please use "--maps", "--glb", or "--ply" to specify the output.')
|
| 55 |
+
save_maps_ = True
|
| 56 |
+
|
| 57 |
+
for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)):
|
| 58 |
+
# Load one image at a time
|
| 59 |
+
image_np = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
|
| 60 |
+
height, width = image_np.shape[:2]
|
| 61 |
+
if image_size is not None and max(image_np.shape[:2]) > image_size:
|
| 62 |
+
height, width = min(image_size, int(image_size * height / width)), min(image_size, int(image_size * width / height))
|
| 63 |
+
image_np = cv2.resize(image_np, (width, height), cv2.INTER_AREA)
|
| 64 |
+
image = torch.from_numpy(image_np.astype(np.float32) / 255.0).permute(2, 0, 1).to(baseline.device)
|
| 65 |
+
|
| 66 |
+
# Inference
|
| 67 |
+
torch.cuda.synchronize()
|
| 68 |
+
with torch.inference_mode(), (timer := timeit('Inference', verbose=False, average=True)):
|
| 69 |
+
output = baseline.infer(image)
|
| 70 |
+
torch.cuda.synchronize()
|
| 71 |
+
|
| 72 |
+
inference_time = timer.average_time
|
| 73 |
+
pbar.set_postfix({'average inference time': f'{inference_time:.3f}s'})
|
| 74 |
+
|
| 75 |
+
# Save the output
|
| 76 |
+
save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
|
| 77 |
+
if skip and save_path.exists():
|
| 78 |
+
continue
|
| 79 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
| 80 |
+
|
| 81 |
+
if save_maps_:
|
| 82 |
+
cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
|
| 83 |
+
|
| 84 |
+
if 'mask' in output:
|
| 85 |
+
mask = output['mask'].cpu().numpy()
|
| 86 |
+
cv2.imwrite(str(save_path /'mask.png'), (mask * 255).astype(np.uint8))
|
| 87 |
+
|
| 88 |
+
for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']:
|
| 89 |
+
if k in output:
|
| 90 |
+
points = output[k].cpu().numpy()
|
| 91 |
+
cv2.imwrite(str(save_path / f'{k}.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 92 |
+
|
| 93 |
+
for k in ['depth_metric', 'depth_scale_invariant', 'depth_affine_invariant', 'disparity_affine_invariant']:
|
| 94 |
+
if k in output:
|
| 95 |
+
depth = output[k].cpu().numpy()
|
| 96 |
+
cv2.imwrite(str(save_path / f'{k}.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 97 |
+
if k in ['depth_metric', 'depth_scale_invariant']:
|
| 98 |
+
depth_vis = colorize_depth(depth)
|
| 99 |
+
elif k == 'depth_affine_invariant':
|
| 100 |
+
depth_vis = colorize_depth_affine(depth)
|
| 101 |
+
elif k == 'disparity_affine_invariant':
|
| 102 |
+
depth_vis = colorize_disparity(depth)
|
| 103 |
+
cv2.imwrite(str(save_path / f'{k}_vis.png'), cv2.cvtColor(depth_vis, cv2.COLOR_RGB2BGR))
|
| 104 |
+
|
| 105 |
+
if 'intrinsics' in output:
|
| 106 |
+
intrinsics = output['intrinsics'].cpu().numpy()
|
| 107 |
+
fov_x, fov_y = intrinsics_to_fov_numpy(intrinsics)
|
| 108 |
+
with open(save_path / 'fov.json', 'w') as f:
|
| 109 |
+
json.dump({
|
| 110 |
+
'fov_x': float(np.rad2deg(fov_x)),
|
| 111 |
+
'fov_y': float(np.rad2deg(fov_y)),
|
| 112 |
+
'intrinsics': intrinsics.tolist()
|
| 113 |
+
}, f, indent=4)
|
| 114 |
+
|
| 115 |
+
# Export mesh & visulization
|
| 116 |
+
if save_ply_ or save_glb_:
|
| 117 |
+
assert any(k in output for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']), 'No point map found in output'
|
| 118 |
+
points = next(output[k] for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant'] if k in output).cpu().numpy()
|
| 119 |
+
mask = output['mask'] if 'mask' in output else np.ones_like(points[..., 0], dtype=bool)
|
| 120 |
+
normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask)
|
| 121 |
+
faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
|
| 122 |
+
points,
|
| 123 |
+
image_np.astype(np.float32) / 255,
|
| 124 |
+
utils3d.numpy.image_uv(width=width, height=height),
|
| 125 |
+
mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=threshold, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)),
|
| 126 |
+
tri=True
|
| 127 |
+
)
|
| 128 |
+
# When exporting the model, follow the OpenGL coordinate conventions:
|
| 129 |
+
# - world coordinate system: x right, y up, z backward.
|
| 130 |
+
# - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top.
|
| 131 |
+
vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1]
|
| 132 |
+
|
| 133 |
+
if save_glb_:
|
| 134 |
+
save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image_np)
|
| 135 |
+
|
| 136 |
+
if save_ply_:
|
| 137 |
+
save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors)
|
| 138 |
+
|
| 139 |
+
if __name__ == '__main__':
|
| 140 |
+
main()
|
moge/scripts/infer_panorama.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import sys
|
| 5 |
+
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
|
| 6 |
+
sys.path.insert(0, _package_root)
|
| 7 |
+
from typing import *
|
| 8 |
+
import itertools
|
| 9 |
+
import json
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
import click
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@click.command(help='Inference script for panorama images')
|
| 16 |
+
@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), required=True, help='Input image or folder path. "jpg" and "png" are supported.')
|
| 17 |
+
@click.option('--output', '-o', 'output_path', type=click.Path(), default='./output', help='Output folder path')
|
| 18 |
+
@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Defaults to "Ruicheng/moge-vitl"')
|
| 19 |
+
@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"')
|
| 20 |
+
@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).')
|
| 21 |
+
@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Defaults to 9. Note that it is irrelevant to the output resolution.')
|
| 22 |
+
@click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Defaults to 0.03. Smaller value removes more edges. "inf" means no thresholding.')
|
| 23 |
+
@click.option('--batch_size', type=int, default=4, help='Batch size for inference. Defaults to 4.')
|
| 24 |
+
@click.option('--splitted', 'save_splitted', is_flag=True, help='Whether to save the splitted images. Defaults to False.')
|
| 25 |
+
@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).')
|
| 26 |
+
@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.')
|
| 27 |
+
@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.')
|
| 28 |
+
@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.')
|
| 29 |
+
def main(
|
| 30 |
+
input_path: str,
|
| 31 |
+
output_path: str,
|
| 32 |
+
pretrained_model_name_or_path: str,
|
| 33 |
+
device_name: str,
|
| 34 |
+
resize_to: int,
|
| 35 |
+
resolution_level: int,
|
| 36 |
+
threshold: float,
|
| 37 |
+
batch_size: int,
|
| 38 |
+
save_splitted: bool,
|
| 39 |
+
save_maps_: bool,
|
| 40 |
+
save_glb_: bool,
|
| 41 |
+
save_ply_: bool,
|
| 42 |
+
show: bool,
|
| 43 |
+
):
|
| 44 |
+
# Lazy import
|
| 45 |
+
import cv2
|
| 46 |
+
import numpy as np
|
| 47 |
+
from numpy import ndarray
|
| 48 |
+
import torch
|
| 49 |
+
from PIL import Image
|
| 50 |
+
from tqdm import tqdm, trange
|
| 51 |
+
import trimesh
|
| 52 |
+
import trimesh.visual
|
| 53 |
+
from scipy.sparse import csr_array, hstack, vstack
|
| 54 |
+
from scipy.ndimage import convolve
|
| 55 |
+
from scipy.sparse.linalg import lsmr
|
| 56 |
+
|
| 57 |
+
import utils3d
|
| 58 |
+
from moge.model.v1 import MoGeModel
|
| 59 |
+
from moge.utils.io import save_glb, save_ply
|
| 60 |
+
from moge.utils.vis import colorize_depth
|
| 61 |
+
from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
device = torch.device(device_name)
|
| 65 |
+
|
| 66 |
+
include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
|
| 67 |
+
if Path(input_path).is_dir():
|
| 68 |
+
image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
|
| 69 |
+
else:
|
| 70 |
+
image_paths = [Path(input_path)]
|
| 71 |
+
|
| 72 |
+
if len(image_paths) == 0:
|
| 73 |
+
raise FileNotFoundError(f'No image files found in {input_path}')
|
| 74 |
+
|
| 75 |
+
# Write outputs
|
| 76 |
+
if not any([save_maps_, save_glb_, save_ply_]):
|
| 77 |
+
warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.')
|
| 78 |
+
save_maps_ = save_glb_ = save_ply_ = True
|
| 79 |
+
|
| 80 |
+
model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval()
|
| 81 |
+
|
| 82 |
+
for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)):
|
| 83 |
+
image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
|
| 84 |
+
height, width = image.shape[:2]
|
| 85 |
+
if resize_to is not None:
|
| 86 |
+
height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height))
|
| 87 |
+
image = cv2.resize(image, (width, height), cv2.INTER_AREA)
|
| 88 |
+
|
| 89 |
+
splitted_extrinsics, splitted_intriniscs = get_panorama_cameras()
|
| 90 |
+
splitted_resolution = 512
|
| 91 |
+
splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution)
|
| 92 |
+
|
| 93 |
+
# Infer each view
|
| 94 |
+
print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring')
|
| 95 |
+
|
| 96 |
+
splitted_distance_maps, splitted_masks = [], []
|
| 97 |
+
for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False):
|
| 98 |
+
image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
|
| 99 |
+
fov_x, fov_y = np.rad2deg(utils3d.numpy.intrinsics_to_fov(np.array(splitted_intriniscs[i:i + batch_size])))
|
| 100 |
+
fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device)
|
| 101 |
+
output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False)
|
| 102 |
+
distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy()
|
| 103 |
+
splitted_distance_maps.extend(list(distance_map))
|
| 104 |
+
splitted_masks.extend(list(mask))
|
| 105 |
+
|
| 106 |
+
# Save splitted
|
| 107 |
+
if save_splitted:
|
| 108 |
+
splitted_save_path = Path(output_path, image_path.stem, 'splitted')
|
| 109 |
+
splitted_save_path.mkdir(exist_ok=True, parents=True)
|
| 110 |
+
for i in range(len(splitted_images)):
|
| 111 |
+
cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR))
|
| 112 |
+
cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR))
|
| 113 |
+
|
| 114 |
+
# Merge
|
| 115 |
+
print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging')
|
| 116 |
+
|
| 117 |
+
merging_width, merging_height = min(1920, width), min(960, height)
|
| 118 |
+
panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs)
|
| 119 |
+
panorama_depth = panorama_depth.astype(np.float32)
|
| 120 |
+
panorama_depth = cv2.resize(panorama_depth, (width, height), cv2.INTER_LINEAR)
|
| 121 |
+
panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), cv2.INTER_NEAREST) > 0
|
| 122 |
+
points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.numpy.image_uv(width=width, height=height))
|
| 123 |
+
|
| 124 |
+
# Write outputs
|
| 125 |
+
print('Writing outputs...') if pbar.disable else pbar.set_postfix_str(f'Inferring')
|
| 126 |
+
save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
|
| 127 |
+
save_path.mkdir(exist_ok=True, parents=True)
|
| 128 |
+
if save_maps_:
|
| 129 |
+
cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
|
| 130 |
+
cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR))
|
| 131 |
+
cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 132 |
+
cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 133 |
+
cv2.imwrite(str(save_path /'mask.png'), (panorama_mask * 255).astype(np.uint8))
|
| 134 |
+
|
| 135 |
+
# Export mesh & visulization
|
| 136 |
+
if save_glb_ or save_ply_ or show:
|
| 137 |
+
normals, normals_mask = utils3d.numpy.points_to_normals(points, panorama_mask)
|
| 138 |
+
faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
|
| 139 |
+
points,
|
| 140 |
+
image.astype(np.float32) / 255,
|
| 141 |
+
utils3d.numpy.image_uv(width=width, height=height),
|
| 142 |
+
mask=panorama_mask & ~(utils3d.numpy.depth_edge(panorama_depth, rtol=threshold) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)),
|
| 143 |
+
tri=True
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
if save_glb_:
|
| 147 |
+
save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image)
|
| 148 |
+
|
| 149 |
+
if save_ply_:
|
| 150 |
+
save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors)
|
| 151 |
+
|
| 152 |
+
if show:
|
| 153 |
+
trimesh.Trimesh(
|
| 154 |
+
vertices=vertices,
|
| 155 |
+
vertex_colors=vertex_colors,
|
| 156 |
+
faces=faces,
|
| 157 |
+
process=False
|
| 158 |
+
).show()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == '__main__':
|
| 162 |
+
main()
|
moge/scripts/train.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import sys
|
| 4 |
+
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
|
| 5 |
+
sys.path.insert(0, _package_root)
|
| 6 |
+
import json
|
| 7 |
+
import time
|
| 8 |
+
import random
|
| 9 |
+
from typing import *
|
| 10 |
+
import itertools
|
| 11 |
+
from contextlib import nullcontext
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 13 |
+
import io
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import cv2
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import torch.version
|
| 22 |
+
import accelerate
|
| 23 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
| 24 |
+
from accelerate.utils import set_seed
|
| 25 |
+
import utils3d
|
| 26 |
+
import click
|
| 27 |
+
from tqdm import tqdm, trange
|
| 28 |
+
import mlflow
|
| 29 |
+
torch.backends.cudnn.benchmark = False # Varying input size, make sure cudnn benchmark is disabled
|
| 30 |
+
|
| 31 |
+
from moge.train.dataloader import TrainDataLoaderPipeline
|
| 32 |
+
from moge.train.losses import (
|
| 33 |
+
affine_invariant_global_loss,
|
| 34 |
+
affine_invariant_local_loss,
|
| 35 |
+
edge_loss,
|
| 36 |
+
normal_loss,
|
| 37 |
+
mask_l2_loss,
|
| 38 |
+
mask_bce_loss,
|
| 39 |
+
monitoring,
|
| 40 |
+
)
|
| 41 |
+
from moge.train.utils import build_optimizer, build_lr_scheduler
|
| 42 |
+
from moge.utils.geometry_torch import intrinsics_to_fov
|
| 43 |
+
from moge.utils.vis import colorize_depth, colorize_normal
|
| 44 |
+
from moge.utils.tools import key_average, recursive_replace, CallbackOnException, flatten_nested_dict
|
| 45 |
+
from moge.test.metrics import compute_metrics
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@click.command()
|
| 49 |
+
@click.option('--config', 'config_path', type=str, default='configs/debug.json')
|
| 50 |
+
@click.option('--workspace', type=str, default='workspace/debug', help='Path to the workspace')
|
| 51 |
+
@click.option('--checkpoint', 'checkpoint_path', type=str, default=None, help='Path to the checkpoint to load')
|
| 52 |
+
@click.option('--batch_size_forward', type=int, default=8, help='Batch size for each forward pass on each device')
|
| 53 |
+
@click.option('--gradient_accumulation_steps', type=int, default=1, help='Number of steps to accumulate gradients')
|
| 54 |
+
@click.option('--enable_gradient_checkpointing', type=bool, default=True, help='Use gradient checkpointing in backbone')
|
| 55 |
+
@click.option('--enable_mixed_precision', type=bool, default=False, help='Use mixed precision training. Backbone is converted to FP16')
|
| 56 |
+
@click.option('--enable_ema', type=bool, default=True, help='Maintain an exponential moving average of the model weights')
|
| 57 |
+
@click.option('--num_iterations', type=int, default=1000000, help='Number of iterations to train the model')
|
| 58 |
+
@click.option('--save_every', type=int, default=10000, help='Save checkpoint every n iterations')
|
| 59 |
+
@click.option('--log_every', type=int, default=1000, help='Log metrics every n iterations')
|
| 60 |
+
@click.option('--vis_every', type=int, default=0, help='Visualize every n iterations')
|
| 61 |
+
@click.option('--num_vis_images', type=int, default=32, help='Number of images to visualize, must be a multiple of divided batch size')
|
| 62 |
+
@click.option('--enable_mlflow', type=bool, default=True, help='Log metrics to MLFlow')
|
| 63 |
+
@click.option('--seed', type=int, default=0, help='Random seed')
|
| 64 |
+
def main(
|
| 65 |
+
config_path: str,
|
| 66 |
+
workspace: str,
|
| 67 |
+
checkpoint_path: str,
|
| 68 |
+
batch_size_forward: int,
|
| 69 |
+
gradient_accumulation_steps: int,
|
| 70 |
+
enable_gradient_checkpointing: bool,
|
| 71 |
+
enable_mixed_precision: bool,
|
| 72 |
+
enable_ema: bool,
|
| 73 |
+
num_iterations: int,
|
| 74 |
+
save_every: int,
|
| 75 |
+
log_every: int,
|
| 76 |
+
vis_every: int,
|
| 77 |
+
num_vis_images: int,
|
| 78 |
+
enable_mlflow: bool,
|
| 79 |
+
seed: Optional[int],
|
| 80 |
+
):
|
| 81 |
+
# Load config
|
| 82 |
+
with open(config_path, 'r') as f:
|
| 83 |
+
config = json.load(f)
|
| 84 |
+
|
| 85 |
+
accelerator = Accelerator(
|
| 86 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 87 |
+
mixed_precision='fp16' if enable_mixed_precision else None,
|
| 88 |
+
kwargs_handlers=[
|
| 89 |
+
DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 90 |
+
]
|
| 91 |
+
)
|
| 92 |
+
device = accelerator.device
|
| 93 |
+
batch_size_total = batch_size_forward * gradient_accumulation_steps * accelerator.num_processes
|
| 94 |
+
|
| 95 |
+
# Log config
|
| 96 |
+
if accelerator.is_main_process:
|
| 97 |
+
if enable_mlflow:
|
| 98 |
+
try:
|
| 99 |
+
mlflow.log_params({
|
| 100 |
+
**click.get_current_context().params,
|
| 101 |
+
'batch_size_total': batch_size_total,
|
| 102 |
+
})
|
| 103 |
+
except:
|
| 104 |
+
print('Failed to log config to MLFlow')
|
| 105 |
+
Path(workspace).mkdir(parents=True, exist_ok=True)
|
| 106 |
+
with Path(workspace).joinpath('config.json').open('w') as f:
|
| 107 |
+
json.dump(config, f, indent=4)
|
| 108 |
+
|
| 109 |
+
# Set seed
|
| 110 |
+
if seed is not None:
|
| 111 |
+
set_seed(seed, device_specific=True)
|
| 112 |
+
|
| 113 |
+
# Initialize model
|
| 114 |
+
print('Initialize model')
|
| 115 |
+
with accelerator.local_main_process_first():
|
| 116 |
+
from moge.model import import_model_class_by_version
|
| 117 |
+
MoGeModel = import_model_class_by_version(config['model_version'])
|
| 118 |
+
model = MoGeModel(**config['model'])
|
| 119 |
+
count_total_parameters = sum(p.numel() for p in model.parameters())
|
| 120 |
+
print(f'Total parameters: {count_total_parameters}')
|
| 121 |
+
|
| 122 |
+
# Set up EMA model
|
| 123 |
+
if enable_ema and accelerator.is_main_process:
|
| 124 |
+
ema_avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: 0.999 * averaged_model_parameter + 0.001 * model_parameter
|
| 125 |
+
ema_model = torch.optim.swa_utils.AveragedModel(model, device=accelerator.device, avg_fn=ema_avg_fn)
|
| 126 |
+
|
| 127 |
+
# Set gradient checkpointing
|
| 128 |
+
if enable_gradient_checkpointing:
|
| 129 |
+
model.enable_gradient_checkpointing()
|
| 130 |
+
import warnings
|
| 131 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.utils.checkpoint")
|
| 132 |
+
|
| 133 |
+
# Initalize optimizer & lr scheduler
|
| 134 |
+
optimizer = build_optimizer(model, config['optimizer'])
|
| 135 |
+
lr_scheduler = build_lr_scheduler(optimizer, config['lr_scheduler'])
|
| 136 |
+
|
| 137 |
+
count_grouped_parameters = [sum(p.numel() for p in param_group['params'] if p.requires_grad) for param_group in optimizer.param_groups]
|
| 138 |
+
for i, count in enumerate(count_grouped_parameters):
|
| 139 |
+
print(f'- Group {i}: {count} parameters')
|
| 140 |
+
|
| 141 |
+
# Attempt to load checkpoint
|
| 142 |
+
checkpoint: Dict[str, Any]
|
| 143 |
+
with accelerator.local_main_process_first():
|
| 144 |
+
if checkpoint_path.endswith('.pt'):
|
| 145 |
+
# - Load specific checkpoint file
|
| 146 |
+
print(f'Load checkpoint: {checkpoint_path}')
|
| 147 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
|
| 148 |
+
elif checkpoint_path == "latest":
|
| 149 |
+
# - Load latest
|
| 150 |
+
checkpoint_path = Path(workspace, 'checkpoint', 'latest.pt')
|
| 151 |
+
if checkpoint_path.exists():
|
| 152 |
+
print(f'Load checkpoint: {checkpoint_path}')
|
| 153 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
|
| 154 |
+
i_step = checkpoint['step']
|
| 155 |
+
if 'model' not in checkpoint and (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists():
|
| 156 |
+
print(f'Load model checkpoint: {checkpoint_model_path}')
|
| 157 |
+
checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model']
|
| 158 |
+
if 'optimizer' not in checkpoint and (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists():
|
| 159 |
+
print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}')
|
| 160 |
+
checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True))
|
| 161 |
+
if enable_ema and accelerator.is_main_process:
|
| 162 |
+
if 'ema_model' not in checkpoint and (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists():
|
| 163 |
+
print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}')
|
| 164 |
+
checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model']
|
| 165 |
+
else:
|
| 166 |
+
checkpoint = None
|
| 167 |
+
elif checkpoint_path is not None:
|
| 168 |
+
# - Load by step number
|
| 169 |
+
i_step = int(checkpoint_path)
|
| 170 |
+
checkpoint = {'step': i_step}
|
| 171 |
+
if (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists():
|
| 172 |
+
print(f'Load model checkpoint: {checkpoint_model_path}')
|
| 173 |
+
checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model']
|
| 174 |
+
if (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists():
|
| 175 |
+
print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}')
|
| 176 |
+
checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True))
|
| 177 |
+
if enable_ema and accelerator.is_main_process:
|
| 178 |
+
if (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists():
|
| 179 |
+
print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}')
|
| 180 |
+
checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model']
|
| 181 |
+
else:
|
| 182 |
+
checkpoint = None
|
| 183 |
+
|
| 184 |
+
if checkpoint is None:
|
| 185 |
+
# Initialize model weights
|
| 186 |
+
print('Initialize model weights')
|
| 187 |
+
with accelerator.local_main_process_first():
|
| 188 |
+
model.init_weights()
|
| 189 |
+
initial_step = 0
|
| 190 |
+
else:
|
| 191 |
+
model.load_state_dict(checkpoint['model'], strict=False)
|
| 192 |
+
if 'step' in checkpoint:
|
| 193 |
+
initial_step = checkpoint['step'] + 1
|
| 194 |
+
else:
|
| 195 |
+
initial_step = 0
|
| 196 |
+
if 'optimizer' in checkpoint:
|
| 197 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 198 |
+
if enable_ema and accelerator.is_main_process and 'ema_model' in checkpoint:
|
| 199 |
+
ema_model.module.load_state_dict(checkpoint['ema_model'], strict=False)
|
| 200 |
+
if 'lr_scheduler' in checkpoint:
|
| 201 |
+
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
|
| 202 |
+
|
| 203 |
+
del checkpoint
|
| 204 |
+
|
| 205 |
+
model, optimizer = accelerator.prepare(model, optimizer)
|
| 206 |
+
if torch.version.hip and isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
| 207 |
+
# Hacking potential gradient synchronization issue in ROCm backend
|
| 208 |
+
from moge.model.utils import sync_ddp_hook
|
| 209 |
+
model.register_comm_hook(None, sync_ddp_hook)
|
| 210 |
+
|
| 211 |
+
# Initialize training data pipeline
|
| 212 |
+
with accelerator.local_main_process_first():
|
| 213 |
+
train_data_pipe = TrainDataLoaderPipeline(config['data'], batch_size_forward)
|
| 214 |
+
|
| 215 |
+
def _write_bytes_retry_loop(save_path: Path, data: bytes):
|
| 216 |
+
while True:
|
| 217 |
+
try:
|
| 218 |
+
save_path.write_bytes(data)
|
| 219 |
+
break
|
| 220 |
+
except Exception as e:
|
| 221 |
+
print('Error while saving checkpoint, retrying in 1 minute: ', e)
|
| 222 |
+
time.sleep(60)
|
| 223 |
+
|
| 224 |
+
# Ready to train
|
| 225 |
+
records = []
|
| 226 |
+
model.train()
|
| 227 |
+
with (
|
| 228 |
+
train_data_pipe,
|
| 229 |
+
tqdm(initial=initial_step, total=num_iterations, desc='Training', disable=not accelerator.is_main_process) as pbar,
|
| 230 |
+
ThreadPoolExecutor(max_workers=1) as save_checkpoint_executor,
|
| 231 |
+
):
|
| 232 |
+
# Get some batches for visualization
|
| 233 |
+
if accelerator.is_main_process:
|
| 234 |
+
batches_for_vis: List[Dict[str, torch.Tensor]] = []
|
| 235 |
+
num_vis_images = num_vis_images // batch_size_forward * batch_size_forward
|
| 236 |
+
for _ in range(num_vis_images // batch_size_forward):
|
| 237 |
+
batch = train_data_pipe.get()
|
| 238 |
+
batches_for_vis.append(batch)
|
| 239 |
+
|
| 240 |
+
# Visualize GT
|
| 241 |
+
if vis_every > 0 and accelerator.is_main_process and initial_step == 0:
|
| 242 |
+
save_dir = Path(workspace).joinpath('vis/gt')
|
| 243 |
+
for i_batch, batch in enumerate(tqdm(batches_for_vis, desc='Visualize GT', leave=False)):
|
| 244 |
+
image, gt_depth, gt_mask, gt_mask_inf, gt_intrinsics, info = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_inf'], batch['intrinsics'], batch['info']
|
| 245 |
+
gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics)
|
| 246 |
+
gt_normal, gt_normal_mask = utils3d.torch.points_to_normals(gt_points, gt_mask)
|
| 247 |
+
for i_instance in range(batch['image'].shape[0]):
|
| 248 |
+
idx = i_batch * batch_size_forward + i_instance
|
| 249 |
+
image_i = (image[i_instance].numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
| 250 |
+
gt_depth_i = gt_depth[i_instance].numpy()
|
| 251 |
+
gt_mask_i = gt_mask[i_instance].numpy()
|
| 252 |
+
gt_mask_inf_i = gt_mask_inf[i_instance].numpy()
|
| 253 |
+
gt_points_i = gt_points[i_instance].numpy()
|
| 254 |
+
gt_normal_i = gt_normal[i_instance].numpy()
|
| 255 |
+
save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True)
|
| 256 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR))
|
| 257 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(gt_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 258 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), gt_mask_i * 255)
|
| 259 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(gt_depth_i, gt_mask_i), cv2.COLOR_RGB2BGR))
|
| 260 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/normal.png')), cv2.cvtColor(colorize_normal(gt_normal_i), cv2.COLOR_RGB2BGR))
|
| 261 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask_inf.png')), gt_mask_inf_i * 255)
|
| 262 |
+
with save_dir.joinpath(f'{idx:04d}/info.json').open('w') as f:
|
| 263 |
+
json.dump(info[i_instance], f)
|
| 264 |
+
|
| 265 |
+
# Reset seed to avoid training on the same data when resuming training
|
| 266 |
+
if seed is not None:
|
| 267 |
+
set_seed(seed + initial_step, device_specific=True)
|
| 268 |
+
|
| 269 |
+
# Training loop
|
| 270 |
+
for i_step in range(initial_step, num_iterations):
|
| 271 |
+
|
| 272 |
+
i_accumulate, weight_accumulate = 0, 0
|
| 273 |
+
while i_accumulate < gradient_accumulation_steps:
|
| 274 |
+
# Load batch
|
| 275 |
+
batch = train_data_pipe.get()
|
| 276 |
+
image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics, label_type, is_metric = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_fin'], batch['depth_mask_inf'], batch['intrinsics'], batch['label_type'], batch['is_metric']
|
| 277 |
+
image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_mask_fin.to(device), gt_mask_inf.to(device), gt_intrinsics.to(device)
|
| 278 |
+
current_batch_size = image.shape[0]
|
| 279 |
+
if all(label == 'invalid' for label in label_type):
|
| 280 |
+
continue # NOTE: Skip all-invalid batches to avoid messing up the optimizer.
|
| 281 |
+
|
| 282 |
+
gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics)
|
| 283 |
+
gt_focal = 1 / (1 / gt_intrinsics[..., 0, 0] ** 2 + 1 / gt_intrinsics[..., 1, 1] ** 2) ** 0.5
|
| 284 |
+
|
| 285 |
+
with accelerator.accumulate(model):
|
| 286 |
+
# Forward
|
| 287 |
+
if i_step <= config.get('low_resolution_training_steps', 0):
|
| 288 |
+
num_tokens = config['model']['num_tokens_range'][0]
|
| 289 |
+
else:
|
| 290 |
+
num_tokens = accelerate.utils.broadcast_object_list([random.randint(*config['model']['num_tokens_range'])])[0]
|
| 291 |
+
with torch.autocast(device_type=accelerator.device.type, dtype=torch.float16, enabled=enable_mixed_precision):
|
| 292 |
+
output = model(image, num_tokens=num_tokens)
|
| 293 |
+
pred_points, pred_mask, pred_metric_scale = output['points'], output['mask'], output.get('metric_scale', None)
|
| 294 |
+
|
| 295 |
+
# Compute loss (per instance)
|
| 296 |
+
loss_list, weight_list = [], []
|
| 297 |
+
for i in range(current_batch_size):
|
| 298 |
+
gt_metric_scale = None
|
| 299 |
+
loss_dict, weight_dict, misc_dict = {}, {}, {}
|
| 300 |
+
misc_dict['monitoring'] = monitoring(pred_points[i])
|
| 301 |
+
for k, v in config['loss'][label_type[i]].items():
|
| 302 |
+
weight_dict[k] = v['weight']
|
| 303 |
+
if v['function'] == 'affine_invariant_global_loss':
|
| 304 |
+
loss_dict[k], misc_dict[k], gt_metric_scale = affine_invariant_global_loss(pred_points[i], gt_points[i], gt_mask[i], **v['params'])
|
| 305 |
+
elif v['function'] == 'affine_invariant_local_loss':
|
| 306 |
+
loss_dict[k], misc_dict[k] = affine_invariant_local_loss(pred_points[i], gt_points[i], gt_mask[i], gt_focal[i], gt_metric_scale, **v['params'])
|
| 307 |
+
elif v['function'] == 'normal_loss':
|
| 308 |
+
loss_dict[k], misc_dict[k] = normal_loss(pred_points[i], gt_points[i], gt_mask[i])
|
| 309 |
+
elif v['function'] == 'edge_loss':
|
| 310 |
+
loss_dict[k], misc_dict[k] = edge_loss(pred_points[i], gt_points[i], gt_mask[i])
|
| 311 |
+
elif v['function'] == 'mask_bce_loss':
|
| 312 |
+
loss_dict[k], misc_dict[k] = mask_bce_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i])
|
| 313 |
+
elif v['function'] == 'mask_l2_loss':
|
| 314 |
+
loss_dict[k], misc_dict[k] = mask_l2_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i])
|
| 315 |
+
else:
|
| 316 |
+
raise ValueError(f'Undefined loss function: {v["function"]}')
|
| 317 |
+
weight_dict = {'.'.join(k): v for k, v in flatten_nested_dict(weight_dict).items()}
|
| 318 |
+
loss_dict = {'.'.join(k): v for k, v in flatten_nested_dict(loss_dict).items()}
|
| 319 |
+
loss_ = sum([weight_dict[k] * loss_dict[k] for k in loss_dict], start=torch.tensor(0.0, device=device))
|
| 320 |
+
loss_list.append(loss_)
|
| 321 |
+
|
| 322 |
+
if torch.isnan(loss_).item():
|
| 323 |
+
pbar.write(f'NaN loss in process {accelerator.process_index}')
|
| 324 |
+
pbar.write(str(loss_dict))
|
| 325 |
+
|
| 326 |
+
misc_dict = {'.'.join(k): v for k, v in flatten_nested_dict(misc_dict).items()}
|
| 327 |
+
records.append({
|
| 328 |
+
**{k: v.item() for k, v in loss_dict.items()},
|
| 329 |
+
**misc_dict,
|
| 330 |
+
})
|
| 331 |
+
|
| 332 |
+
loss = sum(loss_list) / len(loss_list)
|
| 333 |
+
|
| 334 |
+
# Backward & update
|
| 335 |
+
accelerator.backward(loss)
|
| 336 |
+
if accelerator.sync_gradients:
|
| 337 |
+
if not enable_mixed_precision and any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
|
| 338 |
+
if accelerator.is_main_process:
|
| 339 |
+
pbar.write(f'NaN gradients, skip update')
|
| 340 |
+
optimizer.zero_grad()
|
| 341 |
+
continue
|
| 342 |
+
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
| 343 |
+
|
| 344 |
+
optimizer.step()
|
| 345 |
+
optimizer.zero_grad()
|
| 346 |
+
|
| 347 |
+
i_accumulate += 1
|
| 348 |
+
|
| 349 |
+
lr_scheduler.step()
|
| 350 |
+
|
| 351 |
+
# EMA update
|
| 352 |
+
if enable_ema and accelerator.is_main_process and accelerator.sync_gradients:
|
| 353 |
+
ema_model.update_parameters(model)
|
| 354 |
+
|
| 355 |
+
# Log metrics
|
| 356 |
+
if i_step == initial_step or i_step % log_every == 0:
|
| 357 |
+
records = [key_average(records)]
|
| 358 |
+
records = accelerator.gather_for_metrics(records, use_gather_object=True)
|
| 359 |
+
if accelerator.is_main_process:
|
| 360 |
+
records = key_average(records)
|
| 361 |
+
if enable_mlflow:
|
| 362 |
+
try:
|
| 363 |
+
mlflow.log_metrics(records, step=i_step)
|
| 364 |
+
except Exception as e:
|
| 365 |
+
print(f'Error while logging metrics to mlflow: {e}')
|
| 366 |
+
records = []
|
| 367 |
+
|
| 368 |
+
# Save model weight checkpoint
|
| 369 |
+
if accelerator.is_main_process and (i_step % save_every == 0):
|
| 370 |
+
# NOTE: Writing checkpoint is done in a separate thread to avoid blocking the main process
|
| 371 |
+
pbar.write(f'Save checkpoint: {i_step:08d}')
|
| 372 |
+
Path(workspace, 'checkpoint').mkdir(parents=True, exist_ok=True)
|
| 373 |
+
|
| 374 |
+
# Model checkpoint
|
| 375 |
+
with io.BytesIO() as f:
|
| 376 |
+
torch.save({
|
| 377 |
+
'model_config': config['model'],
|
| 378 |
+
'model': accelerator.unwrap_model(model).state_dict(),
|
| 379 |
+
}, f)
|
| 380 |
+
checkpoint_bytes = f.getvalue()
|
| 381 |
+
save_checkpoint_executor.submit(
|
| 382 |
+
_write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}.pt'), checkpoint_bytes
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Optimizer checkpoint
|
| 386 |
+
with io.BytesIO() as f:
|
| 387 |
+
torch.save({
|
| 388 |
+
'model_config': config['model'],
|
| 389 |
+
'step': i_step,
|
| 390 |
+
'optimizer': optimizer.state_dict(),
|
| 391 |
+
'lr_scheduler': lr_scheduler.state_dict(),
|
| 392 |
+
}, f)
|
| 393 |
+
checkpoint_bytes = f.getvalue()
|
| 394 |
+
save_checkpoint_executor.submit(
|
| 395 |
+
_write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt'), checkpoint_bytes
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# EMA model checkpoint
|
| 399 |
+
if enable_ema:
|
| 400 |
+
with io.BytesIO() as f:
|
| 401 |
+
torch.save({
|
| 402 |
+
'model_config': config['model'],
|
| 403 |
+
'model': ema_model.module.state_dict(),
|
| 404 |
+
}, f)
|
| 405 |
+
checkpoint_bytes = f.getvalue()
|
| 406 |
+
save_checkpoint_executor.submit(
|
| 407 |
+
_write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt'), checkpoint_bytes
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# Latest checkpoint
|
| 411 |
+
with io.BytesIO() as f:
|
| 412 |
+
torch.save({
|
| 413 |
+
'model_config': config['model'],
|
| 414 |
+
'step': i_step,
|
| 415 |
+
}, f)
|
| 416 |
+
checkpoint_bytes = f.getvalue()
|
| 417 |
+
save_checkpoint_executor.submit(
|
| 418 |
+
_write_bytes_retry_loop, Path(workspace, 'checkpoint', 'latest.pt'), checkpoint_bytes
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
# Visualize
|
| 422 |
+
if vis_every > 0 and accelerator.is_main_process and (i_step == initial_step or i_step % vis_every == 0):
|
| 423 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
| 424 |
+
save_dir = Path(workspace).joinpath(f'vis/step_{i_step:08d}')
|
| 425 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 426 |
+
with torch.inference_mode():
|
| 427 |
+
for i_batch, batch in enumerate(tqdm(batches_for_vis, desc=f'Visualize: {i_step:08d}', leave=False)):
|
| 428 |
+
image, gt_depth, gt_mask, gt_intrinsics = batch['image'], batch['depth'], batch['depth_mask'], batch['intrinsics']
|
| 429 |
+
image, gt_depth, gt_mask, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_intrinsics.to(device)
|
| 430 |
+
|
| 431 |
+
output = unwrapped_model.infer(image)
|
| 432 |
+
pred_points, pred_depth, pred_mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy()
|
| 433 |
+
image = image.cpu().numpy()
|
| 434 |
+
|
| 435 |
+
for i_instance in range(image.shape[0]):
|
| 436 |
+
idx = i_batch * batch_size_forward + i_instance
|
| 437 |
+
image_i = (image[i_instance].transpose(1, 2, 0) * 255).astype(np.uint8)
|
| 438 |
+
pred_points_i = pred_points[i_instance]
|
| 439 |
+
pred_mask_i = pred_mask[i_instance]
|
| 440 |
+
pred_depth_i = pred_depth[i_instance]
|
| 441 |
+
save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True)
|
| 442 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR))
|
| 443 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(pred_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
|
| 444 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), pred_mask_i * 255)
|
| 445 |
+
cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(pred_depth_i, pred_mask_i), cv2.COLOR_RGB2BGR))
|
| 446 |
+
|
| 447 |
+
pbar.set_postfix({'loss': loss.item()}, refresh=False)
|
| 448 |
+
pbar.update(1)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
if __name__ == '__main__':
|
| 452 |
+
main()
|
moge/scripts/vis_data.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
|
| 6 |
+
sys.path.insert(0, _package_root)
|
| 7 |
+
|
| 8 |
+
import click
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@click.command()
|
| 12 |
+
@click.argument('folder_or_path', type=click.Path(exists=True))
|
| 13 |
+
@click.option('--output', '-o', 'output_folder', type=click.Path(), help='Path to output folder')
|
| 14 |
+
@click.option('--max_depth', '-m', type=float, default=float('inf'), help='max depth')
|
| 15 |
+
@click.option('--fov', type=float, default=None, help='field of view in degrees')
|
| 16 |
+
@click.option('--show', 'show', is_flag=True, help='show point cloud')
|
| 17 |
+
@click.option('--depth', 'depth_filename', type=str, default='depth.png', help='depth image file name')
|
| 18 |
+
@click.option('--ply', 'save_ply', is_flag=True, help='save point cloud as PLY file')
|
| 19 |
+
@click.option('--depth_vis', 'save_depth_vis', is_flag=True, help='save depth image')
|
| 20 |
+
@click.option('--inf', 'inf_mask', is_flag=True, help='use infinity mask')
|
| 21 |
+
@click.option('--version', 'version', type=str, default='v3', help='version of rgbd data')
|
| 22 |
+
def main(
|
| 23 |
+
folder_or_path: str,
|
| 24 |
+
output_folder: str,
|
| 25 |
+
max_depth: float,
|
| 26 |
+
fov: float,
|
| 27 |
+
depth_filename: str,
|
| 28 |
+
show: bool,
|
| 29 |
+
save_ply: bool,
|
| 30 |
+
save_depth_vis: bool,
|
| 31 |
+
inf_mask: bool,
|
| 32 |
+
version: str
|
| 33 |
+
):
|
| 34 |
+
# Lazy import
|
| 35 |
+
import cv2
|
| 36 |
+
import numpy as np
|
| 37 |
+
import utils3d
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
import trimesh
|
| 40 |
+
|
| 41 |
+
from moge.utils.io import read_image, read_depth, read_meta
|
| 42 |
+
from moge.utils.vis import colorize_depth, colorize_normal
|
| 43 |
+
|
| 44 |
+
filepaths = sorted(p.parent for p in Path(folder_or_path).rglob('meta.json'))
|
| 45 |
+
|
| 46 |
+
for filepath in tqdm(filepaths):
|
| 47 |
+
image = read_image(Path(filepath, 'image.jpg'))
|
| 48 |
+
depth, unit = read_depth(Path(filepath, depth_filename))
|
| 49 |
+
meta = read_meta(Path(filepath,'meta.json'))
|
| 50 |
+
depth_mask = np.isfinite(depth)
|
| 51 |
+
depth_mask_inf = (depth == np.inf)
|
| 52 |
+
intrinsics = np.array(meta['intrinsics'])
|
| 53 |
+
|
| 54 |
+
extrinsics = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=float) # OpenGL's identity camera
|
| 55 |
+
verts = utils3d.numpy.unproject_cv(utils3d.numpy.image_uv(*image.shape[:2]), depth, extrinsics=extrinsics, intrinsics=intrinsics)
|
| 56 |
+
|
| 57 |
+
depth_mask_ply = depth_mask & (depth < depth[depth_mask].min() * max_depth)
|
| 58 |
+
point_cloud = trimesh.PointCloud(verts[depth_mask_ply], image[depth_mask_ply] / 255)
|
| 59 |
+
|
| 60 |
+
if show:
|
| 61 |
+
point_cloud.show()
|
| 62 |
+
|
| 63 |
+
if output_folder is None:
|
| 64 |
+
output_path = filepath
|
| 65 |
+
else:
|
| 66 |
+
output_path = Path(output_folder, filepath.name)
|
| 67 |
+
output_path.mkdir(exist_ok=True, parents=True)
|
| 68 |
+
|
| 69 |
+
if inf_mask:
|
| 70 |
+
depth = np.where(depth_mask_inf, np.inf, depth)
|
| 71 |
+
depth_mask = depth_mask | depth_mask_inf
|
| 72 |
+
|
| 73 |
+
if save_depth_vis:
|
| 74 |
+
p = output_path.joinpath('depth_vis.png')
|
| 75 |
+
cv2.imwrite(str(p), cv2.cvtColor(colorize_depth(depth, depth_mask), cv2.COLOR_RGB2BGR))
|
| 76 |
+
print(f"{p}")
|
| 77 |
+
|
| 78 |
+
if save_ply:
|
| 79 |
+
p = output_path.joinpath('pointcloud.ply')
|
| 80 |
+
point_cloud.export(p)
|
| 81 |
+
print(f"{p}")
|
| 82 |
+
|
| 83 |
+
if __name__ == '__main__':
|
| 84 |
+
main()
|
moge/test/__init__.py
ADDED
|
File without changes
|
moge/test/baseline.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
import click
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MGEBaselineInterface:
|
| 8 |
+
"""
|
| 9 |
+
Abstract class for model wrapper to uniformize the interface of loading and inference across different models.
|
| 10 |
+
"""
|
| 11 |
+
device: torch.device
|
| 12 |
+
|
| 13 |
+
@click.command()
|
| 14 |
+
@staticmethod
|
| 15 |
+
def load(*args, **kwargs) -> "MGEBaselineInterface":
|
| 16 |
+
"""
|
| 17 |
+
Customized static method to create an instance of the model wrapper from command line arguments. Decorated by `click.command()`
|
| 18 |
+
"""
|
| 19 |
+
raise NotImplementedError(f"{type(self).__name__} has not implemented the load method.")
|
| 20 |
+
|
| 21 |
+
def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
|
| 22 |
+
"""
|
| 23 |
+
### Parameters
|
| 24 |
+
`image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1]
|
| 25 |
+
`intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional.
|
| 26 |
+
|
| 27 |
+
### Returns
|
| 28 |
+
A dictionary containing:
|
| 29 |
+
- `points_*`. point map output in OpenCV identity camera space.
|
| 30 |
+
Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`.
|
| 31 |
+
- `depth_*`. depth map output
|
| 32 |
+
Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`.
|
| 33 |
+
- `disparity_affine_invariant`. affine disparity map output
|
| 34 |
+
"""
|
| 35 |
+
raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.")
|
| 36 |
+
|
| 37 |
+
def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
|
| 38 |
+
"""
|
| 39 |
+
If the model has a special evaluation mode, override this method to provide the evaluation mode inference.
|
| 40 |
+
|
| 41 |
+
By default, this method simply calls `infer()`.
|
| 42 |
+
"""
|
| 43 |
+
return self.infer(image, intrinsics)
|
moge/test/dataloader.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import *
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import cv2
|
| 10 |
+
import utils3d
|
| 11 |
+
|
| 12 |
+
from ..utils import pipeline
|
| 13 |
+
from ..utils.geometry_numpy import focal_to_fov_numpy, mask_aware_nearest_resize_numpy, norm3d
|
| 14 |
+
from ..utils.io import *
|
| 15 |
+
from ..utils.tools import timeit
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EvalDataLoaderPipeline:
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
path: str,
|
| 23 |
+
width: int,
|
| 24 |
+
height: int,
|
| 25 |
+
split: int = '.index.txt',
|
| 26 |
+
drop_max_depth: float = 1000.,
|
| 27 |
+
num_load_workers: int = 4,
|
| 28 |
+
num_process_workers: int = 8,
|
| 29 |
+
include_segmentation: bool = False,
|
| 30 |
+
include_normal: bool = False,
|
| 31 |
+
depth_to_normal: bool = False,
|
| 32 |
+
max_segments: int = 100,
|
| 33 |
+
min_seg_area: int = 1000,
|
| 34 |
+
depth_unit: str = None,
|
| 35 |
+
has_sharp_boundary = False,
|
| 36 |
+
subset: int = None,
|
| 37 |
+
):
|
| 38 |
+
filenames = Path(path).joinpath(split).read_text(encoding='utf-8').splitlines()
|
| 39 |
+
filenames = filenames[::subset]
|
| 40 |
+
self.width = width
|
| 41 |
+
self.height = height
|
| 42 |
+
self.drop_max_depth = drop_max_depth
|
| 43 |
+
self.path = Path(path)
|
| 44 |
+
self.filenames = filenames
|
| 45 |
+
self.include_segmentation = include_segmentation
|
| 46 |
+
self.include_normal = include_normal
|
| 47 |
+
self.max_segments = max_segments
|
| 48 |
+
self.min_seg_area = min_seg_area
|
| 49 |
+
self.depth_to_normal = depth_to_normal
|
| 50 |
+
self.depth_unit = depth_unit
|
| 51 |
+
self.has_sharp_boundary = has_sharp_boundary
|
| 52 |
+
|
| 53 |
+
self.rng = np.random.default_rng(seed=0)
|
| 54 |
+
|
| 55 |
+
self.pipeline = pipeline.Sequential([
|
| 56 |
+
self._generator,
|
| 57 |
+
pipeline.Parallel([self._load_instance] * num_load_workers),
|
| 58 |
+
pipeline.Parallel([self._process_instance] * num_process_workers),
|
| 59 |
+
pipeline.Buffer(4)
|
| 60 |
+
])
|
| 61 |
+
|
| 62 |
+
def __len__(self):
|
| 63 |
+
return math.ceil(len(self.filenames))
|
| 64 |
+
|
| 65 |
+
def _generator(self):
|
| 66 |
+
for idx in range(len(self)):
|
| 67 |
+
yield idx
|
| 68 |
+
|
| 69 |
+
def _load_instance(self, idx):
|
| 70 |
+
if idx >= len(self.filenames):
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
path = self.path.joinpath(self.filenames[idx])
|
| 74 |
+
|
| 75 |
+
instance = {
|
| 76 |
+
'filename': self.filenames[idx],
|
| 77 |
+
'width': self.width,
|
| 78 |
+
'height': self.height,
|
| 79 |
+
}
|
| 80 |
+
instance['image'] = read_image(Path(path, 'image.jpg'))
|
| 81 |
+
|
| 82 |
+
depth, _ = read_depth(Path(path, 'depth.png')) # ignore depth unit from depth file, use config instead
|
| 83 |
+
instance.update({
|
| 84 |
+
'depth': np.nan_to_num(depth, nan=1, posinf=1, neginf=1),
|
| 85 |
+
'depth_mask': np.isfinite(depth),
|
| 86 |
+
'depth_mask_inf': np.isinf(depth),
|
| 87 |
+
})
|
| 88 |
+
|
| 89 |
+
if self.include_segmentation:
|
| 90 |
+
segmentation_mask, segmentation_labels = read_segmentation(Path(path,'segmentation.png'))
|
| 91 |
+
instance.update({
|
| 92 |
+
'segmentation_mask': segmentation_mask,
|
| 93 |
+
'segmentation_labels': segmentation_labels,
|
| 94 |
+
})
|
| 95 |
+
|
| 96 |
+
meta = read_meta(Path(path, 'meta.json'))
|
| 97 |
+
instance['intrinsics'] = np.array(meta['intrinsics'], dtype=np.float32)
|
| 98 |
+
|
| 99 |
+
return instance
|
| 100 |
+
|
| 101 |
+
def _process_instance(self, instance: dict):
|
| 102 |
+
if instance is None:
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
image, depth, depth_mask, intrinsics = instance['image'], instance['depth'], instance['depth_mask'], instance['intrinsics']
|
| 106 |
+
segmentation_mask, segmentation_labels = instance.get('segmentation_mask', None), instance.get('segmentation_labels', None)
|
| 107 |
+
|
| 108 |
+
raw_height, raw_width = image.shape[:2]
|
| 109 |
+
raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
|
| 110 |
+
raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height
|
| 111 |
+
tgt_width, tgt_height = instance['width'], instance['height']
|
| 112 |
+
tgt_aspect = tgt_width / tgt_height
|
| 113 |
+
|
| 114 |
+
# set expected target view field
|
| 115 |
+
tgt_horizontal = min(raw_horizontal, raw_vertical * tgt_aspect)
|
| 116 |
+
tgt_vertical = tgt_horizontal / tgt_aspect
|
| 117 |
+
|
| 118 |
+
# set target view direction
|
| 119 |
+
cu, cv = 0.5, 0.5
|
| 120 |
+
direction = utils3d.numpy.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0]
|
| 121 |
+
R = utils3d.numpy.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
|
| 122 |
+
|
| 123 |
+
# restrict target view field within the raw view
|
| 124 |
+
corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
|
| 125 |
+
corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane
|
| 126 |
+
corners = corners[:, :2] / corners[:, 2:3]
|
| 127 |
+
|
| 128 |
+
warp_horizontal, warp_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
|
| 129 |
+
for i in range(4):
|
| 130 |
+
intersection, _ = utils3d.numpy.ray_intersection(
|
| 131 |
+
np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
|
| 132 |
+
corners[i - 1], corners[i] - corners[i - 1],
|
| 133 |
+
)
|
| 134 |
+
warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
|
| 135 |
+
tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
|
| 136 |
+
|
| 137 |
+
# get target view intrinsics
|
| 138 |
+
fx, fy = 1.0 / tgt_horizontal, 1.0 / tgt_vertical
|
| 139 |
+
tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
|
| 140 |
+
|
| 141 |
+
# do homogeneous transformation with the rotation and intrinsics
|
| 142 |
+
# 4.1 The image and depth is resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling
|
| 143 |
+
tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes)
|
| 144 |
+
rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h)
|
| 145 |
+
image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS))
|
| 146 |
+
|
| 147 |
+
depth, depth_mask = mask_aware_nearest_resize_numpy(depth, depth_mask, (rescaled_w, rescaled_h))
|
| 148 |
+
distance = norm3d(utils3d.numpy.depth_to_points(depth, intrinsics=intrinsics))
|
| 149 |
+
segmentation_mask = cv2.resize(segmentation_mask, (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) if segmentation_mask is not None else None
|
| 150 |
+
|
| 151 |
+
# 4.2 calculate homography warping
|
| 152 |
+
transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics)
|
| 153 |
+
uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height)
|
| 154 |
+
pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T
|
| 155 |
+
uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12)
|
| 156 |
+
pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32)
|
| 157 |
+
|
| 158 |
+
tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR)
|
| 159 |
+
tgt_distance = cv2.remap(distance, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST)
|
| 160 |
+
tgt_ray_length = utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics)
|
| 161 |
+
tgt_ray_length = (tgt_ray_length[:, :, 0] ** 2 + tgt_ray_length[:, :, 1] ** 2 + tgt_ray_length[:, :, 2] ** 2) ** 0.5
|
| 162 |
+
tgt_depth = tgt_distance / (tgt_ray_length + 1e-12)
|
| 163 |
+
tgt_depth_mask = cv2.remap(depth_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
|
| 164 |
+
tgt_segmentation_mask = cv2.remap(segmentation_mask, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) if segmentation_mask is not None else None
|
| 165 |
+
|
| 166 |
+
# drop depth greater than drop_max_depth
|
| 167 |
+
max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.drop_max_depth
|
| 168 |
+
tgt_depth_mask &= tgt_depth <= max_depth
|
| 169 |
+
tgt_depth = np.nan_to_num(tgt_depth, nan=0.0)
|
| 170 |
+
|
| 171 |
+
if self.depth_unit is not None:
|
| 172 |
+
tgt_depth *= self.depth_unit
|
| 173 |
+
|
| 174 |
+
if not np.any(tgt_depth_mask):
|
| 175 |
+
# always make sure that mask is not empty, otherwise the loss calculation will crash
|
| 176 |
+
tgt_depth_mask = np.ones_like(tgt_depth_mask)
|
| 177 |
+
tgt_depth = np.ones_like(tgt_depth)
|
| 178 |
+
instance['label_type'] = 'invalid'
|
| 179 |
+
|
| 180 |
+
tgt_pts = utils3d.numpy.unproject_cv(uv_tgt, tgt_depth, intrinsics=tgt_intrinsics)
|
| 181 |
+
|
| 182 |
+
# Process segmentation labels
|
| 183 |
+
if self.include_segmentation and segmentation_mask is not None:
|
| 184 |
+
for k in ['undefined', 'unannotated', 'background', 'sky']:
|
| 185 |
+
if k in segmentation_labels:
|
| 186 |
+
del segmentation_labels[k]
|
| 187 |
+
seg_id2count = dict(zip(*np.unique(tgt_segmentation_mask, return_counts=True)))
|
| 188 |
+
sorted_labels = sorted(segmentation_labels.keys(), key=lambda x: seg_id2count.get(segmentation_labels[x], 0), reverse=True)
|
| 189 |
+
segmentation_labels = {k: segmentation_labels[k] for k in sorted_labels[:self.max_segments] if seg_id2count.get(segmentation_labels[k], 0) >= self.min_seg_area}
|
| 190 |
+
|
| 191 |
+
instance.update({
|
| 192 |
+
'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
|
| 193 |
+
'depth': torch.from_numpy(tgt_depth).float(),
|
| 194 |
+
'depth_mask': torch.from_numpy(tgt_depth_mask).bool(),
|
| 195 |
+
'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
|
| 196 |
+
'points': torch.from_numpy(tgt_pts).float(),
|
| 197 |
+
'segmentation_mask': torch.from_numpy(tgt_segmentation_mask).long() if tgt_segmentation_mask is not None else None,
|
| 198 |
+
'segmentation_labels': segmentation_labels,
|
| 199 |
+
'is_metric': self.depth_unit is not None,
|
| 200 |
+
'has_sharp_boundary': self.has_sharp_boundary,
|
| 201 |
+
})
|
| 202 |
+
|
| 203 |
+
instance = {k: v for k, v in instance.items() if v is not None}
|
| 204 |
+
|
| 205 |
+
return instance
|
| 206 |
+
|
| 207 |
+
def start(self):
|
| 208 |
+
self.pipeline.start()
|
| 209 |
+
|
| 210 |
+
def stop(self):
|
| 211 |
+
self.pipeline.stop()
|
| 212 |
+
|
| 213 |
+
def __enter__(self):
|
| 214 |
+
self.start()
|
| 215 |
+
return self
|
| 216 |
+
|
| 217 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
| 218 |
+
self.stop()
|
| 219 |
+
|
| 220 |
+
def get(self):
|
| 221 |
+
return self.pipeline.get()
|
moge/test/metrics.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from numbers import Number
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
import utils3d
|
| 8 |
+
|
| 9 |
+
from ..utils.geometry_torch import (
|
| 10 |
+
weighted_mean,
|
| 11 |
+
mask_aware_nearest_resize,
|
| 12 |
+
intrinsics_to_fov
|
| 13 |
+
)
|
| 14 |
+
from ..utils.alignment import (
|
| 15 |
+
align_points_scale_z_shift,
|
| 16 |
+
align_points_scale_xyz_shift,
|
| 17 |
+
align_points_xyz_shift,
|
| 18 |
+
align_affine_lstsq,
|
| 19 |
+
align_depth_scale,
|
| 20 |
+
align_depth_affine,
|
| 21 |
+
align_points_scale,
|
| 22 |
+
)
|
| 23 |
+
from ..utils.tools import key_average, timeit
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def rel_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
|
| 27 |
+
rel = (torch.abs(pred - gt) / (gt + eps)).mean()
|
| 28 |
+
return rel.item()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def delta1_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
|
| 32 |
+
delta1 = (torch.maximum(gt / pred, pred / gt) < 1.25).float().mean()
|
| 33 |
+
return delta1.item()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def rel_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
|
| 37 |
+
dist_gt = torch.norm(gt, dim=-1)
|
| 38 |
+
dist_err = torch.norm(pred - gt, dim=-1)
|
| 39 |
+
rel = (dist_err / (dist_gt + eps)).mean()
|
| 40 |
+
return rel.item()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def delta1_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
|
| 44 |
+
dist_pred = torch.norm(pred, dim=-1)
|
| 45 |
+
dist_gt = torch.norm(gt, dim=-1)
|
| 46 |
+
dist_err = torch.norm(pred - gt, dim=-1)
|
| 47 |
+
|
| 48 |
+
delta1 = (dist_err < 0.25 * torch.minimum(dist_gt, dist_pred)).float().mean()
|
| 49 |
+
return delta1.item()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def rel_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor):
|
| 53 |
+
dist_err = torch.norm(pred - gt, dim=-1)
|
| 54 |
+
rel = (dist_err / diameter).mean()
|
| 55 |
+
return rel.item()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def delta1_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor):
|
| 59 |
+
dist_err = torch.norm(pred - gt, dim=-1)
|
| 60 |
+
delta1 = (dist_err < 0.25 * diameter).float().mean()
|
| 61 |
+
return delta1.item()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def boundary_f1(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, radius: int = 1):
|
| 65 |
+
neighbor_x, neight_y = torch.meshgrid(
|
| 66 |
+
torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device),
|
| 67 |
+
torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device),
|
| 68 |
+
indexing='xy'
|
| 69 |
+
)
|
| 70 |
+
neighbor_mask = (neighbor_x ** 2 + neight_y ** 2) <= radius ** 2 + 1e-5
|
| 71 |
+
|
| 72 |
+
pred_window = utils3d.torch.sliding_window_2d(pred, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
|
| 73 |
+
gt_window = utils3d.torch.sliding_window_2d(gt, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
|
| 74 |
+
mask_window = neighbor_mask & utils3d.torch.sliding_window_2d(mask, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
|
| 75 |
+
|
| 76 |
+
pred_rel = pred_window / pred[radius:-radius, radius:-radius, None, None]
|
| 77 |
+
gt_rel = gt_window / gt[radius:-radius, radius:-radius, None, None]
|
| 78 |
+
valid = mask[radius:-radius, radius:-radius, None, None] & mask_window
|
| 79 |
+
|
| 80 |
+
f1_list = []
|
| 81 |
+
w_list = t_list = torch.linspace(0.05, 0.25, 10).tolist()
|
| 82 |
+
|
| 83 |
+
for t in t_list:
|
| 84 |
+
pred_label = pred_rel > 1 + t
|
| 85 |
+
gt_label = gt_rel > 1 + t
|
| 86 |
+
TP = (pred_label & gt_label & valid).float().sum()
|
| 87 |
+
precision = TP / (gt_label & valid).float().sum().clamp_min(1e-12)
|
| 88 |
+
recall = TP / (pred_label & valid).float().sum().clamp_min(1e-12)
|
| 89 |
+
f1 = 2 * precision * recall / (precision + recall).clamp_min(1e-12)
|
| 90 |
+
f1_list.append(f1.item())
|
| 91 |
+
|
| 92 |
+
f1_avg = sum(w * f1 for w, f1 in zip(w_list, f1_list)) / sum(w_list)
|
| 93 |
+
return f1_avg
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def compute_metrics(
|
| 97 |
+
pred: Dict[str, torch.Tensor],
|
| 98 |
+
gt: Dict[str, torch.Tensor],
|
| 99 |
+
vis: bool = False
|
| 100 |
+
) -> Tuple[Dict[str, Dict[str, Number]], Dict[str, torch.Tensor]]:
|
| 101 |
+
"""
|
| 102 |
+
A unified function to compute metrics for different types of predictions and ground truths.
|
| 103 |
+
|
| 104 |
+
#### Supported keys in pred:
|
| 105 |
+
- `disparity_affine_invariant`: disparity map predicted by a depth estimator with scale and shift invariant.
|
| 106 |
+
- `depth_scale_invariant`: depth map predicted by a depth estimator with scale invariant.
|
| 107 |
+
- `depth_affine_invariant`: depth map predicted by a depth estimator with scale and shift invariant.
|
| 108 |
+
- `depth_metric`: depth map predicted by a depth estimator with no scale or shift.
|
| 109 |
+
- `points_scale_invariant`: point map predicted by a point estimator with scale invariant.
|
| 110 |
+
- `points_affine_invariant`: point map predicted by a point estimator with scale and xyz shift invariant.
|
| 111 |
+
- `points_metric`: point map predicted by a point estimator with no scale or shift.
|
| 112 |
+
- `intrinsics`: normalized camera intrinsics matrix.
|
| 113 |
+
|
| 114 |
+
#### Required keys in gt:
|
| 115 |
+
- `depth`: depth map ground truth (in metric units if `depth_metric` is used)
|
| 116 |
+
- `points`: point map ground truth in camera coordinates.
|
| 117 |
+
- `mask`: mask indicating valid pixels in the ground truth.
|
| 118 |
+
- `intrinsics`: normalized ground-truth camera intrinsics matrix.
|
| 119 |
+
- `is_metric`: whether the depth is in metric units.
|
| 120 |
+
"""
|
| 121 |
+
metrics = {}
|
| 122 |
+
misc = {}
|
| 123 |
+
|
| 124 |
+
mask = gt['depth_mask']
|
| 125 |
+
gt_depth = gt['depth']
|
| 126 |
+
gt_points = gt['points']
|
| 127 |
+
|
| 128 |
+
height, width = mask.shape[-2:]
|
| 129 |
+
_, lr_mask, lr_index = mask_aware_nearest_resize(None, mask, (64, 64), return_index=True)
|
| 130 |
+
|
| 131 |
+
only_depth = not any('point' in k for k in pred)
|
| 132 |
+
pred_depth_aligned, pred_points_aligned = None, None
|
| 133 |
+
|
| 134 |
+
# Metric depth
|
| 135 |
+
if 'depth_metric' in pred and gt['is_metric']:
|
| 136 |
+
pred_depth, gt_depth = pred['depth_metric'], gt['depth']
|
| 137 |
+
metrics['depth_metric'] = {
|
| 138 |
+
'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
|
| 139 |
+
'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
if pred_depth_aligned is None:
|
| 143 |
+
pred_depth_aligned = pred_depth
|
| 144 |
+
|
| 145 |
+
# Scale-invariant depth
|
| 146 |
+
if 'depth_scale_invariant' in pred:
|
| 147 |
+
pred_depth_scale_invariant = pred['depth_scale_invariant']
|
| 148 |
+
elif 'depth_metric' in pred:
|
| 149 |
+
pred_depth_scale_invariant = pred['depth_metric']
|
| 150 |
+
else:
|
| 151 |
+
pred_depth_scale_invariant = None
|
| 152 |
+
|
| 153 |
+
if pred_depth_scale_invariant is not None:
|
| 154 |
+
pred_depth = pred_depth_scale_invariant
|
| 155 |
+
|
| 156 |
+
pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask]
|
| 157 |
+
scale = align_depth_scale(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked)
|
| 158 |
+
pred_depth = pred_depth * scale
|
| 159 |
+
|
| 160 |
+
metrics['depth_scale_invariant'] = {
|
| 161 |
+
'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
|
| 162 |
+
'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
if pred_depth_aligned is None:
|
| 166 |
+
pred_depth_aligned = pred_depth
|
| 167 |
+
|
| 168 |
+
# Affine-invariant depth
|
| 169 |
+
if 'depth_affine_invariant' in pred:
|
| 170 |
+
pred_depth_affine_invariant = pred['depth_affine_invariant']
|
| 171 |
+
elif 'depth_scale_invariant' in pred:
|
| 172 |
+
pred_depth_affine_invariant = pred['depth_scale_invariant']
|
| 173 |
+
elif 'depth_metric' in pred:
|
| 174 |
+
pred_depth_affine_invariant = pred['depth_metric']
|
| 175 |
+
else:
|
| 176 |
+
pred_depth_affine_invariant = None
|
| 177 |
+
|
| 178 |
+
if pred_depth_affine_invariant is not None:
|
| 179 |
+
pred_depth = pred_depth_affine_invariant
|
| 180 |
+
|
| 181 |
+
pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask]
|
| 182 |
+
scale, shift = align_depth_affine(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked)
|
| 183 |
+
pred_depth = pred_depth * scale + shift
|
| 184 |
+
|
| 185 |
+
metrics['depth_affine_invariant'] = {
|
| 186 |
+
'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
|
| 187 |
+
'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
if pred_depth_aligned is None:
|
| 191 |
+
pred_depth_aligned = pred_depth
|
| 192 |
+
|
| 193 |
+
# Affine-invariant disparity
|
| 194 |
+
if 'disparity_affine_invariant' in pred:
|
| 195 |
+
pred_disparity_affine_invariant = pred['disparity_affine_invariant']
|
| 196 |
+
elif 'depth_scale_invariant' in pred:
|
| 197 |
+
pred_disparity_affine_invariant = 1 / pred['depth_scale_invariant']
|
| 198 |
+
elif 'depth_metric' in pred:
|
| 199 |
+
pred_disparity_affine_invariant = 1 / pred['depth_metric']
|
| 200 |
+
else:
|
| 201 |
+
pred_disparity_affine_invariant = None
|
| 202 |
+
|
| 203 |
+
if pred_disparity_affine_invariant is not None:
|
| 204 |
+
pred_disp = pred_disparity_affine_invariant
|
| 205 |
+
|
| 206 |
+
scale, shift = align_affine_lstsq(pred_disp[mask], 1 / gt_depth[mask])
|
| 207 |
+
pred_disp = pred_disp * scale + shift
|
| 208 |
+
|
| 209 |
+
# NOTE: The alignment is done on the disparity map could introduce extreme outliers at disparities close to 0.
|
| 210 |
+
# Therefore we clamp the disparities by minimum ground truth disparity.
|
| 211 |
+
pred_depth = 1 / pred_disp.clamp_min(1 / gt_depth[mask].max().item())
|
| 212 |
+
|
| 213 |
+
metrics['disparity_affine_invariant'] = {
|
| 214 |
+
'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
|
| 215 |
+
'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
if pred_depth_aligned is None:
|
| 219 |
+
pred_depth_aligned = 1 / pred_disp.clamp_min(1e-6)
|
| 220 |
+
|
| 221 |
+
# Metric points
|
| 222 |
+
if 'points_metric' in pred and gt['is_metric']:
|
| 223 |
+
pred_points = pred['points_metric']
|
| 224 |
+
|
| 225 |
+
pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask]
|
| 226 |
+
shift = align_points_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
|
| 227 |
+
pred_points = pred_points + shift
|
| 228 |
+
|
| 229 |
+
metrics['points_metric'] = {
|
| 230 |
+
'rel': rel_point(pred_points[mask], gt_points[mask]),
|
| 231 |
+
'delta1': delta1_point(pred_points[mask], gt_points[mask])
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
if pred_points_aligned is None:
|
| 235 |
+
pred_points_aligned = pred['points_metric']
|
| 236 |
+
|
| 237 |
+
# Scale-invariant points (in camera space)
|
| 238 |
+
if 'points_scale_invariant' in pred:
|
| 239 |
+
pred_points_scale_invariant = pred['points_scale_invariant']
|
| 240 |
+
elif 'points_metric' in pred:
|
| 241 |
+
pred_points_scale_invariant = pred['points_metric']
|
| 242 |
+
else:
|
| 243 |
+
pred_points_scale_invariant = None
|
| 244 |
+
|
| 245 |
+
if pred_points_scale_invariant is not None:
|
| 246 |
+
pred_points = pred_points_scale_invariant
|
| 247 |
+
|
| 248 |
+
pred_points_lr_masked, gt_points_lr_masked = pred_points_scale_invariant[lr_index][lr_mask], gt_points[lr_index][lr_mask]
|
| 249 |
+
scale = align_points_scale(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
|
| 250 |
+
pred_points = pred_points * scale
|
| 251 |
+
|
| 252 |
+
metrics['points_scale_invariant'] = {
|
| 253 |
+
'rel': rel_point(pred_points[mask], gt_points[mask]),
|
| 254 |
+
'delta1': delta1_point(pred_points[mask], gt_points[mask])
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
if vis and pred_points_aligned is None:
|
| 258 |
+
pred_points_aligned = pred['points_scale_invariant'] * scale
|
| 259 |
+
|
| 260 |
+
# Affine-invariant points
|
| 261 |
+
if 'points_affine_invariant' in pred:
|
| 262 |
+
pred_points_affine_invariant = pred['points_affine_invariant']
|
| 263 |
+
elif 'points_scale_invariant' in pred:
|
| 264 |
+
pred_points_affine_invariant = pred['points_scale_invariant']
|
| 265 |
+
elif 'points_metric' in pred:
|
| 266 |
+
pred_points_affine_invariant = pred['points_metric']
|
| 267 |
+
else:
|
| 268 |
+
pred_points_affine_invariant = None
|
| 269 |
+
|
| 270 |
+
if pred_points_affine_invariant is not None:
|
| 271 |
+
pred_points = pred_points_affine_invariant
|
| 272 |
+
|
| 273 |
+
pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask]
|
| 274 |
+
scale, shift = align_points_scale_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
|
| 275 |
+
pred_points = pred_points * scale + shift
|
| 276 |
+
|
| 277 |
+
metrics['points_affine_invariant'] = {
|
| 278 |
+
'rel': rel_point(pred_points[mask], gt_points[mask]),
|
| 279 |
+
'delta1': delta1_point(pred_points[mask], gt_points[mask])
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
if vis and pred_points_aligned is None:
|
| 283 |
+
pred_points_aligned = pred['points_affine_invariant'] * scale + shift
|
| 284 |
+
|
| 285 |
+
# Local points
|
| 286 |
+
if 'segmentation_mask' in gt and 'points' in gt and any('points' in k for k in pred.keys()):
|
| 287 |
+
pred_points = next(pred[k] for k in pred.keys() if 'points' in k)
|
| 288 |
+
gt_points = gt['points']
|
| 289 |
+
segmentation_mask = gt['segmentation_mask']
|
| 290 |
+
segmentation_labels = gt['segmentation_labels']
|
| 291 |
+
segmentation_mask_lr = segmentation_mask[lr_index]
|
| 292 |
+
local_points_metrics = []
|
| 293 |
+
for _, seg_id in segmentation_labels.items():
|
| 294 |
+
valid_mask = (segmentation_mask == seg_id) & mask
|
| 295 |
+
|
| 296 |
+
pred_points_masked = pred_points[valid_mask]
|
| 297 |
+
gt_points_masked = gt_points[valid_mask]
|
| 298 |
+
|
| 299 |
+
valid_mask_lr = (segmentation_mask_lr == seg_id) & lr_mask
|
| 300 |
+
if valid_mask_lr.sum().item() < 10:
|
| 301 |
+
continue
|
| 302 |
+
pred_points_masked_lr = pred_points[lr_index][valid_mask_lr]
|
| 303 |
+
gt_points_masked_lr = gt_points[lr_index][valid_mask_lr]
|
| 304 |
+
diameter = (gt_points_masked.max(dim=0).values - gt_points_masked.min(dim=0).values).max()
|
| 305 |
+
scale, shift = align_points_scale_xyz_shift(pred_points_masked_lr, gt_points_masked_lr, 1 / diameter.expand(gt_points_masked_lr.shape[0]))
|
| 306 |
+
pred_points_masked = pred_points_masked * scale + shift
|
| 307 |
+
|
| 308 |
+
local_points_metrics.append({
|
| 309 |
+
'rel': rel_point_local(pred_points_masked, gt_points_masked, diameter),
|
| 310 |
+
'delta1': delta1_point_local(pred_points_masked, gt_points_masked, diameter),
|
| 311 |
+
})
|
| 312 |
+
|
| 313 |
+
metrics['local_points'] = key_average(local_points_metrics)
|
| 314 |
+
|
| 315 |
+
# FOV. NOTE: If there is no random augmentation applied to the input images, all GT FOV are generallly the same.
|
| 316 |
+
# Fair evaluation of FOV requires random augmentation.
|
| 317 |
+
if 'intrinsics' in pred and 'intrinsics' in gt:
|
| 318 |
+
pred_intrinsics = pred['intrinsics']
|
| 319 |
+
gt_intrinsics = gt['intrinsics']
|
| 320 |
+
pred_fov_x, pred_fov_y = intrinsics_to_fov(pred_intrinsics)
|
| 321 |
+
gt_fov_x, gt_fov_y = intrinsics_to_fov(gt_intrinsics)
|
| 322 |
+
metrics['fov_x'] = {
|
| 323 |
+
'mae': torch.rad2deg(pred_fov_x - gt_fov_x).abs().mean().item(),
|
| 324 |
+
'deviation': torch.rad2deg(pred_fov_x - gt_fov_x).item(),
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
# Boundary F1
|
| 328 |
+
if pred_depth_aligned is not None and gt['has_sharp_boundary']:
|
| 329 |
+
metrics['boundary'] = {
|
| 330 |
+
'radius1_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=1),
|
| 331 |
+
'radius2_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=2),
|
| 332 |
+
'radius3_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=3),
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
if vis:
|
| 336 |
+
if pred_points_aligned is not None:
|
| 337 |
+
misc['pred_points'] = pred_points_aligned
|
| 338 |
+
if only_depth:
|
| 339 |
+
misc['pred_points'] = utils3d.torch.depth_to_points(pred_depth_aligned, intrinsics=gt['intrinsics'])
|
| 340 |
+
if pred_depth_aligned is not None:
|
| 341 |
+
misc['pred_depth'] = pred_depth_aligned
|
| 342 |
+
|
| 343 |
+
return metrics, misc
|
moge/train/__init__.py
ADDED
|
File without changes
|
moge/train/dataloader.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import random
|
| 6 |
+
from typing import *
|
| 7 |
+
import traceback
|
| 8 |
+
import itertools
|
| 9 |
+
from numbers import Number
|
| 10 |
+
import io
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import cv2
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import torch
|
| 16 |
+
import torchvision.transforms.v2.functional as TF
|
| 17 |
+
import utils3d
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from ..utils import pipeline
|
| 21 |
+
from ..utils.io import *
|
| 22 |
+
from ..utils.geometry_numpy import mask_aware_nearest_resize_numpy, harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy, depth_of_field
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TrainDataLoaderPipeline:
|
| 26 |
+
def __init__(self, config: dict, batch_size: int, num_load_workers: int = 4, num_process_workers: int = 8, buffer_size: int = 8):
|
| 27 |
+
self.config = config
|
| 28 |
+
|
| 29 |
+
self.batch_size = batch_size
|
| 30 |
+
self.clamp_max_depth = config['clamp_max_depth']
|
| 31 |
+
self.fov_range_absolute = config.get('fov_range_absolute', 0.0)
|
| 32 |
+
self.fov_range_relative = config.get('fov_range_relative', 0.0)
|
| 33 |
+
self.center_augmentation = config.get('center_augmentation', 0.0)
|
| 34 |
+
self.image_augmentation = config.get('image_augmentation', [])
|
| 35 |
+
self.depth_interpolation = config.get('depth_interpolation', 'bilinear')
|
| 36 |
+
|
| 37 |
+
if 'image_sizes' in config:
|
| 38 |
+
self.image_size_strategy = 'fixed'
|
| 39 |
+
self.image_sizes = config['image_sizes']
|
| 40 |
+
elif 'aspect_ratio_range' in config and 'area_range' in config:
|
| 41 |
+
self.image_size_strategy = 'aspect_area'
|
| 42 |
+
self.aspect_ratio_range = config['aspect_ratio_range']
|
| 43 |
+
self.area_range = config['area_range']
|
| 44 |
+
else:
|
| 45 |
+
raise ValueError('Invalid image size configuration')
|
| 46 |
+
|
| 47 |
+
# Load datasets
|
| 48 |
+
self.datasets = {}
|
| 49 |
+
for dataset in tqdm(config['datasets'], desc='Loading datasets'):
|
| 50 |
+
name = dataset['name']
|
| 51 |
+
content = Path(dataset['path'], dataset.get('index', '.index.txt')).joinpath().read_text()
|
| 52 |
+
filenames = content.splitlines()
|
| 53 |
+
self.datasets[name] = {
|
| 54 |
+
**dataset,
|
| 55 |
+
'path': dataset['path'],
|
| 56 |
+
'filenames': filenames,
|
| 57 |
+
}
|
| 58 |
+
self.dataset_names = [dataset['name'] for dataset in config['datasets']]
|
| 59 |
+
self.dataset_weights = [dataset['weight'] for dataset in config['datasets']]
|
| 60 |
+
|
| 61 |
+
# Build pipeline
|
| 62 |
+
self.pipeline = pipeline.Sequential([
|
| 63 |
+
self._sample_batch,
|
| 64 |
+
pipeline.Unbatch(),
|
| 65 |
+
pipeline.Parallel([self._load_instance] * num_load_workers),
|
| 66 |
+
pipeline.Parallel([self._process_instance] * num_process_workers),
|
| 67 |
+
pipeline.Batch(self.batch_size),
|
| 68 |
+
self._collate_batch,
|
| 69 |
+
pipeline.Buffer(buffer_size),
|
| 70 |
+
])
|
| 71 |
+
|
| 72 |
+
self.invalid_instance = {
|
| 73 |
+
'intrinsics': np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], dtype=np.float32),
|
| 74 |
+
'image': np.zeros((256, 256, 3), dtype=np.uint8),
|
| 75 |
+
'depth': np.ones((256, 256), dtype=np.float32),
|
| 76 |
+
'depth_mask': np.ones((256, 256), dtype=bool),
|
| 77 |
+
'depth_mask_inf': np.zeros((256, 256), dtype=bool),
|
| 78 |
+
'label_type': 'invalid',
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
def _sample_batch(self):
|
| 82 |
+
batch_id = 0
|
| 83 |
+
last_area = None
|
| 84 |
+
while True:
|
| 85 |
+
# Depending on the sample strategy, choose a dataset and a filename
|
| 86 |
+
batch_id += 1
|
| 87 |
+
batch = []
|
| 88 |
+
|
| 89 |
+
# Sample instances
|
| 90 |
+
for _ in range(self.batch_size):
|
| 91 |
+
dataset_name = random.choices(self.dataset_names, weights=self.dataset_weights)[0]
|
| 92 |
+
filename = random.choice(self.datasets[dataset_name]['filenames'])
|
| 93 |
+
|
| 94 |
+
path = Path(self.datasets[dataset_name]['path'], filename)
|
| 95 |
+
|
| 96 |
+
instance = {
|
| 97 |
+
'batch_id': batch_id,
|
| 98 |
+
'seed': random.randint(0, 2 ** 32 - 1),
|
| 99 |
+
'dataset': dataset_name,
|
| 100 |
+
'filename': filename,
|
| 101 |
+
'path': path,
|
| 102 |
+
'label_type': self.datasets[dataset_name]['label_type'],
|
| 103 |
+
}
|
| 104 |
+
batch.append(instance)
|
| 105 |
+
|
| 106 |
+
# Decide the image size for this batch
|
| 107 |
+
if self.image_size_strategy == 'fixed':
|
| 108 |
+
width, height = random.choice(self.config['image_sizes'])
|
| 109 |
+
elif self.image_size_strategy == 'aspect_area':
|
| 110 |
+
area = random.uniform(*self.area_range)
|
| 111 |
+
aspect_ratio_ranges = [self.datasets[instance['dataset']].get('aspect_ratio_range', self.aspect_ratio_range) for instance in batch]
|
| 112 |
+
aspect_ratio_range = (min(r[0] for r in aspect_ratio_ranges), max(r[1] for r in aspect_ratio_ranges))
|
| 113 |
+
aspect_ratio = random.uniform(*aspect_ratio_range)
|
| 114 |
+
width, height = int((area * aspect_ratio) ** 0.5), int((area / aspect_ratio) ** 0.5)
|
| 115 |
+
else:
|
| 116 |
+
raise ValueError('Invalid image size strategy')
|
| 117 |
+
|
| 118 |
+
for instance in batch:
|
| 119 |
+
instance['width'], instance['height'] = width, height
|
| 120 |
+
|
| 121 |
+
yield batch
|
| 122 |
+
|
| 123 |
+
def _load_instance(self, instance: dict):
|
| 124 |
+
try:
|
| 125 |
+
image = read_image(Path(instance['path'], 'image.jpg'))
|
| 126 |
+
depth, _ = read_depth(Path(instance['path'], self.datasets[instance['dataset']].get('depth', 'depth.png')))
|
| 127 |
+
|
| 128 |
+
meta = read_meta(Path(instance['path'], 'meta.json'))
|
| 129 |
+
intrinsics = np.array(meta['intrinsics'], dtype=np.float32)
|
| 130 |
+
depth_mask = np.isfinite(depth)
|
| 131 |
+
depth_mask_inf = np.isinf(depth)
|
| 132 |
+
depth = np.nan_to_num(depth, nan=1, posinf=1, neginf=1)
|
| 133 |
+
data = {
|
| 134 |
+
'image': image,
|
| 135 |
+
'depth': depth,
|
| 136 |
+
'depth_mask': depth_mask,
|
| 137 |
+
'depth_mask_inf': depth_mask_inf,
|
| 138 |
+
'intrinsics': intrinsics
|
| 139 |
+
}
|
| 140 |
+
instance.update({
|
| 141 |
+
**data,
|
| 142 |
+
})
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"Failed to load instance {instance['dataset']}/{instance['filename']} because of exception:", e)
|
| 145 |
+
instance.update(self.invalid_instance)
|
| 146 |
+
return instance
|
| 147 |
+
|
| 148 |
+
def _process_instance(self, instance: Dict[str, Union[np.ndarray, str, float, bool]]):
|
| 149 |
+
image, depth, depth_mask, depth_mask_inf, intrinsics, label_type = instance['image'], instance['depth'], instance['depth_mask'], instance['depth_mask_inf'], instance['intrinsics'], instance['label_type']
|
| 150 |
+
depth_unit = self.datasets[instance['dataset']].get('depth_unit', None)
|
| 151 |
+
|
| 152 |
+
raw_height, raw_width = image.shape[:2]
|
| 153 |
+
raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
|
| 154 |
+
raw_fov_x, raw_fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
|
| 155 |
+
raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height
|
| 156 |
+
tgt_width, tgt_height = instance['width'], instance['height']
|
| 157 |
+
tgt_aspect = tgt_width / tgt_height
|
| 158 |
+
|
| 159 |
+
rng = np.random.default_rng(instance['seed'])
|
| 160 |
+
|
| 161 |
+
# 1. set target fov
|
| 162 |
+
center_augmentation = self.datasets[instance['dataset']].get('center_augmentation', self.center_augmentation)
|
| 163 |
+
fov_range_absolute_min, fov_range_absolute_max = self.datasets[instance['dataset']].get('fov_range_absolute', self.fov_range_absolute)
|
| 164 |
+
fov_range_relative_min, fov_range_relative_max = self.datasets[instance['dataset']].get('fov_range_relative', self.fov_range_relative)
|
| 165 |
+
tgt_fov_x_min = min(fov_range_relative_min * raw_fov_x, fov_range_relative_min * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect))
|
| 166 |
+
tgt_fov_x_max = min(fov_range_relative_max * raw_fov_x, fov_range_relative_max * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect))
|
| 167 |
+
tgt_fov_x_min, tgt_fov_x_max = max(np.deg2rad(fov_range_absolute_min), tgt_fov_x_min), min(np.deg2rad(fov_range_absolute_max), tgt_fov_x_max)
|
| 168 |
+
tgt_fov_x = rng.uniform(min(tgt_fov_x_min, tgt_fov_x_max), tgt_fov_x_max)
|
| 169 |
+
tgt_fov_y = utils3d.focal_to_fov(utils3d.numpy.fov_to_focal(tgt_fov_x) * tgt_aspect)
|
| 170 |
+
|
| 171 |
+
# 2. set target image center (principal point) and the corresponding z-direction in raw camera space
|
| 172 |
+
center_dtheta = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_x - tgt_fov_x)
|
| 173 |
+
center_dphi = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_y - tgt_fov_y)
|
| 174 |
+
cu, cv = 0.5 + 0.5 * np.tan(center_dtheta) / np.tan(raw_fov_x / 2), 0.5 + 0.5 * np.tan(center_dphi) / np.tan(raw_fov_y / 2)
|
| 175 |
+
direction = utils3d.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0]
|
| 176 |
+
|
| 177 |
+
# 3. obtain the rotation matrix for homography warping
|
| 178 |
+
R = utils3d.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
|
| 179 |
+
|
| 180 |
+
# 4. shrink the target view to fit into the warped image
|
| 181 |
+
corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
|
| 182 |
+
corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane
|
| 183 |
+
corners = corners[:, :2] / corners[:, 2:3]
|
| 184 |
+
tgt_horizontal, tgt_vertical = np.tan(tgt_fov_x / 2) * 2, np.tan(tgt_fov_y / 2) * 2
|
| 185 |
+
warp_horizontal, warp_vertical = float('inf'), float('inf')
|
| 186 |
+
for i in range(4):
|
| 187 |
+
intersection, _ = utils3d.numpy.ray_intersection(
|
| 188 |
+
np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
|
| 189 |
+
corners[i - 1], corners[i] - corners[i - 1],
|
| 190 |
+
)
|
| 191 |
+
warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
|
| 192 |
+
tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
|
| 193 |
+
|
| 194 |
+
# 5. obtain the target intrinsics
|
| 195 |
+
fx, fy = 1 / tgt_horizontal, 1 / tgt_vertical
|
| 196 |
+
tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
|
| 197 |
+
|
| 198 |
+
# 6. do homogeneous transformation
|
| 199 |
+
# 6.1 The image and depth are resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling
|
| 200 |
+
tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes)
|
| 201 |
+
rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h)
|
| 202 |
+
image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS))
|
| 203 |
+
|
| 204 |
+
edge_mask = depth_occlusion_edge_numpy(depth, mask=depth_mask, thickness=2, tol=0.01)
|
| 205 |
+
_, depth_mask_nearest, resize_index = mask_aware_nearest_resize_numpy(None, depth_mask, (rescaled_w, rescaled_h), return_index=True)
|
| 206 |
+
depth_nearest = depth[resize_index]
|
| 207 |
+
distance_nearest = norm3d(utils3d.numpy.depth_to_points(depth_nearest, intrinsics=intrinsics))
|
| 208 |
+
edge_mask = edge_mask[resize_index]
|
| 209 |
+
|
| 210 |
+
if self.depth_interpolation == 'bilinear':
|
| 211 |
+
depth_mask_bilinear = cv2.resize(depth_mask.astype(np.float32), (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR)
|
| 212 |
+
depth_bilinear = 1 / cv2.resize(1 / depth, (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR)
|
| 213 |
+
distance_bilinear = norm3d(utils3d.numpy.depth_to_points(depth_bilinear, intrinsics=intrinsics))
|
| 214 |
+
|
| 215 |
+
depth_mask_inf = cv2.resize(depth_mask_inf.astype(np.uint8), (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) > 0
|
| 216 |
+
|
| 217 |
+
# 6.2 calculate homography warping
|
| 218 |
+
transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics)
|
| 219 |
+
uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height)
|
| 220 |
+
pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T
|
| 221 |
+
uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12)
|
| 222 |
+
pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32)
|
| 223 |
+
|
| 224 |
+
tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LANCZOS4)
|
| 225 |
+
tgt_ray_length = norm3d(utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics))
|
| 226 |
+
tgt_depth_mask_nearest = cv2.remap(depth_mask_nearest.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
|
| 227 |
+
tgt_depth_nearest = cv2.remap(distance_nearest, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) / tgt_ray_length
|
| 228 |
+
tgt_edge_mask = cv2.remap(edge_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
|
| 229 |
+
if self.depth_interpolation == 'bilinear':
|
| 230 |
+
tgt_depth_mask_bilinear = cv2.remap(depth_mask_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR)
|
| 231 |
+
tgt_depth_bilinear = cv2.remap(distance_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) / tgt_ray_length
|
| 232 |
+
tgt_depth = np.where((tgt_depth_mask_bilinear == 1) & ~tgt_edge_mask, tgt_depth_bilinear, tgt_depth_nearest)
|
| 233 |
+
else:
|
| 234 |
+
tgt_depth = tgt_depth_nearest
|
| 235 |
+
tgt_depth_mask = tgt_depth_mask_nearest
|
| 236 |
+
|
| 237 |
+
tgt_depth_mask_inf = cv2.remap(depth_mask_inf.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
|
| 238 |
+
|
| 239 |
+
# always make sure that mask is not empty
|
| 240 |
+
if tgt_depth_mask.sum() / tgt_depth_mask.size < 0.001:
|
| 241 |
+
tgt_depth_mask = np.ones_like(tgt_depth_mask)
|
| 242 |
+
tgt_depth = np.ones_like(tgt_depth)
|
| 243 |
+
instance['label_type'] = 'invalid'
|
| 244 |
+
|
| 245 |
+
# Flip augmentation
|
| 246 |
+
if rng.choice([True, False]):
|
| 247 |
+
tgt_image = np.flip(tgt_image, axis=1).copy()
|
| 248 |
+
tgt_depth = np.flip(tgt_depth, axis=1).copy()
|
| 249 |
+
tgt_depth_mask = np.flip(tgt_depth_mask, axis=1).copy()
|
| 250 |
+
tgt_depth_mask_inf = np.flip(tgt_depth_mask_inf, axis=1).copy()
|
| 251 |
+
|
| 252 |
+
# Color augmentation
|
| 253 |
+
image_augmentation = self.datasets[instance['dataset']].get('image_augmentation', self.image_augmentation)
|
| 254 |
+
if 'jittering' in image_augmentation:
|
| 255 |
+
tgt_image = torch.from_numpy(tgt_image).permute(2, 0, 1)
|
| 256 |
+
tgt_image = TF.adjust_brightness(tgt_image, rng.uniform(0.7, 1.3))
|
| 257 |
+
tgt_image = TF.adjust_contrast(tgt_image, rng.uniform(0.7, 1.3))
|
| 258 |
+
tgt_image = TF.adjust_saturation(tgt_image, rng.uniform(0.7, 1.3))
|
| 259 |
+
tgt_image = TF.adjust_hue(tgt_image, rng.uniform(-0.1, 0.1))
|
| 260 |
+
tgt_image = TF.adjust_gamma(tgt_image, rng.uniform(0.7, 1.3))
|
| 261 |
+
tgt_image = tgt_image.permute(1, 2, 0).numpy()
|
| 262 |
+
if 'dof' in image_augmentation:
|
| 263 |
+
if rng.uniform() < 0.5:
|
| 264 |
+
dof_strength = rng.integers(12)
|
| 265 |
+
tgt_disp = np.where(tgt_depth_mask_inf, 0, 1 / tgt_depth)
|
| 266 |
+
disp_min, disp_max = tgt_disp[tgt_depth_mask].min(), tgt_disp[tgt_depth_mask].max()
|
| 267 |
+
tgt_disp = cv2.inpaint(tgt_disp, (~tgt_depth_mask & ~tgt_depth_mask_inf).astype(np.uint8), 3, cv2.INPAINT_TELEA).clip(disp_min, disp_max)
|
| 268 |
+
dof_focus = rng.uniform(disp_min, disp_max)
|
| 269 |
+
tgt_image = depth_of_field(tgt_image, tgt_disp, dof_focus, dof_strength)
|
| 270 |
+
if 'shot_noise' in image_augmentation:
|
| 271 |
+
if rng.uniform() < 0.5:
|
| 272 |
+
k = np.exp(rng.uniform(np.log(100), np.log(10000))) / 255
|
| 273 |
+
tgt_image = (rng.poisson(tgt_image * k) / k).clip(0, 255).astype(np.uint8)
|
| 274 |
+
if 'jpeg_loss' in image_augmentation:
|
| 275 |
+
if rng.uniform() < 0.5:
|
| 276 |
+
tgt_image = cv2.imdecode(cv2.imencode('.jpg', tgt_image, [cv2.IMWRITE_JPEG_QUALITY, rng.integers(20, 100)])[1], cv2.IMREAD_COLOR)
|
| 277 |
+
if 'blurring' in image_augmentation:
|
| 278 |
+
if rng.uniform() < 0.5:
|
| 279 |
+
ratio = rng.uniform(0.25, 1)
|
| 280 |
+
tgt_image = cv2.resize(cv2.resize(tgt_image, (int(tgt_width * ratio), int(tgt_height * ratio)), interpolation=cv2.INTER_AREA), (tgt_width, tgt_height), interpolation=rng.choice([cv2.INTER_LINEAR_EXACT, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]))
|
| 281 |
+
|
| 282 |
+
# convert depth to metric if necessary
|
| 283 |
+
if depth_unit is not None:
|
| 284 |
+
tgt_depth *= depth_unit
|
| 285 |
+
instance['is_metric'] = True
|
| 286 |
+
else:
|
| 287 |
+
instance['is_metric'] = False
|
| 288 |
+
|
| 289 |
+
# clamp depth maximum values
|
| 290 |
+
max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.clamp_max_depth
|
| 291 |
+
tgt_depth = np.clip(tgt_depth, 0, max_depth)
|
| 292 |
+
tgt_depth = np.nan_to_num(tgt_depth, nan=1.0)
|
| 293 |
+
|
| 294 |
+
if self.datasets[instance['dataset']].get('finite_depth_mask', None) == "only_known":
|
| 295 |
+
tgt_depth_mask_fin = tgt_depth_mask
|
| 296 |
+
else:
|
| 297 |
+
tgt_depth_mask_fin = ~tgt_depth_mask_inf
|
| 298 |
+
|
| 299 |
+
instance.update({
|
| 300 |
+
'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
|
| 301 |
+
'depth': torch.from_numpy(tgt_depth).float(),
|
| 302 |
+
'depth_mask': torch.from_numpy(tgt_depth_mask).bool(),
|
| 303 |
+
'depth_mask_fin': torch.from_numpy(tgt_depth_mask_fin).bool(),
|
| 304 |
+
'depth_mask_inf': torch.from_numpy(tgt_depth_mask_inf).bool(),
|
| 305 |
+
'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
|
| 306 |
+
})
|
| 307 |
+
|
| 308 |
+
return instance
|
| 309 |
+
|
| 310 |
+
def _collate_batch(self, instances: List[Dict[str, Any]]):
|
| 311 |
+
batch = {k: torch.stack([instance[k] for instance in instances], dim=0) for k in ['image', 'depth', 'depth_mask', 'depth_mask_fin', 'depth_mask_inf', 'intrinsics']}
|
| 312 |
+
batch = {
|
| 313 |
+
'label_type': [instance['label_type'] for instance in instances],
|
| 314 |
+
'is_metric': [instance['is_metric'] for instance in instances],
|
| 315 |
+
'info': [{'dataset': instance['dataset'], 'filename': instance['filename']} for instance in instances],
|
| 316 |
+
**batch,
|
| 317 |
+
}
|
| 318 |
+
return batch
|
| 319 |
+
|
| 320 |
+
def get(self) -> Dict[str, Union[torch.Tensor, str]]:
|
| 321 |
+
return self.pipeline.get()
|
| 322 |
+
|
| 323 |
+
def start(self):
|
| 324 |
+
self.pipeline.start()
|
| 325 |
+
|
| 326 |
+
def stop(self):
|
| 327 |
+
self.pipeline.stop()
|
| 328 |
+
|
| 329 |
+
def __enter__(self):
|
| 330 |
+
self.start()
|
| 331 |
+
return self
|
| 332 |
+
|
| 333 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
| 334 |
+
self.pipeline.terminate()
|
| 335 |
+
self.pipeline.join()
|
| 336 |
+
return False
|
| 337 |
+
|
| 338 |
+
|
moge/train/losses.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import utils3d
|
| 7 |
+
|
| 8 |
+
from ..utils.geometry_torch import (
|
| 9 |
+
weighted_mean,
|
| 10 |
+
harmonic_mean,
|
| 11 |
+
geometric_mean,
|
| 12 |
+
mask_aware_nearest_resize,
|
| 13 |
+
normalized_view_plane_uv,
|
| 14 |
+
angle_diff_vec3
|
| 15 |
+
)
|
| 16 |
+
from ..utils.alignment import (
|
| 17 |
+
align_points_scale_z_shift,
|
| 18 |
+
align_points_scale,
|
| 19 |
+
align_points_scale_xyz_shift,
|
| 20 |
+
align_points_z_shift,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
|
| 25 |
+
if beta == 0:
|
| 26 |
+
return err
|
| 27 |
+
else:
|
| 28 |
+
return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def affine_invariant_global_loss(
|
| 32 |
+
pred_points: torch.Tensor,
|
| 33 |
+
gt_points: torch.Tensor,
|
| 34 |
+
mask: torch.Tensor,
|
| 35 |
+
align_resolution: int = 64,
|
| 36 |
+
beta: float = 0.0,
|
| 37 |
+
trunc: float = 1.0,
|
| 38 |
+
sparsity_aware: bool = False
|
| 39 |
+
):
|
| 40 |
+
device = pred_points.device
|
| 41 |
+
|
| 42 |
+
# Align
|
| 43 |
+
(pred_points_lr, gt_points_lr), lr_mask = mask_aware_nearest_resize((pred_points, gt_points), mask=mask, size=(align_resolution, align_resolution))
|
| 44 |
+
scale, shift = align_points_scale_z_shift(pred_points_lr.flatten(-3, -2), gt_points_lr.flatten(-3, -2), lr_mask.flatten(-2, -1) / gt_points_lr[..., 2].flatten(-2, -1).clamp_min(1e-2), trunc=trunc)
|
| 45 |
+
valid = scale > 0
|
| 46 |
+
scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)
|
| 47 |
+
|
| 48 |
+
pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]
|
| 49 |
+
|
| 50 |
+
# Compute loss
|
| 51 |
+
weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
|
| 52 |
+
weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True)) # In case your data contains extremely small depth values
|
| 53 |
+
loss = _smooth((pred_points - gt_points).abs() * weight[..., None], beta=beta).mean(dim=(-3, -2, -1))
|
| 54 |
+
|
| 55 |
+
if sparsity_aware:
|
| 56 |
+
# Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
|
| 57 |
+
sparsity = mask.float().mean(dim=(-2, -1)) / lr_mask.float().mean(dim=(-2, -1))
|
| 58 |
+
loss = loss / (sparsity + 1e-7)
|
| 59 |
+
|
| 60 |
+
err = (pred_points.detach() - gt_points).norm(dim=-1) / gt_points[..., 2]
|
| 61 |
+
|
| 62 |
+
# Record any scalar metric
|
| 63 |
+
misc = {
|
| 64 |
+
'truncated_error': weighted_mean(err.clamp_max(1.0), mask).item(),
|
| 65 |
+
'delta': weighted_mean((err < 1).float(), mask).item()
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
return loss, misc, scale.detach()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def monitoring(points: torch.Tensor):
|
| 72 |
+
return {
|
| 73 |
+
'std': points.std().item(),
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def compute_anchor_sampling_weight(
|
| 78 |
+
points: torch.Tensor,
|
| 79 |
+
mask: torch.Tensor,
|
| 80 |
+
radius_2d: torch.Tensor,
|
| 81 |
+
radius_3d: torch.Tensor,
|
| 82 |
+
num_test: int = 64
|
| 83 |
+
) -> torch.Tensor:
|
| 84 |
+
# Importance sampling to balance the sampled probability of fine strutures.
|
| 85 |
+
# NOTE: MoGe-1 uses uniform random sampling instead of importance sampling.
|
| 86 |
+
# This is an incremental trick introduced later than the publication of MoGe-1 paper.
|
| 87 |
+
|
| 88 |
+
height, width = points.shape[-3:-1]
|
| 89 |
+
|
| 90 |
+
pixel_i, pixel_j = torch.meshgrid(
|
| 91 |
+
torch.arange(height, device=points.device),
|
| 92 |
+
torch.arange(width, device=points.device),
|
| 93 |
+
indexing='ij'
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
test_delta_i = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [num_test]
|
| 97 |
+
test_delta_j = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [num_test]
|
| 98 |
+
test_i, test_j = pixel_i[..., None] + test_delta_i, pixel_j[..., None] + test_delta_j # [height, width, num_test]
|
| 99 |
+
test_mask = (test_i >= 0) & (test_i < height) & (test_j >= 0) & (test_j < width) # [height, width, num_test]
|
| 100 |
+
test_i, test_j = test_i.clamp(0, height - 1), test_j.clamp(0, width - 1) # [height, width, num_test]
|
| 101 |
+
test_mask = test_mask & mask[..., test_i, test_j] # [..., height, width, num_test]
|
| 102 |
+
test_points = points[..., test_i, test_j, :] # [..., height, width, num_test, 3]
|
| 103 |
+
test_dist = (test_points - points[..., None, :]).norm(dim=-1) # [..., height, width, num_test]
|
| 104 |
+
|
| 105 |
+
weight = 1 / ((test_dist <= radius_3d[..., None]) & test_mask).float().sum(dim=-1).clamp_min(1)
|
| 106 |
+
weight = torch.where(mask, weight, 0)
|
| 107 |
+
weight = weight / weight.sum(dim=(-2, -1), keepdim=True).add(1e-7) # [..., height, width]
|
| 108 |
+
return weight
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def affine_invariant_local_loss(
|
| 112 |
+
pred_points: torch.Tensor,
|
| 113 |
+
gt_points: torch.Tensor,
|
| 114 |
+
gt_mask: torch.Tensor,
|
| 115 |
+
focal: torch.Tensor,
|
| 116 |
+
global_scale: torch.Tensor,
|
| 117 |
+
level: Literal[4, 16, 64],
|
| 118 |
+
align_resolution: int = 32,
|
| 119 |
+
num_patches: int = 16,
|
| 120 |
+
beta: float = 0.0,
|
| 121 |
+
trunc: float = 1.0,
|
| 122 |
+
sparsity_aware: bool = False
|
| 123 |
+
):
|
| 124 |
+
device, dtype = pred_points.device, pred_points.dtype
|
| 125 |
+
*batch_shape, height, width, _ = pred_points.shape
|
| 126 |
+
batch_size = math.prod(batch_shape)
|
| 127 |
+
pred_points, gt_points, gt_mask, focal, global_scale = pred_points.reshape(-1, height, width, 3), gt_points.reshape(-1, height, width, 3), gt_mask.reshape(-1, height, width), focal.reshape(-1), global_scale.reshape(-1) if global_scale is not None else None
|
| 128 |
+
|
| 129 |
+
# Sample patch anchor points indices [num_total_patches]
|
| 130 |
+
radius_2d = math.ceil(0.5 / level * (height ** 2 + width ** 2) ** 0.5)
|
| 131 |
+
radius_3d = 0.5 / level / focal * gt_points[..., 2]
|
| 132 |
+
anchor_sampling_weights = compute_anchor_sampling_weight(gt_points, gt_mask, radius_2d, radius_3d, num_test=64)
|
| 133 |
+
where_mask = torch.where(gt_mask)
|
| 134 |
+
random_selection = torch.multinomial(anchor_sampling_weights[where_mask], num_patches * batch_size, replacement=True)
|
| 135 |
+
patch_batch_idx, patch_anchor_i, patch_anchor_j = [indices[random_selection] for indices in where_mask] # [num_total_patches]
|
| 136 |
+
|
| 137 |
+
# Get patch indices [num_total_patches, patch_h, patch_w]
|
| 138 |
+
patch_i, patch_j = torch.meshgrid(
|
| 139 |
+
torch.arange(-radius_2d, radius_2d + 1, device=device),
|
| 140 |
+
torch.arange(-radius_2d, radius_2d + 1, device=device),
|
| 141 |
+
indexing='ij'
|
| 142 |
+
)
|
| 143 |
+
patch_i, patch_j = patch_i + patch_anchor_i[:, None, None], patch_j + patch_anchor_j[:, None, None]
|
| 144 |
+
patch_mask = (patch_i >= 0) & (patch_i < height) & (patch_j >= 0) & (patch_j < width)
|
| 145 |
+
patch_i, patch_j = patch_i.clamp(0, height - 1), patch_j.clamp(0, width - 1)
|
| 146 |
+
|
| 147 |
+
# Get patch mask and gt patch points
|
| 148 |
+
gt_patch_anchor_points = gt_points[patch_batch_idx, patch_anchor_i, patch_anchor_j]
|
| 149 |
+
gt_patch_radius_3d = 0.5 / level / focal[patch_batch_idx] * gt_patch_anchor_points[:, 2]
|
| 150 |
+
gt_patch_points = gt_points[patch_batch_idx[:, None, None], patch_i, patch_j]
|
| 151 |
+
gt_patch_dist = (gt_patch_points - gt_patch_anchor_points[:, None, None, :]).norm(dim=-1)
|
| 152 |
+
patch_mask &= gt_mask[patch_batch_idx[:, None, None], patch_i, patch_j]
|
| 153 |
+
patch_mask &= gt_patch_dist <= gt_patch_radius_3d[:, None, None]
|
| 154 |
+
|
| 155 |
+
# Pick only non-empty patches
|
| 156 |
+
MINIMUM_POINTS_PER_PATCH = 32
|
| 157 |
+
nonempty = torch.where(patch_mask.sum(dim=(-2, -1)) >= MINIMUM_POINTS_PER_PATCH)
|
| 158 |
+
num_nonempty_patches = nonempty[0].shape[0]
|
| 159 |
+
if num_nonempty_patches == 0:
|
| 160 |
+
return torch.tensor(0.0, dtype=dtype, device=device), {}
|
| 161 |
+
|
| 162 |
+
# Finalize all patch variables
|
| 163 |
+
patch_batch_idx, patch_i, patch_j = patch_batch_idx[nonempty], patch_i[nonempty], patch_j[nonempty]
|
| 164 |
+
patch_mask = patch_mask[nonempty] # [num_nonempty_patches, patch_h, patch_w]
|
| 165 |
+
gt_patch_points = gt_patch_points[nonempty] # [num_nonempty_patches, patch_h, patch_w, 3]
|
| 166 |
+
gt_patch_radius_3d = gt_patch_radius_3d[nonempty] # [num_nonempty_patches]
|
| 167 |
+
gt_patch_anchor_points = gt_patch_anchor_points[nonempty] # [num_nonempty_patches, 3]
|
| 168 |
+
pred_patch_points = pred_points[patch_batch_idx[:, None, None], patch_i, patch_j]
|
| 169 |
+
|
| 170 |
+
# Align patch points
|
| 171 |
+
(pred_patch_points_lr, gt_patch_points_lr), patch_lr_mask = mask_aware_nearest_resize((pred_patch_points, gt_patch_points), mask=patch_mask, size=(align_resolution, align_resolution))
|
| 172 |
+
local_scale, local_shift = align_points_scale_xyz_shift(pred_patch_points_lr.flatten(-3, -2), gt_patch_points_lr.flatten(-3, -2), patch_lr_mask.flatten(-2) / gt_patch_radius_3d[:, None].add(1e-7), trunc=trunc)
|
| 173 |
+
if global_scale is not None:
|
| 174 |
+
scale_differ = local_scale / global_scale[patch_batch_idx]
|
| 175 |
+
patch_valid = (scale_differ > 0.1) & (scale_differ < 10.0) & (global_scale > 0)
|
| 176 |
+
else:
|
| 177 |
+
patch_valid = local_scale > 0
|
| 178 |
+
local_scale, local_shift = torch.where(patch_valid, local_scale, 0), torch.where(patch_valid[:, None], local_shift, 0)
|
| 179 |
+
patch_mask &= patch_valid[:, None, None]
|
| 180 |
+
|
| 181 |
+
pred_patch_points = local_scale[:, None, None, None] * pred_patch_points + local_shift[:, None, None, :] # [num_patches_nonempty, patch_h, patch_w, 3]
|
| 182 |
+
|
| 183 |
+
# Compute loss
|
| 184 |
+
gt_mean = harmonic_mean(gt_points[..., 2], gt_mask, dim=(-2, -1))
|
| 185 |
+
patch_weight = patch_mask.float() / gt_patch_points[..., 2].clamp_min(0.1 * gt_mean[patch_batch_idx, None, None]) # [num_patches_nonempty, patch_h, patch_w]
|
| 186 |
+
loss = _smooth((pred_patch_points - gt_patch_points).abs() * patch_weight[..., None], beta=beta).mean(dim=(-3, -2, -1)) # [num_patches_nonempty]
|
| 187 |
+
|
| 188 |
+
if sparsity_aware:
|
| 189 |
+
# Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
|
| 190 |
+
sparsity = patch_mask.float().mean(dim=(-2, -1)) / patch_lr_mask.float().mean(dim=(-2, -1))
|
| 191 |
+
loss = loss / (sparsity + 1e-7)
|
| 192 |
+
loss = torch.scatter_reduce(torch.zeros(batch_size, dtype=dtype, device=device), dim=0, index=patch_batch_idx, src=loss, reduce='sum') / num_patches
|
| 193 |
+
loss = loss.reshape(batch_shape)
|
| 194 |
+
|
| 195 |
+
err = (pred_patch_points.detach() - gt_patch_points).norm(dim=-1) / gt_patch_radius_3d[..., None, None]
|
| 196 |
+
|
| 197 |
+
# Record any scalar metric
|
| 198 |
+
misc = {
|
| 199 |
+
'truncated_error': weighted_mean(err.clamp_max(1), patch_mask).item(),
|
| 200 |
+
'delta': weighted_mean((err < 1).float(), patch_mask).item()
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
return loss, misc
|
| 204 |
+
|
| 205 |
+
def normal_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 206 |
+
device, dtype = points.device, points.dtype
|
| 207 |
+
height, width = points.shape[-3:-1]
|
| 208 |
+
|
| 209 |
+
leftup, rightup, leftdown, rightdown = points[..., :-1, :-1, :], points[..., :-1, 1:, :], points[..., 1:, :-1, :], points[..., 1:, 1:, :]
|
| 210 |
+
upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
|
| 211 |
+
leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
|
| 212 |
+
downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
|
| 213 |
+
rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)
|
| 214 |
+
|
| 215 |
+
gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = gt_points[..., :-1, :-1, :], gt_points[..., :-1, 1:, :], gt_points[..., 1:, :-1, :], gt_points[..., 1:, 1:, :]
|
| 216 |
+
gt_upxleft = torch.cross(gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1)
|
| 217 |
+
gt_leftxdown = torch.cross(gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1)
|
| 218 |
+
gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
|
| 219 |
+
gt_rightxup = torch.cross(gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1)
|
| 220 |
+
|
| 221 |
+
mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = mask[..., :-1, :-1], mask[..., :-1, 1:], mask[..., 1:, :-1], mask[..., 1:, 1:]
|
| 222 |
+
mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
|
| 223 |
+
mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
|
| 224 |
+
mask_downxright = mask_leftdown & mask_rightup & mask_leftup
|
| 225 |
+
mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown
|
| 226 |
+
|
| 227 |
+
MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
|
| 228 |
+
|
| 229 |
+
loss = mask_upxleft * _smooth(angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
|
| 230 |
+
+ mask_leftxdown * _smooth(angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
|
| 231 |
+
+ mask_downxright * _smooth(angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
|
| 232 |
+
+ mask_rightxup * _smooth(angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
|
| 233 |
+
|
| 234 |
+
loss = loss.mean() / (4 * max(points.shape[-3:-1]))
|
| 235 |
+
|
| 236 |
+
return loss, {}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def edge_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 240 |
+
device, dtype = points.device, points.dtype
|
| 241 |
+
height, width = points.shape[-3:-1]
|
| 242 |
+
|
| 243 |
+
dx = points[..., :-1, :, :] - points[..., 1:, :, :]
|
| 244 |
+
dy = points[..., :, :-1, :] - points[..., :, 1:, :]
|
| 245 |
+
|
| 246 |
+
gt_dx = gt_points[..., :-1, :, :] - gt_points[..., 1:, :, :]
|
| 247 |
+
gt_dy = gt_points[..., :, :-1, :] - gt_points[..., :, 1:, :]
|
| 248 |
+
|
| 249 |
+
mask_dx = mask[..., :-1, :] & mask[..., 1:, :]
|
| 250 |
+
mask_dy = mask[..., :, :-1] & mask[..., :, 1:]
|
| 251 |
+
|
| 252 |
+
MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(0.1), math.radians(90), math.radians(3)
|
| 253 |
+
|
| 254 |
+
loss_dx = mask_dx * _smooth(angle_diff_vec3(dx, gt_dx).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
|
| 255 |
+
loss_dy = mask_dy * _smooth(angle_diff_vec3(dy, gt_dy).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
|
| 256 |
+
loss = (loss_dx.mean(dim=(-2, -1)) + loss_dy.mean(dim=(-2, -1))) / (2 * max(points.shape[-3:-1]))
|
| 257 |
+
|
| 258 |
+
return loss, {}
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def mask_l2_loss(pred_mask: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor:
|
| 262 |
+
loss = gt_mask_neg.float() * pred_mask.square() + gt_mask_pos.float() * (1 - pred_mask).square()
|
| 263 |
+
loss = loss.mean(dim=(-2, -1))
|
| 264 |
+
return loss, {}
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def mask_bce_loss(pred_mask_prob: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor:
|
| 268 |
+
loss = (gt_mask_pos | gt_mask_neg) * F.binary_cross_entropy(pred_mask_prob, gt_mask_pos.float(), reduction='none')
|
| 269 |
+
loss = loss.mean(dim=(-2, -1))
|
| 270 |
+
return loss, {}
|
moge/train/utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import fnmatch
|
| 3 |
+
|
| 4 |
+
import sympy
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def any_match(s: str, patterns: List[str]) -> bool:
|
| 10 |
+
return any(fnmatch.fnmatch(s, pat) for pat in patterns)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_optimizer(model: nn.Module, optimizer_config: Dict[str, Any]) -> torch.optim.Optimizer:
|
| 14 |
+
named_param_groups = [
|
| 15 |
+
{
|
| 16 |
+
k: p for k, p in model.named_parameters() if any_match(k, param_group_config['params']['include']) and not any_match(k, param_group_config['params'].get('exclude', []))
|
| 17 |
+
} for param_group_config in optimizer_config['params']
|
| 18 |
+
]
|
| 19 |
+
excluded_params = [k for k, p in model.named_parameters() if p.requires_grad and not any(k in named_params for named_params in named_param_groups)]
|
| 20 |
+
assert len(excluded_params) == 0, f'The following parameters require grad but are excluded from the optimizer: {excluded_params}'
|
| 21 |
+
optimizer_cls = getattr(torch.optim, optimizer_config['type'])
|
| 22 |
+
optimizer = optimizer_cls([
|
| 23 |
+
{
|
| 24 |
+
**param_group_config,
|
| 25 |
+
'params': list(params.values()),
|
| 26 |
+
} for param_group_config, params in zip(optimizer_config['params'], named_param_groups)
|
| 27 |
+
])
|
| 28 |
+
return optimizer
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def parse_lr_lambda(s: str) -> Callable[[int], float]:
|
| 32 |
+
epoch = sympy.symbols('epoch')
|
| 33 |
+
lr_lambda = sympy.sympify(s)
|
| 34 |
+
return sympy.lambdify(epoch, lr_lambda, 'math')
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> torch.optim.lr_scheduler._LRScheduler:
|
| 38 |
+
if scheduler_config['type'] == "SequentialLR":
|
| 39 |
+
child_schedulers = [
|
| 40 |
+
build_lr_scheduler(optimizer, child_scheduler_config)
|
| 41 |
+
for child_scheduler_config in scheduler_config['params']['schedulers']
|
| 42 |
+
]
|
| 43 |
+
return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=child_schedulers, milestones=scheduler_config['params']['milestones'])
|
| 44 |
+
elif scheduler_config['type'] == "LambdaLR":
|
| 45 |
+
lr_lambda = scheduler_config['params']['lr_lambda']
|
| 46 |
+
if isinstance(lr_lambda, str):
|
| 47 |
+
lr_lambda = parse_lr_lambda(lr_lambda)
|
| 48 |
+
elif isinstance(lr_lambda, list):
|
| 49 |
+
lr_lambda = [parse_lr_lambda(l) for l in lr_lambda]
|
| 50 |
+
return torch.optim.lr_scheduler.LambdaLR(
|
| 51 |
+
optimizer,
|
| 52 |
+
lr_lambda=lr_lambda,
|
| 53 |
+
)
|
| 54 |
+
else:
|
| 55 |
+
scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_config['type'])
|
| 56 |
+
scheduler = scheduler_cls(optimizer, **scheduler_config.get('params', {}))
|
| 57 |
+
return scheduler
|
moge/utils/__init__.py
ADDED
|
File without changes
|
moge/utils/alignment.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import math
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.types
|
| 10 |
+
import utils3d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
|
| 14 |
+
"Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`."
|
| 15 |
+
shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
|
| 16 |
+
minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
|
| 17 |
+
minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
|
| 18 |
+
indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
|
| 19 |
+
indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
|
| 20 |
+
return torch.return_types.min((minimum, indices))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
|
| 24 |
+
batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
|
| 25 |
+
n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
|
| 26 |
+
splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
|
| 27 |
+
splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()}
|
| 28 |
+
results = []
|
| 29 |
+
for i in range(n_chunks):
|
| 30 |
+
chunk_args = tuple(arg[i] for arg in splited_args)
|
| 31 |
+
chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
|
| 32 |
+
results.append(fn(*chunk_args, **chunk_kwargs))
|
| 33 |
+
|
| 34 |
+
if isinstance(results[0], tuple):
|
| 35 |
+
return tuple(torch.cat(r, dim=0) for r in zip(*results))
|
| 36 |
+
else:
|
| 37 |
+
return torch.cat(results, dim=0)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _pad_inf(x_: torch.Tensor):
|
| 41 |
+
return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _pad_cumsum(cumsum: torch.Tensor):
|
| 45 |
+
return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
|
| 49 |
+
return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
|
| 53 |
+
"""
|
| 54 |
+
If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.
|
| 55 |
+
|
| 56 |
+
w_i must be >= 0.
|
| 57 |
+
|
| 58 |
+
### Parameters:
|
| 59 |
+
- `x`: tensor of shape (..., n)
|
| 60 |
+
- `y`: tensor of shape (..., n)
|
| 61 |
+
- `w`: tensor of shape (..., n)
|
| 62 |
+
- `trunc`: optional, float or tensor of shape (..., n) or None
|
| 63 |
+
|
| 64 |
+
### Returns:
|
| 65 |
+
- `a`: tensor of shape (...), differentiable
|
| 66 |
+
- `loss`: tensor of shape (...), value of loss function at `a`, detached
|
| 67 |
+
- `index`: tensor of shape (...), where a = y[idx] / x[idx]
|
| 68 |
+
"""
|
| 69 |
+
if trunc is None:
|
| 70 |
+
x, y, w = torch.broadcast_tensors(x, y, w)
|
| 71 |
+
sign = torch.sign(x)
|
| 72 |
+
x, y = x * sign, y * sign
|
| 73 |
+
y_div_x = y / x.clamp_min(eps)
|
| 74 |
+
y_div_x, argsort = y_div_x.sort(dim=-1)
|
| 75 |
+
|
| 76 |
+
wx = torch.gather(x * w, dim=-1, index=argsort)
|
| 77 |
+
derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
|
| 78 |
+
search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)
|
| 79 |
+
|
| 80 |
+
a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
|
| 81 |
+
index = argsort.gather(dim=-1, index=search).squeeze(-1)
|
| 82 |
+
loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
|
| 83 |
+
|
| 84 |
+
else:
|
| 85 |
+
# Reshape to (batch_size, n) for simplicity
|
| 86 |
+
x, y, w = torch.broadcast_tensors(x, y, w)
|
| 87 |
+
batch_shape = x.shape[:-1]
|
| 88 |
+
batch_size = math.prod(batch_shape)
|
| 89 |
+
x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])
|
| 90 |
+
|
| 91 |
+
sign = torch.sign(x)
|
| 92 |
+
x, y = x * sign, y * sign
|
| 93 |
+
wx, wy = w * x, w * y
|
| 94 |
+
xyw = torch.stack([x, y, w], dim=-1) # Stacked for convenient gathering
|
| 95 |
+
|
| 96 |
+
y_div_x = A = y / x.clamp_min(eps)
|
| 97 |
+
B = (wy - trunc) / wx.clamp_min(eps)
|
| 98 |
+
C = (wy + trunc) / wx.clamp_min(eps)
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
# Caculate prefix sum by orders of A, B, C
|
| 101 |
+
A, A_argsort = A.sort(dim=-1)
|
| 102 |
+
Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
|
| 103 |
+
A, Q_A = _pad_inf(A), _pad_cumsum(Q_A) # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.
|
| 104 |
+
|
| 105 |
+
B, B_argsort = B.sort(dim=-1)
|
| 106 |
+
Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
|
| 107 |
+
B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)
|
| 108 |
+
|
| 109 |
+
C, C_argsort = C.sort(dim=-1)
|
| 110 |
+
Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
|
| 111 |
+
C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)
|
| 112 |
+
|
| 113 |
+
# Caculate left and right derivative of A
|
| 114 |
+
j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
|
| 115 |
+
j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
|
| 116 |
+
j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
|
| 117 |
+
left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
|
| 118 |
+
j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
|
| 119 |
+
j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
|
| 120 |
+
j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
|
| 121 |
+
right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
|
| 122 |
+
|
| 123 |
+
# Find extrema
|
| 124 |
+
is_extrema = (left_derivative < 0) & (right_derivative >= 0)
|
| 125 |
+
is_extrema[..., 0] |= ~is_extrema.any(dim=-1) # In case all derivatives are zero, take the first one as extrema.
|
| 126 |
+
where_extrema_batch, where_extrema_index = torch.where(is_extrema)
|
| 127 |
+
|
| 128 |
+
# Calculate objective value at extrema
|
| 129 |
+
extrema_a = y_div_x[where_extrema_batch, where_extrema_index] # (num_extrema,)
|
| 130 |
+
MAX_ELEMENTS = 4096 ** 2 # Split into small batches to avoid OOM in case there are too many extrema.(~1G)
|
| 131 |
+
SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
|
| 132 |
+
extrema_value = torch.cat([
|
| 133 |
+
_compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
|
| 134 |
+
for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
|
| 135 |
+
]) # (num_extrema,)
|
| 136 |
+
|
| 137 |
+
# Find minima among corresponding extrema
|
| 138 |
+
minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value) # (batch_size,)
|
| 139 |
+
index = where_extrema_index[indices]
|
| 140 |
+
|
| 141 |
+
a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
|
| 142 |
+
a = a.reshape(batch_shape)
|
| 143 |
+
loss = minima.reshape(batch_shape)
|
| 144 |
+
index = index.reshape(batch_shape)
|
| 145 |
+
|
| 146 |
+
return a, loss, index
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
|
| 150 |
+
"""
|
| 151 |
+
Align `depth_src` to `depth_tgt` with given constant weights.
|
| 152 |
+
|
| 153 |
+
### Parameters:
|
| 154 |
+
- `depth_src: torch.Tensor` of shape (..., N)
|
| 155 |
+
- `depth_tgt: torch.Tensor` of shape (..., N)
|
| 156 |
+
|
| 157 |
+
"""
|
| 158 |
+
scale, _, _ = align(depth_src, depth_tgt, weight, trunc)
|
| 159 |
+
|
| 160 |
+
return scale
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
|
| 164 |
+
"""
|
| 165 |
+
Align `depth_src` to `depth_tgt` with given constant weights.
|
| 166 |
+
|
| 167 |
+
### Parameters:
|
| 168 |
+
- `depth_src: torch.Tensor` of shape (..., N)
|
| 169 |
+
- `depth_tgt: torch.Tensor` of shape (..., N)
|
| 170 |
+
- `weight: torch.Tensor` of shape (..., N)
|
| 171 |
+
- `trunc: float` or tensor of shape (..., N) or None
|
| 172 |
+
|
| 173 |
+
### Returns:
|
| 174 |
+
- `scale: torch.Tensor` of shape (...).
|
| 175 |
+
- `shift: torch.Tensor` of shape (...).
|
| 176 |
+
"""
|
| 177 |
+
dtype, device = depth_src.dtype, depth_src.device
|
| 178 |
+
|
| 179 |
+
# Flatten batch dimensions for simplicity
|
| 180 |
+
batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
|
| 181 |
+
batch_size = math.prod(batch_shape)
|
| 182 |
+
depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)
|
| 183 |
+
|
| 184 |
+
# Here, we take anchors only for non-zero weights.
|
| 185 |
+
# Although the results will be still correct even anchor points have zero weight,
|
| 186 |
+
# it is wasting computation and may cause instability in some cases, e.g. too many extrema.
|
| 187 |
+
anchors_where_batch, anchors_where_n = torch.where(weight > 0)
|
| 188 |
+
|
| 189 |
+
# Stop gradient when solving optimal anchors
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n] # (anchors)
|
| 192 |
+
depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n] # (anchors)
|
| 193 |
+
|
| 194 |
+
depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None] # (anchors, n)
|
| 195 |
+
depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None] # (anchors, n)
|
| 196 |
+
weight_anchored = weight[anchors_where_batch, :] # (anchors, n)
|
| 197 |
+
|
| 198 |
+
scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc) # (anchors)
|
| 199 |
+
|
| 200 |
+
loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss) # (batch_size,)
|
| 201 |
+
|
| 202 |
+
# Reproduce by indexing for shorter compute graph
|
| 203 |
+
index_1 = anchors_where_n[index_anchor] # (batch_size,)
|
| 204 |
+
index_2 = index[index_anchor] # (batch_size,)
|
| 205 |
+
|
| 206 |
+
tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
|
| 207 |
+
tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)
|
| 208 |
+
|
| 209 |
+
scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
|
| 210 |
+
shift = tgt_1 - scale * src_1
|
| 211 |
+
|
| 212 |
+
scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)
|
| 213 |
+
|
| 214 |
+
return scale, shift
|
| 215 |
+
|
| 216 |
+
def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
|
| 217 |
+
"""
|
| 218 |
+
Align `depth_src` to `depth_tgt` with given constant weights using IRLS.
|
| 219 |
+
"""
|
| 220 |
+
dtype, device = depth_src.dtype, depth_src.device
|
| 221 |
+
|
| 222 |
+
w = weight
|
| 223 |
+
x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
|
| 224 |
+
y = depth_tgt
|
| 225 |
+
|
| 226 |
+
for i in range(max_iter):
|
| 227 |
+
beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
|
| 228 |
+
w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)
|
| 229 |
+
|
| 230 |
+
return beta[..., 0], beta[..., 1]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
|
| 234 |
+
"""
|
| 235 |
+
### Parameters:
|
| 236 |
+
- `points_src: torch.Tensor` of shape (..., N, 3)
|
| 237 |
+
- `points_tgt: torch.Tensor` of shape (..., N, 3)
|
| 238 |
+
- `weight: torch.Tensor` of shape (..., N)
|
| 239 |
+
|
| 240 |
+
### Returns:
|
| 241 |
+
- `a: torch.Tensor` of shape (...). Only positive solutions are garunteed. You should filter out negative scales before using it.
|
| 242 |
+
- `b: torch.Tensor` of shape (...)
|
| 243 |
+
"""
|
| 244 |
+
dtype, device = points_src.dtype, points_src.device
|
| 245 |
+
|
| 246 |
+
scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)
|
| 247 |
+
|
| 248 |
+
return scale
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
|
| 252 |
+
"""
|
| 253 |
+
Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
|
| 254 |
+
It is similar to `align_affine` but scale and shift are applied to different dimensions.
|
| 255 |
+
|
| 256 |
+
### Parameters:
|
| 257 |
+
- `points_src: torch.Tensor` of shape (..., N, 3)
|
| 258 |
+
- `points_tgt: torch.Tensor` of shape (..., N, 3)
|
| 259 |
+
- `weights: torch.Tensor` of shape (..., N)
|
| 260 |
+
|
| 261 |
+
### Returns:
|
| 262 |
+
- `scale: torch.Tensor` of shape (...).
|
| 263 |
+
- `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
|
| 264 |
+
"""
|
| 265 |
+
dtype, device = points_src.dtype, points_src.device
|
| 266 |
+
|
| 267 |
+
# Flatten batch dimensions for simplicity
|
| 268 |
+
batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
|
| 269 |
+
batch_size = math.prod(batch_shape)
|
| 270 |
+
points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
|
| 271 |
+
|
| 272 |
+
# Take anchors
|
| 273 |
+
anchor_where_batch, anchor_where_n = torch.where(weight > 0)
|
| 274 |
+
with torch.no_grad():
|
| 275 |
+
zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
|
| 276 |
+
points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
|
| 277 |
+
points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
|
| 278 |
+
|
| 279 |
+
points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
|
| 280 |
+
points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
|
| 281 |
+
weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
|
| 282 |
+
|
| 283 |
+
# Solve optimal scale and shift for each anchor
|
| 284 |
+
MAX_ELEMENTS = 2 ** 20
|
| 285 |
+
scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
|
| 286 |
+
|
| 287 |
+
loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
|
| 288 |
+
|
| 289 |
+
# Reproduce by indexing for shorter compute graph
|
| 290 |
+
index_2 = index[index_anchor] # (batch_size,) [0, 3n)
|
| 291 |
+
index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
|
| 292 |
+
|
| 293 |
+
zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
|
| 294 |
+
points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
|
| 295 |
+
tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
|
| 296 |
+
tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
|
| 297 |
+
|
| 298 |
+
scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
|
| 299 |
+
shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
|
| 300 |
+
scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
|
| 301 |
+
|
| 302 |
+
return scale, shift
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
|
| 306 |
+
"""
|
| 307 |
+
Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
|
| 308 |
+
It is similar to `align_affine` but scale and shift are applied to different dimensions.
|
| 309 |
+
|
| 310 |
+
### Parameters:
|
| 311 |
+
- `points_src: torch.Tensor` of shape (..., N, 3)
|
| 312 |
+
- `points_tgt: torch.Tensor` of shape (..., N, 3)
|
| 313 |
+
- `weights: torch.Tensor` of shape (..., N)
|
| 314 |
+
|
| 315 |
+
### Returns:
|
| 316 |
+
- `scale: torch.Tensor` of shape (...).
|
| 317 |
+
- `shift: torch.Tensor` of shape (..., 3)
|
| 318 |
+
"""
|
| 319 |
+
dtype, device = points_src.dtype, points_src.device
|
| 320 |
+
|
| 321 |
+
# Flatten batch dimensions for simplicity
|
| 322 |
+
batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
|
| 323 |
+
batch_size = math.prod(batch_shape)
|
| 324 |
+
points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
|
| 325 |
+
|
| 326 |
+
# Take anchors
|
| 327 |
+
anchor_where_batch, anchor_where_n = torch.where(weight > 0)
|
| 328 |
+
|
| 329 |
+
with torch.no_grad():
|
| 330 |
+
points_src_anchor = points_src[anchor_where_batch, anchor_where_n] # (anchors, 3)
|
| 331 |
+
points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n] # (anchors, 3)
|
| 332 |
+
|
| 333 |
+
points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
|
| 334 |
+
points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
|
| 335 |
+
weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
|
| 336 |
+
|
| 337 |
+
# Solve optimal scale and shift for each anchor
|
| 338 |
+
MAX_ELEMENTS = 2 ** 20
|
| 339 |
+
scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
|
| 340 |
+
|
| 341 |
+
# Get optimal scale and shift for each batch element
|
| 342 |
+
loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
|
| 343 |
+
|
| 344 |
+
index_2 = index[index_anchor] # (batch_size,) [0, 3n)
|
| 345 |
+
index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
|
| 346 |
+
|
| 347 |
+
src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
|
| 348 |
+
src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
|
| 349 |
+
|
| 350 |
+
scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
|
| 351 |
+
shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
|
| 352 |
+
|
| 353 |
+
scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
|
| 354 |
+
|
| 355 |
+
return scale, shift
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
|
| 359 |
+
"""
|
| 360 |
+
Align `points_src` to `points_tgt` with respect to a Z-axis shift.
|
| 361 |
+
|
| 362 |
+
### Parameters:
|
| 363 |
+
- `points_src: torch.Tensor` of shape (..., N, 3)
|
| 364 |
+
- `points_tgt: torch.Tensor` of shape (..., N, 3)
|
| 365 |
+
- `weights: torch.Tensor` of shape (..., N)
|
| 366 |
+
|
| 367 |
+
### Returns:
|
| 368 |
+
- `scale: torch.Tensor` of shape (...).
|
| 369 |
+
- `shift: torch.Tensor` of shape (..., 3)
|
| 370 |
+
"""
|
| 371 |
+
dtype, device = points_src.dtype, points_src.device
|
| 372 |
+
|
| 373 |
+
shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
|
| 374 |
+
shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)
|
| 375 |
+
|
| 376 |
+
return shift
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
|
| 380 |
+
"""
|
| 381 |
+
Align `points_src` to `points_tgt` with respect to a Z-axis shift.
|
| 382 |
+
|
| 383 |
+
### Parameters:
|
| 384 |
+
- `points_src: torch.Tensor` of shape (..., N, 3)
|
| 385 |
+
- `points_tgt: torch.Tensor` of shape (..., N, 3)
|
| 386 |
+
- `weights: torch.Tensor` of shape (..., N)
|
| 387 |
+
|
| 388 |
+
### Returns:
|
| 389 |
+
- `scale: torch.Tensor` of shape (...).
|
| 390 |
+
- `shift: torch.Tensor` of shape (..., 3)
|
| 391 |
+
"""
|
| 392 |
+
dtype, device = points_src.dtype, points_src.device
|
| 393 |
+
|
| 394 |
+
shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)
|
| 395 |
+
|
| 396 |
+
return shift
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 400 |
+
"""
|
| 401 |
+
Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.
|
| 402 |
+
|
| 403 |
+
### Parameters:
|
| 404 |
+
- `x: torch.Tensor` of shape (..., N)
|
| 405 |
+
- `y: torch.Tensor` of shape (..., N)
|
| 406 |
+
- `w: torch.Tensor` of shape (..., N)
|
| 407 |
+
|
| 408 |
+
### Returns:
|
| 409 |
+
- `a: torch.Tensor` of shape (...,)
|
| 410 |
+
- `b: torch.Tensor` of shape (...,)
|
| 411 |
+
"""
|
| 412 |
+
w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
|
| 413 |
+
A = torch.stack([w_sqrt * x, torch.ones_like(x)], dim=-1)
|
| 414 |
+
B = (w_sqrt * y)[..., None]
|
| 415 |
+
a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
|
| 416 |
+
return a, b
|
moge/utils/download.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import *
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
__all__ = ["download_file", "download_bytes"]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
|
| 12 |
+
# Ensure headers is a dict if not provided
|
| 13 |
+
headers = headers or {}
|
| 14 |
+
|
| 15 |
+
# Initialize local variables
|
| 16 |
+
file_path = Path(filepath)
|
| 17 |
+
downloaded_bytes = 0
|
| 18 |
+
|
| 19 |
+
# Check if we should resume the download
|
| 20 |
+
if resume and file_path.exists():
|
| 21 |
+
downloaded_bytes = file_path.stat().st_size
|
| 22 |
+
headers['Range'] = f"bytes={downloaded_bytes}-"
|
| 23 |
+
|
| 24 |
+
# Make a GET request to fetch the file
|
| 25 |
+
with requests.get(url, stream=True, headers=headers) as response:
|
| 26 |
+
response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
|
| 27 |
+
|
| 28 |
+
# Calculate the total size to download
|
| 29 |
+
total_size = downloaded_bytes + int(response.headers.get('content-length', 0))
|
| 30 |
+
|
| 31 |
+
# Display a progress bar while downloading
|
| 32 |
+
with (
|
| 33 |
+
tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
|
| 34 |
+
open(file_path, 'ab') as file,
|
| 35 |
+
):
|
| 36 |
+
# Set the initial position of the progress bar
|
| 37 |
+
pbar.update(downloaded_bytes)
|
| 38 |
+
|
| 39 |
+
# Write the content to the file in chunks
|
| 40 |
+
for chunk in response.iter_content(chunk_size=4096):
|
| 41 |
+
file.write(chunk)
|
| 42 |
+
pbar.update(len(chunk))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def download_bytes(url: str, headers: dict = None) -> bytes:
|
| 46 |
+
# Ensure headers is a dict if not provided
|
| 47 |
+
headers = headers or {}
|
| 48 |
+
|
| 49 |
+
# Make a GET request to fetch the file
|
| 50 |
+
with requests.get(url, stream=True, headers=headers) as response:
|
| 51 |
+
response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
|
| 52 |
+
|
| 53 |
+
# Read the content of the response
|
| 54 |
+
return response.content
|
| 55 |
+
|
moge/utils/geometry_numpy.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from functools import partial
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
from scipy.signal import fftconvolve
|
| 8 |
+
import numpy as np
|
| 9 |
+
import utils3d
|
| 10 |
+
|
| 11 |
+
from .tools import timeit
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
|
| 15 |
+
if w is None:
|
| 16 |
+
return np.mean(x, axis=axis)
|
| 17 |
+
else:
|
| 18 |
+
w = w.astype(x.dtype)
|
| 19 |
+
return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
|
| 23 |
+
if w is None:
|
| 24 |
+
return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
|
| 25 |
+
else:
|
| 26 |
+
w = w.astype(x.dtype)
|
| 27 |
+
return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
|
| 31 |
+
"UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
|
| 32 |
+
if aspect_ratio is None:
|
| 33 |
+
aspect_ratio = width / height
|
| 34 |
+
|
| 35 |
+
span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
|
| 36 |
+
span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
|
| 37 |
+
|
| 38 |
+
u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
|
| 39 |
+
v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
|
| 40 |
+
u, v = np.meshgrid(u, v, indexing='xy')
|
| 41 |
+
uv = np.stack([u, v], axis=-1)
|
| 42 |
+
return uv
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def focal_to_fov_numpy(focal: np.ndarray):
|
| 46 |
+
return 2 * np.arctan(0.5 / focal)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def fov_to_focal_numpy(fov: np.ndarray):
|
| 50 |
+
return 0.5 / np.tan(fov / 2)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 54 |
+
fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
|
| 55 |
+
fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
|
| 56 |
+
return fov_x, fov_y
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def point_map_to_depth_legacy_numpy(points: np.ndarray):
|
| 60 |
+
height, width = points.shape[-3:-1]
|
| 61 |
+
diagonal = (height ** 2 + width ** 2) ** 0.5
|
| 62 |
+
uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2)
|
| 63 |
+
_, uv = np.broadcast_arrays(points[..., :2], uv)
|
| 64 |
+
|
| 65 |
+
# Solve least squares problem
|
| 66 |
+
b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2)
|
| 67 |
+
A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2)
|
| 68 |
+
|
| 69 |
+
M = A.swapaxes(-2, -1) @ A
|
| 70 |
+
solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
|
| 71 |
+
focal, shift = solution
|
| 72 |
+
|
| 73 |
+
depth = points[..., 2] + shift[..., None, None]
|
| 74 |
+
fov_x = np.arctan(width / diagonal / focal) * 2
|
| 75 |
+
fov_y = np.arctan(height / diagonal / focal) * 2
|
| 76 |
+
return depth, fov_x, fov_y, shift
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
|
| 80 |
+
"Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
|
| 81 |
+
from scipy.optimize import least_squares
|
| 82 |
+
uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
|
| 83 |
+
|
| 84 |
+
def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
|
| 85 |
+
xy_proj = xy / (z + shift)[: , None]
|
| 86 |
+
f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
|
| 87 |
+
err = (f * xy_proj - uv).ravel()
|
| 88 |
+
return err
|
| 89 |
+
|
| 90 |
+
solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
|
| 91 |
+
optim_shift = solution['x'].squeeze().astype(np.float32)
|
| 92 |
+
|
| 93 |
+
xy_proj = xy / (z + optim_shift)[: , None]
|
| 94 |
+
optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
|
| 95 |
+
|
| 96 |
+
return optim_shift, optim_focal
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
|
| 100 |
+
"Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
|
| 101 |
+
from scipy.optimize import least_squares
|
| 102 |
+
uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
|
| 103 |
+
|
| 104 |
+
def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
|
| 105 |
+
xy_proj = xy / (z + shift)[: , None]
|
| 106 |
+
err = (focal * xy_proj - uv).ravel()
|
| 107 |
+
return err
|
| 108 |
+
|
| 109 |
+
solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
|
| 110 |
+
optim_shift = solution['x'].squeeze().astype(np.float32)
|
| 111 |
+
|
| 112 |
+
return optim_shift
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
|
| 116 |
+
import cv2
|
| 117 |
+
assert points.shape[-1] == 3, "Points should (H, W, 3)"
|
| 118 |
+
|
| 119 |
+
height, width = points.shape[-3], points.shape[-2]
|
| 120 |
+
diagonal = (height ** 2 + width ** 2) ** 0.5
|
| 121 |
+
|
| 122 |
+
uv = normalized_view_plane_uv_numpy(width=width, height=height)
|
| 123 |
+
|
| 124 |
+
if mask is None:
|
| 125 |
+
points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
|
| 126 |
+
uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
|
| 127 |
+
else:
|
| 128 |
+
(points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size)
|
| 129 |
+
|
| 130 |
+
if points_lr.size < 2:
|
| 131 |
+
return 1., 0.
|
| 132 |
+
|
| 133 |
+
if focal is None:
|
| 134 |
+
focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
|
| 135 |
+
else:
|
| 136 |
+
shift = solve_optimal_shift(uv_lr, points_lr, focal)
|
| 137 |
+
|
| 138 |
+
return focal, shift
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def mask_aware_nearest_resize_numpy(
|
| 142 |
+
inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None],
|
| 143 |
+
mask: np.ndarray,
|
| 144 |
+
size: Tuple[int, int],
|
| 145 |
+
return_index: bool = False
|
| 146 |
+
) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]:
|
| 147 |
+
"""
|
| 148 |
+
Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
|
| 149 |
+
|
| 150 |
+
### Parameters
|
| 151 |
+
- `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
|
| 152 |
+
- `mask`: input 2D mask of shape (..., H, W)
|
| 153 |
+
- `size`: target size (width, height)
|
| 154 |
+
|
| 155 |
+
### Returns
|
| 156 |
+
- `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
|
| 157 |
+
- `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
|
| 158 |
+
- `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
|
| 159 |
+
"""
|
| 160 |
+
height, width = mask.shape[-2:]
|
| 161 |
+
target_width, target_height = size
|
| 162 |
+
filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
|
| 163 |
+
filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
|
| 164 |
+
filter_size = filter_h_i * filter_w_i
|
| 165 |
+
padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
|
| 166 |
+
|
| 167 |
+
# Window the original mask and uv
|
| 168 |
+
uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
|
| 169 |
+
indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
|
| 170 |
+
padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
|
| 171 |
+
padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
|
| 172 |
+
padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
|
| 173 |
+
padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
|
| 174 |
+
padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
|
| 175 |
+
padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
|
| 176 |
+
windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
|
| 177 |
+
windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
|
| 178 |
+
windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
|
| 179 |
+
|
| 180 |
+
# Gather the target pixels's local window
|
| 181 |
+
target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
|
| 182 |
+
target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
|
| 183 |
+
target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
|
| 184 |
+
|
| 185 |
+
target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
|
| 186 |
+
target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
|
| 187 |
+
target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
|
| 188 |
+
|
| 189 |
+
# Compute nearest neighbor in the local window for each pixel
|
| 190 |
+
dist = np.square(target_window_centers - target_centers[..., None])
|
| 191 |
+
dist = dist[..., 0, :] + dist[..., 1, :]
|
| 192 |
+
dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
|
| 193 |
+
nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
|
| 194 |
+
nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
|
| 195 |
+
nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
|
| 196 |
+
target_mask = np.any(target_window_mask, axis=-1)
|
| 197 |
+
batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
|
| 198 |
+
|
| 199 |
+
index = (*batch_indices, nearest_i, nearest_j)
|
| 200 |
+
|
| 201 |
+
if inputs is None:
|
| 202 |
+
outputs = None
|
| 203 |
+
elif isinstance(inputs, np.ndarray):
|
| 204 |
+
outputs = inputs[index]
|
| 205 |
+
elif isinstance(inputs, Sequence):
|
| 206 |
+
outputs = tuple(x[index] for x in inputs)
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError(f'Invalid input type: {type(inputs)}')
|
| 209 |
+
|
| 210 |
+
if return_index:
|
| 211 |
+
return outputs, target_mask, index
|
| 212 |
+
else:
|
| 213 |
+
return outputs, target_mask
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def mask_aware_area_resize_numpy(image: np.ndarray, mask: np.ndarray, target_width: int, target_height: int) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
|
| 217 |
+
"""
|
| 218 |
+
Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
|
| 219 |
+
|
| 220 |
+
### Parameters
|
| 221 |
+
- `image`: Input 2D image of shape (..., H, W, C)
|
| 222 |
+
- `mask`: Input 2D mask of shape (..., H, W)
|
| 223 |
+
- `target_width`: target width of the resized map
|
| 224 |
+
- `target_height`: target height of the resized map
|
| 225 |
+
|
| 226 |
+
### Returns
|
| 227 |
+
- `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width).
|
| 228 |
+
- `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
|
| 229 |
+
"""
|
| 230 |
+
height, width = mask.shape[-2:]
|
| 231 |
+
|
| 232 |
+
if image.shape[-2:] == (height, width):
|
| 233 |
+
omit_channel_dim = True
|
| 234 |
+
else:
|
| 235 |
+
omit_channel_dim = False
|
| 236 |
+
if omit_channel_dim:
|
| 237 |
+
image = image[..., None]
|
| 238 |
+
|
| 239 |
+
image = np.where(mask[..., None], image, 0)
|
| 240 |
+
|
| 241 |
+
filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
|
| 242 |
+
filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1
|
| 243 |
+
filter_size = filter_h_i * filter_w_i
|
| 244 |
+
padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
|
| 245 |
+
|
| 246 |
+
# Window the original mask and uv (non-copy)
|
| 247 |
+
uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
|
| 248 |
+
indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
|
| 249 |
+
padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
|
| 250 |
+
padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
|
| 251 |
+
padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
|
| 252 |
+
padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
|
| 253 |
+
padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
|
| 254 |
+
padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
|
| 255 |
+
windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
|
| 256 |
+
windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
|
| 257 |
+
windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
|
| 258 |
+
|
| 259 |
+
# Gather the target pixels's local window
|
| 260 |
+
target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
|
| 261 |
+
target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
|
| 262 |
+
target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
|
| 263 |
+
target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
|
| 264 |
+
|
| 265 |
+
target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
|
| 266 |
+
target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
|
| 267 |
+
target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
|
| 268 |
+
|
| 269 |
+
# Compute pixel area in the local windows
|
| 270 |
+
target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None])
|
| 271 |
+
target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None])
|
| 272 |
+
target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None)
|
| 273 |
+
target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0)
|
| 274 |
+
|
| 275 |
+
# Weighted sum by area
|
| 276 |
+
target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1)
|
| 277 |
+
target_mask = np.sum(target_window_area, axis=-1) >= 0.25
|
| 278 |
+
target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1)
|
| 279 |
+
|
| 280 |
+
if omit_channel_dim:
|
| 281 |
+
target_image = target_image[..., 0]
|
| 282 |
+
|
| 283 |
+
return target_image, target_mask
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def norm3d(x: np.ndarray) -> np.ndarray:
|
| 287 |
+
"Faster `np.linalg.norm(x, axis=-1)` for 3D vectors"
|
| 288 |
+
return np.sqrt(np.square(x[..., 0]) + np.square(x[..., 1]) + np.square(x[..., 2]))
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, thickness: int = 1, tol: float = 0.1):
|
| 292 |
+
disp = np.where(mask, 1 / depth, 0)
|
| 293 |
+
disp_pad = np.pad(disp, (thickness, thickness), constant_values=0)
|
| 294 |
+
mask_pad = np.pad(mask, (thickness, thickness), constant_values=False)
|
| 295 |
+
kernel_size = 2 * thickness + 1
|
| 296 |
+
disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
|
| 297 |
+
mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
|
| 298 |
+
|
| 299 |
+
disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1))
|
| 300 |
+
fg_edge_mask = mask & (disp > (1 + tol) * disp_mean)
|
| 301 |
+
bg_edge_mask = mask & (disp_mean > (1 + tol) * disp)
|
| 302 |
+
|
| 303 |
+
edge_mask = (cv2.dilate(fg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) \
|
| 304 |
+
& (cv2.dilate(bg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0)
|
| 305 |
+
|
| 306 |
+
return edge_mask
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def disk_kernel(radius: int) -> np.ndarray:
|
| 310 |
+
"""
|
| 311 |
+
Generate disk kernel with given radius.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
radius (int): Radius of the disk (in pixels).
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel.
|
| 318 |
+
"""
|
| 319 |
+
# Create coordinate grid centered at (0,0)
|
| 320 |
+
L = np.arange(-radius, radius + 1)
|
| 321 |
+
X, Y = np.meshgrid(L, L)
|
| 322 |
+
# Generate disk: region inside circle with radius R is 1
|
| 323 |
+
kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32)
|
| 324 |
+
# Normalize the kernel
|
| 325 |
+
kernel /= np.sum(kernel)
|
| 326 |
+
return kernel
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
|
| 330 |
+
"""
|
| 331 |
+
Apply disk blur to an image using FFT convolution.
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
image (np.ndarray): Input image, can be grayscale or color.
|
| 335 |
+
radius (int): Blur radius (in pixels).
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
np.ndarray: Blurred image.
|
| 339 |
+
"""
|
| 340 |
+
if radius == 0:
|
| 341 |
+
return image
|
| 342 |
+
kernel = disk_kernel(radius)
|
| 343 |
+
if image.ndim == 2:
|
| 344 |
+
blurred = fftconvolve(image, kernel, mode='same')
|
| 345 |
+
elif image.ndim == 3:
|
| 346 |
+
channels = []
|
| 347 |
+
for i in range(image.shape[2]):
|
| 348 |
+
blurred_channel = fftconvolve(image[..., i], kernel, mode='same')
|
| 349 |
+
channels.append(blurred_channel)
|
| 350 |
+
blurred = np.stack(channels, axis=-1)
|
| 351 |
+
else:
|
| 352 |
+
raise ValueError("Image must be 2D or 3D.")
|
| 353 |
+
return blurred
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def depth_of_field(
|
| 357 |
+
img: np.ndarray,
|
| 358 |
+
disp: np.ndarray,
|
| 359 |
+
focus_disp : float,
|
| 360 |
+
max_blur_radius : int = 10,
|
| 361 |
+
) -> np.ndarray:
|
| 362 |
+
"""
|
| 363 |
+
Apply depth of field effect to an image.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
img (numpy.ndarray): (H, W, 3) input image.
|
| 367 |
+
depth (numpy.ndarray): (H, W) depth map of the scene.
|
| 368 |
+
focus_depth (float): Focus depth of the lens.
|
| 369 |
+
strength (float): Strength of the depth of field effect.
|
| 370 |
+
max_blur_radius (int): Maximum blur radius (in pixels).
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
|
| 374 |
+
"""
|
| 375 |
+
# Precalculate dialated depth map for each blur radius
|
| 376 |
+
max_disp = np.max(disp)
|
| 377 |
+
disp = disp / max_disp
|
| 378 |
+
focus_disp = focus_disp / max_disp
|
| 379 |
+
dilated_disp = []
|
| 380 |
+
for radius in range(max_blur_radius + 1):
|
| 381 |
+
dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1))
|
| 382 |
+
|
| 383 |
+
# Determine the blur radius for each pixel based on the depth map
|
| 384 |
+
blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
|
| 385 |
+
for radius in range(max_blur_radius + 1):
|
| 386 |
+
dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
|
| 387 |
+
mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
|
| 388 |
+
blur_radii[mask] = dialted_blur_radii[mask]
|
| 389 |
+
blur_radii = np.clip(blur_radii, 0, max_blur_radius)
|
| 390 |
+
blur_radii = cv2.blur(blur_radii, (5, 5))
|
| 391 |
+
|
| 392 |
+
# Precalculate the blured image for each blur radius
|
| 393 |
+
unique_radii = np.unique(blur_radii)
|
| 394 |
+
precomputed = {}
|
| 395 |
+
for radius in range(max_blur_radius + 1):
|
| 396 |
+
if radius not in unique_radii:
|
| 397 |
+
continue
|
| 398 |
+
precomputed[radius] = disk_blur(img, radius)
|
| 399 |
+
|
| 400 |
+
# Composit the blured image for each pixel
|
| 401 |
+
output = np.zeros_like(img)
|
| 402 |
+
for r in unique_radii:
|
| 403 |
+
mask = blur_radii == r
|
| 404 |
+
output[mask] = precomputed[r][mask]
|
| 405 |
+
|
| 406 |
+
return output
|
moge/utils/geometry_torch.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import math
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.types
|
| 10 |
+
import utils3d
|
| 11 |
+
|
| 12 |
+
from .tools import timeit
|
| 13 |
+
from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
| 17 |
+
if w is None:
|
| 18 |
+
return x.mean(dim=dim, keepdim=keepdim)
|
| 19 |
+
else:
|
| 20 |
+
w = w.to(x.dtype)
|
| 21 |
+
return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
| 25 |
+
if w is None:
|
| 26 |
+
return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
|
| 27 |
+
else:
|
| 28 |
+
w = w.to(x.dtype)
|
| 29 |
+
return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
| 33 |
+
if w is None:
|
| 34 |
+
return x.add(eps).log().mean(dim=dim).exp()
|
| 35 |
+
else:
|
| 36 |
+
w = w.to(x.dtype)
|
| 37 |
+
return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
|
| 41 |
+
"UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
|
| 42 |
+
if aspect_ratio is None:
|
| 43 |
+
aspect_ratio = width / height
|
| 44 |
+
|
| 45 |
+
span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
|
| 46 |
+
span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
|
| 47 |
+
|
| 48 |
+
u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
|
| 49 |
+
v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
|
| 50 |
+
u, v = torch.meshgrid(u, v, indexing='xy')
|
| 51 |
+
uv = torch.stack([u, v], dim=-1)
|
| 52 |
+
return uv
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
|
| 56 |
+
kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
|
| 57 |
+
kernel = kernel / kernel.sum()
|
| 58 |
+
kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
|
| 59 |
+
input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
|
| 60 |
+
input = F.conv2d(input, kernel, groups=input.shape[1])
|
| 61 |
+
return input
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def focal_to_fov(focal: torch.Tensor):
|
| 65 |
+
return 2 * torch.atan(0.5 / focal)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def fov_to_focal(fov: torch.Tensor):
|
| 69 |
+
return 0.5 / torch.tan(fov / 2)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
|
| 73 |
+
return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1))
|
| 74 |
+
|
| 75 |
+
def intrinsics_to_fov(intrinsics: torch.Tensor):
|
| 76 |
+
"""
|
| 77 |
+
Returns field of view in radians from normalized intrinsics matrix.
|
| 78 |
+
### Parameters:
|
| 79 |
+
- intrinsics: torch.Tensor of shape (..., 3, 3)
|
| 80 |
+
|
| 81 |
+
### Returns:
|
| 82 |
+
- fov_x: torch.Tensor of shape (...)
|
| 83 |
+
- fov_y: torch.Tensor of shape (...)
|
| 84 |
+
"""
|
| 85 |
+
focal_x = intrinsics[..., 0, 0]
|
| 86 |
+
focal_y = intrinsics[..., 1, 1]
|
| 87 |
+
return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def point_map_to_depth_legacy(points: torch.Tensor):
|
| 91 |
+
height, width = points.shape[-3:-1]
|
| 92 |
+
diagonal = (height ** 2 + width ** 2) ** 0.5
|
| 93 |
+
uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
|
| 94 |
+
|
| 95 |
+
# Solve least squares problem
|
| 96 |
+
b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
|
| 97 |
+
A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
|
| 98 |
+
|
| 99 |
+
M = A.transpose(-2, -1) @ A
|
| 100 |
+
solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
|
| 101 |
+
focal, shift = solution.unbind(-1)
|
| 102 |
+
|
| 103 |
+
depth = points[..., 2] + shift[..., None, None]
|
| 104 |
+
fov_x = torch.atan(width / diagonal / focal) * 2
|
| 105 |
+
fov_y = torch.atan(height / diagonal / focal) * 2
|
| 106 |
+
return depth, fov_x, fov_y, shift
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def view_plane_uv_to_focal(uv: torch.Tensor):
|
| 110 |
+
normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
|
| 111 |
+
focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
|
| 112 |
+
return focal
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
|
| 116 |
+
"""
|
| 117 |
+
Recover the depth map and FoV from a point map with unknown z shift and focal.
|
| 118 |
+
|
| 119 |
+
Note that it assumes:
|
| 120 |
+
- the optical center is at the center of the map
|
| 121 |
+
- the map is undistorted
|
| 122 |
+
- the map is isometric in the x and y directions
|
| 123 |
+
|
| 124 |
+
### Parameters:
|
| 125 |
+
- `points: torch.Tensor` of shape (..., H, W, 3)
|
| 126 |
+
- `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
|
| 127 |
+
|
| 128 |
+
### Returns:
|
| 129 |
+
- `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
|
| 130 |
+
- `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
|
| 131 |
+
"""
|
| 132 |
+
shape = points.shape
|
| 133 |
+
height, width = points.shape[-3], points.shape[-2]
|
| 134 |
+
diagonal = (height ** 2 + width ** 2) ** 0.5
|
| 135 |
+
|
| 136 |
+
points = points.reshape(-1, *shape[-3:])
|
| 137 |
+
mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
|
| 138 |
+
focal = focal.reshape(-1) if focal is not None else None
|
| 139 |
+
uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
|
| 140 |
+
|
| 141 |
+
points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
|
| 142 |
+
uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
|
| 143 |
+
mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
|
| 144 |
+
|
| 145 |
+
uv_lr_np = uv_lr.cpu().numpy()
|
| 146 |
+
points_lr_np = points_lr.detach().cpu().numpy()
|
| 147 |
+
focal_np = focal.cpu().numpy() if focal is not None else None
|
| 148 |
+
mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
|
| 149 |
+
optim_shift, optim_focal = [], []
|
| 150 |
+
for i in range(points.shape[0]):
|
| 151 |
+
points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
|
| 152 |
+
uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
|
| 153 |
+
if uv_lr_i_np.shape[0] < 2:
|
| 154 |
+
optim_focal.append(1)
|
| 155 |
+
optim_shift.append(0)
|
| 156 |
+
continue
|
| 157 |
+
if focal is None:
|
| 158 |
+
optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
|
| 159 |
+
optim_focal.append(float(optim_focal_i))
|
| 160 |
+
else:
|
| 161 |
+
optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
|
| 162 |
+
optim_shift.append(float(optim_shift_i))
|
| 163 |
+
optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
|
| 164 |
+
|
| 165 |
+
if focal is None:
|
| 166 |
+
optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
|
| 167 |
+
else:
|
| 168 |
+
optim_focal = focal.reshape(shape[:-3])
|
| 169 |
+
|
| 170 |
+
return optim_focal, optim_shift
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def mask_aware_nearest_resize(
|
| 174 |
+
inputs: Union[torch.Tensor, Sequence[torch.Tensor], None],
|
| 175 |
+
mask: torch.BoolTensor,
|
| 176 |
+
size: Tuple[int, int],
|
| 177 |
+
return_index: bool = False
|
| 178 |
+
) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]:
|
| 179 |
+
"""
|
| 180 |
+
Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
|
| 181 |
+
|
| 182 |
+
### Parameters
|
| 183 |
+
- `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
|
| 184 |
+
- `mask`: input 2D mask of shape (..., H, W)
|
| 185 |
+
- `size`: target size (target_width, target_height)
|
| 186 |
+
|
| 187 |
+
### Returns
|
| 188 |
+
- `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
|
| 189 |
+
- `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
|
| 190 |
+
- `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension, .
|
| 191 |
+
"""
|
| 192 |
+
height, width = mask.shape[-2:]
|
| 193 |
+
target_width, target_height = size
|
| 194 |
+
device = mask.device
|
| 195 |
+
filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
|
| 196 |
+
filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
|
| 197 |
+
filter_size = filter_h_i * filter_w_i
|
| 198 |
+
padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
|
| 199 |
+
|
| 200 |
+
# Window the original mask and uv
|
| 201 |
+
uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
|
| 202 |
+
indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
|
| 203 |
+
padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
|
| 204 |
+
padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
|
| 205 |
+
padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
|
| 206 |
+
padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
|
| 207 |
+
padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
|
| 208 |
+
padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
|
| 209 |
+
windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
|
| 210 |
+
windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
|
| 211 |
+
windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
|
| 212 |
+
|
| 213 |
+
# Gather the target pixels's local window
|
| 214 |
+
target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
|
| 215 |
+
target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
|
| 216 |
+
target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
|
| 217 |
+
|
| 218 |
+
target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
|
| 219 |
+
target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
|
| 220 |
+
target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
|
| 221 |
+
target_window_indices = target_window_indices.expand_as(target_window_mask)
|
| 222 |
+
|
| 223 |
+
# Compute nearest neighbor in the local window for each pixel
|
| 224 |
+
dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size)
|
| 225 |
+
nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1)
|
| 226 |
+
nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width)
|
| 227 |
+
target_mask = torch.any(target_window_mask, dim=-1)
|
| 228 |
+
nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
|
| 229 |
+
batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
|
| 230 |
+
|
| 231 |
+
index = (*batch_indices, nearest_i, nearest_j)
|
| 232 |
+
|
| 233 |
+
if inputs is None:
|
| 234 |
+
outputs = None
|
| 235 |
+
elif isinstance(inputs, torch.Tensor):
|
| 236 |
+
outputs = inputs[index]
|
| 237 |
+
elif isinstance(inputs, Sequence):
|
| 238 |
+
outputs = tuple(x[index] for x in inputs)
|
| 239 |
+
else:
|
| 240 |
+
raise ValueError(f'Invalid input type: {type(inputs)}')
|
| 241 |
+
|
| 242 |
+
if return_index:
|
| 243 |
+
return outputs, target_mask, index
|
| 244 |
+
else:
|
| 245 |
+
return outputs, target_mask
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3):
|
| 249 |
+
*batch_shape, height, width = depth.shape
|
| 250 |
+
depth = depth.reshape(-1, 1, height, width)
|
| 251 |
+
mask = mask.reshape(-1, 1, height, width)
|
| 252 |
+
if pooler =='max':
|
| 253 |
+
pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
|
| 254 |
+
output_mask = pooled_depth > depth * (1 + rtol)
|
| 255 |
+
elif pooler =='min':
|
| 256 |
+
pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
|
| 257 |
+
output_mask = pooled_depth < depth * (1 - rtol)
|
| 258 |
+
else:
|
| 259 |
+
raise ValueError(f'Unsupported pooler: {pooler}')
|
| 260 |
+
output_mask = output_mask.reshape(*batch_shape, height, width)
|
| 261 |
+
return output_mask
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
|
| 265 |
+
device, dtype = depth.device, depth.dtype
|
| 266 |
+
|
| 267 |
+
disp = torch.where(mask, 1 / depth, 0)
|
| 268 |
+
disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
|
| 269 |
+
mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
|
| 270 |
+
disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2]
|
| 271 |
+
mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2]
|
| 272 |
+
|
| 273 |
+
x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype)
|
| 274 |
+
A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3) # [kernel_size ** 2, 3]
|
| 275 |
+
A = mask_window[..., None] * A
|
| 276 |
+
I = torch.eye(3, device=device, dtype=dtype)
|
| 277 |
+
|
| 278 |
+
affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :] # [..., H, W, kernel_size ** 2]
|
| 279 |
+
diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0)
|
| 280 |
+
|
| 281 |
+
edge_mask = mask & (diff > tol).any(dim=-1)
|
| 282 |
+
|
| 283 |
+
disp_mean = weighted_mean(disp_window, mask_window, dim=-1)
|
| 284 |
+
fg_edge_mask = edge_mask & (disp > disp_mean)
|
| 285 |
+
# fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size)
|
| 286 |
+
bg_edge_mask = edge_mask & ~fg_edge_mask
|
| 287 |
+
return fg_edge_mask, bg_edge_mask
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
|
| 291 |
+
device, dtype = depth.device, depth.dtype
|
| 292 |
+
|
| 293 |
+
disp = torch.where(mask, 1 / depth, 0)
|
| 294 |
+
disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
|
| 295 |
+
mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
|
| 296 |
+
disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2]
|
| 297 |
+
mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2]
|
| 298 |
+
|
| 299 |
+
disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1))
|
| 300 |
+
fg_edge_mask = mask & (disp / disp_mean > 1 + tol)
|
| 301 |
+
bg_edge_mask = mask & (disp_mean / disp > 1 + tol)
|
| 302 |
+
|
| 303 |
+
fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
|
| 304 |
+
bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
|
| 305 |
+
|
| 306 |
+
return fg_edge_mask, bg_edge_mask
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor:
|
| 310 |
+
kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool)
|
| 311 |
+
for _ in range(iterations):
|
| 312 |
+
input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1))
|
| 313 |
+
mask_window = kernel & utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1))
|
| 314 |
+
if filter =='min':
|
| 315 |
+
input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).min(dim=(-2, -1)).values)
|
| 316 |
+
elif filter =='max':
|
| 317 |
+
input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).max(dim=(-2, -1)).values)
|
| 318 |
+
elif filter == 'mean':
|
| 319 |
+
input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1)))
|
| 320 |
+
elif filter =='median':
|
| 321 |
+
input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values)
|
| 322 |
+
mask = mask_window.any(dim=(-2, -1))
|
| 323 |
+
return input, mask
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def refine_depth_with_normal(depth: torch.Tensor, normal: torch.Tensor, intrinsics: torch.Tensor, iterations: int = 10, damp: float = 1e-3, eps: float = 1e-12, kernel_size: int = 5) -> torch.Tensor:
|
| 327 |
+
device, dtype = depth.device, depth.dtype
|
| 328 |
+
height, width = depth.shape[-2:]
|
| 329 |
+
radius = kernel_size // 2
|
| 330 |
+
|
| 331 |
+
duv = torch.stack(torch.meshgrid(torch.linspace(-radius / width, radius / width, kernel_size, device=device, dtype=dtype), torch.linspace(-radius / height, radius / height, kernel_size, device=device, dtype=dtype), indexing='xy'), dim=-1).to(dtype=dtype, device=device)
|
| 332 |
+
|
| 333 |
+
log_depth = depth.clamp_min_(eps).log()
|
| 334 |
+
log_depth_diff = utils3d.torch.sliding_window_2d(log_depth, window_size=kernel_size, stride=1, dim=(-2, -1)) - log_depth[..., radius:-radius, radius:-radius, None, None]
|
| 335 |
+
|
| 336 |
+
weight = torch.exp(-(log_depth_diff / duv.norm(dim=-1).clamp_min_(eps) / 10).square())
|
| 337 |
+
tot_weight = weight.sum(dim=(-2, -1)).clamp_min_(eps)
|
| 338 |
+
|
| 339 |
+
uv = utils3d.torch.image_uv(height=height, width=width, device=device, dtype=dtype)
|
| 340 |
+
K_inv = torch.inverse(intrinsics)
|
| 341 |
+
|
| 342 |
+
grad = -(normal[..., None, :2] @ K_inv[..., None, None, :2, :2]).squeeze(-2) \
|
| 343 |
+
/ (normal[..., None, 2:] + normal[..., None, :2] @ (K_inv[..., None, None, :2, :2] @ uv[..., :, None] + K_inv[..., None, None, :2, 2:])).squeeze(-2)
|
| 344 |
+
laplacian = (weight * ((utils3d.torch.sliding_window_2d(grad, window_size=kernel_size, stride=1, dim=(-3, -2)) + grad[..., radius:-radius, radius:-radius, :, None, None]) * (duv.permute(2, 0, 1) / 2)).sum(dim=-3)).sum(dim=(-2, -1))
|
| 345 |
+
|
| 346 |
+
laplacian = laplacian.clamp(-0.1, 0.1)
|
| 347 |
+
log_depth_refine = log_depth.clone()
|
| 348 |
+
|
| 349 |
+
for _ in range(iterations):
|
| 350 |
+
log_depth_refine[..., radius:-radius, radius:-radius] = 0.1 * log_depth_refine[..., radius:-radius, radius:-radius] + 0.9 * (damp * log_depth[..., radius:-radius, radius:-radius] - laplacian + (weight * utils3d.torch.sliding_window_2d(log_depth_refine, window_size=kernel_size, stride=1, dim=(-2, -1))).sum(dim=(-2, -1))) / (tot_weight + damp)
|
| 351 |
+
|
| 352 |
+
depth_refine = log_depth_refine.exp()
|
| 353 |
+
|
| 354 |
+
return depth_refine
|
moge/utils/io.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 3 |
+
from typing import IO
|
| 4 |
+
import zipfile
|
| 5 |
+
import json
|
| 6 |
+
import io
|
| 7 |
+
from typing import *
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import re
|
| 10 |
+
from PIL import Image, PngImagePlugin
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import cv2
|
| 14 |
+
|
| 15 |
+
from .tools import timeit
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def save_glb(
|
| 19 |
+
save_path: Union[str, os.PathLike],
|
| 20 |
+
vertices: np.ndarray,
|
| 21 |
+
faces: np.ndarray,
|
| 22 |
+
vertex_uvs: np.ndarray,
|
| 23 |
+
texture: np.ndarray,
|
| 24 |
+
vertex_normals: Optional[np.ndarray] = None,
|
| 25 |
+
):
|
| 26 |
+
import trimesh
|
| 27 |
+
import trimesh.visual
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
trimesh.Trimesh(
|
| 31 |
+
vertices=vertices,
|
| 32 |
+
vertex_normals=vertex_normals,
|
| 33 |
+
faces=faces,
|
| 34 |
+
visual = trimesh.visual.texture.TextureVisuals(
|
| 35 |
+
uv=vertex_uvs,
|
| 36 |
+
material=trimesh.visual.material.PBRMaterial(
|
| 37 |
+
baseColorTexture=Image.fromarray(texture),
|
| 38 |
+
metallicFactor=0.5,
|
| 39 |
+
roughnessFactor=1.0
|
| 40 |
+
)
|
| 41 |
+
),
|
| 42 |
+
process=False
|
| 43 |
+
).export(save_path)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def save_ply(
|
| 47 |
+
save_path: Union[str, os.PathLike],
|
| 48 |
+
vertices: np.ndarray,
|
| 49 |
+
faces: np.ndarray,
|
| 50 |
+
vertex_colors: np.ndarray,
|
| 51 |
+
vertex_normals: Optional[np.ndarray] = None,
|
| 52 |
+
):
|
| 53 |
+
import trimesh
|
| 54 |
+
import trimesh.visual
|
| 55 |
+
from PIL import Image
|
| 56 |
+
|
| 57 |
+
trimesh.Trimesh(
|
| 58 |
+
vertices=vertices,
|
| 59 |
+
faces=faces,
|
| 60 |
+
vertex_colors=vertex_colors,
|
| 61 |
+
vertex_normals=vertex_normals,
|
| 62 |
+
process=False
|
| 63 |
+
).export(save_path)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
|
| 67 |
+
"""
|
| 68 |
+
Read a image, return uint8 RGB array of shape (H, W, 3).
|
| 69 |
+
"""
|
| 70 |
+
if isinstance(path, (str, os.PathLike)):
|
| 71 |
+
data = Path(path).read_bytes()
|
| 72 |
+
else:
|
| 73 |
+
data = path.read()
|
| 74 |
+
image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
|
| 75 |
+
return image
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
|
| 79 |
+
"""
|
| 80 |
+
Write a image, input uint8 RGB array of shape (H, W, 3).
|
| 81 |
+
"""
|
| 82 |
+
data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
|
| 83 |
+
if isinstance(path, (str, os.PathLike)):
|
| 84 |
+
Path(path).write_bytes(data)
|
| 85 |
+
else:
|
| 86 |
+
path.write(data)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def read_depth(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, float]:
|
| 90 |
+
"""
|
| 91 |
+
Read a depth image, return float32 depth array of shape (H, W).
|
| 92 |
+
"""
|
| 93 |
+
if isinstance(path, (str, os.PathLike)):
|
| 94 |
+
data = Path(path).read_bytes()
|
| 95 |
+
else:
|
| 96 |
+
data = path.read()
|
| 97 |
+
pil_image = Image.open(io.BytesIO(data))
|
| 98 |
+
near = float(pil_image.info.get('near'))
|
| 99 |
+
far = float(pil_image.info.get('far'))
|
| 100 |
+
unit = float(pil_image.info.get('unit')) if 'unit' in pil_image.info else None
|
| 101 |
+
depth = np.array(pil_image)
|
| 102 |
+
mask_nan, mask_inf = depth == 0, depth == 65535
|
| 103 |
+
depth = (depth.astype(np.float32) - 1) / 65533
|
| 104 |
+
depth = near ** (1 - depth) * far ** depth
|
| 105 |
+
depth[mask_nan] = np.nan
|
| 106 |
+
depth[mask_inf] = np.inf
|
| 107 |
+
return depth, unit
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def write_depth(
|
| 111 |
+
path: Union[str, os.PathLike, IO],
|
| 112 |
+
depth: np.ndarray,
|
| 113 |
+
unit: float = None,
|
| 114 |
+
max_range: float = 1e5,
|
| 115 |
+
compression_level: int = 7,
|
| 116 |
+
):
|
| 117 |
+
"""
|
| 118 |
+
Encode and write a depth image as 16-bit PNG format.
|
| 119 |
+
### Parameters:
|
| 120 |
+
- `path: Union[str, os.PathLike, IO]`
|
| 121 |
+
The file path or file object to write to.
|
| 122 |
+
- `depth: np.ndarray`
|
| 123 |
+
The depth array, float32 array of shape (H, W).
|
| 124 |
+
May contain `NaN` for invalid values and `Inf` for infinite values.
|
| 125 |
+
- `unit: float = None`
|
| 126 |
+
The unit of the depth values.
|
| 127 |
+
|
| 128 |
+
Depth values are encoded as follows:
|
| 129 |
+
- 0: unknown
|
| 130 |
+
- 1 ~ 65534: depth values in logarithmic
|
| 131 |
+
- 65535: infinity
|
| 132 |
+
|
| 133 |
+
metadata is stored in the PNG file as text fields:
|
| 134 |
+
- `near`: the minimum depth value
|
| 135 |
+
- `far`: the maximum depth value
|
| 136 |
+
- `unit`: the unit of the depth values (optional)
|
| 137 |
+
"""
|
| 138 |
+
mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth),np.isinf(depth)
|
| 139 |
+
|
| 140 |
+
depth = depth.astype(np.float32)
|
| 141 |
+
mask_finite = depth
|
| 142 |
+
near = max(depth[mask_values].min(), 1e-5)
|
| 143 |
+
far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
|
| 144 |
+
depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534
|
| 145 |
+
depth[mask_nan] = 0
|
| 146 |
+
depth[mask_inf] = 65535
|
| 147 |
+
|
| 148 |
+
pil_image = Image.fromarray(depth)
|
| 149 |
+
pnginfo = PngImagePlugin.PngInfo()
|
| 150 |
+
pnginfo.add_text('near', str(near))
|
| 151 |
+
pnginfo.add_text('far', str(far))
|
| 152 |
+
if unit is not None:
|
| 153 |
+
pnginfo.add_text('unit', str(unit))
|
| 154 |
+
pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
|
| 158 |
+
"""
|
| 159 |
+
Read a segmentation mask
|
| 160 |
+
### Parameters:
|
| 161 |
+
- `path: Union[str, os.PathLike, IO]`
|
| 162 |
+
The file path or file object to read from.
|
| 163 |
+
### Returns:
|
| 164 |
+
- `Tuple[np.ndarray, Dict[str, int]]`
|
| 165 |
+
A tuple containing:
|
| 166 |
+
- `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
|
| 167 |
+
- `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}.
|
| 168 |
+
"""
|
| 169 |
+
if isinstance(path, (str, os.PathLike)):
|
| 170 |
+
data = Path(path).read_bytes()
|
| 171 |
+
else:
|
| 172 |
+
data = path.read()
|
| 173 |
+
pil_image = Image.open(io.BytesIO(data))
|
| 174 |
+
labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
|
| 175 |
+
mask = np.array(pil_image)
|
| 176 |
+
return mask, labels
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
|
| 180 |
+
"""
|
| 181 |
+
Write a segmentation mask and label mapping, as PNG format.
|
| 182 |
+
### Parameters:
|
| 183 |
+
- `path: Union[str, os.PathLike, IO]`
|
| 184 |
+
The file path or file object to write to.
|
| 185 |
+
- `mask: np.ndarray`
|
| 186 |
+
The segmentation mask, uint8 or uint16 array of shape (H, W).
|
| 187 |
+
- `labels: Dict[str, int] = None`
|
| 188 |
+
The label mapping, a dictionary of {label_name: label_id}.
|
| 189 |
+
- `compression_level: int = 7`
|
| 190 |
+
The compression level for PNG compression.
|
| 191 |
+
"""
|
| 192 |
+
assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
|
| 193 |
+
pil_image = Image.fromarray(mask)
|
| 194 |
+
pnginfo = PngImagePlugin.PngInfo()
|
| 195 |
+
if labels is not None:
|
| 196 |
+
labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':'))
|
| 197 |
+
pnginfo.add_text('labels', labels_json)
|
| 198 |
+
pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
|
| 203 |
+
"""
|
| 204 |
+
Read a normal image, return float32 normal array of shape (H, W, 3).
|
| 205 |
+
"""
|
| 206 |
+
if isinstance(path, (str, os.PathLike)):
|
| 207 |
+
data = Path(path).read_bytes()
|
| 208 |
+
else:
|
| 209 |
+
data = path.read()
|
| 210 |
+
normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
|
| 211 |
+
mask_nan = np.all(normal == 0, axis=-1)
|
| 212 |
+
normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
|
| 213 |
+
normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12)
|
| 214 |
+
normal[mask_nan] = np.nan
|
| 215 |
+
return normal
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray:
|
| 219 |
+
"""
|
| 220 |
+
Write a normal image, input float32 normal array of shape (H, W, 3).
|
| 221 |
+
"""
|
| 222 |
+
mask_nan = np.isnan(normal).any(axis=-1)
|
| 223 |
+
normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
|
| 224 |
+
normal[mask_nan] = 0
|
| 225 |
+
data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
|
| 226 |
+
if isinstance(path, (str, os.PathLike)):
|
| 227 |
+
Path(path).write_bytes(data)
|
| 228 |
+
else:
|
| 229 |
+
path.write(data)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def read_meta(path: Union[str, os.PathLike, IO]) -> Dict[str, Any]:
|
| 233 |
+
return json.loads(Path(path).read_text())
|
| 234 |
+
|
| 235 |
+
def write_meta(path: Union[str, os.PathLike, IO], meta: Dict[str, Any]):
|
| 236 |
+
Path(path).write_text(json.dumps(meta))
|