Commit e2eafba (verified), committed by gangweix · Parent: c50a6f0

Upload 56 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. moge/__init__.py +0 -0
  2. moge/model/__init__.py +18 -0
  3. moge/model/dinov2/__init__.py +6 -0
  4. moge/model/dinov2/hub/__init__.py +4 -0
  5. moge/model/dinov2/hub/backbones.py +156 -0
  6. moge/model/dinov2/hub/utils.py +39 -0
  7. moge/model/dinov2/layers/__init__.py +11 -0
  8. moge/model/dinov2/layers/attention.py +100 -0
  9. moge/model/dinov2/layers/block.py +259 -0
  10. moge/model/dinov2/layers/dino_head.py +58 -0
  11. moge/model/dinov2/layers/drop_path.py +34 -0
  12. moge/model/dinov2/layers/layer_scale.py +27 -0
  13. moge/model/dinov2/layers/mlp.py +40 -0
  14. moge/model/dinov2/layers/patch_embed.py +88 -0
  15. moge/model/dinov2/layers/swiglu_ffn.py +72 -0
  16. moge/model/dinov2/models/__init__.py +43 -0
  17. moge/model/dinov2/models/vision_transformer.py +407 -0
  18. moge/model/dinov2/utils/__init__.py +4 -0
  19. moge/model/dinov2/utils/cluster.py +95 -0
  20. moge/model/dinov2/utils/config.py +72 -0
  21. moge/model/dinov2/utils/dtype.py +37 -0
  22. moge/model/dinov2/utils/param_groups.py +103 -0
  23. moge/model/dinov2/utils/utils.py +95 -0
  24. moge/model/modules.py +254 -0
  25. moge/model/utils.py +49 -0
  26. moge/model/v1.py +392 -0
  27. moge/model/v2.py +303 -0
  28. moge/scripts/__init__.py +0 -0
  29. moge/scripts/app.py +301 -0
  30. moge/scripts/cli.py +27 -0
  31. moge/scripts/eval_baseline.py +165 -0
  32. moge/scripts/infer.py +170 -0
  33. moge/scripts/infer_baseline.py +140 -0
  34. moge/scripts/infer_panorama.py +162 -0
  35. moge/scripts/train.py +452 -0
  36. moge/scripts/vis_data.py +84 -0
  37. moge/test/__init__.py +0 -0
  38. moge/test/baseline.py +43 -0
  39. moge/test/dataloader.py +221 -0
  40. moge/test/metrics.py +343 -0
  41. moge/train/__init__.py +0 -0
  42. moge/train/dataloader.py +338 -0
  43. moge/train/losses.py +270 -0
  44. moge/train/utils.py +57 -0
  45. moge/utils/__init__.py +0 -0
  46. moge/utils/alignment.py +416 -0
  47. moge/utils/download.py +55 -0
  48. moge/utils/geometry_numpy.py +406 -0
  49. moge/utils/geometry_torch.py +354 -0
  50. moge/utils/io.py +236 -0
moge/__init__.py ADDED
(empty file)
moge/model/__init__.py ADDED
@@ -0,0 +1,18 @@
import importlib
from typing import *

if TYPE_CHECKING:
    from .v1 import MoGeModel as MoGeModelV1
    from .v2 import MoGeModel as MoGeModelV2


def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]:
    assert version in ['v1', 'v2'], f'Unsupported model version: {version}'

    try:
        module = importlib.import_module(f'.{version}', __package__)
    except ModuleNotFoundError:
        raise ValueError(f'Model version "{version}" not found.')

    cls = getattr(module, 'MoGeModel')
    return cls
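
Illustrative usage sketch for the lazy loader above (not part of the uploaded files; it assumes the moge package is on the import path):

    # Resolve the MoGe v2 model class without importing v1 at all.
    from moge.model import import_model_class_by_version

    MoGeModel = import_model_class_by_version("v2")
    print(MoGeModel.__module__)  # "moge.model.v2"

    # An unknown version fails fast:
    # import_model_class_by_version("v3")  # AssertionError: Unsupported model version: v3
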
moge/model/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

__version__ = "0.0.1"
moge/model/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
moge/model/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
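
A hedged example of how these factory functions are typically called (not part of the uploaded files; downloading the LVD-142M weights needs network access, so pretrained=False is used here):

    import torch
    from moge.model.dinov2.hub.backbones import dinov2_vitl14

    # Build the ViT-L/14 backbone with randomly initialized weights.
    backbone = dinov2_vitl14(pretrained=False).eval()

    # 518 x 518 matches the default img_size; 518 // 14 = 37 patches per side.
    x = torch.randn(1, 3, 518, 518)
    with torch.no_grad():
        feats = backbone.forward_features(x)
    print(feats["x_norm_patchtokens"].shape)  # torch.Size([1, 1369, 1024])
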
moge/model/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"


class CenterPadding(nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        output = F.pad(x, pads)
        return output
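
Illustrative use of CenterPadding (an aside, not part of the commit): it center-pads the trailing spatial dims of a tensor up to a multiple, e.g. the DINOv2 patch size of 14.

    import torch
    from moge.model.dinov2.hub.utils import CenterPadding, _make_dinov2_model_name

    pad = CenterPadding(multiple=14)
    x = torch.randn(1, 3, 500, 600)
    y = pad(x)
    print(y.shape)  # torch.Size([1, 3, 504, 602]) -- H and W rounded up to multiples of 14

    print(_make_dinov2_model_name("vit_large", 14, 4))  # "dinov2_vitl14_reg4"
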
moge/model/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
moge/model/dinov2/layers/attention.py ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

import torch.nn.functional as F
from torch import Tensor
from torch import nn


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Attention)")
    else:
        # warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    # # Deprecated implementation, extremely slow
    # def forward(self, x: Tensor, attn_bias=None) -> Tensor:
    #     B, N, C = x.shape
    #     qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    #     q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
    #     attn = q @ k.transpose(-2, -1)
    #     attn = attn.softmax(dim=-1)
    #     attn = self.attn_drop(attn)
    #     x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    #     x = self.proj(x)
    #     x = self.proj_drop(x)
    #     return x

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3, B, H, N, C // H)

        q, k, v = qkv.unbind(0)  # (B, H, N, C // H)

        x = F.scaled_dot_product_attention(q, k, v, attn_bias)
        x = x.permute(0, 2, 1, 3).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
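
A small sanity-check sketch for the attention modules above (illustrative only; when xFormers is absent, MemEffAttention falls back to the scaled_dot_product_attention path):

    import torch
    from moge.model.dinov2.layers.attention import Attention, MemEffAttention

    attn = Attention(dim=64, num_heads=8)
    x = torch.randn(2, 16, 64)      # (batch, tokens, dim)
    print(attn(x).shape)            # torch.Size([2, 16, 64])

    mem_attn = MemEffAttention(dim=64, num_heads=8)
    print(mem_attn(x).shape)        # same shape; uses xFormers kernels only if installed
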
moge/model/dinov2/layers/block.py ADDED
@@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
moge/model/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
moge/model/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
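
A brief illustration of the stochastic-depth behaviour (illustrative only): in eval mode DropPath is the identity; in train mode it zeroes whole samples and rescales the kept ones by 1/keep_prob.

    import torch
    from moge.model.dinov2.layers.drop_path import DropPath

    dp = DropPath(drop_prob=0.2)

    dp.eval()
    x = torch.ones(4, 3, 8)
    assert torch.equal(dp(x), x)   # identity at inference time

    dp.train()
    y = dp(x)                      # each sample is either all zeros or scaled by 1 / 0.8 = 1.25
    print(sorted(set(y[:, 0, 0].tolist())))
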
moge/model/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
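
LayerScale just multiplies the residual branch by a learned per-channel gamma; a quick sketch (illustrative only):

    import torch
    from moge.model.dinov2.layers.layer_scale import LayerScale

    ls = LayerScale(dim=8, init_values=1e-5)
    x = torch.randn(2, 4, 8)
    print(torch.allclose(ls(x), x * 1e-5))  # True right after init; gamma is learned during training
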
moge/model/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
moge/model/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
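
Shape behaviour of PatchEmbed in a nutshell (illustrative only):

    import torch
    from moge.model.dinov2.layers.patch_embed import PatchEmbed

    embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=384)
    x = torch.randn(1, 3, 224, 224)
    tokens = embed(x)
    print(tokens.shape)       # torch.Size([1, 256, 384]) -- (224 // 14) ** 2 = 256 patch tokens
    print(embed.num_patches)  # 256
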
moge/model/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (SwiGLU)")
    else:
        # warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    # warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
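
SwiGLUFFNFused shrinks the requested hidden width by 2/3 and rounds it up to a multiple of 8 so the gated FFN keeps roughly the same parameter count as a plain MLP while getting aligned shapes; a quick arithmetic check (illustrative only):

    # Hidden-width rounding used by SwiGLUFFNFused.
    mlp_hidden = 4 * 1024                               # what Block passes for embed_dim=1024, mlp_ratio=4
    fused_hidden = (int(mlp_hidden * 2 / 3) + 7) // 8 * 8
    print(fused_hidden)                                 # 2736
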
moge/model/dinov2/models/__init__.py ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

from . import vision_transformer as vits


logger = logging.getLogger("dinov2")


def build_model(args, only_teacher=False, img_size=224):
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        embed_dim = student.embed_dim
    return student, teacher, embed_dim


def build_model_from_cfg(cfg, only_teacher=False):
    return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
moge/model/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,407 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable, Optional, List

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block


logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    @property
    def onnx_compatible_mode(self):
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        self._onnx_compatible_mode = value

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, h, w):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        batch_size = x.shape[0]
        N = self.pos_embed.shape[1] - 1
        if not self.onnx_compatible_mode and npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0, :]
        patch_pos_embed = pos_embed[:, 1:, :]
        dim = x.shape[-1]
        h0, w0 = h // self.patch_size, w // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if not self.onnx_compatible_mode and self.interpolate_offset > 0:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sy, sx)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (h0, w0)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )

        assert (h0, w0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, h, w = x.shape
        x = self.patch_embed(x)

        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, h, w)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
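
A hedged end-to-end sketch of the ViT entry points above (illustrative only; random weights, patch_size=14, 224x224 input, xFormers optional):

    import torch
    from moge.model.dinov2.models.vision_transformer import vit_small

    model = vit_small(patch_size=14, img_size=518, init_values=1.0, block_chunks=0).eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = model.forward_features(x)
        feats = model.get_intermediate_layers(x, n=4, reshape=True)

    print(out["x_norm_patchtokens"].shape)  # torch.Size([1, 256, 384])
    print(feats[-1].shape)                  # torch.Size([1, 384, 16, 16])
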
moge/model/dinov2/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
moge/model/dinov2/utils/cluster.py ADDED
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
import os
from pathlib import Path
from typing import Any, Dict, Optional


class ClusterType(Enum):
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"


def _guess_cluster_type() -> ClusterType:
    uname = os.uname()
    if uname.sysname == "Linux":
        if uname.release.endswith("-aws"):
            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
            return ClusterType.AWS
        elif uname.nodename.startswith("rsc"):
            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
            return ClusterType.RSC

    return ClusterType.FAIR


def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    if cluster_type is None:
        return _guess_cluster_type()

    return cluster_type


def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    CHECKPOINT_DIRNAMES = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]


def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return checkpoint_path / username


def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    SLURM_PARTITIONS = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return SLURM_PARTITIONS[cluster_type]


def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    # create default parameters
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # apply cluster-specific adjustments
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
moge/model/dinov2/utils/config.py ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import math
import logging
import os

from omegaconf import OmegaConf

import dinov2.distributed as distributed
from dinov2.logging import setup_logging
from dinov2.utils import utils
from dinov2.configs import dinov2_default_config


logger = logging.getLogger("dinov2")


def apply_scaling_rules_to_cfg(cfg):  # to fix
    if cfg.optim.scaling_rule == "sqrt_wrt_1024":
        base_lr = cfg.optim.base_lr
        cfg.optim.lr = base_lr
        cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
        logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    else:
        raise NotImplementedError
    return cfg


def write_config(cfg, output_dir, name="config.yaml"):
    logger.info(OmegaConf.to_yaml(cfg))
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path


def get_cfg_from_args(args):
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]
    default_cfg = OmegaConf.create(dinov2_default_config)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
    return cfg


def default_setup(args):
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n {}\n".format(utils.get_sha()))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
moge/model/dinov2/utils/dtype.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from typing import Dict, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
+ TypeSpec = Union[str, np.dtype, torch.dtype]
14
+
15
+
16
+ _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
17
+ np.dtype("bool"): torch.bool,
18
+ np.dtype("uint8"): torch.uint8,
19
+ np.dtype("int8"): torch.int8,
20
+ np.dtype("int16"): torch.int16,
21
+ np.dtype("int32"): torch.int32,
22
+ np.dtype("int64"): torch.int64,
23
+ np.dtype("float16"): torch.float16,
24
+ np.dtype("float32"): torch.float32,
25
+ np.dtype("float64"): torch.float64,
26
+ np.dtype("complex64"): torch.complex64,
27
+ np.dtype("complex128"): torch.complex128,
28
+ }
29
+
30
+
31
+ def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
32
+ if isinstance(dtype, torch.dtype):
33
+ return dtype
34
+ if isinstance(dtype, str):
35
+ dtype = np.dtype(dtype)
36
+ assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
37
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
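
A quick sanity check of the mapping above (assuming NumPy and PyTorch are installed):

    import numpy as np
    import torch
    from moge.model.dinov2.utils.dtype import as_torch_dtype

    assert as_torch_dtype("float32") is torch.float32         # strings are routed through np.dtype()
    assert as_torch_dtype(np.dtype("int64")) is torch.int64   # numpy dtypes are looked up in the table
    assert as_torch_dtype(torch.bfloat16) is torch.bfloat16   # torch dtypes are returned unchanged
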
moge/model/dinov2/utils/param_groups.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict
7
+ import logging
8
+
9
+
10
+ logger = logging.getLogger("dinov2")
11
+
12
+
13
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
14
+ """
15
+ Calculate lr decay rate for different ViT blocks.
16
+ Args:
17
+ name (string): parameter name.
18
+ lr_decay_rate (float): base lr decay rate.
19
+ num_layers (int): number of ViT blocks.
20
+ Returns:
21
+ lr decay rate for the given parameter.
22
+ """
23
+ layer_id = num_layers + 1
24
+ if name.startswith("backbone") or force_is_backbone:
25
+ if (
26
+ ".pos_embed" in name
27
+ or ".patch_embed" in name
28
+ or ".mask_token" in name
29
+ or ".cls_token" in name
30
+ or ".register_tokens" in name
31
+ ):
32
+ layer_id = 0
33
+ elif force_is_backbone and (
34
+ "pos_embed" in name
35
+ or "patch_embed" in name
36
+ or "mask_token" in name
37
+ or "cls_token" in name
38
+ or "register_tokens" in name
39
+ ):
40
+ layer_id = 0
41
+ elif ".blocks." in name and ".residual." not in name:
42
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
43
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
44
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
45
+ elif "blocks." in name and "residual." not in name:
46
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
47
+
48
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
49
+
50
+
51
+ def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
52
+ chunked_blocks = False
53
+ if hasattr(model, "n_blocks"):
54
+ logger.info("chunked fsdp")
55
+ n_blocks = model.n_blocks
56
+ chunked_blocks = model.chunked_blocks
57
+ elif hasattr(model, "blocks"):
58
+ logger.info("first code branch")
59
+ n_blocks = len(model.blocks)
60
+ elif hasattr(model, "backbone"):
61
+ logger.info("second code branch")
62
+ n_blocks = len(model.backbone.blocks)
63
+ else:
64
+ logger.info("else code branch")
65
+ n_blocks = 0
66
+ all_param_groups = []
67
+
68
+ for name, param in model.named_parameters():
69
+ name = name.replace("_fsdp_wrapped_module.", "")
70
+ if not param.requires_grad:
71
+ continue
72
+ decay_rate = get_vit_lr_decay_rate(
73
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
74
+ )
75
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
76
+
77
+ if "last_layer" in name:
78
+ d.update({"is_last_layer": True})
79
+
80
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
81
+ d.update({"wd_multiplier": 0.0})
82
+
83
+ if "patch_embed" in name:
84
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
85
+
86
+ all_param_groups.append(d)
87
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
88
+
89
+ return all_param_groups
90
+
91
+
92
+ def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
93
+ fused_params_groups = defaultdict(lambda: {"params": []})
94
+ for d in all_params_groups:
95
+ identifier = ""
96
+ for k in keys:
97
+ identifier += k + str(d[k]) + "_"
98
+
99
+ for k in keys:
100
+ fused_params_groups[identifier][k] = d[k]
101
+ fused_params_groups[identifier]["params"].append(d["params"])
102
+
103
+ return fused_params_groups.values()
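
Together these helpers implement layer-wise learning-rate decay for a ViT backbone. A hedged sketch of wiring them into an optimizer (the base learning rate, weight decay and decay rate below are illustrative values, not taken from this repository):

    import torch
    from moge.model.dinov2.hub.backbones import dinov2_vitb14
    from moge.model.dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups

    backbone = dinov2_vitb14(pretrained=False)
    groups = get_params_groups_with_decay(backbone, lr_decay_rate=0.9, patch_embed_lr_mult=0.2)
    fused = list(fuse_params_groups(groups))  # merge groups that share the same multipliers

    base_lr, base_wd = 1e-4, 0.05  # illustrative
    for g in fused:
        g["lr"] = base_lr * g["lr_multiplier"]
        g["weight_decay"] = base_wd * g["wd_multiplier"]
    optimizer = torch.optim.AdamW(fused)
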
moge/model/dinov2/utils/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import random
9
+ import subprocess
10
+ from urllib.parse import urlparse
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
21
+ if urlparse(pretrained_weights).scheme: # If it looks like an URL
22
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
23
+ else:
24
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
25
+ if checkpoint_key is not None and checkpoint_key in state_dict:
26
+ logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
27
+ state_dict = state_dict[checkpoint_key]
28
+ # remove `module.` prefix
29
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
30
+ # remove `backbone.` prefix induced by multicrop wrapper
31
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
32
+ msg = model.load_state_dict(state_dict, strict=False)
33
+ logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
34
+
35
+
36
+ def fix_random_seeds(seed=31):
37
+ """
38
+ Fix random seeds.
39
+ """
40
+ torch.manual_seed(seed)
41
+ torch.cuda.manual_seed_all(seed)
42
+ np.random.seed(seed)
43
+ random.seed(seed)
44
+
45
+
46
+ def get_sha():
47
+ cwd = os.path.dirname(os.path.abspath(__file__))
48
+
49
+ def _run(command):
50
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
51
+
52
+ sha = "N/A"
53
+ diff = "clean"
54
+ branch = "N/A"
55
+ try:
56
+ sha = _run(["git", "rev-parse", "HEAD"])
57
+ subprocess.check_output(["git", "diff"], cwd=cwd)
58
+ diff = _run(["git", "diff-index", "HEAD"])
59
+ diff = "has uncommitted changes" if diff else "clean"
60
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
61
+ except Exception:
62
+ pass
63
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
64
+ return message
65
+
66
+
67
+ class CosineScheduler(object):
68
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
69
+ super().__init__()
70
+ self.final_value = final_value
71
+ self.total_iters = total_iters
72
+
73
+ freeze_schedule = np.zeros((freeze_iters))
74
+
75
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
76
+
77
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
78
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
79
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
80
+
81
+ assert len(self.schedule) == self.total_iters
82
+
83
+ def __getitem__(self, it):
84
+ if it >= self.total_iters:
85
+ return self.final_value
86
+ else:
87
+ return self.schedule[it]
88
+
89
+
90
+ def has_batchnorms(model):
91
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
92
+ for name, module in model.named_modules():
93
+ if isinstance(module, bn_types):
94
+ return True
95
+ return False
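
The CosineScheduler above is a plain per-iteration lookup table rather than a torch LR scheduler; a short illustrative sketch of driving a value with it:

    from moge.model.dinov2.utils.utils import CosineScheduler

    lr_schedule = CosineScheduler(
        base_value=1e-3, final_value=1e-5,
        total_iters=10_000, warmup_iters=500, start_warmup_value=0.0,
    )
    for it in (0, 250, 500, 5_000, 9_999, 20_000):
        print(it, lr_schedule[it])  # linear warmup, cosine decay, clamped to final_value past total_iters
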
moge/model/modules.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from numbers import Number
3
+ import importlib
4
+ import itertools
5
+ import functools
6
+ import sys
7
+
8
+ import torch
9
+ from torch import Tensor
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .dinov2.models.vision_transformer import DinoVisionTransformer
14
+ from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
15
+ from ..utils.geometry_torch import normalized_view_plane_uv
16
+
17
+
18
+ class ResidualConvBlock(nn.Module):
19
+ def __init__(
20
+ self,
21
+ in_channels: int,
22
+ out_channels: int = None,
23
+ hidden_channels: int = None,
24
+ kernel_size: int = 3,
25
+ padding_mode: str = 'replicate',
26
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
27
+ in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
28
+ hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
29
+ ):
30
+ super(ResidualConvBlock, self).__init__()
31
+ if out_channels is None:
32
+ out_channels = in_channels
33
+ if hidden_channels is None:
34
+ hidden_channels = in_channels
35
+
36
+ if activation =='relu':
37
+ activation_cls = nn.ReLU
38
+ elif activation == 'leaky_relu':
39
+ activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
40
+ elif activation =='silu':
41
+ activation_cls = nn.SiLU
42
+ elif activation == 'elu':
43
+ activation_cls = nn.ELU
44
+ else:
45
+ raise ValueError(f'Unsupported activation function: {activation}')
46
+
47
+ self.layers = nn.Sequential(
48
+ nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \
49
+ nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \
50
+ nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \
51
+ nn.Identity(),
52
+ activation_cls(),
53
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode),
54
+ nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \
55
+ nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' else \
56
+ nn.InstanceNorm2d(hidden_channels) if hidden_norm == 'instance_norm' else\
57
+ nn.Identity(),
58
+ activation_cls(),
59
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
60
+ )
61
+
62
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
63
+
64
+ def forward(self, x):
65
+ skip = self.skip_connection(x)
66
+ x = self.layers(x)
67
+ x = x + skip
68
+ return x
69
+
70
+
71
+ class DINOv2Encoder(nn.Module):
72
+ "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."
73
+ backbone: DinoVisionTransformer
74
+ image_mean: torch.Tensor
75
+ image_std: torch.Tensor
76
+ dim_features: int
77
+
78
+ def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, **deprecated_kwargs):
79
+ super(DINOv2Encoder, self).__init__()
80
+
81
+ self.intermediate_layers = intermediate_layers
82
+
83
+ # Load the backbone
84
+ self.hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), backbone)
85
+ self.backbone_name = backbone
86
+ self.backbone = self.hub_loader(pretrained=False)
87
+
88
+ self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
89
+ self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)
90
+
91
+ self.output_projections = nn.ModuleList([
92
+ nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
93
+ for _ in range(self.num_features)
94
+ ])
95
+
96
+ self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
97
+ self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
98
+
99
+ @property
100
+ def onnx_compatible_mode(self):
101
+ return getattr(self, "_onnx_compatible_mode", False)
102
+
103
+ @onnx_compatible_mode.setter
104
+ def onnx_compatible_mode(self, value: bool):
105
+ self._onnx_compatible_mode = value
106
+ self.backbone.onnx_compatible_mode = value
107
+
108
+ def init_weights(self):
109
+ pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
110
+ self.backbone.load_state_dict(pretrained_backbone_state_dict)
111
+
112
+ def enable_gradient_checkpointing(self):
113
+ for i in range(len(self.backbone.blocks)):
114
+ wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
115
+
116
+ def enable_pytorch_native_sdpa(self):
117
+ for i in range(len(self.backbone.blocks)):
118
+ wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
119
+
120
+ def forward(self, image: torch.Tensor, token_rows: Union[int, torch.LongTensor], token_cols: Union[int, torch.LongTensor], return_class_token: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
121
+ image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
122
+ image_14 = (image_14 - self.image_mean) / self.image_std
123
+
124
+ # Get intermediate layers from the backbone
125
+ features = self.backbone.get_intermediate_layers(image_14, n=self.intermediate_layers, return_class_token=True)
126
+
127
+ # Project features to the desired dimensionality
128
+ x = torch.stack([
129
+ proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
130
+ for proj, (feat, clstoken) in zip(self.output_projections, features)
131
+ ], dim=1).sum(dim=1)
132
+
133
+ if return_class_token:
134
+ return x, features[-1][1]
135
+ else:
136
+ return x
137
+
138
+
139
+ class Resampler(nn.Sequential):
140
+ def __init__(self,
141
+ in_channels: int,
142
+ out_channels: int,
143
+ type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
144
+ scale_factor: int = 2,
145
+ ):
146
+ if type_ == 'pixel_shuffle':
147
+ nn.Sequential.__init__(self,
148
+ nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
149
+ nn.PixelShuffle(scale_factor),
150
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
151
+ )
152
+ for i in range(1, scale_factor ** 2):
153
+ self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2]
154
+ self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2]
155
+ elif type_ in ['nearest', 'bilinear']:
156
+ nn.Sequential.__init__(self,
157
+ nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
158
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
159
+ )
160
+ elif type_ == 'conv_transpose':
161
+ nn.Sequential.__init__(self,
162
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
163
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
164
+ )
165
+ self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
166
+ elif type_ == 'pixel_unshuffle':
167
+ nn.Sequential.__init__(self,
168
+ nn.PixelUnshuffle(scale_factor),
169
+ nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
170
+ )
171
+ elif type_ == 'avg_pool':
172
+ nn.Sequential.__init__(self,
173
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
174
+ nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
175
+ )
176
+ elif type_ == 'max_pool':
177
+ nn.Sequential.__init__(self,
178
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
179
+ nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
180
+ )
181
+ else:
182
+ raise ValueError(f'Unsupported resampler type: {type_}')
183
+
184
+ class MLP(nn.Sequential):
185
+ def __init__(self, dims: Sequence[int]):
186
+ nn.Sequential.__init__(self,
187
+ *itertools.chain(*[
188
+ (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True))
189
+ for dim_in, dim_out in zip(dims[:-2], dims[1:-1])
190
+ ]),
191
+ nn.Linear(dims[-2], dims[-1]),
192
+ )
193
+
194
+
195
+ class ConvStack(nn.Module):
196
+ def __init__(self,
197
+ dim_in: List[Optional[int]],
198
+ dim_res_blocks: List[int],
199
+ dim_out: List[Optional[int]],
200
+ resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
201
+ dim_times_res_block_hidden: int = 1,
202
+ num_res_blocks: int = 1,
203
+ res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
204
+ res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
205
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
206
+ ):
207
+ super().__init__()
208
+ self.input_blocks = nn.ModuleList([
209
+ nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
210
+ for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
211
+ ])
212
+ self.resamplers = nn.ModuleList([
213
+ Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
214
+ for i, (dim_prev, dim_succ, resampler) in enumerate(zip(
215
+ dim_res_blocks[:-1],
216
+ dim_res_blocks[1:],
217
+ resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers)
218
+ ))
219
+ ])
220
+ self.res_blocks = nn.ModuleList([
221
+ nn.Sequential(
222
+ *(
223
+ ResidualConvBlock(
224
+ dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
225
+ activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
226
+ ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
227
+ )
228
+ ) for i, dim_res_block_ in enumerate(dim_res_blocks)
229
+ ])
230
+ self.output_blocks = nn.ModuleList([
231
+ nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
232
+ for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
233
+ ])
234
+
235
+ def enable_gradient_checkpointing(self):
236
+ for i in range(len(self.resamplers)):
237
+ self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
238
+ for i in range(len(self.res_blocks)):
239
+ for j in range(len(self.res_blocks[i])):
240
+ self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])
241
+
242
+ def forward(self, in_features: List[torch.Tensor]):
243
+ out_features = []
244
+ for i in range(len(self.res_blocks)):
245
+ feature = self.input_blocks[i](in_features[i])
246
+ if i == 0:
247
+ x = feature
248
+ elif feature is not None:
249
+ x = x + feature
250
+ x = self.res_blocks[i](x)
251
+ out_features.append(self.output_blocks[i](x))
252
+ if i < len(self.res_blocks) - 1:
253
+ x = self.resamplers[i](x)
254
+ return out_features
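
A hedged sketch of the encoder module above in isolation (the 448x616 input and channel sizes are illustrative; `moge/model/v2.py` below shows how the encoder and the ConvStack neck are actually wired together):

    import torch
    from moge.model.modules import DINOv2Encoder

    encoder = DINOv2Encoder(backbone='dinov2_vitb14', intermediate_layers=4, dim_out=256)
    image = torch.rand(1, 3, 448, 616)      # RGB in [0, 1]
    token_rows, token_cols = 32, 44         # the image is resampled to 14x14-pixel ViT patches
    feat, cls_token = encoder(image, token_rows, token_cols, return_class_token=True)
    print(feat.shape, cls_token.shape)      # torch.Size([1, 256, 32, 44]) torch.Size([1, 768])
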
moge/model/utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ def wrap_module_with_gradient_checkpointing(module: nn.Module):
8
+ from torch.utils.checkpoint import checkpoint
9
+ class _CheckpointingWrapper(module.__class__):
10
+ _restore_cls = module.__class__
11
+ def forward(self, *args, **kwargs):
12
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
13
+
14
+ module.__class__ = _CheckpointingWrapper
15
+ return module
16
+
17
+
18
+ def unwrap_module_with_gradient_checkpointing(module: nn.Module):
19
+ module.__class__ = module.__class__._restore_cls
20
+
21
+
22
+ def wrap_dinov2_attention_with_sdpa(module: nn.Module):
23
+ assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
24
+ class _AttentionWrapper(module.__class__):
25
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
26
+ B, N, C = x.shape
27
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
28
+
29
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
30
+
31
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
32
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
33
+
34
+ x = self.proj(x)
35
+ x = self.proj_drop(x)
36
+ return x
37
+ module.__class__ = _AttentionWrapper
38
+ return module
39
+
40
+
41
+ def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
42
+ group_to_use = torch.distributed.group.WORLD
43
+ world_size = group_to_use.size()
44
+ grad = bucket.buffer()
45
+ grad.div_(world_size)
46
+ torch.distributed.all_reduce(grad, group=group_to_use)
47
+ fut = torch.futures.Future()
48
+ fut.set_result(grad)
49
+ return fut
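
These helpers patch a module's class in place. A small illustrative sketch of the gradient-checkpointing wrapper (the toy block below is not from this repository):

    import torch
    import torch.nn as nn
    from moge.model.utils import wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing

    block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
    wrap_module_with_gradient_checkpointing(block)    # forward() now recomputes activations during backward
    loss = block(torch.randn(2, 64, requires_grad=True)).sum()
    loss.backward()
    unwrap_module_with_gradient_checkpointing(block)  # restores the original class
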
moge/model/v1.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from numbers import Number
3
+ from functools import partial
4
+ from pathlib import Path
5
+ import importlib
6
+ import warnings
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils
13
+ import torch.utils.checkpoint
14
+ import torch.version
15
+ import utils3d
16
+ from huggingface_hub import hf_hub_download
17
+
18
+
19
+ from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask
20
+ from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
21
+ from ..utils.tools import timeit
22
+
23
+
24
+ class ResidualConvBlock(nn.Module):
25
+ def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
26
+ super(ResidualConvBlock, self).__init__()
27
+ if out_channels is None:
28
+ out_channels = in_channels
29
+ if hidden_channels is None:
30
+ hidden_channels = in_channels
31
+
32
+ if activation =='relu':
33
+ activation_cls = lambda: nn.ReLU(inplace=True)
34
+ elif activation == 'leaky_relu':
35
+ activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
36
+ elif activation =='silu':
37
+ activation_cls = lambda: nn.SiLU(inplace=True)
38
+ elif activation == 'elu':
39
+ activation_cls = lambda: nn.ELU(inplace=True)
40
+ else:
41
+ raise ValueError(f'Unsupported activation function: {activation}')
42
+
43
+ self.layers = nn.Sequential(
44
+ nn.GroupNorm(1, in_channels),
45
+ activation_cls(),
46
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
47
+ nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
48
+ activation_cls(),
49
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
50
+ )
51
+
52
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
53
+
54
+ def forward(self, x):
55
+ skip = self.skip_connection(x)
56
+ x = self.layers(x)
57
+ x = x + skip
58
+ return x
59
+
60
+
61
+ class Head(nn.Module):
62
+ def __init__(
63
+ self,
64
+ num_features: int,
65
+ dim_in: int,
66
+ dim_out: List[int],
67
+ dim_proj: int = 512,
68
+ dim_upsample: List[int] = [256, 128, 128],
69
+ dim_times_res_block_hidden: int = 1,
70
+ num_res_blocks: int = 1,
71
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
72
+ last_res_blocks: int = 0,
73
+ last_conv_channels: int = 32,
74
+ last_conv_size: int = 1
75
+ ):
76
+ super().__init__()
77
+
78
+ self.projects = nn.ModuleList([
79
+ nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features)
80
+ ])
81
+
82
+ self.upsample_blocks = nn.ModuleList([
83
+ nn.Sequential(
84
+ self._make_upsampler(in_ch + 2, out_ch),
85
+ *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
86
+ ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
87
+ ])
88
+
89
+ self.output_block = nn.ModuleList([
90
+ self._make_output_block(
91
+ dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
92
+ ) for dim_out_ in dim_out
93
+ ])
94
+
95
+ def _make_upsampler(self, in_channels: int, out_channels: int):
96
+ upsampler = nn.Sequential(
97
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
98
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
99
+ )
100
+ upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
101
+ return upsampler
102
+
103
+ def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
104
+ return nn.Sequential(
105
+ nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
106
+ *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
107
+ nn.ReLU(inplace=True),
108
+ nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
109
+ )
110
+
111
+ def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
112
+ img_h, img_w = image.shape[-2:]
113
+ patch_h, patch_w = img_h // 14, img_w // 14
114
+
115
+ # Process the hidden states
116
+ x = torch.stack([
117
+ proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
118
+ for proj, (feat, clstoken) in zip(self.projects, hidden_states)
119
+ ], dim=1).sum(dim=1)
120
+
121
+ # Upsample stage
122
+ # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
123
+ for i, block in enumerate(self.upsample_blocks):
124
+ # UV coordinates is for awareness of image aspect ratio
125
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
126
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
127
+ x = torch.cat([x, uv], dim=1)
128
+ for layer in block:
129
+ x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
130
+
131
+ # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
132
+ x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
133
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
134
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
135
+ x = torch.cat([x, uv], dim=1)
136
+
137
+ if isinstance(self.output_block, nn.ModuleList):
138
+ output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
139
+ else:
140
+ output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)
141
+
142
+ return output
143
+
144
+
145
+ class MoGeModel(nn.Module):
146
+ image_mean: torch.Tensor
147
+ image_std: torch.Tensor
148
+
149
+ def __init__(self,
150
+ encoder: str = 'dinov2_vitb14',
151
+ intermediate_layers: Union[int, List[int]] = 4,
152
+ dim_proj: int = 512,
153
+ dim_upsample: List[int] = [256, 128, 128],
154
+ dim_times_res_block_hidden: int = 1,
155
+ num_res_blocks: int = 1,
156
+ remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
157
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
158
+ num_tokens_range: Tuple[Number, Number] = [1200, 2500],
159
+ last_res_blocks: int = 0,
160
+ last_conv_channels: int = 32,
161
+ last_conv_size: int = 1,
162
+ mask_threshold: float = 0.5,
163
+ **deprecated_kwargs
164
+ ):
165
+ super(MoGeModel, self).__init__()
166
+
167
+ if deprecated_kwargs:
168
+ # Process legacy arguments
169
+ if 'trained_area_range' in deprecated_kwargs:
170
+ num_tokens_range = [deprecated_kwargs['trained_area_range'][0] // 14 ** 2, deprecated_kwargs['trained_area_range'][1] // 14 ** 2]
171
+ del deprecated_kwargs['trained_area_range']
172
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
173
+
174
+ self.encoder = encoder
175
+ self.remap_output = remap_output
176
+ self.intermediate_layers = intermediate_layers
177
+ self.num_tokens_range = num_tokens_range
178
+ self.mask_threshold = mask_threshold
179
+
180
+ # NOTE: We have copied the DINOv2 code in torchhub to this repository.
181
+ # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
182
+ hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
183
+ self.backbone = hub_loader(pretrained=False)
184
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
185
+
186
+ self.head = Head(
187
+ num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
188
+ dim_in=dim_feature,
189
+ dim_out=[3, 1],
190
+ dim_proj=dim_proj,
191
+ dim_upsample=dim_upsample,
192
+ dim_times_res_block_hidden=dim_times_res_block_hidden,
193
+ num_res_blocks=num_res_blocks,
194
+ res_block_norm=res_block_norm,
195
+ last_res_blocks=last_res_blocks,
196
+ last_conv_channels=last_conv_channels,
197
+ last_conv_size=last_conv_size
198
+ )
199
+
200
+ image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
201
+ image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
202
+
203
+ self.register_buffer("image_mean", image_mean)
204
+ self.register_buffer("image_std", image_std)
205
+
206
+ @property
207
+ def device(self) -> torch.device:
208
+ return next(self.parameters()).device
209
+
210
+ @property
211
+ def dtype(self) -> torch.dtype:
212
+ return next(self.parameters()).dtype
213
+
214
+ @classmethod
215
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
216
+ """
217
+ Load a model from a checkpoint file.
218
+
219
+ ### Parameters:
220
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
221
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
222
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
223
+
224
+ ### Returns:
225
+ - A new instance of `MoGeModel` with the parameters loaded from the checkpoint.
226
+ """
227
+ if Path(pretrained_model_name_or_path).exists():
228
+ checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
229
+ else:
230
+ cached_checkpoint_path = hf_hub_download(
231
+ repo_id=pretrained_model_name_or_path,
232
+ repo_type="model",
233
+ filename="model.pt",
234
+ **hf_kwargs
235
+ )
236
+ checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
237
+ model_config = checkpoint['model_config']
238
+ if model_kwargs is not None:
239
+ model_config.update(model_kwargs)
240
+ model = cls(**model_config)
241
+ model.load_state_dict(checkpoint['model'])
242
+ return model
243
+
244
+ def init_weights(self):
245
+ "Load the backbone with pretrained dinov2 weights from torch hub"
246
+ state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
247
+ self.backbone.load_state_dict(state_dict)
248
+
249
+ def enable_gradient_checkpointing(self):
250
+ for i in range(len(self.backbone.blocks)):
251
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
252
+
253
+ def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
254
+ if self.remap_output == 'linear':
255
+ pass
256
+ elif self.remap_output =='sinh':
257
+ points = torch.sinh(points)
258
+ elif self.remap_output == 'exp':
259
+ xy, z = points.split([2, 1], dim=-1)
260
+ z = torch.exp(z)
261
+ points = torch.cat([xy * z, z], dim=-1)
262
+ elif self.remap_output =='sinh_exp':
263
+ xy, z = points.split([2, 1], dim=-1)
264
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
265
+ else:
266
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
267
+ return points
268
+
269
+ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
270
+ original_height, original_width = image.shape[-2:]
271
+
272
+ # Resize to expected resolution defined by num_tokens
273
+ resize_factor = ((num_tokens * 14 ** 2) / (original_height * original_width)) ** 0.5
274
+ resized_width, resized_height = int(original_width * resize_factor), int(original_height * resize_factor)
275
+ image = F.interpolate(image, (resized_height, resized_width), mode="bicubic", align_corners=False, antialias=True)
276
+
277
+ # Apply image transformation for DINOv2
278
+ image = (image - self.image_mean) / self.image_std
279
+ image_14 = F.interpolate(image, (resized_height // 14 * 14, resized_width // 14 * 14), mode="bilinear", align_corners=False, antialias=True)
280
+
281
+ # Get intermediate layers from the backbone
282
+ features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
283
+
284
+ # Predict points (and mask)
285
+ output = self.head(features, image)
286
+ points, mask = output
287
+
288
+ # Make sure fp32 precision for output
289
+ with torch.autocast(device_type=image.device.type, dtype=torch.float32):
290
+ # Resize to original resolution
291
+ points = F.interpolate(points, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
292
+ mask = F.interpolate(mask, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
293
+
294
+ # Post-process points and mask
295
+ points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
296
+ points = self._remap_points(points) # slightly improves the performance in case of very large output values
297
+
298
+ return_dict = {'points': points, 'mask': mask}
299
+ return return_dict
300
+
301
+ @torch.inference_mode()
302
+ def infer(
303
+ self,
304
+ image: torch.Tensor,
305
+ fov_x: Union[Number, torch.Tensor] = None,
306
+ resolution_level: int = 9,
307
+ num_tokens: int = None,
308
+ apply_mask: bool = True,
309
+ force_projection: bool = True,
310
+ use_fp16: bool = True,
311
+ ) -> Dict[str, torch.Tensor]:
312
+ """
313
+ User-friendly inference function
314
+
315
+ ### Parameters
316
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
317
+ - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
318
+ - `resolution_level`: An integer [0-9] for the resolution level for inference.
319
+ Higher values capture finer details but are slower. Defaults to 9. Note that it does not affect the output size, which always matches the input size.
320
+ `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.
321
+ - `num_tokens`: number of tokens used for inference. An integer in the suggested range of `[1200, 2500]`.
322
+ `resolution_level` will be ignored if `num_tokens` is provided. Default: None
323
+ - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
324
+ - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True
325
+ - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
326
+
327
+ ### Returns
328
+
329
+ A dictionary containing the following keys:
330
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
331
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
332
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
333
+ """
334
+ if image.dim() == 3:
335
+ omit_batch_dim = True
336
+ image = image.unsqueeze(0)
337
+ else:
338
+ omit_batch_dim = False
339
+ image = image.to(dtype=self.dtype, device=self.device)
340
+
341
+ original_height, original_width = image.shape[-2:]
342
+ aspect_ratio = original_width / original_height
343
+
344
+ if num_tokens is None:
345
+ min_tokens, max_tokens = self.num_tokens_range
346
+ num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
347
+
348
+ with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
349
+ output = self.forward(image, num_tokens)
350
+ points, mask = output['points'], output['mask']
351
+
352
+ # Always process the output in fp32 precision
353
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
354
+ points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x])
355
+
356
+ mask_binary = mask > self.mask_threshold
357
+
358
+ # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
359
+ if fov_x is None:
360
+ focal, shift = recover_focal_shift(points, mask_binary)
361
+ else:
362
+ focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
363
+ if focal.ndim == 0:
364
+ focal = focal[None].expand(points.shape[0])
365
+ _, shift = recover_focal_shift(points, mask_binary, focal=focal)
366
+ fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
367
+ fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
368
+ intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
369
+ depth = points[..., 2] + shift[..., None, None]
370
+
371
+ # If projection constraint is forced, recompute the point map using the actual depth map
372
+ if force_projection:
373
+ points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
374
+ else:
375
+ points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]
376
+
377
+ # Apply mask if needed
378
+ if apply_mask:
379
+ points = torch.where(mask_binary[..., None], points, torch.inf)
380
+ depth = torch.where(mask_binary, depth, torch.inf)
381
+
382
+ return_dict = {
383
+ 'points': points,
384
+ 'intrinsics': intrinsics,
385
+ 'depth': depth,
386
+ 'mask': mask_binary,
387
+ }
388
+
389
+ if omit_batch_dim:
390
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
391
+
392
+ return return_dict
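
For reference, a hedged end-to-end sketch of the v1 model above (the checkpoint id `Ruicheng/moge-vitl` is the v1 default used by `moge/scripts/app.py` below; the input image here is random noise, standing in for a real photo):

    import torch
    from moge.model.v1 import MoGeModel

    model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").cuda().eval()
    image = torch.rand(3, 480, 640, device="cuda")      # RGB in [0, 1], shape (3, H, W)
    output = model.infer(image, resolution_level=9, apply_mask=True)
    points, depth, intrinsics, mask = output["points"], output["depth"], output["intrinsics"], output["mask"]
    print(points.shape, depth.shape, intrinsics.shape)  # (480, 640, 3), (480, 640), (3, 3)
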
moge/model/v2.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from numbers import Number
3
+ from functools import partial
4
+ from pathlib import Path
5
+ import warnings
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils
11
+ import torch.utils.checkpoint
12
+ import torch.amp
13
+ import torch.version
14
+ import utils3d
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3
18
+ from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
19
+ from .modules import DINOv2Encoder, MLP, ConvStack
20
+
21
+
22
+ class MoGeModel(nn.Module):
23
+ encoder: DINOv2Encoder
24
+ neck: ConvStack
25
+ points_head: ConvStack
26
+ mask_head: ConvStack
27
+ scale_head: MLP
28
+ onnx_compatible_mode: bool
29
+
30
+ def __init__(self,
31
+ encoder: Dict[str, Any],
32
+ neck: Dict[str, Any],
33
+ points_head: Dict[str, Any] = None,
34
+ mask_head: Dict[str, Any] = None,
35
+ normal_head: Dict[str, Any] = None,
36
+ scale_head: Dict[str, Any] = None,
37
+ remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
38
+ num_tokens_range: List[int] = [1200, 3600],
39
+ **deprecated_kwargs
40
+ ):
41
+ super(MoGeModel, self).__init__()
42
+ if deprecated_kwargs:
43
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
44
+
45
+ self.remap_output = remap_output
46
+ self.num_tokens_range = num_tokens_range
47
+
48
+ self.encoder = DINOv2Encoder(**encoder)
49
+ self.neck = ConvStack(**neck)
50
+ if points_head is not None:
51
+ self.points_head = ConvStack(**points_head)
52
+ if mask_head is not None:
53
+ self.mask_head = ConvStack(**mask_head)
54
+ if normal_head is not None:
55
+ self.normal_head = ConvStack(**normal_head)
56
+ if scale_head is not None:
57
+ self.scale_head = MLP(**scale_head)
58
+
59
+ @property
60
+ def device(self) -> torch.device:
61
+ return next(self.parameters()).device
62
+
63
+ @property
64
+ def dtype(self) -> torch.dtype:
65
+ return next(self.parameters()).dtype
66
+
67
+ @property
68
+ def onnx_compatible_mode(self) -> bool:
69
+ return getattr(self, "_onnx_compatible_mode", False)
70
+
71
+ @onnx_compatible_mode.setter
72
+ def onnx_compatible_mode(self, value: bool):
73
+ self._onnx_compatible_mode = value
74
+ self.encoder.onnx_compatible_mode = value
75
+
76
+ @classmethod
77
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
78
+ """
79
+ Load a model from a checkpoint file.
80
+
81
+ ### Parameters:
82
+ - `pretrained_model_name_or_path`: path to a local checkpoint file, or a Hugging Face
83
+ repo id from which `model.pt` is downloaded.
84
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
85
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
86
+
87
+ ### Returns:
88
+ - A new instance of `MoGeModel` with the parameters loaded from the checkpoint.
89
+ """
90
+ if Path(pretrained_model_name_or_path).exists():
91
+ checkpoint_path = pretrained_model_name_or_path
92
+ else:
93
+ checkpoint_path = hf_hub_download(
94
+ repo_id=pretrained_model_name_or_path,
95
+ repo_type="model",
96
+ filename="model.pt",
97
+ **hf_kwargs
98
+ )
99
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
100
+
101
+ model_config = checkpoint['model_config']
102
+ if model_kwargs is not None:
103
+ model_config.update(model_kwargs)
104
+ model = cls(**model_config)
105
+ model.load_state_dict(checkpoint['model'], strict=False)
106
+
107
+ return model
108
+
109
+ def init_weights(self):
110
+ self.encoder.init_weights()
111
+
112
+ def enable_gradient_checkpointing(self):
113
+ self.encoder.enable_gradient_checkpointing()
114
+ self.neck.enable_gradient_checkpointing()
115
+ for head in ['points_head', 'normal_head', 'mask_head']:
116
+ if hasattr(self, head):
117
+ getattr(self, head).enable_gradient_checkpointing()
118
+
119
+ def enable_pytorch_native_sdpa(self):
120
+ self.encoder.enable_pytorch_native_sdpa()
121
+
122
+ def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
123
+ if self.remap_output == 'linear':
124
+ pass
125
+ elif self.remap_output =='sinh':
126
+ points = torch.sinh(points)
127
+ elif self.remap_output == 'exp':
128
+ xy, z = points.split([2, 1], dim=-1)
129
+ z = torch.exp(z)
130
+ points = torch.cat([xy * z, z], dim=-1)
131
+ elif self.remap_output =='sinh_exp':
132
+ xy, z = points.split([2, 1], dim=-1)
133
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
134
+ else:
135
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
136
+ return points
137
+
138
+ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
139
+ batch_size, _, img_h, img_w = image.shape
140
+ device, dtype = image.device, image.dtype
141
+
142
+ aspect_ratio = img_w / img_h
143
+ base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
144
+ num_tokens = base_h * base_w
145
+
146
+ # Backbones encoding
147
+ features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
148
+ features = [features, None, None, None, None]
149
+
150
+ # Concat UVs for aspect ratio input
151
+ for level in range(5):
152
+ uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
153
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
154
+ if features[level] is None:
155
+ features[level] = uv
156
+ else:
157
+ features[level] = torch.concat([features[level], uv], dim=1)
158
+
159
+ # Shared neck
160
+ features = self.neck(features)
161
+
162
+ # Heads decoding
163
+
164
+ points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head'])
165
+ metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None
166
+
167
+ # Resize
168
+ points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask])
169
+
170
+ # Remap output
171
+ if points is not None:
172
+ points = points.permute(0, 2, 3, 1)
173
+ points = self._remap_points(points) # slightly improves the performance in case of very large output values
174
+ if normal is not None:
175
+ normal = normal.permute(0, 2, 3, 1)
176
+ normal = F.normalize(normal, dim=-1)
177
+ if mask is not None:
178
+ mask = mask.squeeze(1).sigmoid()
179
+ if metric_scale is not None:
180
+ metric_scale = metric_scale.squeeze(1).exp()
181
+
182
+ return_dict = {
183
+ 'points': points,
184
+ 'normal': normal,
185
+ 'mask': mask,
186
+ 'metric_scale': metric_scale
187
+ }
188
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
189
+
190
+ return return_dict
191
+
192
+ @torch.inference_mode()
193
+ def infer(
194
+ self,
195
+ image: torch.Tensor,
196
+ num_tokens: int = None,
197
+ resolution_level: int = 9,
198
+ force_projection: bool = True,
199
+ apply_mask: Literal[False, True, 'blend'] = True,
200
+ fov_x: Optional[Union[Number, torch.Tensor]] = None,
201
+ use_fp16: bool = True,
202
+ ) -> Dict[str, torch.Tensor]:
203
+ """
204
+ User-friendly inference function
205
+
206
+ ### Parameters
207
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
208
+ - `num_tokens`: the number of base ViT tokens to use for inference. An integer in the suggested range of `[1200, 2500]`. If None, it is derived from `resolution_level`.
209
+ More tokens result in significantly higher accuracy and finer details, but slower inference. Default: None (i.e. use `resolution_level`, an integer in [0, 9] defaulting to 9).
210
+ - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
211
+ - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
212
+ - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
213
+ - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
214
+
215
+ ### Returns
216
+
217
+ A dictionary containing the following keys (plus `mask` and `normal` when those heads are present):
218
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
219
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
220
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
221
+ """
222
+ if image.dim() == 3:
223
+ omit_batch_dim = True
224
+ image = image.unsqueeze(0)
225
+ else:
226
+ omit_batch_dim = False
227
+ image = image.to(dtype=self.dtype, device=self.device)
228
+
229
+ original_height, original_width = image.shape[-2:]
230
+ area = original_height * original_width
231
+ aspect_ratio = original_width / original_height
232
+
233
+ # Determine the number of base tokens to use
234
+ if num_tokens is None:
235
+ min_tokens, max_tokens = self.num_tokens_range
236
+ num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
237
+
238
+ # Forward pass
239
+ with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
240
+ output = self.forward(image, num_tokens=num_tokens)
241
+ points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale'])
242
+
243
+ # Always process the output in fp32 precision
244
+ points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x])
245
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
246
+ if mask is not None:
247
+ mask_binary = mask > 0.5
248
+ else:
249
+ mask_binary = None
250
+
251
+ if points is not None:
252
+ # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
253
+ # NOTE: Focal here is the focal length relative to half the image diagonal
254
+ if fov_x is None:
255
+ # Recover focal and shift from predicted point map
256
+ focal, shift = recover_focal_shift(points, mask_binary)
257
+ else:
258
+ # Focal is known, recover shift only
259
+ focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
260
+ if focal.ndim == 0:
261
+ focal = focal[None].expand(points.shape[0])
262
+ _, shift = recover_focal_shift(points, mask_binary, focal=focal)
263
+ fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
264
+ intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
265
+ points[..., 2] += shift[..., None, None]
266
+ if mask_binary is not None:
267
+ mask_binary &= points[..., 2] > 0 # in case depth contains negative values (which should never happen in practice)
268
+ depth = points[..., 2].clone()
269
+ else:
270
+ depth, intrinsics = None, None
271
+
272
+ # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
273
+ if force_projection and depth is not None:
274
+ points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
275
+
276
+ # Apply metric scale
277
+ if metric_scale is not None:
278
+ if points is not None:
279
+ points *= metric_scale[:, None, None, None]
280
+ if depth is not None:
281
+ depth *= metric_scale[:, None, None]
282
+
283
+ # Apply mask
284
+ if apply_mask and mask_binary is not None:
285
+ points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
286
+ depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
287
+ normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None
288
+
289
+ # Return a dictionary, dropping entries for heads that are absent in this checkpoint
290
+ return_dict = {
291
+ 'points': points,
292
+ 'intrinsics': intrinsics,
293
+ 'depth': depth,
294
+ 'mask': mask_binary,
295
+ 'normal': normal
296
+ }
297
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
298
+ 
299
+ # Squeeze the batch dimension if the input was a single (3, H, W) image
300
+ if omit_batch_dim:
301
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
302
+ 
303
+ return return_dict
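
A corresponding hedged sketch for the v2 model above, using `Ruicheng/moge-2-vitl-normal`, the v2 default from `moge/scripts/app.py` below (which keys are present depends on the heads included in the checkpoint):

    import torch
    from moge.model.v2 import MoGeModel

    model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").cuda().eval()
    image = torch.rand(3, 480, 640, device="cuda")  # RGB in [0, 1]
    output = model.infer(image, resolution_level=9, apply_mask=True)
    for k, v in output.items():                     # e.g. points, depth, intrinsics, mask, normal
        print(k, tuple(v.shape))
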
moge/scripts/__init__.py ADDED
File without changes
moge/scripts/app.py ADDED
@@ -0,0 +1,301 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ import sys
4
+ from pathlib import Path
5
+ if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
6
+ sys.path.insert(0, _package_root)
7
+ import time
8
+ import uuid
9
+ import tempfile
10
+ import itertools
11
+ from typing import *
12
+ import atexit
13
+ from concurrent.futures import ThreadPoolExecutor
14
+ import shutil
15
+
16
+ import click
17
+
18
+
19
+ @click.command(help='Web demo')
20
+ @click.option('--share', is_flag=True, help='Whether to run the app in shared mode.')
21
+ @click.option('--pretrained', 'pretrained_model_name_or_path', default=None, help='The name or path of the pre-trained model.')
22
+ @click.option('--version', 'model_version', default='v2', help='The version of the model.')
23
+ @click.option('--fp16', 'use_fp16', is_flag=True, help='Whether to use fp16 inference.')
24
+ def main(share: bool, pretrained_model_name_or_path: str, model_version: str, use_fp16: bool):
25
+ print("Import modules...")
26
+ # Lazy import
27
+ import cv2
28
+ import torch
29
+ import numpy as np
30
+ import trimesh
31
+ import trimesh.visual
32
+ from PIL import Image
33
+ import gradio as gr
34
+ try:
35
+ import spaces # This is for deployment at huggingface.co/spaces
36
+ HUGGINFACE_SPACES_INSTALLED = True
37
+ except ImportError:
38
+ HUGGINFACE_SPACES_INSTALLED = False
39
+
40
+ import utils3d
41
+ from moge.utils.io import write_normal
42
+ from moge.utils.vis import colorize_depth, colorize_normal
43
+ from moge.model import import_model_class_by_version
44
+ from moge.utils.geometry_numpy import depth_occlusion_edge_numpy
45
+ from moge.utils.tools import timeit
46
+
47
+ print("Load model...")
48
+ if pretrained_model_name_or_path is None:
49
+ DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = {
50
+ "v1": "Ruicheng/moge-vitl",
51
+ "v2": "Ruicheng/moge-2-vitl-normal",
52
+ }
53
+ pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
54
+ model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).cuda().eval()
55
+ if use_fp16:
56
+ model.half()
57
+ thread_pool_executor = ThreadPoolExecutor(max_workers=1)
58
+
59
+ def delete_later(path: Union[str, os.PathLike], delay: int = 300):
60
+ def _delete():
61
+ try:
62
+ os.remove(path)
63
+ except FileNotFoundError:
64
+ pass
65
+ def _wait_and_delete():
66
+ time.sleep(delay)
67
+ _delete()
68
+ thread_pool_executor.submit(_wait_and_delete)
69
+ atexit.register(_delete)
70
+
71
+ # Inference on GPU.
72
+ @(spaces.GPU if HUGGINFACE_SPACES_INSTALLED else lambda x: x)
73
+ def run_with_gpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]:
74
+ image_tensor = torch.tensor(image, dtype=torch.float32 if not use_fp16 else torch.float16, device=torch.device('cuda')).permute(2, 0, 1) / 255
75
+ output = model.infer(image_tensor, apply_mask=apply_mask, resolution_level=resolution_level, use_fp16=use_fp16)
76
+ output = {k: v.cpu().numpy() for k, v in output.items()}
77
+ return output
78
+
79
+ # Full inference pipeline
80
+ def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None):
81
+ larger_size = max(image.shape[:2])
82
+ if larger_size > max_size:
83
+ scale = max_size / larger_size
84
+ image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
85
+
86
+ height, width = image.shape[:2]
87
+
88
+ resolution_level_int = {'Low': 0, 'Medium': 5, 'High': 9, 'Ultra': 30}.get(resolution_level, 9)
89
+ output = run_with_gpu(image, resolution_level_int, apply_mask)
90
+
91
+ points, depth, mask, normal = output['points'], output['depth'], output['mask'], output.get('normal', None)
92
+
93
+ if remove_edge:
94
+ mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=0.04)
95
+ else:
96
+ mask_cleaned = mask
97
+
98
+ results = {
99
+ **output,
100
+ 'mask_cleaned': mask_cleaned,
101
+ 'image': image
102
+ }
103
+
104
+ # depth & normal visualization
105
+ depth_vis = colorize_depth(depth)
106
+ if normal is not None:
107
+ normal_vis = colorize_normal(normal)
108
+ else:
109
+ normal_vis = gr.update(label="Normal map (not available for this model)")
110
+
111
+ # mesh & pointcloud
112
+ if normal is None:
113
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
114
+ points,
115
+ image.astype(np.float32) / 255,
116
+ utils3d.numpy.image_uv(width=width, height=height),
117
+ mask=mask_cleaned,
118
+ tri=True
119
+ )
120
+ vertex_normals = None
121
+ else:
122
+ faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
123
+ points,
124
+ image.astype(np.float32) / 255,
125
+ utils3d.numpy.image_uv(width=width, height=height),
126
+ normal,
127
+ mask=mask_cleaned,
128
+ tri=True
129
+ )
130
+ vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
131
+ vertex_uvs = vertex_uvs * np.array([1, -1], dtype=np.float32) + np.array([0, 1], dtype=np.float32)
132
+ if vertex_normals is not None:
133
+ vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
134
+
135
+ tempdir = Path(tempfile.gettempdir(), 'moge')
136
+ tempdir.mkdir(exist_ok=True)
137
+ output_path = Path(tempdir, request.session_hash)
138
+ shutil.rmtree(output_path, ignore_errors=True)
139
+ output_path.mkdir(exist_ok=True, parents=True)
140
+ trimesh.Trimesh(
141
+ vertices=vertices,
142
+ faces=faces,
143
+ visual = trimesh.visual.texture.TextureVisuals(
144
+ uv=vertex_uvs,
145
+ material=trimesh.visual.material.PBRMaterial(
146
+ baseColorTexture=Image.fromarray(image),
147
+ metallicFactor=0.5,
148
+ roughnessFactor=1.0
149
+ )
150
+ ),
151
+ vertex_normals=vertex_normals,
152
+ process=False
153
+ ).export(output_path / 'mesh.glb')
154
+ pointcloud = trimesh.PointCloud(
155
+ vertices=vertices,
156
+ colors=vertex_colors,
157
+ )
158
+ pointcloud.vertex_normals = vertex_normals
159
+ pointcloud.export(output_path / 'pointcloud.ply', vertex_normal=True)
160
+ trimesh.PointCloud(
161
+ vertices=vertices,
162
+ colors=vertex_colors,
163
+ ).export(output_path / 'pointcloud.glb', include_normals=True)
164
+ cv2.imwrite(str(output_path /'mask.png'), mask.astype(np.uint8) * 255)
165
+ cv2.imwrite(str(output_path / 'depth.exr'), depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
166
+ cv2.imwrite(str(output_path / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
167
+ if normal is not None:
168
+ cv2.imwrite(str(output_path / 'normal.exr'), cv2.cvtColor(normal.astype(np.float32) * np.array([1, -1, -1], dtype=np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
169
+
170
+ files = ['mesh.glb', 'pointcloud.ply', 'depth.exr', 'points.exr', 'mask.png']
171
+ if normal is not None:
172
+ files.append('normal.exr')
173
+
174
+ for f in files:
175
+ delete_later(output_path / f)
176
+
177
+ # FOV
178
+ intrinsics = results['intrinsics']
179
+ fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
180
+ fov_x, fov_y = np.rad2deg([fov_x, fov_y])
181
+
182
+ # messages
183
+ viewer_message = f'**Note:** Inference has been completed. It may take a few seconds to download the 3D model.'
184
+ if resolution_level != 'Ultra':
185
+ depth_message = f'**Note:** Want sharper depth map? Try increasing the `maximum image size` and setting the `inference resolution level` to `Ultra` in the settings.'
186
+ else:
187
+ depth_message = ""
188
+
189
+ return (
190
+ results,
191
+ depth_vis,
192
+ normal_vis,
193
+ output_path / 'pointcloud.glb',
194
+ [(output_path / f).as_posix() for f in files if (output_path / f).exists()],
195
+ f'- **Horizontal FOV: {fov_x:.1f}°**. \n - **Vertical FOV: {fov_y:.1f}°**',
196
+ viewer_message,
197
+ depth_message
198
+ )
199
+
200
+ def reset_measure(results: Dict[str, np.ndarray]):
201
+ return [results['image'], [], ""]
202
+
203
+
204
+ def measure(results: Dict[str, np.ndarray], measure_points: List[Tuple[int, int]], event: gr.SelectData):
205
+ point2d = event.index[0], event.index[1]
206
+ measure_points.append(point2d)
207
+
208
+ image = results['image'].copy()
209
+ for p in measure_points:
210
+ image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
211
+
212
+ depth_text = ""
213
+ for i, p in enumerate(measure_points):
214
+ d = results['depth'][p[1], p[0]]
215
+ depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
216
+
217
+ if len(measure_points) == 2:
218
+ point1, point2 = measure_points
219
+ image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
220
+ distance = np.linalg.norm(results['points'][point1[1], point1[0]] - results['points'][point2[1], point2[0]])
221
+ measure_points = []
222
+
223
+ distance_text = f"- **Distance: {distance:.2f}m**"
224
+
225
+ text = depth_text + distance_text
226
+ return [image, measure_points, text]
227
+ else:
228
+ return [image, measure_points, depth_text]
229
+
230
+ print("Create Gradio app...")
231
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
232
+ gr.Markdown(
233
+ f'''
234
+ <div align="center">
235
+ <h1> Turn a 2D image into 3D with MoGe <a title="Github" href="https://github.com/microsoft/MoGe" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> <img src="https://img.shields.io/github/stars/microsoft/MoGe?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> </a> </h1>
236
+ </div>
237
+ ''')
238
+ results = gr.State(value=None)
239
+ measure_points = gr.State(value=[])
240
+
241
+ with gr.Row():
242
+ with gr.Column():
243
+ input_image = gr.Image(type="numpy", image_mode="RGB", label="Input Image")
244
+ with gr.Accordion(label="Settings", open=False):
245
+ max_size_input = gr.Number(value=800, label="Maximum Image Size", precision=0, minimum=256, maximum=2048)
246
+ resolution_level = gr.Dropdown(['Low', 'Medium', 'High', 'Ultra'], label="Inference Resolution Level", value='High')
247
+ apply_mask = gr.Checkbox(value=True, label="Apply mask")
248
+ remove_edges = gr.Checkbox(value=True, label="Remove edges")
249
+ submit_btn = gr.Button("Submit", variant='primary')
250
+
251
+ with gr.Column():
252
+ with gr.Tabs():
253
+ with gr.Tab("3D View"):
254
+ viewer_message = gr.Markdown("")
255
+ model_3d = gr.Model3D(display_mode="solid", label="3D Point Map", clear_color=[1.0, 1.0, 1.0, 1.0], height="60vh")
256
+ fov = gr.Markdown()
257
+ with gr.Tab("Depth"):
258
+ depth_message = gr.Markdown("")
259
+ depth_map = gr.Image(type="numpy", label="Colorized Depth Map", format='png', interactive=False)
260
+ with gr.Tab("Normal", interactive=hasattr(model, 'normal_head')):
261
+ normal_map = gr.Image(type="numpy", label="Normal Map", format='png', interactive=False)
262
+ with gr.Tab("Measure", interactive=hasattr(model, 'scale_head')):
263
+ gr.Markdown("### Click on the image to measure the distance between two points. \n"
264
+ "**Note:** Metric scale is most reliable for typical indoor or street scenes, and may degrade for contents unfamiliar to the model (e.g., stylized or close-up images).")
265
+ measure_image = gr.Image(type="numpy", show_label=False, format='webp', interactive=False, sources=[])
266
+ measure_text = gr.Markdown("")
267
+ with gr.Tab("Download"):
268
+ files = gr.File(type='filepath', label="Output Files")
269
+
270
+ if Path('example_images').exists():
271
+ example_image_paths = sorted(list(itertools.chain(*[Path('example_images').glob(f'*.{ext}') for ext in ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']])))
272
+ examples = gr.Examples(
273
+ examples = example_image_paths,
274
+ inputs=input_image,
275
+ label="Examples"
276
+ )
277
+
278
+ submit_btn.click(
279
+ fn=lambda: [None, None, None, None, None, "", "", ""],
280
+ outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message]
281
+ ).then(
282
+ fn=run,
283
+ inputs=[input_image, max_size_input, resolution_level, apply_mask, remove_edges],
284
+ outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message]
285
+ ).then(
286
+ fn=reset_measure,
287
+ inputs=[results],
288
+ outputs=[measure_image, measure_points, measure_text]
289
+ )
290
+
291
+ measure_image.select(
292
+ fn=measure,
293
+ inputs=[results, measure_points],
294
+ outputs=[measure_image, measure_points, measure_text]
295
+ )
296
+
297
+ demo.launch(share=share)
298
+
299
+
300
+ if __name__ == '__main__':
301
+ main()
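For reference, the preprocessing and inference that `run_with_gpu` above performs can be reproduced outside Gradio. A minimal sketch, assuming CUDA is available and using the default v2 checkpoint name from the script; the input path is hypothetical:

import cv2
import torch
from moge.model import import_model_class_by_version

model = import_model_class_by_version('v2').from_pretrained('Ruicheng/moge-2-vitl-normal').cuda().eval()

# Load an RGB image and convert it to a CHW float tensor in [0, 1], as the app does.
image = cv2.cvtColor(cv2.imread('example.jpg'), cv2.COLOR_BGR2RGB)
image_tensor = torch.tensor(image, dtype=torch.float32, device='cuda').permute(2, 0, 1) / 255

with torch.inference_mode():
    output = model.infer(image_tensor, apply_mask=True, resolution_level=9, use_fp16=False)
# Upstream, `infer` returns a dict of tensors (points, depth, mask, intrinsics, and normal for
# models with a normal head); note that the `v2.py` in this upload has been patched to return a tuple instead.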
moge/scripts/cli.py ADDED
@@ -0,0 +1,27 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ from pathlib import Path
4
+ import sys
5
+ if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
6
+ sys.path.insert(0, _package_root)
7
+
8
+ import click
9
+
10
+
11
+ @click.group(help='MoGe command line interface.')
12
+ def cli():
13
+ pass
14
+
15
+ def main():
16
+ from moge.scripts import app, infer, infer_baseline, infer_panorama, eval_baseline, vis_data
17
+ cli.add_command(app.main, name='app')
18
+ cli.add_command(infer.main, name='infer')
19
+ cli.add_command(infer_baseline.main, name='infer_baseline')
20
+ cli.add_command(infer_panorama.main, name='infer_panorama')
21
+ cli.add_command(eval_baseline.main, name='eval_baseline')
22
+ cli.add_command(vis_data.main, name='vis_data')
23
+ cli()
24
+
25
+
26
+ if __name__ == '__main__':
27
+ main()
moge/scripts/eval_baseline.py ADDED
@@ -0,0 +1,165 @@
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+ if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
5
+ sys.path.insert(0, _package_root)
6
+ import json
7
+ from typing import *
8
+ import importlib
9
+ import importlib.util
10
+
11
+ import click
12
+
13
+
14
+ @click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Evaluation script.')
15
+ @click.option('--baseline', 'baseline_code_path', type=click.Path(), required=True, help='Path to the baseline model python code.')
16
+ @click.option('--config', 'config_path', type=click.Path(), default='configs/eval/all_benchmarks.json', help='Path to the evaluation configurations. '
17
+ 'Defaults to "configs/eval/all_benchmarks.json".')
18
+ @click.option('--output', '-o', 'output_path', type=click.Path(), required=True, help='Path to the output json file.')
19
+ @click.option('--oracle', 'oracle_mode', is_flag=True, help='Use oracle mode for evaluation, i.e., use the GT intrinsics input.')
20
+ @click.option('--dump_pred', is_flag=True, help='Dump prediction results.')
21
+ @click.option('--dump_gt', is_flag=True, help='Dump ground truth.')
22
+ @click.pass_context
23
+ def main(ctx: click.Context, baseline_code_path: str, config_path: str, oracle_mode: bool, output_path: Union[str, Path], dump_pred: bool, dump_gt: bool):
24
+ # Lazy import
25
+ import cv2
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import utils3d
31
+
32
+ from moge.test.baseline import MGEBaselineInterface
33
+ from moge.test.dataloader import EvalDataLoaderPipeline
34
+ from moge.test.metrics import compute_metrics
35
+ from moge.utils.geometry_torch import intrinsics_to_fov
36
+ from moge.utils.vis import colorize_depth, colorize_normal
37
+ from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module
38
+
39
+ # Load the baseline model
40
+ module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem)
41
+ baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline')
42
+ baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False)
43
+
44
+ # Load the evaluation configurations
45
+ with open(config_path, 'r') as f:
46
+ config = json.load(f)
47
+
48
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
49
+ all_metrics = {}
50
+ # Iterate over the dataset
51
+ for benchmark_name, benchmark_config in tqdm(list(config.items()), desc='Benchmarks'):
52
+ filenames, metrics_list = [], []
53
+ with (
54
+ EvalDataLoaderPipeline(**benchmark_config) as eval_data_pipe,
55
+ tqdm(total=len(eval_data_pipe), desc=benchmark_name, leave=False) as pbar
56
+ ):
57
+ # Iterate over the samples in the dataset
58
+ for i in range(len(eval_data_pipe)):
59
+ sample = eval_data_pipe.get()
60
+ sample = {k: v.to(baseline.device) if isinstance(v, torch.Tensor) else v for k, v in sample.items()}
61
+ image = sample['image']
62
+ gt_intrinsics = sample['intrinsics']
63
+
64
+ # Inference
65
+ torch.cuda.synchronize()
66
+ with torch.inference_mode(), timeit('_inference_timer', verbose=False) as timer:
67
+ if oracle_mode:
68
+ pred = baseline.infer_for_evaluation(image, gt_intrinsics)
69
+ else:
70
+ pred = baseline.infer_for_evaluation(image)
71
+ torch.cuda.synchronize()
72
+
73
+ # Compute metrics
74
+ metrics, misc = compute_metrics(pred, sample, vis=dump_pred or dump_gt)
75
+ metrics['inference_time'] = timer.time
76
+ metrics_list.append(metrics)
77
+
78
+ # Dump results
79
+ dump_path = Path(output_path.replace(".json", f"_dump"), f'{benchmark_name}', sample['filename'].replace('.zip', ''))
80
+ if dump_pred:
81
+ dump_path.joinpath('pred').mkdir(parents=True, exist_ok=True)
82
+ cv2.imwrite(str(dump_path / 'pred' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
83
+
84
+ with Path(dump_path, 'pred', 'metrics.json').open('w') as f:
85
+ json.dump(metrics, f, indent=4)
86
+
87
+ if 'pred_points' in misc:
88
+ points = misc['pred_points'].cpu().numpy()
89
+ cv2.imwrite(str(dump_path / 'pred' / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
90
+
91
+ if 'pred_depth' in misc:
92
+ depth = misc['pred_depth'].cpu().numpy()
93
+ if 'mask' in pred:
94
+ mask = pred['mask'].cpu().numpy()
95
+ depth = np.where(mask, depth, np.inf)
96
+ cv2.imwrite(str(dump_path / 'pred' / 'depth.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR))
97
+
98
+ if 'mask' in pred:
99
+ mask = pred['mask'].cpu().numpy()
100
+ cv2.imwrite(str(dump_path / 'pred' / 'mask.png'), (mask * 255).astype(np.uint8))
101
+
102
+ if 'normal' in pred:
103
+ normal = pred['normal'].cpu().numpy()
104
+ cv2.imwrite(str(dump_path / 'pred' / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR))
105
+
106
+ if 'intrinsics' in pred:
107
+ intrinsics = pred['intrinsics']
108
+ fov_x, fov_y = intrinsics_to_fov(intrinsics)
109
+ with open(dump_path / 'pred' / 'fov.json', 'w') as f:
110
+ json.dump({
111
+ 'fov_x': np.rad2deg(fov_x.item()),
112
+ 'fov_y': np.rad2deg(fov_y.item()),
113
+ 'intrinsics': intrinsics.cpu().numpy().tolist(),
114
+ }, f)
115
+
116
+ if dump_gt:
117
+ dump_path.joinpath('gt').mkdir(parents=True, exist_ok=True)
118
+ cv2.imwrite(str(dump_path / 'gt' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
119
+
120
+ if 'points' in sample:
121
+ points = sample['points']
122
+ cv2.imwrite(str(dump_path / 'gt' / 'points.exr'), cv2.cvtColor(points.cpu().numpy().astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
123
+
124
+ if 'depth' in sample:
125
+ depth = sample['depth']
126
+ mask = sample['depth_mask']
127
+ cv2.imwrite(str(dump_path / 'gt' / 'depth.png'), cv2.cvtColor(colorize_depth(depth.cpu().numpy(), mask=mask.cpu().numpy()), cv2.COLOR_RGB2BGR))
128
+
129
+ if 'normal' in sample:
130
+ normal = sample['normal']
131
+ cv2.imwrite(str(dump_path / 'gt' / 'normal.png'), cv2.cvtColor(colorize_normal(normal.cpu().numpy()), cv2.COLOR_RGB2BGR))
132
+
133
+ if 'depth_mask' in sample:
134
+ mask = sample['depth_mask']
135
+ cv2.imwrite(str(dump_path / 'gt' /'mask.png'), (mask.cpu().numpy() * 255).astype(np.uint8))
136
+
137
+ if 'intrinsics' in sample:
138
+ intrinsics = sample['intrinsics']
139
+ fov_x, fov_y = intrinsics_to_fov(intrinsics)
140
+ with open(dump_path / 'gt' / 'info.json', 'w') as f:
141
+ json.dump({
142
+ 'fov_x': np.rad2deg(fov_x.item()),
143
+ 'fov_y': np.rad2deg(fov_y.item()),
144
+ 'intrinsics': intrinsics.cpu().numpy().tolist(),
145
+ }, f)
146
+
147
+ # Save intermediate results
148
+ if i % 100 == 0 or i == len(eval_data_pipe) - 1:
149
+ Path(output_path).write_text(
150
+ json.dumps({
151
+ **all_metrics,
152
+ benchmark_name: key_average(metrics_list)
153
+ }, indent=4)
154
+ )
155
+ pbar.update(1)
156
+
157
+ all_metrics[benchmark_name] = key_average(metrics_list)
158
+
159
+ # Save final results
160
+ all_metrics['mean'] = key_average(list(all_metrics.values()))
161
+ Path(output_path).write_text(json.dumps(all_metrics, indent=4))
162
+
163
+
164
+ if __name__ == '__main__':
165
+ main()
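The per-benchmark numbers above are produced by averaging per-sample metric dictionaries via `key_average` from `moge.utils.tools`, whose implementation is not shown in this diff. As an illustration of that aggregation step only, a stand-in might look like this (the metric names are made up):

from typing import Dict, List

def average_metrics(metrics_list: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric key over all evaluated samples (illustrative stand-in, not the repo's key_average)."""
    keys = set().union(*(m.keys() for m in metrics_list))
    return {
        k: sum(m[k] for m in metrics_list if k in m) / sum(1 for m in metrics_list if k in m)
        for k in keys
    }

# e.g. average_metrics([{'absrel': 0.05, 'delta1': 0.97}, {'absrel': 0.07, 'delta1': 0.95}])
# -> approximately {'absrel': 0.06, 'delta1': 0.96}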
moge/scripts/infer.py ADDED
@@ -0,0 +1,170 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ from pathlib import Path
4
+ import sys
5
+ if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
6
+ sys.path.insert(0, _package_root)
7
+ from typing import *
8
+ import itertools
9
+ import json
10
+ import warnings
11
+
12
+
13
+ import click
14
+
15
+
16
+ @click.command(help='Inference script')
17
+ @click.option('--input', '-i', 'input_path', type=click.Path(exists=True), help='Input image or folder path. "jpg" and "png" are supported.')
18
+ @click.option('--fov_x', 'fov_x_', type=float, default=None, help='If camera parameters are known, set the horizontal field of view in degrees. Otherwise, MoGe will estimate it.')
19
+ @click.option('--output', '-o', 'output_path', default='./output', type=click.Path(), help='Output folder path')
20
+ @click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default=None, help='Pretrained model name or path. If not provided, the corresponding default model will be chosen.')
21
+ @click.option('--version', 'model_version', type=click.Choice(['v1', 'v2']), default='v2', help='Model version. Defaults to "v2"')
22
+ @click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"')
23
+ @click.option('--fp16', 'use_fp16', is_flag=True, help='Use fp16 precision for much faster inference.')
24
+ @click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).')
25
+ @click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level for inference. \
26
+ Higher value means more tokens and the finer details will be captured, but inference can be slower. \
27
+ Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. \
28
+ `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.')
29
+ @click.option('--num_tokens', type=int, default=None, help='Number of tokens used for inference. An integer in the (suggested) range of `[1200, 2500]`. \
30
+ `resolution_level` will be ignored if `num_tokens` is provided. Default: None')
31
+ @click.option('--threshold', type=float, default=0.04, help='Threshold for removing edges. Defaults to 0.04. Smaller value removes more edges. "inf" means no thresholding.')
32
+ @click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps (image, point map, depth map, normal map, mask) and fov.')
33
+ @click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a .glb file. The color will be saved as a texture.')
34
+ @click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a .ply file. The color will be saved as vertex colors.')
35
+ @click.option('--show', 'show', is_flag=True, help='Whether to show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.')
36
+ def main(
37
+ input_path: str,
38
+ fov_x_: float,
39
+ output_path: str,
40
+ pretrained_model_name_or_path: str,
41
+ model_version: str,
42
+ device_name: str,
43
+ use_fp16: bool,
44
+ resize_to: int,
45
+ resolution_level: int,
46
+ num_tokens: int,
47
+ threshold: float,
48
+ save_maps_: bool,
49
+ save_glb_: bool,
50
+ save_ply_: bool,
51
+ show: bool,
52
+ ):
53
+ import cv2
54
+ import numpy as np
55
+ import torch
56
+ from PIL import Image
57
+ from tqdm import tqdm
58
+ import trimesh
59
+ import trimesh.visual
60
+ import click
61
+
62
+ from moge.model import import_model_class_by_version
63
+ from moge.utils.io import save_glb, save_ply
64
+ from moge.utils.vis import colorize_depth, colorize_normal
65
+ from moge.utils.geometry_numpy import depth_occlusion_edge_numpy
66
+ import utils3d
67
+
68
+ device = torch.device(device_name)
69
+
70
+ include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
71
+ if Path(input_path).is_dir():
72
+ image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
73
+ else:
74
+ image_paths = [Path(input_path)]
75
+
76
+ if len(image_paths) == 0:
77
+ raise FileNotFoundError(f'No image files found in {input_path}')
78
+
79
+ if pretrained_model_name_or_path is None:
80
+ DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = {
81
+ "v1": "Ruicheng/moge-vitl",
82
+ "v2": "Ruicheng/moge-2-vitl-normal",
83
+ }
84
+ pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
85
+ model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).to(device).eval()
86
+ if use_fp16:
87
+ model.half()
88
+
89
+ if not any([save_maps_, save_glb_, save_ply_]):
90
+ warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.')
91
+ save_maps_ = save_glb_ = save_ply_ = True
92
+
93
+ for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)):
94
+ image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
95
+ height, width = image.shape[:2]
96
+ if resize_to is not None:
97
+ height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height))
98
+ image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
99
+ image_tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
100
+
101
+ # Inference
102
+ output = model.infer(image_tensor, fov_x=fov_x_, resolution_level=resolution_level, num_tokens=num_tokens, use_fp16=use_fp16)
103
+ points, depth, mask, intrinsics = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy(), output['intrinsics'].cpu().numpy()
104
+ normal = output['normal'].cpu().numpy() if 'normal' in output else None
105
+
106
+ save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
107
+ save_path.mkdir(exist_ok=True, parents=True)
108
+
109
+ # Save images / maps
110
+ if save_maps_:
111
+ cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
112
+ cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR))
113
+ cv2.imwrite(str(save_path / 'depth.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
114
+ cv2.imwrite(str(save_path / 'mask.png'), (mask * 255).astype(np.uint8))
115
+ cv2.imwrite(str(save_path / 'points.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
116
+ if normal is not None:
117
+ cv2.imwrite(str(save_path / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR))
118
+ fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
119
+ with open(save_path / 'fov.json', 'w') as f:
120
+ json.dump({
121
+ 'fov_x': round(float(np.rad2deg(fov_x)), 2),
122
+ 'fov_y': round(float(np.rad2deg(fov_y)), 2),
123
+ }, f)
124
+
125
+ # Export mesh & visualization
126
+ if save_glb_ or save_ply_ or show:
127
+ mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=threshold)
128
+ if normal is None:
129
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
130
+ points,
131
+ image.astype(np.float32) / 255,
132
+ utils3d.numpy.image_uv(width=width, height=height),
133
+ mask=mask_cleaned,
134
+ tri=True
135
+ )
136
+ vertex_normals = None
137
+ else:
138
+ faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
139
+ points,
140
+ image.astype(np.float32) / 255,
141
+ utils3d.numpy.image_uv(width=width, height=height),
142
+ normal,
143
+ mask=mask_cleaned,
144
+ tri=True
145
+ )
146
+ # When exporting the model, follow the OpenGL coordinate conventions:
147
+ # - world coordinate system: x right, y up, z backward.
148
+ # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top.
149
+ vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1]
150
+ if normal is not None:
151
+ vertex_normals = vertex_normals * [1, -1, -1]
152
+
153
+ if save_glb_:
154
+ save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image, vertex_normals)
155
+
156
+ if save_ply_:
157
+ save_ply(save_path / 'pointcloud.ply', vertices, np.zeros((0, 3), dtype=np.int32), vertex_colors, vertex_normals)
158
+
159
+ if show:
160
+ trimesh.Trimesh(
161
+ vertices=vertices,
162
+ vertex_colors=vertex_colors,
163
+ vertex_normals=vertex_normals,
164
+ faces=faces,
165
+ process=False
166
+ ).show()
167
+
168
+
169
+ if __name__ == '__main__':
170
+ main()
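The `--resolution_level` option is a convenience over `--num_tokens`: the v2 `infer` method shown earlier in this diff interpolates the token budget linearly between the model's minimum and maximum. A small sketch of that mapping; the min/max values below are placeholders taken from the suggested range in the help text, the real bounds come from the model configuration:

def num_tokens_for_level(resolution_level: int, min_tokens: int = 1200, max_tokens: int = 2500) -> int:
    # Same interpolation as in the v2 `infer` method above.
    return int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

for level in (0, 5, 9):
    print(level, num_tokens_for_level(level))
# 0 1200
# 5 1922
# 9 2500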
moge/scripts/infer_baseline.py ADDED
@@ -0,0 +1,140 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ from pathlib import Path
4
+ import sys
5
+ if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
6
+ sys.path.insert(0, _package_root)
7
+ import json
8
+ from pathlib import Path
9
+ from typing import *
10
+ import itertools
11
+ import warnings
12
+
13
+ import click
14
+
15
+
16
+ @click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Inference script for wrapped baseline methods')
17
+ @click.option('--baseline', 'baseline_code_path', required=True, type=click.Path(), help='Path to the baseline model python code.')
18
+ @click.option('--input', '-i', 'input_path', type=str, required=True, help='Input image or folder')
19
+ @click.option('--output', '-o', 'output_path', type=str, default='./output', help='Output folder')
20
+ @click.option('--size', 'image_size', type=int, default=None, help='Resize input image')
21
+ @click.option('--skip', is_flag=True, help='Skip existing output')
22
+ @click.option('--maps', 'save_maps_', is_flag=True, help='Save output point / depth maps')
23
+ @click.option('--ply', 'save_ply_', is_flag=True, help='Save mesh in PLY format')
24
+ @click.option('--glb', 'save_glb_', is_flag=True, help='Save mesh in GLB format')
25
+ @click.option('--threshold', type=float, default=0.03, help='Depth edge detection threshold for saving mesh')
26
+ @click.pass_context
27
+ def main(ctx: click.Context, baseline_code_path: str, input_path: str, output_path: str, image_size: int, skip: bool, save_maps_: bool, save_ply_: bool, save_glb_: bool, threshold: float):
28
+ # Lazy import
29
+ import cv2
30
+ import numpy as np
31
+ from tqdm import tqdm
32
+ import torch
33
+ import utils3d
34
+
35
+ from moge.utils.io import save_ply, save_glb
36
+ from moge.utils.geometry_numpy import intrinsics_to_fov_numpy
37
+ from moge.utils.vis import colorize_depth, colorize_depth_affine, colorize_disparity
38
+ from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module
39
+ from moge.test.baseline import MGEBaselineInterface
40
+
41
+ # Load the baseline model
42
+ module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem)
43
+ baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline')
44
+ baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False)
45
+
46
+ # Input images list
47
+ include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
48
+ if Path(input_path).is_dir():
49
+ image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
50
+ else:
51
+ image_paths = [Path(input_path)]
52
+
53
+ if not any([save_maps_, save_glb_, save_ply_]):
54
+ warnings.warn('No output format specified. Defaults to saving maps only. Please use "--maps", "--glb", or "--ply" to specify the output.')
55
+ save_maps_ = True
56
+
57
+ for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)):
58
+ # Load one image at a time
59
+ image_np = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
60
+ height, width = image_np.shape[:2]
61
+ if image_size is not None and max(image_np.shape[:2]) > image_size:
62
+ height, width = min(image_size, int(image_size * height / width)), min(image_size, int(image_size * width / height))
63
+ image_np = cv2.resize(image_np, (width, height), interpolation=cv2.INTER_AREA)
64
+ image = torch.from_numpy(image_np.astype(np.float32) / 255.0).permute(2, 0, 1).to(baseline.device)
65
+
66
+ # Inference
67
+ torch.cuda.synchronize()
68
+ with torch.inference_mode(), (timer := timeit('Inference', verbose=False, average=True)):
69
+ output = baseline.infer(image)
70
+ torch.cuda.synchronize()
71
+
72
+ inference_time = timer.average_time
73
+ pbar.set_postfix({'average inference time': f'{inference_time:.3f}s'})
74
+
75
+ # Save the output
76
+ save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
77
+ if skip and save_path.exists():
78
+ continue
79
+ save_path.mkdir(parents=True, exist_ok=True)
80
+
81
+ if save_maps_:
82
+ cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
83
+
84
+ if 'mask' in output:
85
+ mask = output['mask'].cpu().numpy()
86
+ cv2.imwrite(str(save_path /'mask.png'), (mask * 255).astype(np.uint8))
87
+
88
+ for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']:
89
+ if k in output:
90
+ points = output[k].cpu().numpy()
91
+ cv2.imwrite(str(save_path / f'{k}.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
92
+
93
+ for k in ['depth_metric', 'depth_scale_invariant', 'depth_affine_invariant', 'disparity_affine_invariant']:
94
+ if k in output:
95
+ depth = output[k].cpu().numpy()
96
+ cv2.imwrite(str(save_path / f'{k}.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
97
+ if k in ['depth_metric', 'depth_scale_invariant']:
98
+ depth_vis = colorize_depth(depth)
99
+ elif k == 'depth_affine_invariant':
100
+ depth_vis = colorize_depth_affine(depth)
101
+ elif k == 'disparity_affine_invariant':
102
+ depth_vis = colorize_disparity(depth)
103
+ cv2.imwrite(str(save_path / f'{k}_vis.png'), cv2.cvtColor(depth_vis, cv2.COLOR_RGB2BGR))
104
+
105
+ if 'intrinsics' in output:
106
+ intrinsics = output['intrinsics'].cpu().numpy()
107
+ fov_x, fov_y = intrinsics_to_fov_numpy(intrinsics)
108
+ with open(save_path / 'fov.json', 'w') as f:
109
+ json.dump({
110
+ 'fov_x': float(np.rad2deg(fov_x)),
111
+ 'fov_y': float(np.rad2deg(fov_y)),
112
+ 'intrinsics': intrinsics.tolist()
113
+ }, f, indent=4)
114
+
115
+ # Export mesh & visualization
116
+ if save_ply_ or save_glb_:
117
+ assert any(k in output for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']), 'No point map found in output'
118
+ points = next(output[k] for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant'] if k in output).cpu().numpy()
119
+ mask = output['mask'].cpu().numpy() if 'mask' in output else np.ones_like(points[..., 0], dtype=bool)
120
+ normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask)
121
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
122
+ points,
123
+ image_np.astype(np.float32) / 255,
124
+ utils3d.numpy.image_uv(width=width, height=height),
125
+ mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=threshold, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)),
126
+ tri=True
127
+ )
128
+ # When exporting the model, follow the OpenGL coordinate conventions:
129
+ # - world coordinate system: x right, y up, z backward.
130
+ # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top.
131
+ vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1]
132
+
133
+ if save_glb_:
134
+ save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image_np)
135
+
136
+ if save_ply_:
137
+ save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors)
138
+
139
+ if __name__ == '__main__':
140
+ main()
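The `--threshold` option feeds a relative depth-edge test (`utils3d.numpy.depth_edge` with `rtol=threshold`) that strips stretched triangles along occlusion boundaries before meshing. The sketch below only illustrates the idea of that test and is not the utils3d implementation:

import numpy as np

def simple_depth_edge(depth: np.ndarray, rtol: float = 0.03) -> np.ndarray:
    """Flag pixels whose depth differs from a 4-neighbor by more than rtol * local depth (conceptual sketch)."""
    edge = np.zeros_like(depth, dtype=bool)
    dx = np.abs(np.diff(depth, axis=1))  # horizontal neighbor differences, shape (H, W-1)
    dy = np.abs(np.diff(depth, axis=0))  # vertical neighbor differences, shape (H-1, W)
    edge[:, 1:] |= dx > rtol * depth[:, 1:]
    edge[:, :-1] |= dx > rtol * depth[:, :-1]
    edge[1:, :] |= dy > rtol * depth[1:, :]
    edge[:-1, :] |= dy > rtol * depth[:-1, :]
    return edge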
moge/scripts/infer_panorama.py ADDED
@@ -0,0 +1,162 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ from pathlib import Path
4
+ import sys
5
+ if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
6
+ sys.path.insert(0, _package_root)
7
+ from typing import *
8
+ import itertools
9
+ import json
10
+ import warnings
11
+
12
+ import click
13
+
14
+
15
+ @click.command(help='Inference script for panorama images')
16
+ @click.option('--input', '-i', 'input_path', type=click.Path(exists=True), required=True, help='Input image or folder path. "jpg" and "png" are supported.')
17
+ @click.option('--output', '-o', 'output_path', type=click.Path(), default='./output', help='Output folder path')
18
+ @click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Defaults to "Ruicheng/moge-vitl"')
19
+ @click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"')
20
+ @click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).')
21
+ @click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Defaults to 9. Note that it is irrelevant to the output resolution.')
22
+ @click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Defaults to 0.03. Smaller value removes more edges. "inf" means no thresholding.')
23
+ @click.option('--batch_size', type=int, default=4, help='Batch size for inference. Defaults to 4.')
24
+ @click.option('--splitted', 'save_splitted', is_flag=True, help='Whether to save the splitted images. Defaults to False.')
25
+ @click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and FOV (image, depth, mask, points, fov).')
26
+ @click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a .glb file. The color will be saved as a texture.')
27
+ @click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a .ply file. The color will be saved as vertex colors.')
28
+ @click.option('--show', 'show', is_flag=True, help='Whether to show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.')
29
+ def main(
30
+ input_path: str,
31
+ output_path: str,
32
+ pretrained_model_name_or_path: str,
33
+ device_name: str,
34
+ resize_to: int,
35
+ resolution_level: int,
36
+ threshold: float,
37
+ batch_size: int,
38
+ save_splitted: bool,
39
+ save_maps_: bool,
40
+ save_glb_: bool,
41
+ save_ply_: bool,
42
+ show: bool,
43
+ ):
44
+ # Lazy import
45
+ import cv2
46
+ import numpy as np
47
+ from numpy import ndarray
48
+ import torch
49
+ from PIL import Image
50
+ from tqdm import tqdm, trange
51
+ import trimesh
52
+ import trimesh.visual
53
+ from scipy.sparse import csr_array, hstack, vstack
54
+ from scipy.ndimage import convolve
55
+ from scipy.sparse.linalg import lsmr
56
+
57
+ import utils3d
58
+ from moge.model.v1 import MoGeModel
59
+ from moge.utils.io import save_glb, save_ply
60
+ from moge.utils.vis import colorize_depth
61
+ from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth
62
+
63
+
64
+ device = torch.device(device_name)
65
+
66
+ include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
67
+ if Path(input_path).is_dir():
68
+ image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
69
+ else:
70
+ image_paths = [Path(input_path)]
71
+
72
+ if len(image_paths) == 0:
73
+ raise FileNotFoundError(f'No image files found in {input_path}')
74
+
75
+ # Write outputs
76
+ if not any([save_maps_, save_glb_, save_ply_]):
77
+ warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.')
78
+ save_maps_ = save_glb_ = save_ply_ = True
79
+
80
+ model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval()
81
+
82
+ for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)):
83
+ image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
84
+ height, width = image.shape[:2]
85
+ if resize_to is not None:
86
+ height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height))
87
+ image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
88
+
89
+ splitted_extrinsics, splitted_intriniscs = get_panorama_cameras()
90
+ splitted_resolution = 512
91
+ splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution)
92
+
93
+ # Infer each view
94
+ print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring')
95
+
96
+ splitted_distance_maps, splitted_masks = [], []
97
+ for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False):
98
+ image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
99
+ fov_x, fov_y = np.rad2deg(utils3d.numpy.intrinsics_to_fov(np.array(splitted_intriniscs[i:i + batch_size])))
100
+ fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device)
101
+ output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False)
102
+ distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy()
103
+ splitted_distance_maps.extend(list(distance_map))
104
+ splitted_masks.extend(list(mask))
105
+
106
+ # Save splitted
107
+ if save_splitted:
108
+ splitted_save_path = Path(output_path, image_path.stem, 'splitted')
109
+ splitted_save_path.mkdir(exist_ok=True, parents=True)
110
+ for i in range(len(splitted_images)):
111
+ cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR))
112
+ cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR))
113
+
114
+ # Merge
115
+ print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging')
116
+
117
+ merging_width, merging_height = min(1920, width), min(960, height)
118
+ panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs)
119
+ panorama_depth = panorama_depth.astype(np.float32)
120
+ panorama_depth = cv2.resize(panorama_depth, (width, height), interpolation=cv2.INTER_LINEAR)
121
+ panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST) > 0
122
+ points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.numpy.image_uv(width=width, height=height))
123
+
124
+ # Write outputs
125
+ print('Writing outputs...') if pbar.disable else pbar.set_postfix_str('Writing outputs')
126
+ save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
127
+ save_path.mkdir(exist_ok=True, parents=True)
128
+ if save_maps_:
129
+ cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
130
+ cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR))
131
+ cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
132
+ cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
133
+ cv2.imwrite(str(save_path /'mask.png'), (panorama_mask * 255).astype(np.uint8))
134
+
135
+ # Export mesh & visualization
136
+ if save_glb_ or save_ply_ or show:
137
+ normals, normals_mask = utils3d.numpy.points_to_normals(points, panorama_mask)
138
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
139
+ points,
140
+ image.astype(np.float32) / 255,
141
+ utils3d.numpy.image_uv(width=width, height=height),
142
+ mask=panorama_mask & ~(utils3d.numpy.depth_edge(panorama_depth, rtol=threshold) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)),
143
+ tri=True
144
+ )
145
+
146
+ if save_glb_:
147
+ save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image)
148
+
149
+ if save_ply_:
150
+ save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors)
151
+
152
+ if show:
153
+ trimesh.Trimesh(
154
+ vertices=vertices,
155
+ vertex_colors=vertex_colors,
156
+ faces=faces,
157
+ process=False
158
+ ).show()
159
+
160
+
161
+ if __name__ == '__main__':
162
+ main()
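The panorama script lifts the merged distance map to 3D by multiplying each pixel's distance with the viewing ray from `spherical_uv_to_directions`. The exact axis convention lives in `moge/utils/panorama.py`, which is not part of this diff; the sketch below only illustrates a common equirectangular UV-to-direction mapping and should be treated as an assumption:

import numpy as np

def uv_to_directions(uv: np.ndarray) -> np.ndarray:
    """Hypothetical equirectangular mapping: u sweeps longitude, v sweeps latitude (unit vectors out)."""
    theta = (1 - uv[..., 0]) * (2 * np.pi)   # longitude in [0, 2*pi]
    phi = uv[..., 1] * np.pi                 # latitude in [0, pi]
    return np.stack([np.sin(phi) * np.cos(theta),
                     np.sin(phi) * np.sin(theta),
                     np.cos(phi)], axis=-1)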
moge/scripts/train.py ADDED
@@ -0,0 +1,452 @@
1
+ import os
2
+ from pathlib import Path
3
+ import sys
4
+ if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
5
+ sys.path.insert(0, _package_root)
6
+ import json
7
+ import time
8
+ import random
9
+ from typing import *
10
+ import itertools
11
+ from contextlib import nullcontext
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ import io
14
+
15
+ import numpy as np
16
+ import cv2
17
+ from PIL import Image
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.version
22
+ import accelerate
23
+ from accelerate import Accelerator, DistributedDataParallelKwargs
24
+ from accelerate.utils import set_seed
25
+ import utils3d
26
+ import click
27
+ from tqdm import tqdm, trange
28
+ import mlflow
29
+ torch.backends.cudnn.benchmark = False # Input sizes vary, so make sure the cuDNN benchmark is disabled
30
+
31
+ from moge.train.dataloader import TrainDataLoaderPipeline
32
+ from moge.train.losses import (
33
+ affine_invariant_global_loss,
34
+ affine_invariant_local_loss,
35
+ edge_loss,
36
+ normal_loss,
37
+ mask_l2_loss,
38
+ mask_bce_loss,
39
+ monitoring,
40
+ )
41
+ from moge.train.utils import build_optimizer, build_lr_scheduler
42
+ from moge.utils.geometry_torch import intrinsics_to_fov
43
+ from moge.utils.vis import colorize_depth, colorize_normal
44
+ from moge.utils.tools import key_average, recursive_replace, CallbackOnException, flatten_nested_dict
45
+ from moge.test.metrics import compute_metrics
46
+
47
+
48
+ @click.command()
49
+ @click.option('--config', 'config_path', type=str, default='configs/debug.json')
50
+ @click.option('--workspace', type=str, default='workspace/debug', help='Path to the workspace')
51
+ @click.option('--checkpoint', 'checkpoint_path', type=str, default=None, help='Path to the checkpoint to load')
52
+ @click.option('--batch_size_forward', type=int, default=8, help='Batch size for each forward pass on each device')
53
+ @click.option('--gradient_accumulation_steps', type=int, default=1, help='Number of steps to accumulate gradients')
54
+ @click.option('--enable_gradient_checkpointing', type=bool, default=True, help='Use gradient checkpointing in backbone')
55
+ @click.option('--enable_mixed_precision', type=bool, default=False, help='Use mixed precision training. Backbone is converted to FP16')
56
+ @click.option('--enable_ema', type=bool, default=True, help='Maintain an exponential moving average of the model weights')
57
+ @click.option('--num_iterations', type=int, default=1000000, help='Number of iterations to train the model')
58
+ @click.option('--save_every', type=int, default=10000, help='Save checkpoint every n iterations')
59
+ @click.option('--log_every', type=int, default=1000, help='Log metrics every n iterations')
60
+ @click.option('--vis_every', type=int, default=0, help='Visualize every n iterations')
61
+ @click.option('--num_vis_images', type=int, default=32, help='Number of images to visualize, must be a multiple of divided batch size')
62
+ @click.option('--enable_mlflow', type=bool, default=True, help='Log metrics to MLFlow')
63
+ @click.option('--seed', type=int, default=0, help='Random seed')
64
+ def main(
65
+ config_path: str,
66
+ workspace: str,
67
+ checkpoint_path: str,
68
+ batch_size_forward: int,
69
+ gradient_accumulation_steps: int,
70
+ enable_gradient_checkpointing: bool,
71
+ enable_mixed_precision: bool,
72
+ enable_ema: bool,
73
+ num_iterations: int,
74
+ save_every: int,
75
+ log_every: int,
76
+ vis_every: int,
77
+ num_vis_images: int,
78
+ enable_mlflow: bool,
79
+ seed: Optional[int],
80
+ ):
81
+ # Load config
82
+ with open(config_path, 'r') as f:
83
+ config = json.load(f)
84
+
85
+ accelerator = Accelerator(
86
+ gradient_accumulation_steps=gradient_accumulation_steps,
87
+ mixed_precision='fp16' if enable_mixed_precision else None,
88
+ kwargs_handlers=[
89
+ DistributedDataParallelKwargs(find_unused_parameters=True)
90
+ ]
91
+ )
92
+ device = accelerator.device
93
+ batch_size_total = batch_size_forward * gradient_accumulation_steps * accelerator.num_processes
94
+
95
+ # Log config
96
+ if accelerator.is_main_process:
97
+ if enable_mlflow:
98
+ try:
99
+ mlflow.log_params({
100
+ **click.get_current_context().params,
101
+ 'batch_size_total': batch_size_total,
102
+ })
103
+ except Exception:
104
+ print('Failed to log config to MLFlow')
105
+ Path(workspace).mkdir(parents=True, exist_ok=True)
106
+ with Path(workspace).joinpath('config.json').open('w') as f:
107
+ json.dump(config, f, indent=4)
108
+
109
+ # Set seed
110
+ if seed is not None:
111
+ set_seed(seed, device_specific=True)
112
+
113
+ # Initialize model
114
+ print('Initialize model')
115
+ with accelerator.local_main_process_first():
116
+ from moge.model import import_model_class_by_version
117
+ MoGeModel = import_model_class_by_version(config['model_version'])
118
+ model = MoGeModel(**config['model'])
119
+ count_total_parameters = sum(p.numel() for p in model.parameters())
120
+ print(f'Total parameters: {count_total_parameters}')
121
+
122
+ # Set up EMA model
123
+ if enable_ema and accelerator.is_main_process:
124
+ ema_avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: 0.999 * averaged_model_parameter + 0.001 * model_parameter
125
+ ema_model = torch.optim.swa_utils.AveragedModel(model, device=accelerator.device, avg_fn=ema_avg_fn)
126
+
127
+ # Set gradient checkpointing
128
+ if enable_gradient_checkpointing:
129
+ model.enable_gradient_checkpointing()
130
+ import warnings
131
+ warnings.filterwarnings("ignore", category=FutureWarning, module="torch.utils.checkpoint")
132
+
133
+ # Initialize optimizer & lr scheduler
134
+ optimizer = build_optimizer(model, config['optimizer'])
135
+ lr_scheduler = build_lr_scheduler(optimizer, config['lr_scheduler'])
136
+
137
+ count_grouped_parameters = [sum(p.numel() for p in param_group['params'] if p.requires_grad) for param_group in optimizer.param_groups]
138
+ for i, count in enumerate(count_grouped_parameters):
139
+ print(f'- Group {i}: {count} parameters')
140
+
141
+ # Attempt to load checkpoint
142
+ checkpoint: Dict[str, Any]
143
+ with accelerator.local_main_process_first():
144
+ if checkpoint_path is not None and checkpoint_path.endswith('.pt'):
145
+ # - Load specific checkpoint file
146
+ print(f'Load checkpoint: {checkpoint_path}')
147
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
148
+ elif checkpoint_path == "latest":
149
+ # - Load latest
150
+ checkpoint_path = Path(workspace, 'checkpoint', 'latest.pt')
151
+ if checkpoint_path.exists():
152
+ print(f'Load checkpoint: {checkpoint_path}')
153
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
154
+ i_step = checkpoint['step']
155
+ if 'model' not in checkpoint and (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists():
156
+ print(f'Load model checkpoint: {checkpoint_model_path}')
157
+ checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model']
158
+ if 'optimizer' not in checkpoint and (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists():
159
+ print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}')
160
+ checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True))
161
+ if enable_ema and accelerator.is_main_process:
162
+ if 'ema_model' not in checkpoint and (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists():
163
+ print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}')
164
+ checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model']
165
+ else:
166
+ checkpoint = None
167
+ elif checkpoint_path is not None:
168
+ # - Load by step number
169
+ i_step = int(checkpoint_path)
170
+ checkpoint = {'step': i_step}
171
+ if (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists():
172
+ print(f'Load model checkpoint: {checkpoint_model_path}')
173
+ checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model']
174
+ if (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists():
175
+ print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}')
176
+ checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True))
177
+ if enable_ema and accelerator.is_main_process:
178
+ if (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists():
179
+ print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}')
180
+ checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model']
181
+ else:
182
+ checkpoint = None
183
+
184
+ if checkpoint is None:
185
+ # Initialize model weights
186
+ print('Initialize model weights')
187
+ with accelerator.local_main_process_first():
188
+ model.init_weights()
189
+ initial_step = 0
190
+ else:
191
+ model.load_state_dict(checkpoint['model'], strict=False)
192
+ if 'step' in checkpoint:
193
+ initial_step = checkpoint['step'] + 1
194
+ else:
195
+ initial_step = 0
196
+ if 'optimizer' in checkpoint:
197
+ optimizer.load_state_dict(checkpoint['optimizer'])
198
+ if enable_ema and accelerator.is_main_process and 'ema_model' in checkpoint:
199
+ ema_model.module.load_state_dict(checkpoint['ema_model'], strict=False)
200
+ if 'lr_scheduler' in checkpoint:
201
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
202
+
203
+ del checkpoint
204
+
205
+ model, optimizer = accelerator.prepare(model, optimizer)
206
+ if torch.version.hip and isinstance(model, torch.nn.parallel.DistributedDataParallel):
207
+ # Work around a potential gradient synchronization issue in the ROCm backend
208
+ from moge.model.utils import sync_ddp_hook
209
+ model.register_comm_hook(None, sync_ddp_hook)
210
+
211
+ # Initialize training data pipeline
212
+ with accelerator.local_main_process_first():
213
+ train_data_pipe = TrainDataLoaderPipeline(config['data'], batch_size_forward)
214
+
215
+ def _write_bytes_retry_loop(save_path: Path, data: bytes):
216
+ while True:
217
+ try:
218
+ save_path.write_bytes(data)
219
+ break
220
+ except Exception as e:
221
+ print('Error while saving checkpoint, retrying in 1 minute: ', e)
222
+ time.sleep(60)
223
+
224
+ # Ready to train
225
+ records = []
226
+ model.train()
227
+ with (
228
+ train_data_pipe,
229
+ tqdm(initial=initial_step, total=num_iterations, desc='Training', disable=not accelerator.is_main_process) as pbar,
230
+ ThreadPoolExecutor(max_workers=1) as save_checkpoint_executor,
231
+ ):
232
+ # Get some batches for visualization
233
+ if accelerator.is_main_process:
234
+ batches_for_vis: List[Dict[str, torch.Tensor]] = []
235
+ num_vis_images = num_vis_images // batch_size_forward * batch_size_forward
236
+ for _ in range(num_vis_images // batch_size_forward):
237
+ batch = train_data_pipe.get()
238
+ batches_for_vis.append(batch)
239
+
240
+ # Visualize GT
241
+ if vis_every > 0 and accelerator.is_main_process and initial_step == 0:
242
+ save_dir = Path(workspace).joinpath('vis/gt')
243
+ for i_batch, batch in enumerate(tqdm(batches_for_vis, desc='Visualize GT', leave=False)):
244
+ image, gt_depth, gt_mask, gt_mask_inf, gt_intrinsics, info = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_inf'], batch['intrinsics'], batch['info']
245
+ gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics)
246
+ gt_normal, gt_normal_mask = utils3d.torch.points_to_normals(gt_points, gt_mask)
247
+ for i_instance in range(batch['image'].shape[0]):
248
+ idx = i_batch * batch_size_forward + i_instance
249
+ image_i = (image[i_instance].numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
250
+ gt_depth_i = gt_depth[i_instance].numpy()
251
+ gt_mask_i = gt_mask[i_instance].numpy()
252
+ gt_mask_inf_i = gt_mask_inf[i_instance].numpy()
253
+ gt_points_i = gt_points[i_instance].numpy()
254
+ gt_normal_i = gt_normal[i_instance].numpy()
255
+ save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True)
256
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR))
257
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(gt_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
258
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), gt_mask_i * 255)
259
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(gt_depth_i, gt_mask_i), cv2.COLOR_RGB2BGR))
260
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/normal.png')), cv2.cvtColor(colorize_normal(gt_normal_i), cv2.COLOR_RGB2BGR))
261
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask_inf.png')), gt_mask_inf_i * 255)
262
+ with save_dir.joinpath(f'{idx:04d}/info.json').open('w') as f:
263
+ json.dump(info[i_instance], f)
264
+
265
+ # Reset seed to avoid training on the same data when resuming training
266
+ if seed is not None:
267
+ set_seed(seed + initial_step, device_specific=True)
268
+
269
+ # Training loop
270
+ for i_step in range(initial_step, num_iterations):
271
+
272
+ i_accumulate, weight_accumulate = 0, 0
273
+ while i_accumulate < gradient_accumulation_steps:
274
+ # Load batch
275
+ batch = train_data_pipe.get()
276
+ image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics, label_type, is_metric = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_fin'], batch['depth_mask_inf'], batch['intrinsics'], batch['label_type'], batch['is_metric']
277
+ image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_mask_fin.to(device), gt_mask_inf.to(device), gt_intrinsics.to(device)
278
+ current_batch_size = image.shape[0]
279
+ if all(label == 'invalid' for label in label_type):
280
+ continue # NOTE: Skip all-invalid batches to avoid messing up the optimizer.
281
+
282
+ gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics)
283
+ gt_focal = 1 / (1 / gt_intrinsics[..., 0, 0] ** 2 + 1 / gt_intrinsics[..., 1, 1] ** 2) ** 0.5
284
+
285
+ with accelerator.accumulate(model):
286
+ # Forward
287
+ if i_step <= config.get('low_resolution_training_steps', 0):
288
+ num_tokens = config['model']['num_tokens_range'][0]
289
+ else:
290
+ num_tokens = accelerate.utils.broadcast_object_list([random.randint(*config['model']['num_tokens_range'])])[0]
291
+ with torch.autocast(device_type=accelerator.device.type, dtype=torch.float16, enabled=enable_mixed_precision):
292
+ output = model(image, num_tokens=num_tokens)
293
+ pred_points, pred_mask, pred_metric_scale = output['points'], output['mask'], output.get('metric_scale', None)
294
+
295
+ # Compute loss (per instance)
296
+ loss_list, weight_list = [], []
297
+ for i in range(current_batch_size):
298
+ gt_metric_scale = None
299
+ loss_dict, weight_dict, misc_dict = {}, {}, {}
300
+ misc_dict['monitoring'] = monitoring(pred_points[i])
301
+ for k, v in config['loss'][label_type[i]].items():
302
+ weight_dict[k] = v['weight']
303
+ if v['function'] == 'affine_invariant_global_loss':
304
+ loss_dict[k], misc_dict[k], gt_metric_scale = affine_invariant_global_loss(pred_points[i], gt_points[i], gt_mask[i], **v['params'])
305
+ elif v['function'] == 'affine_invariant_local_loss':
306
+ loss_dict[k], misc_dict[k] = affine_invariant_local_loss(pred_points[i], gt_points[i], gt_mask[i], gt_focal[i], gt_metric_scale, **v['params'])
307
+ elif v['function'] == 'normal_loss':
308
+ loss_dict[k], misc_dict[k] = normal_loss(pred_points[i], gt_points[i], gt_mask[i])
309
+ elif v['function'] == 'edge_loss':
310
+ loss_dict[k], misc_dict[k] = edge_loss(pred_points[i], gt_points[i], gt_mask[i])
311
+ elif v['function'] == 'mask_bce_loss':
312
+ loss_dict[k], misc_dict[k] = mask_bce_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i])
313
+ elif v['function'] == 'mask_l2_loss':
314
+ loss_dict[k], misc_dict[k] = mask_l2_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i])
315
+ else:
316
+ raise ValueError(f'Undefined loss function: {v["function"]}')
317
+ weight_dict = {'.'.join(k): v for k, v in flatten_nested_dict(weight_dict).items()}
318
+ loss_dict = {'.'.join(k): v for k, v in flatten_nested_dict(loss_dict).items()}
319
+ loss_ = sum([weight_dict[k] * loss_dict[k] for k in loss_dict], start=torch.tensor(0.0, device=device))
320
+ loss_list.append(loss_)
321
+
322
+ if torch.isnan(loss_).item():
323
+ pbar.write(f'NaN loss in process {accelerator.process_index}')
324
+ pbar.write(str(loss_dict))
325
+
326
+ misc_dict = {'.'.join(k): v for k, v in flatten_nested_dict(misc_dict).items()}
327
+ records.append({
328
+ **{k: v.item() for k, v in loss_dict.items()},
329
+ **misc_dict,
330
+ })
331
+
332
+ loss = sum(loss_list) / len(loss_list)
333
+
334
+ # Backward & update
335
+ accelerator.backward(loss)
336
+ if accelerator.sync_gradients:
337
+ if not enable_mixed_precision and any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
338
+ if accelerator.is_main_process:
339
+ pbar.write('NaN gradients, skipping update')
340
+ optimizer.zero_grad()
341
+ continue
342
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
343
+
344
+ optimizer.step()
345
+ optimizer.zero_grad()
346
+
347
+ i_accumulate += 1
348
+
349
+ lr_scheduler.step()
350
+
351
+ # EMA update
352
+ if enable_ema and accelerator.is_main_process and accelerator.sync_gradients:
353
+ ema_model.update_parameters(model)
354
+
355
+ # Log metrics
356
+ if i_step == initial_step or i_step % log_every == 0:
357
+ records = [key_average(records)]
358
+ records = accelerator.gather_for_metrics(records, use_gather_object=True)
359
+ if accelerator.is_main_process:
360
+ records = key_average(records)
361
+ if enable_mlflow:
362
+ try:
363
+ mlflow.log_metrics(records, step=i_step)
364
+ except Exception as e:
365
+ print(f'Error while logging metrics to mlflow: {e}')
366
+ records = []
367
+
368
+ # Save model weight checkpoint
369
+ if accelerator.is_main_process and (i_step % save_every == 0):
370
+ # NOTE: Writing checkpoint is done in a separate thread to avoid blocking the main process
371
+ pbar.write(f'Save checkpoint: {i_step:08d}')
372
+ Path(workspace, 'checkpoint').mkdir(parents=True, exist_ok=True)
373
+
374
+ # Model checkpoint
375
+ with io.BytesIO() as f:
376
+ torch.save({
377
+ 'model_config': config['model'],
378
+ 'model': accelerator.unwrap_model(model).state_dict(),
379
+ }, f)
380
+ checkpoint_bytes = f.getvalue()
381
+ save_checkpoint_executor.submit(
382
+ _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}.pt'), checkpoint_bytes
383
+ )
384
+
385
+ # Optimizer checkpoint
386
+ with io.BytesIO() as f:
387
+ torch.save({
388
+ 'model_config': config['model'],
389
+ 'step': i_step,
390
+ 'optimizer': optimizer.state_dict(),
391
+ 'lr_scheduler': lr_scheduler.state_dict(),
392
+ }, f)
393
+ checkpoint_bytes = f.getvalue()
394
+ save_checkpoint_executor.submit(
395
+ _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt'), checkpoint_bytes
396
+ )
397
+
398
+ # EMA model checkpoint
399
+ if enable_ema:
400
+ with io.BytesIO() as f:
401
+ torch.save({
402
+ 'model_config': config['model'],
403
+ 'model': ema_model.module.state_dict(),
404
+ }, f)
405
+ checkpoint_bytes = f.getvalue()
406
+ save_checkpoint_executor.submit(
407
+ _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt'), checkpoint_bytes
408
+ )
409
+
410
+ # Latest checkpoint
411
+ with io.BytesIO() as f:
412
+ torch.save({
413
+ 'model_config': config['model'],
414
+ 'step': i_step,
415
+ }, f)
416
+ checkpoint_bytes = f.getvalue()
417
+ save_checkpoint_executor.submit(
418
+ _write_bytes_retry_loop, Path(workspace, 'checkpoint', 'latest.pt'), checkpoint_bytes
419
+ )
420
+
421
+ # Visualize
422
+ if vis_every > 0 and accelerator.is_main_process and (i_step == initial_step or i_step % vis_every == 0):
423
+ unwrapped_model = accelerator.unwrap_model(model)
424
+ save_dir = Path(workspace).joinpath(f'vis/step_{i_step:08d}')
425
+ save_dir.mkdir(parents=True, exist_ok=True)
426
+ with torch.inference_mode():
427
+ for i_batch, batch in enumerate(tqdm(batches_for_vis, desc=f'Visualize: {i_step:08d}', leave=False)):
428
+ image, gt_depth, gt_mask, gt_intrinsics = batch['image'], batch['depth'], batch['depth_mask'], batch['intrinsics']
429
+ image, gt_depth, gt_mask, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_intrinsics.to(device)
430
+
431
+ output = unwrapped_model.infer(image)
432
+ pred_points, pred_depth, pred_mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy()
433
+ image = image.cpu().numpy()
434
+
435
+ for i_instance in range(image.shape[0]):
436
+ idx = i_batch * batch_size_forward + i_instance
437
+ image_i = (image[i_instance].transpose(1, 2, 0) * 255).astype(np.uint8)
438
+ pred_points_i = pred_points[i_instance]
439
+ pred_mask_i = pred_mask[i_instance]
440
+ pred_depth_i = pred_depth[i_instance]
441
+ save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True)
442
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR))
443
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(pred_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
444
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), pred_mask_i * 255)
445
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(pred_depth_i, pred_mask_i), cv2.COLOR_RGB2BGR))
446
+
447
+ pbar.set_postfix({'loss': loss.item()}, refresh=False)
448
+ pbar.update(1)
449
+
450
+
451
+ if __name__ == '__main__':
452
+ main()
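For reference, the EMA bookkeeping used above (an `AveragedModel` with a custom `avg_fn`) can be exercised in isolation. Below is a minimal sketch with a toy model, matching the 0.999 decay hard-coded above; the model, optimizer and data here are illustrative placeholders, not taken from the training config:

import torch

decay = 0.999
ema_avg_fn = lambda averaged, current, num_averaged: decay * averaged + (1 - decay) * current

model = torch.nn.Linear(8, 8)  # toy stand-in for the real network
ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg_fn)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    loss = model(torch.randn(4, 8)).square().mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_model.update_parameters(model)  # blend the current weights into the running average

# ema_model.module holds the exponentially averaged weights, analogous to the *_ema.pt checkpoints above.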
moge/scripts/vis_data.py ADDED
@@ -0,0 +1,84 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ import sys
4
+ from pathlib import Path
5
+ if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
6
+ sys.path.insert(0, _package_root)
7
+
8
+ import click
9
+
10
+
11
+ @click.command()
12
+ @click.argument('folder_or_path', type=click.Path(exists=True))
13
+ @click.option('--output', '-o', 'output_folder', type=click.Path(), help='Path to output folder')
14
+ @click.option('--max_depth', '-m', type=float, default=float('inf'), help='max depth')
15
+ @click.option('--fov', type=float, default=None, help='field of view in degrees')
16
+ @click.option('--show', 'show', is_flag=True, help='show point cloud')
17
+ @click.option('--depth', 'depth_filename', type=str, default='depth.png', help='depth image file name')
18
+ @click.option('--ply', 'save_ply', is_flag=True, help='save point cloud as PLY file')
19
+ @click.option('--depth_vis', 'save_depth_vis', is_flag=True, help='save depth image')
20
+ @click.option('--inf', 'inf_mask', is_flag=True, help='use infinity mask')
21
+ @click.option('--version', 'version', type=str, default='v3', help='version of rgbd data')
22
+ def main(
23
+ folder_or_path: str,
24
+ output_folder: str,
25
+ max_depth: float,
26
+ fov: float,
27
+ depth_filename: str,
28
+ show: bool,
29
+ save_ply: bool,
30
+ save_depth_vis: bool,
31
+ inf_mask: bool,
32
+ version: str
33
+ ):
34
+ # Lazy import
35
+ import cv2
36
+ import numpy as np
37
+ import utils3d
38
+ from tqdm import tqdm
39
+ import trimesh
40
+
41
+ from moge.utils.io import read_image, read_depth, read_meta
42
+ from moge.utils.vis import colorize_depth, colorize_normal
43
+
44
+ filepaths = sorted(p.parent for p in Path(folder_or_path).rglob('meta.json'))
45
+
46
+ for filepath in tqdm(filepaths):
47
+ image = read_image(Path(filepath, 'image.jpg'))
48
+ depth, unit = read_depth(Path(filepath, depth_filename))
49
+ meta = read_meta(Path(filepath, 'meta.json'))
50
+ depth_mask = np.isfinite(depth)
51
+ depth_mask_inf = (depth == np.inf)
52
+ intrinsics = np.array(meta['intrinsics'])
53
+
54
+ extrinsics = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=float) # OpenGL's identity camera
55
+ verts = utils3d.numpy.unproject_cv(utils3d.numpy.image_uv(*image.shape[:2]), depth, extrinsics=extrinsics, intrinsics=intrinsics)
56
+
57
+ depth_mask_ply = depth_mask & (depth < depth[depth_mask].min() * max_depth)
58
+ point_cloud = trimesh.PointCloud(verts[depth_mask_ply], image[depth_mask_ply] / 255)
59
+
60
+ if show:
61
+ point_cloud.show()
62
+
63
+ if output_folder is None:
64
+ output_path = filepath
65
+ else:
66
+ output_path = Path(output_folder, filepath.name)
67
+ output_path.mkdir(exist_ok=True, parents=True)
68
+
69
+ if inf_mask:
70
+ depth = np.where(depth_mask_inf, np.inf, depth)
71
+ depth_mask = depth_mask | depth_mask_inf
72
+
73
+ if save_depth_vis:
74
+ p = output_path.joinpath('depth_vis.png')
75
+ cv2.imwrite(str(p), cv2.cvtColor(colorize_depth(depth, depth_mask), cv2.COLOR_RGB2BGR))
76
+ print(f"{p}")
77
+
78
+ if save_ply:
79
+ p = output_path.joinpath('pointcloud.ply')
80
+ point_cloud.export(p)
81
+ print(f"{p}")
82
+
83
+ if __name__ == '__main__':
84
+ main()
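The point cloud above is produced by `utils3d.numpy.unproject_cv` with normalized intrinsics. As a rough reference, here is an equivalent plain-NumPy sketch of that pinhole back-projection (a simplified stand-in, not the library implementation; the OpenGL-style extrinsics flip used above is omitted):

import numpy as np

def unproject_depth(depth: np.ndarray, intrinsics: np.ndarray) -> np.ndarray:
    """Back-project a depth map to camera-space points (OpenCV convention, z forward)."""
    h, w = depth.shape
    # Normalized pixel-center coordinates in [0, 1], matching utils3d.numpy.image_uv
    u, v = np.meshgrid((np.arange(w) + 0.5) / w, (np.arange(h) + 0.5) / h)
    uv1 = np.stack([u, v, np.ones_like(u)], axis=-1)   # [H, W, 3]
    rays = uv1 @ np.linalg.inv(intrinsics).T           # ray directions with z = 1
    return rays * depth[..., None]                     # scale rays by depth

intrinsics = np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]])
points = unproject_depth(np.ones((480, 640), dtype=np.float32), intrinsics)  # [480, 640, 3]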
moge/test/__init__.py ADDED
File without changes
moge/test/baseline.py ADDED
@@ -0,0 +1,43 @@
1
+ from typing import *
2
+
3
+ import click
4
+ import torch
5
+
6
+
7
+ class MGEBaselineInterface:
8
+ """
9
+ Abstract base class for model wrappers, unifying the loading and inference interface across different baselines.
10
+ """
11
+ device: torch.device
12
+
13
+ @click.command()
14
+ @staticmethod
15
+ def load(*args, **kwargs) -> "MGEBaselineInterface":
16
+ """
17
+ Customized static method to create an instance of the model wrapper from command line arguments. Decorated by `click.command()`
18
+ """
19
+ raise NotImplementedError("Subclasses must implement the load method.")
20
+
21
+ def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
22
+ """
23
+ ### Parameters
24
+ `image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1]
25
+ `intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional.
26
+
27
+ ### Returns
28
+ A dictionary containing:
29
+ - `points_*`. point map output in OpenCV identity camera space.
30
+ Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`.
31
+ - `depth_*`. depth map output
32
+ Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`.
33
+ - `disparity_affine_invariant`. affine disparity map output
34
+ """
35
+ raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.")
36
+
37
+ def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
38
+ """
39
+ If the model has a special evaluation mode, override this method to provide the evaluation mode inference.
40
+
41
+ By default, this method simply calls `infer()`.
42
+ """
43
+ return self.infer(image, intrinsics)
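For illustration, a hypothetical minimal subclass showing the output contract of `infer` (the constant-depth model is a placeholder; a real baseline would wrap an actual network and implement `load` as a `click` command as documented above):

class ConstantDepthBaseline(MGEBaselineInterface):
    """Toy baseline that predicts a constant scale-invariant depth map (illustrative only)."""

    def __init__(self, device: str = 'cpu'):
        self.device = torch.device(device)

    def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        height, width = image.shape[-2:]
        # A single [H, W] map for a [3, H, W] input; batched [B, 3, H, W] inputs would return [B, H, W].
        depth = torch.ones(height, width, device=self.device)
        return {'depth_scale_invariant': depth}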
moge/test/dataloader.py ADDED
@@ -0,0 +1,221 @@
1
+ import os
2
+ from typing import *
3
+ from pathlib import Path
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+ import cv2
10
+ import utils3d
11
+
12
+ from ..utils import pipeline
13
+ from ..utils.geometry_numpy import focal_to_fov_numpy, mask_aware_nearest_resize_numpy, norm3d
14
+ from ..utils.io import *
15
+ from ..utils.tools import timeit
16
+
17
+
18
+ class EvalDataLoaderPipeline:
19
+
20
+ def __init__(
21
+ self,
22
+ path: str,
23
+ width: int,
24
+ height: int,
25
+ split: str = '.index.txt',
26
+ drop_max_depth: float = 1000.,
27
+ num_load_workers: int = 4,
28
+ num_process_workers: int = 8,
29
+ include_segmentation: bool = False,
30
+ include_normal: bool = False,
31
+ depth_to_normal: bool = False,
32
+ max_segments: int = 100,
33
+ min_seg_area: int = 1000,
34
+ depth_unit: str = None,
35
+ has_sharp_boundary: bool = False,
36
+ subset: int = None,
37
+ ):
38
+ filenames = Path(path).joinpath(split).read_text(encoding='utf-8').splitlines()
39
+ filenames = filenames[::subset]
40
+ self.width = width
41
+ self.height = height
42
+ self.drop_max_depth = drop_max_depth
43
+ self.path = Path(path)
44
+ self.filenames = filenames
45
+ self.include_segmentation = include_segmentation
46
+ self.include_normal = include_normal
47
+ self.max_segments = max_segments
48
+ self.min_seg_area = min_seg_area
49
+ self.depth_to_normal = depth_to_normal
50
+ self.depth_unit = depth_unit
51
+ self.has_sharp_boundary = has_sharp_boundary
52
+
53
+ self.rng = np.random.default_rng(seed=0)
54
+
55
+ self.pipeline = pipeline.Sequential([
56
+ self._generator,
57
+ pipeline.Parallel([self._load_instance] * num_load_workers),
58
+ pipeline.Parallel([self._process_instance] * num_process_workers),
59
+ pipeline.Buffer(4)
60
+ ])
61
+
62
+ def __len__(self):
63
+ return math.ceil(len(self.filenames))
64
+
65
+ def _generator(self):
66
+ for idx in range(len(self)):
67
+ yield idx
68
+
69
+ def _load_instance(self, idx):
70
+ if idx >= len(self.filenames):
71
+ return None
72
+
73
+ path = self.path.joinpath(self.filenames[idx])
74
+
75
+ instance = {
76
+ 'filename': self.filenames[idx],
77
+ 'width': self.width,
78
+ 'height': self.height,
79
+ }
80
+ instance['image'] = read_image(Path(path, 'image.jpg'))
81
+
82
+ depth, _ = read_depth(Path(path, 'depth.png')) # ignore depth unit from depth file, use config instead
83
+ instance.update({
84
+ 'depth': np.nan_to_num(depth, nan=1, posinf=1, neginf=1),
85
+ 'depth_mask': np.isfinite(depth),
86
+ 'depth_mask_inf': np.isinf(depth),
87
+ })
88
+
89
+ if self.include_segmentation:
90
+ segmentation_mask, segmentation_labels = read_segmentation(Path(path, 'segmentation.png'))
91
+ instance.update({
92
+ 'segmentation_mask': segmentation_mask,
93
+ 'segmentation_labels': segmentation_labels,
94
+ })
95
+
96
+ meta = read_meta(Path(path, 'meta.json'))
97
+ instance['intrinsics'] = np.array(meta['intrinsics'], dtype=np.float32)
98
+
99
+ return instance
100
+
101
+ def _process_instance(self, instance: dict):
102
+ if instance is None:
103
+ return None
104
+
105
+ image, depth, depth_mask, intrinsics = instance['image'], instance['depth'], instance['depth_mask'], instance['intrinsics']
106
+ segmentation_mask, segmentation_labels = instance.get('segmentation_mask', None), instance.get('segmentation_labels', None)
107
+
108
+ raw_height, raw_width = image.shape[:2]
109
+ raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
110
+ raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height
111
+ tgt_width, tgt_height = instance['width'], instance['height']
112
+ tgt_aspect = tgt_width / tgt_height
113
+
114
+ # set expected target view field
115
+ tgt_horizontal = min(raw_horizontal, raw_vertical * tgt_aspect)
116
+ tgt_vertical = tgt_horizontal / tgt_aspect
117
+
118
+ # set target view direction
119
+ cu, cv = 0.5, 0.5
120
+ direction = utils3d.numpy.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0]
121
+ R = utils3d.numpy.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
122
+
123
+ # restrict target view field within the raw view
124
+ corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
125
+ corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane
126
+ corners = corners[:, :2] / corners[:, 2:3]
127
+
128
+ warp_horizontal, warp_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
129
+ for i in range(4):
130
+ intersection, _ = utils3d.numpy.ray_intersection(
131
+ np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
132
+ corners[i - 1], corners[i] - corners[i - 1],
133
+ )
134
+ warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
135
+ tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
136
+
137
+ # get target view intrinsics
138
+ fx, fy = 1.0 / tgt_horizontal, 1.0 / tgt_vertical
139
+ tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
140
+
141
+ # do homogeneous transformation with the rotation and intrinsics
142
+ # 4.1 The image and depth are first resized to approximately the same pixel size as the target image using PIL's antialiasing resampling
143
+ tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes)
144
+ rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h)
145
+ image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS))
146
+
147
+ depth, depth_mask = mask_aware_nearest_resize_numpy(depth, depth_mask, (rescaled_w, rescaled_h))
148
+ distance = norm3d(utils3d.numpy.depth_to_points(depth, intrinsics=intrinsics))
149
+ segmentation_mask = cv2.resize(segmentation_mask, (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) if segmentation_mask is not None else None
150
+
151
+ # 4.2 calculate homography warping
152
+ transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics)
153
+ uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height)
154
+ pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T
155
+ uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12)
156
+ pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32)
157
+
158
+ tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR)
159
+ tgt_distance = cv2.remap(distance, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST)
160
+ tgt_ray_length = utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics)
161
+ tgt_ray_length = (tgt_ray_length[:, :, 0] ** 2 + tgt_ray_length[:, :, 1] ** 2 + tgt_ray_length[:, :, 2] ** 2) ** 0.5
162
+ tgt_depth = tgt_distance / (tgt_ray_length + 1e-12)
163
+ tgt_depth_mask = cv2.remap(depth_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
164
+ tgt_segmentation_mask = cv2.remap(segmentation_mask, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) if segmentation_mask is not None else None
165
+
166
+ # drop depths greater than drop_max_depth times the near depth (1st percentile)
167
+ max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.drop_max_depth
168
+ tgt_depth_mask &= tgt_depth <= max_depth
169
+ tgt_depth = np.nan_to_num(tgt_depth, nan=0.0)
170
+
171
+ if self.depth_unit is not None:
172
+ tgt_depth *= self.depth_unit
173
+
174
+ if not np.any(tgt_depth_mask):
175
+ # always make sure that mask is not empty, otherwise the loss calculation will crash
176
+ tgt_depth_mask = np.ones_like(tgt_depth_mask)
177
+ tgt_depth = np.ones_like(tgt_depth)
178
+ instance['label_type'] = 'invalid'
179
+
180
+ tgt_pts = utils3d.numpy.unproject_cv(uv_tgt, tgt_depth, intrinsics=tgt_intrinsics)
181
+
182
+ # Process segmentation labels
183
+ if self.include_segmentation and segmentation_mask is not None:
184
+ for k in ['undefined', 'unannotated', 'background', 'sky']:
185
+ if k in segmentation_labels:
186
+ del segmentation_labels[k]
187
+ seg_id2count = dict(zip(*np.unique(tgt_segmentation_mask, return_counts=True)))
188
+ sorted_labels = sorted(segmentation_labels.keys(), key=lambda x: seg_id2count.get(segmentation_labels[x], 0), reverse=True)
189
+ segmentation_labels = {k: segmentation_labels[k] for k in sorted_labels[:self.max_segments] if seg_id2count.get(segmentation_labels[k], 0) >= self.min_seg_area}
190
+
191
+ instance.update({
192
+ 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
193
+ 'depth': torch.from_numpy(tgt_depth).float(),
194
+ 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(),
195
+ 'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
196
+ 'points': torch.from_numpy(tgt_pts).float(),
197
+ 'segmentation_mask': torch.from_numpy(tgt_segmentation_mask).long() if tgt_segmentation_mask is not None else None,
198
+ 'segmentation_labels': segmentation_labels,
199
+ 'is_metric': self.depth_unit is not None,
200
+ 'has_sharp_boundary': self.has_sharp_boundary,
201
+ })
202
+
203
+ instance = {k: v for k, v in instance.items() if v is not None}
204
+
205
+ return instance
206
+
207
+ def start(self):
208
+ self.pipeline.start()
209
+
210
+ def stop(self):
211
+ self.pipeline.stop()
212
+
213
+ def __enter__(self):
214
+ self.start()
215
+ return self
216
+
217
+ def __exit__(self, exc_type, exc_value, traceback):
218
+ self.stop()
219
+
220
+ def get(self):
221
+ return self.pipeline.get()
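A usage sketch of the evaluation loader (the dataset path and resolution are placeholders; the folder is expected to follow the image.jpg / depth.png / meta.json layout read by `_load_instance`):

if __name__ == '__main__':
    loader = EvalDataLoaderPipeline('/path/to/eval_dataset', width=640, height=480)
    with loader:  # starts and stops the internal pipeline workers
        for _ in range(len(loader)):
            instance = loader.get()
            if instance is None:
                continue
            print(instance['filename'], tuple(instance['image'].shape), tuple(instance['depth'].shape))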
moge/test/metrics.py ADDED
@@ -0,0 +1,343 @@
1
+ from typing import *
2
+ from numbers import Number
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ import utils3d
8
+
9
+ from ..utils.geometry_torch import (
10
+ weighted_mean,
11
+ mask_aware_nearest_resize,
12
+ intrinsics_to_fov
13
+ )
14
+ from ..utils.alignment import (
15
+ align_points_scale_z_shift,
16
+ align_points_scale_xyz_shift,
17
+ align_points_xyz_shift,
18
+ align_affine_lstsq,
19
+ align_depth_scale,
20
+ align_depth_affine,
21
+ align_points_scale,
22
+ )
23
+ from ..utils.tools import key_average, timeit
24
+
25
+
26
+ def rel_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
27
+ rel = (torch.abs(pred - gt) / (gt + eps)).mean()
28
+ return rel.item()
29
+
30
+
31
+ def delta1_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
32
+ delta1 = (torch.maximum(gt / pred, pred / gt) < 1.25).float().mean()
33
+ return delta1.item()
34
+
35
+
36
+ def rel_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
37
+ dist_gt = torch.norm(gt, dim=-1)
38
+ dist_err = torch.norm(pred - gt, dim=-1)
39
+ rel = (dist_err / (dist_gt + eps)).mean()
40
+ return rel.item()
41
+
42
+
43
+ def delta1_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
44
+ dist_pred = torch.norm(pred, dim=-1)
45
+ dist_gt = torch.norm(gt, dim=-1)
46
+ dist_err = torch.norm(pred - gt, dim=-1)
47
+
48
+ delta1 = (dist_err < 0.25 * torch.minimum(dist_gt, dist_pred)).float().mean()
49
+ return delta1.item()
50
+
51
+
52
+ def rel_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor):
53
+ dist_err = torch.norm(pred - gt, dim=-1)
54
+ rel = (dist_err / diameter).mean()
55
+ return rel.item()
56
+
57
+
58
+ def delta1_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor):
59
+ dist_err = torch.norm(pred - gt, dim=-1)
60
+ delta1 = (dist_err < 0.25 * diameter).float().mean()
61
+ return delta1.item()
62
+
63
+
64
+ def boundary_f1(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, radius: int = 1):
65
+ neighbor_x, neighbor_y = torch.meshgrid(
66
+ torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device),
67
+ torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device),
68
+ indexing='xy'
69
+ )
70
+ neighbor_mask = (neighbor_x ** 2 + neighbor_y ** 2) <= radius ** 2 + 1e-5
71
+
72
+ pred_window = utils3d.torch.sliding_window_2d(pred, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
73
+ gt_window = utils3d.torch.sliding_window_2d(gt, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
74
+ mask_window = neighbor_mask & utils3d.torch.sliding_window_2d(mask, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
75
+
76
+ pred_rel = pred_window / pred[radius:-radius, radius:-radius, None, None]
77
+ gt_rel = gt_window / gt[radius:-radius, radius:-radius, None, None]
78
+ valid = mask[radius:-radius, radius:-radius, None, None] & mask_window
79
+
80
+ f1_list = []
81
+ w_list = t_list = torch.linspace(0.05, 0.25, 10).tolist()
82
+
83
+ for t in t_list:
84
+ pred_label = pred_rel > 1 + t
85
+ gt_label = gt_rel > 1 + t
86
+ TP = (pred_label & gt_label & valid).float().sum()
87
+ recall = TP / (gt_label & valid).float().sum().clamp_min(1e-12)
88
+ precision = TP / (pred_label & valid).float().sum().clamp_min(1e-12)
89
+ f1 = 2 * precision * recall / (precision + recall).clamp_min(1e-12)
90
+ f1_list.append(f1.item())
91
+
92
+ f1_avg = sum(w * f1 for w, f1 in zip(w_list, f1_list)) / sum(w_list)
93
+ return f1_avg
94
+
95
+
96
+ def compute_metrics(
97
+ pred: Dict[str, torch.Tensor],
98
+ gt: Dict[str, torch.Tensor],
99
+ vis: bool = False
100
+ ) -> Tuple[Dict[str, Dict[str, Number]], Dict[str, torch.Tensor]]:
101
+ """
102
+ A unified function to compute metrics for different types of predictions and ground truths.
103
+
104
+ #### Supported keys in pred:
105
+ - `disparity_affine_invariant`: disparity map predicted by a depth estimator with scale and shift invariant.
106
+ - `depth_scale_invariant`: depth map predicted by a depth estimator with scale invariant.
107
+ - `depth_affine_invariant`: depth map predicted by a depth estimator with scale and shift invariant.
108
+ - `depth_metric`: depth map predicted by a depth estimator with no scale or shift.
109
+ - `points_scale_invariant`: point map predicted by a point estimator with scale invariant.
110
+ - `points_affine_invariant`: point map predicted by a point estimator with scale and xyz shift invariant.
111
+ - `points_metric`: point map predicted by a point estimator with no scale or shift.
112
+ - `intrinsics`: normalized camera intrinsics matrix.
113
+
114
+ #### Required keys in gt:
115
+ - `depth`: depth map ground truth (in metric units if `depth_metric` is used)
116
+ - `points`: point map ground truth in camera coordinates.
117
+ - `depth_mask`: mask indicating valid pixels in the ground truth.
118
+ - `intrinsics`: normalized ground-truth camera intrinsics matrix.
119
+ - `is_metric`: whether the depth is in metric units.
120
+ """
121
+ metrics = {}
122
+ misc = {}
123
+
124
+ mask = gt['depth_mask']
125
+ gt_depth = gt['depth']
126
+ gt_points = gt['points']
127
+
128
+ height, width = mask.shape[-2:]
129
+ _, lr_mask, lr_index = mask_aware_nearest_resize(None, mask, (64, 64), return_index=True)
130
+
131
+ only_depth = not any('point' in k for k in pred)
132
+ pred_depth_aligned, pred_points_aligned = None, None
133
+
134
+ # Metric depth
135
+ if 'depth_metric' in pred and gt['is_metric']:
136
+ pred_depth, gt_depth = pred['depth_metric'], gt['depth']
137
+ metrics['depth_metric'] = {
138
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
139
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
140
+ }
141
+
142
+ if pred_depth_aligned is None:
143
+ pred_depth_aligned = pred_depth
144
+
145
+ # Scale-invariant depth
146
+ if 'depth_scale_invariant' in pred:
147
+ pred_depth_scale_invariant = pred['depth_scale_invariant']
148
+ elif 'depth_metric' in pred:
149
+ pred_depth_scale_invariant = pred['depth_metric']
150
+ else:
151
+ pred_depth_scale_invariant = None
152
+
153
+ if pred_depth_scale_invariant is not None:
154
+ pred_depth = pred_depth_scale_invariant
155
+
156
+ pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask]
157
+ scale = align_depth_scale(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked)
158
+ pred_depth = pred_depth * scale
159
+
160
+ metrics['depth_scale_invariant'] = {
161
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
162
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
163
+ }
164
+
165
+ if pred_depth_aligned is None:
166
+ pred_depth_aligned = pred_depth
167
+
168
+ # Affine-invariant depth
169
+ if 'depth_affine_invariant' in pred:
170
+ pred_depth_affine_invariant = pred['depth_affine_invariant']
171
+ elif 'depth_scale_invariant' in pred:
172
+ pred_depth_affine_invariant = pred['depth_scale_invariant']
173
+ elif 'depth_metric' in pred:
174
+ pred_depth_affine_invariant = pred['depth_metric']
175
+ else:
176
+ pred_depth_affine_invariant = None
177
+
178
+ if pred_depth_affine_invariant is not None:
179
+ pred_depth = pred_depth_affine_invariant
180
+
181
+ pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask]
182
+ scale, shift = align_depth_affine(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked)
183
+ pred_depth = pred_depth * scale + shift
184
+
185
+ metrics['depth_affine_invariant'] = {
186
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
187
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
188
+ }
189
+
190
+ if pred_depth_aligned is None:
191
+ pred_depth_aligned = pred_depth
192
+
193
+ # Affine-invariant disparity
194
+ if 'disparity_affine_invariant' in pred:
195
+ pred_disparity_affine_invariant = pred['disparity_affine_invariant']
196
+ elif 'depth_scale_invariant' in pred:
197
+ pred_disparity_affine_invariant = 1 / pred['depth_scale_invariant']
198
+ elif 'depth_metric' in pred:
199
+ pred_disparity_affine_invariant = 1 / pred['depth_metric']
200
+ else:
201
+ pred_disparity_affine_invariant = None
202
+
203
+ if pred_disparity_affine_invariant is not None:
204
+ pred_disp = pred_disparity_affine_invariant
205
+
206
+ scale, shift = align_affine_lstsq(pred_disp[mask], 1 / gt_depth[mask])
207
+ pred_disp = pred_disp * scale + shift
208
+
209
+ # NOTE: Aligning in disparity space can introduce extreme outliers where the disparity is close to 0.
210
+ # Therefore we clamp the predicted disparity at the minimum ground-truth disparity.
211
+ pred_depth = 1 / pred_disp.clamp_min(1 / gt_depth[mask].max().item())
212
+
213
+ metrics['disparity_affine_invariant'] = {
214
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
215
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
216
+ }
217
+
218
+ if pred_depth_aligned is None:
219
+ pred_depth_aligned = 1 / pred_disp.clamp_min(1e-6)
220
+
221
+ # Metric points
222
+ if 'points_metric' in pred and gt['is_metric']:
223
+ pred_points = pred['points_metric']
224
+
225
+ pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask]
226
+ shift = align_points_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
227
+ pred_points = pred_points + shift
228
+
229
+ metrics['points_metric'] = {
230
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
231
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
232
+ }
233
+
234
+ if pred_points_aligned is None:
235
+ pred_points_aligned = pred['points_metric']
236
+
237
+ # Scale-invariant points (in camera space)
238
+ if 'points_scale_invariant' in pred:
239
+ pred_points_scale_invariant = pred['points_scale_invariant']
240
+ elif 'points_metric' in pred:
241
+ pred_points_scale_invariant = pred['points_metric']
242
+ else:
243
+ pred_points_scale_invariant = None
244
+
245
+ if pred_points_scale_invariant is not None:
246
+ pred_points = pred_points_scale_invariant
247
+
248
+ pred_points_lr_masked, gt_points_lr_masked = pred_points_scale_invariant[lr_index][lr_mask], gt_points[lr_index][lr_mask]
249
+ scale = align_points_scale(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
250
+ pred_points = pred_points * scale
251
+
252
+ metrics['points_scale_invariant'] = {
253
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
254
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
255
+ }
256
+
257
+ if vis and pred_points_aligned is None:
258
+ pred_points_aligned = pred['points_scale_invariant'] * scale
259
+
260
+ # Affine-invariant points
261
+ if 'points_affine_invariant' in pred:
262
+ pred_points_affine_invariant = pred['points_affine_invariant']
263
+ elif 'points_scale_invariant' in pred:
264
+ pred_points_affine_invariant = pred['points_scale_invariant']
265
+ elif 'points_metric' in pred:
266
+ pred_points_affine_invariant = pred['points_metric']
267
+ else:
268
+ pred_points_affine_invariant = None
269
+
270
+ if pred_points_affine_invariant is not None:
271
+ pred_points = pred_points_affine_invariant
272
+
273
+ pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask]
274
+ scale, shift = align_points_scale_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
275
+ pred_points = pred_points * scale + shift
276
+
277
+ metrics['points_affine_invariant'] = {
278
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
279
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
280
+ }
281
+
282
+ if vis and pred_points_aligned is None:
283
+ pred_points_aligned = pred['points_affine_invariant'] * scale + shift
284
+
285
+ # Local points
286
+ if 'segmentation_mask' in gt and 'points' in gt and any('points' in k for k in pred.keys()):
287
+ pred_points = next(pred[k] for k in pred.keys() if 'points' in k)
288
+ gt_points = gt['points']
289
+ segmentation_mask = gt['segmentation_mask']
290
+ segmentation_labels = gt['segmentation_labels']
291
+ segmentation_mask_lr = segmentation_mask[lr_index]
292
+ local_points_metrics = []
293
+ for _, seg_id in segmentation_labels.items():
294
+ valid_mask = (segmentation_mask == seg_id) & mask
295
+
296
+ pred_points_masked = pred_points[valid_mask]
297
+ gt_points_masked = gt_points[valid_mask]
298
+
299
+ valid_mask_lr = (segmentation_mask_lr == seg_id) & lr_mask
300
+ if valid_mask_lr.sum().item() < 10:
301
+ continue
302
+ pred_points_masked_lr = pred_points[lr_index][valid_mask_lr]
303
+ gt_points_masked_lr = gt_points[lr_index][valid_mask_lr]
304
+ diameter = (gt_points_masked.max(dim=0).values - gt_points_masked.min(dim=0).values).max()
305
+ scale, shift = align_points_scale_xyz_shift(pred_points_masked_lr, gt_points_masked_lr, 1 / diameter.expand(gt_points_masked_lr.shape[0]))
306
+ pred_points_masked = pred_points_masked * scale + shift
307
+
308
+ local_points_metrics.append({
309
+ 'rel': rel_point_local(pred_points_masked, gt_points_masked, diameter),
310
+ 'delta1': delta1_point_local(pred_points_masked, gt_points_masked, diameter),
311
+ })
312
+
313
+ metrics['local_points'] = key_average(local_points_metrics)
314
+
315
+ # FOV. NOTE: Without random augmentation of the input images, all GT FOVs are generally the same.
316
+ # Fair evaluation of FOV requires random augmentation.
317
+ if 'intrinsics' in pred and 'intrinsics' in gt:
318
+ pred_intrinsics = pred['intrinsics']
319
+ gt_intrinsics = gt['intrinsics']
320
+ pred_fov_x, pred_fov_y = intrinsics_to_fov(pred_intrinsics)
321
+ gt_fov_x, gt_fov_y = intrinsics_to_fov(gt_intrinsics)
322
+ metrics['fov_x'] = {
323
+ 'mae': torch.rad2deg(pred_fov_x - gt_fov_x).abs().mean().item(),
324
+ 'deviation': torch.rad2deg(pred_fov_x - gt_fov_x).item(),
325
+ }
326
+
327
+ # Boundary F1
328
+ if pred_depth_aligned is not None and gt['has_sharp_boundary']:
329
+ metrics['boundary'] = {
330
+ 'radius1_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=1),
331
+ 'radius2_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=2),
332
+ 'radius3_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=3),
333
+ }
334
+
335
+ if vis:
336
+ if pred_points_aligned is not None:
337
+ misc['pred_points'] = pred_points_aligned
338
+ if only_depth:
339
+ misc['pred_points'] = utils3d.torch.depth_to_points(pred_depth_aligned, intrinsics=gt['intrinsics'])
340
+ if pred_depth_aligned is not None:
341
+ misc['pred_depth'] = pred_depth_aligned
342
+
343
+ return metrics, misc
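A usage sketch of `compute_metrics` with synthetic tensors (shapes and values are illustrative only; in practice `pred` and `gt` come from a baseline wrapper and the evaluation dataloader):

if __name__ == '__main__':
    h, w = 128, 128
    intrinsics = torch.tensor([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]])
    gt_depth = torch.rand(h, w) * 3.0 + 1.0
    gt = {
        'depth': gt_depth,
        'points': utils3d.torch.depth_to_points(gt_depth, intrinsics=intrinsics),
        'depth_mask': torch.ones(h, w, dtype=torch.bool),
        'intrinsics': intrinsics,
        'is_metric': True,
        'has_sharp_boundary': False,
    }
    pred = {'depth_metric': gt_depth * 1.05}  # a slightly biased "prediction"
    metrics, _ = compute_metrics(pred, gt)
    print(metrics['depth_metric'])  # roughly {'rel': 0.05, 'delta1': 1.0}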
moge/train/__init__.py ADDED
File without changes
moge/train/dataloader.py ADDED
@@ -0,0 +1,338 @@
1
+ import os
2
+ from pathlib import Path
3
+ import json
4
+ import time
5
+ import random
6
+ from typing import *
7
+ import traceback
8
+ import itertools
9
+ from numbers import Number
10
+ import io
11
+
12
+ import numpy as np
13
+ import cv2
14
+ from PIL import Image
15
+ import torch
16
+ import torchvision.transforms.v2.functional as TF
17
+ import utils3d
18
+ from tqdm import tqdm
19
+
20
+ from ..utils import pipeline
21
+ from ..utils.io import *
22
+ from ..utils.geometry_numpy import mask_aware_nearest_resize_numpy, harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy, depth_of_field
23
+
24
+
25
+ class TrainDataLoaderPipeline:
26
+ def __init__(self, config: dict, batch_size: int, num_load_workers: int = 4, num_process_workers: int = 8, buffer_size: int = 8):
27
+ self.config = config
28
+
29
+ self.batch_size = batch_size
30
+ self.clamp_max_depth = config['clamp_max_depth']
31
+ self.fov_range_absolute = config.get('fov_range_absolute', 0.0)
32
+ self.fov_range_relative = config.get('fov_range_relative', 0.0)
33
+ self.center_augmentation = config.get('center_augmentation', 0.0)
34
+ self.image_augmentation = config.get('image_augmentation', [])
35
+ self.depth_interpolation = config.get('depth_interpolation', 'bilinear')
36
+
37
+ if 'image_sizes' in config:
38
+ self.image_size_strategy = 'fixed'
39
+ self.image_sizes = config['image_sizes']
40
+ elif 'aspect_ratio_range' in config and 'area_range' in config:
41
+ self.image_size_strategy = 'aspect_area'
42
+ self.aspect_ratio_range = config['aspect_ratio_range']
43
+ self.area_range = config['area_range']
44
+ else:
45
+ raise ValueError('Invalid image size configuration')
46
+
47
+ # Load datasets
48
+ self.datasets = {}
49
+ for dataset in tqdm(config['datasets'], desc='Loading datasets'):
50
+ name = dataset['name']
51
+ content = Path(dataset['path'], dataset.get('index', '.index.txt')).read_text()
52
+ filenames = content.splitlines()
53
+ self.datasets[name] = {
54
+ **dataset,
55
+ 'path': dataset['path'],
56
+ 'filenames': filenames,
57
+ }
58
+ self.dataset_names = [dataset['name'] for dataset in config['datasets']]
59
+ self.dataset_weights = [dataset['weight'] for dataset in config['datasets']]
60
+
61
+ # Build pipeline
62
+ self.pipeline = pipeline.Sequential([
63
+ self._sample_batch,
64
+ pipeline.Unbatch(),
65
+ pipeline.Parallel([self._load_instance] * num_load_workers),
66
+ pipeline.Parallel([self._process_instance] * num_process_workers),
67
+ pipeline.Batch(self.batch_size),
68
+ self._collate_batch,
69
+ pipeline.Buffer(buffer_size),
70
+ ])
71
+
72
+ self.invalid_instance = {
73
+ 'intrinsics': np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], dtype=np.float32),
74
+ 'image': np.zeros((256, 256, 3), dtype=np.uint8),
75
+ 'depth': np.ones((256, 256), dtype=np.float32),
76
+ 'depth_mask': np.ones((256, 256), dtype=bool),
77
+ 'depth_mask_inf': np.zeros((256, 256), dtype=bool),
78
+ 'label_type': 'invalid',
79
+ }
80
+
81
+ def _sample_batch(self):
82
+ batch_id = 0
83
+ last_area = None
84
+ while True:
85
+ # Depending on the sample strategy, choose a dataset and a filename
86
+ batch_id += 1
87
+ batch = []
88
+
89
+ # Sample instances
90
+ for _ in range(self.batch_size):
91
+ dataset_name = random.choices(self.dataset_names, weights=self.dataset_weights)[0]
92
+ filename = random.choice(self.datasets[dataset_name]['filenames'])
93
+
94
+ path = Path(self.datasets[dataset_name]['path'], filename)
95
+
96
+ instance = {
97
+ 'batch_id': batch_id,
98
+ 'seed': random.randint(0, 2 ** 32 - 1),
99
+ 'dataset': dataset_name,
100
+ 'filename': filename,
101
+ 'path': path,
102
+ 'label_type': self.datasets[dataset_name]['label_type'],
103
+ }
104
+ batch.append(instance)
105
+
106
+ # Decide the image size for this batch
107
+ if self.image_size_strategy == 'fixed':
108
+ width, height = random.choice(self.config['image_sizes'])
109
+ elif self.image_size_strategy == 'aspect_area':
110
+ area = random.uniform(*self.area_range)
111
+ aspect_ratio_ranges = [self.datasets[instance['dataset']].get('aspect_ratio_range', self.aspect_ratio_range) for instance in batch]
112
+ aspect_ratio_range = (min(r[0] for r in aspect_ratio_ranges), max(r[1] for r in aspect_ratio_ranges))
113
+ aspect_ratio = random.uniform(*aspect_ratio_range)
114
+ width, height = int((area * aspect_ratio) ** 0.5), int((area / aspect_ratio) ** 0.5)
115
+ else:
116
+ raise ValueError('Invalid image size strategy')
117
+
118
+ for instance in batch:
119
+ instance['width'], instance['height'] = width, height
120
+
121
+ yield batch
122
+
123
+ def _load_instance(self, instance: dict):
124
+ try:
125
+ image = read_image(Path(instance['path'], 'image.jpg'))
126
+ depth, _ = read_depth(Path(instance['path'], self.datasets[instance['dataset']].get('depth', 'depth.png')))
127
+
128
+ meta = read_meta(Path(instance['path'], 'meta.json'))
129
+ intrinsics = np.array(meta['intrinsics'], dtype=np.float32)
130
+ depth_mask = np.isfinite(depth)
131
+ depth_mask_inf = np.isinf(depth)
132
+ depth = np.nan_to_num(depth, nan=1, posinf=1, neginf=1)
133
+ data = {
134
+ 'image': image,
135
+ 'depth': depth,
136
+ 'depth_mask': depth_mask,
137
+ 'depth_mask_inf': depth_mask_inf,
138
+ 'intrinsics': intrinsics
139
+ }
140
+ instance.update({
141
+ **data,
142
+ })
143
+ except Exception as e:
144
+ print(f"Failed to load instance {instance['dataset']}/{instance['filename']} because of exception:", e)
145
+ instance.update(self.invalid_instance)
146
+ return instance
147
+
148
+ def _process_instance(self, instance: Dict[str, Union[np.ndarray, str, float, bool]]):
149
+ image, depth, depth_mask, depth_mask_inf, intrinsics, label_type = instance['image'], instance['depth'], instance['depth_mask'], instance['depth_mask_inf'], instance['intrinsics'], instance['label_type']
150
+ depth_unit = self.datasets[instance['dataset']].get('depth_unit', None)
151
+
152
+ raw_height, raw_width = image.shape[:2]
153
+ raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
154
+ raw_fov_x, raw_fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
155
+ raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height
156
+ tgt_width, tgt_height = instance['width'], instance['height']
157
+ tgt_aspect = tgt_width / tgt_height
158
+
159
+ rng = np.random.default_rng(instance['seed'])
160
+
161
+ # 1. set target fov
162
+ center_augmentation = self.datasets[instance['dataset']].get('center_augmentation', self.center_augmentation)
163
+ fov_range_absolute_min, fov_range_absolute_max = self.datasets[instance['dataset']].get('fov_range_absolute', self.fov_range_absolute)
164
+ fov_range_relative_min, fov_range_relative_max = self.datasets[instance['dataset']].get('fov_range_relative', self.fov_range_relative)
165
+ tgt_fov_x_min = min(fov_range_relative_min * raw_fov_x, fov_range_relative_min * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect))
166
+ tgt_fov_x_max = min(fov_range_relative_max * raw_fov_x, fov_range_relative_max * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect))
167
+ tgt_fov_x_min, tgt_fov_x_max = max(np.deg2rad(fov_range_absolute_min), tgt_fov_x_min), min(np.deg2rad(fov_range_absolute_max), tgt_fov_x_max)
168
+ tgt_fov_x = rng.uniform(min(tgt_fov_x_min, tgt_fov_x_max), tgt_fov_x_max)
169
+ tgt_fov_y = utils3d.focal_to_fov(utils3d.numpy.fov_to_focal(tgt_fov_x) * tgt_aspect)
170
+
171
+ # 2. set target image center (principal point) and the corresponding z-direction in raw camera space
172
+ center_dtheta = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_x - tgt_fov_x)
173
+ center_dphi = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_y - tgt_fov_y)
174
+ cu, cv = 0.5 + 0.5 * np.tan(center_dtheta) / np.tan(raw_fov_x / 2), 0.5 + 0.5 * np.tan(center_dphi) / np.tan(raw_fov_y / 2)
175
+ direction = utils3d.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0]
176
+
177
+ # 3. obtain the rotation matrix for homography warping
178
+ R = utils3d.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
179
+
180
+ # 4. shrink the target view to fit into the warped image
181
+ corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
182
+ corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane
183
+ corners = corners[:, :2] / corners[:, 2:3]
184
+ tgt_horizontal, tgt_vertical = np.tan(tgt_fov_x / 2) * 2, np.tan(tgt_fov_y / 2) * 2
185
+ warp_horizontal, warp_vertical = float('inf'), float('inf')
186
+ for i in range(4):
187
+ intersection, _ = utils3d.numpy.ray_intersection(
188
+ np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
189
+ corners[i - 1], corners[i] - corners[i - 1],
190
+ )
191
+ warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
192
+ tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
193
+
194
+ # 5. obtain the target intrinsics
195
+ fx, fy = 1 / tgt_horizontal, 1 / tgt_vertical
196
+ tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
197
+
198
+ # 6. do homogeneous transformation
199
+ # 6.1 The image and depth are resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling
200
+ tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes)
201
+ rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h)
202
+ image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS))
203
+
204
+ edge_mask = depth_occlusion_edge_numpy(depth, mask=depth_mask, thickness=2, tol=0.01)
205
+ _, depth_mask_nearest, resize_index = mask_aware_nearest_resize_numpy(None, depth_mask, (rescaled_w, rescaled_h), return_index=True)
206
+ depth_nearest = depth[resize_index]
207
+ distance_nearest = norm3d(utils3d.numpy.depth_to_points(depth_nearest, intrinsics=intrinsics))
208
+ edge_mask = edge_mask[resize_index]
209
+
210
+ if self.depth_interpolation == 'bilinear':
211
+ depth_mask_bilinear = cv2.resize(depth_mask.astype(np.float32), (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR)
212
+ depth_bilinear = 1 / cv2.resize(1 / depth, (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR)
213
+ distance_bilinear = norm3d(utils3d.numpy.depth_to_points(depth_bilinear, intrinsics=intrinsics))
214
+
215
+ depth_mask_inf = cv2.resize(depth_mask_inf.astype(np.uint8), (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) > 0
216
+
217
+ # 6.2 calculate homography warping
218
+ transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics)
219
+ uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height)
220
+ pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T
221
+ uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12)
222
+ pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32)
223
+
224
+ tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LANCZOS4)
225
+ tgt_ray_length = norm3d(utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics))
226
+ tgt_depth_mask_nearest = cv2.remap(depth_mask_nearest.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
227
+ tgt_depth_nearest = cv2.remap(distance_nearest, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) / tgt_ray_length
228
+ tgt_edge_mask = cv2.remap(edge_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
229
+ if self.depth_interpolation == 'bilinear':
230
+ tgt_depth_mask_bilinear = cv2.remap(depth_mask_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR)
231
+ tgt_depth_bilinear = cv2.remap(distance_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) / tgt_ray_length
232
+ tgt_depth = np.where((tgt_depth_mask_bilinear == 1) & ~tgt_edge_mask, tgt_depth_bilinear, tgt_depth_nearest)
233
+ else:
234
+ tgt_depth = tgt_depth_nearest
235
+ tgt_depth_mask = tgt_depth_mask_nearest
236
+
237
+ tgt_depth_mask_inf = cv2.remap(depth_mask_inf.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
238
+
239
+ # always make sure that mask is not empty
240
+ if tgt_depth_mask.sum() / tgt_depth_mask.size < 0.001:
241
+ tgt_depth_mask = np.ones_like(tgt_depth_mask)
242
+ tgt_depth = np.ones_like(tgt_depth)
243
+ instance['label_type'] = 'invalid'
244
+
245
+ # Flip augmentation
246
+ if rng.choice([True, False]):
247
+ tgt_image = np.flip(tgt_image, axis=1).copy()
248
+ tgt_depth = np.flip(tgt_depth, axis=1).copy()
249
+ tgt_depth_mask = np.flip(tgt_depth_mask, axis=1).copy()
250
+ tgt_depth_mask_inf = np.flip(tgt_depth_mask_inf, axis=1).copy()
251
+
252
+ # Color augmentation
253
+ image_augmentation = self.datasets[instance['dataset']].get('image_augmentation', self.image_augmentation)
254
+ if 'jittering' in image_augmentation:
255
+ tgt_image = torch.from_numpy(tgt_image).permute(2, 0, 1)
256
+ tgt_image = TF.adjust_brightness(tgt_image, rng.uniform(0.7, 1.3))
257
+ tgt_image = TF.adjust_contrast(tgt_image, rng.uniform(0.7, 1.3))
258
+ tgt_image = TF.adjust_saturation(tgt_image, rng.uniform(0.7, 1.3))
259
+ tgt_image = TF.adjust_hue(tgt_image, rng.uniform(-0.1, 0.1))
260
+ tgt_image = TF.adjust_gamma(tgt_image, rng.uniform(0.7, 1.3))
261
+ tgt_image = tgt_image.permute(1, 2, 0).numpy()
262
+ if 'dof' in image_augmentation:
263
+ if rng.uniform() < 0.5:
264
+ dof_strength = rng.integers(12)
265
+ tgt_disp = np.where(tgt_depth_mask_inf, 0, 1 / tgt_depth)
266
+ disp_min, disp_max = tgt_disp[tgt_depth_mask].min(), tgt_disp[tgt_depth_mask].max()
267
+ tgt_disp = cv2.inpaint(tgt_disp, (~tgt_depth_mask & ~tgt_depth_mask_inf).astype(np.uint8), 3, cv2.INPAINT_TELEA).clip(disp_min, disp_max)
268
+ dof_focus = rng.uniform(disp_min, disp_max)
269
+ tgt_image = depth_of_field(tgt_image, tgt_disp, dof_focus, dof_strength)
270
+ if 'shot_noise' in image_augmentation:
271
+ if rng.uniform() < 0.5:
272
+ k = np.exp(rng.uniform(np.log(100), np.log(10000))) / 255
273
+ tgt_image = (rng.poisson(tgt_image * k) / k).clip(0, 255).astype(np.uint8)
274
+ if 'jpeg_loss' in image_augmentation:
275
+ if rng.uniform() < 0.5:
276
+ tgt_image = cv2.imdecode(cv2.imencode('.jpg', tgt_image, [cv2.IMWRITE_JPEG_QUALITY, rng.integers(20, 100)])[1], cv2.IMREAD_COLOR)
277
+ if 'blurring' in image_augmentation:
278
+ if rng.uniform() < 0.5:
279
+ ratio = rng.uniform(0.25, 1)
280
+ tgt_image = cv2.resize(cv2.resize(tgt_image, (int(tgt_width * ratio), int(tgt_height * ratio)), interpolation=cv2.INTER_AREA), (tgt_width, tgt_height), interpolation=rng.choice([cv2.INTER_LINEAR_EXACT, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]))
281
+
282
+ # convert depth to metric if necessary
283
+ if depth_unit is not None:
284
+ tgt_depth *= depth_unit
285
+ instance['is_metric'] = True
286
+ else:
287
+ instance['is_metric'] = False
288
+
289
+ # clamp depth maximum values
290
+ max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.clamp_max_depth
291
+ tgt_depth = np.clip(tgt_depth, 0, max_depth)
292
+ tgt_depth = np.nan_to_num(tgt_depth, nan=1.0)
293
+
294
+ if self.datasets[instance['dataset']].get('finite_depth_mask', None) == "only_known":
295
+ tgt_depth_mask_fin = tgt_depth_mask
296
+ else:
297
+ tgt_depth_mask_fin = ~tgt_depth_mask_inf
298
+
299
+ instance.update({
300
+ 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
301
+ 'depth': torch.from_numpy(tgt_depth).float(),
302
+ 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(),
303
+ 'depth_mask_fin': torch.from_numpy(tgt_depth_mask_fin).bool(),
304
+ 'depth_mask_inf': torch.from_numpy(tgt_depth_mask_inf).bool(),
305
+ 'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
306
+ })
307
+
308
+ return instance
309
+
310
+ def _collate_batch(self, instances: List[Dict[str, Any]]):
311
+ batch = {k: torch.stack([instance[k] for instance in instances], dim=0) for k in ['image', 'depth', 'depth_mask', 'depth_mask_fin', 'depth_mask_inf', 'intrinsics']}
312
+ batch = {
313
+ 'label_type': [instance['label_type'] for instance in instances],
314
+ 'is_metric': [instance['is_metric'] for instance in instances],
315
+ 'info': [{'dataset': instance['dataset'], 'filename': instance['filename']} for instance in instances],
316
+ **batch,
317
+ }
318
+ return batch
319
+
320
+ def get(self) -> Dict[str, Union[torch.Tensor, str]]:
321
+ return self.pipeline.get()
322
+
323
+ def start(self):
324
+ self.pipeline.start()
325
+
326
+ def stop(self):
327
+ self.pipeline.stop()
328
+
329
+ def __enter__(self):
330
+ self.start()
331
+ return self
332
+
333
+ def __exit__(self, exc_type, exc_value, traceback):
334
+ self.pipeline.terminate()
335
+ self.pipeline.join()
336
+ return False
337
+
338
+
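A minimal usage sketch of the dataloader defined above, based only on the `start`/`get`/`stop` and context-manager methods shown in this diff; `dataloader` and `num_steps` are placeholders, and the constructor arguments are configured elsewhere in the training scripts.

# Hedged sketch: `dataloader` is assumed to be an already-constructed instance of the class above.
with dataloader:                          # starts the background pipeline; terminates it on exit
    for _ in range(num_steps):            # `num_steps` is a placeholder
        batch = dataloader.get()          # a collated batch, see `_collate_batch` above
        image = batch['image']            # (B, 3, H, W), float32 in [0, 1]
        depth = batch['depth']            # (B, H, W), float32
        depth_mask = batch['depth_mask']  # (B, H, W), bool
        intrinsics = batch['intrinsics']  # (B, 3, 3), normalized (principal point at 0.5, 0.5)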
moge/train/losses.py ADDED
@@ -0,0 +1,270 @@
1
+ from typing import *
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import utils3d
7
+
8
+ from ..utils.geometry_torch import (
9
+ weighted_mean,
10
+ harmonic_mean,
11
+ geometric_mean,
12
+ mask_aware_nearest_resize,
13
+ normalized_view_plane_uv,
14
+ angle_diff_vec3
15
+ )
16
+ from ..utils.alignment import (
17
+ align_points_scale_z_shift,
18
+ align_points_scale,
19
+ align_points_scale_xyz_shift,
20
+ align_points_z_shift,
21
+ )
22
+
23
+
24
+ def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
25
+ if beta == 0:
26
+ return err
27
+ else:
28
+ return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
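A small, self-contained numerical illustration of the smooth-L1 behaviour of `_smooth` above; the `beta` value is an arbitrary choice for the example.

import torch

err = torch.tensor([0.05, 0.10, 0.50, 2.00])
beta = 0.1
smoothed = torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
# -> tensor([0.0125, 0.0500, 0.4500, 1.9500]): quadratic below beta, linear (offset by beta / 2) above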
29
+
30
+
31
+ def affine_invariant_global_loss(
32
+ pred_points: torch.Tensor,
33
+ gt_points: torch.Tensor,
34
+ mask: torch.Tensor,
35
+ align_resolution: int = 64,
36
+ beta: float = 0.0,
37
+ trunc: float = 1.0,
38
+ sparsity_aware: bool = False
39
+ ):
40
+ device = pred_points.device
41
+
42
+ # Align
43
+ (pred_points_lr, gt_points_lr), lr_mask = mask_aware_nearest_resize((pred_points, gt_points), mask=mask, size=(align_resolution, align_resolution))
44
+ scale, shift = align_points_scale_z_shift(pred_points_lr.flatten(-3, -2), gt_points_lr.flatten(-3, -2), lr_mask.flatten(-2, -1) / gt_points_lr[..., 2].flatten(-2, -1).clamp_min(1e-2), trunc=trunc)
45
+ valid = scale > 0
46
+ scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)
47
+
48
+ pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]
49
+
50
+ # Compute loss
51
+ weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
52
+ weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True)) # In case your data contains extremely small depth values
53
+ loss = _smooth((pred_points - gt_points).abs() * weight[..., None], beta=beta).mean(dim=(-3, -2, -1))
54
+
55
+ if sparsity_aware:
56
+ # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
57
+ sparsity = mask.float().mean(dim=(-2, -1)) / lr_mask.float().mean(dim=(-2, -1))
58
+ loss = loss / (sparsity + 1e-7)
59
+
60
+ err = (pred_points.detach() - gt_points).norm(dim=-1) / gt_points[..., 2]
61
+
62
+ # Record any scalar metric
63
+ misc = {
64
+ 'truncated_error': weighted_mean(err.clamp_max(1.0), mask).item(),
65
+ 'delta': weighted_mean((err < 1).float(), mask).item()
66
+ }
67
+
68
+ return loss, misc, scale.detach()
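A hedged usage sketch for `affine_invariant_global_loss` above. The import path `moge.train.losses` is an assumption, and the random tensors only illustrate the expected shapes and return values, not meaningful geometry.

import torch
from moge.train.losses import affine_invariant_global_loss  # import path assumed

pred = torch.randn(2, 64, 64, 3)                 # predicted point map (B, H, W, 3)
gt = torch.randn(2, 64, 64, 3)
gt[..., 2] = gt[..., 2].abs() + 1.0              # ground truth with positive z
mask = torch.rand(2, 64, 64) > 0.3               # valid-pixel mask (B, H, W)
loss, misc, scale = affine_invariant_global_loss(pred, gt, mask)
# loss: (B,) per-sample loss; misc: {'truncated_error', 'delta'}; scale: (B,) detached alignment scale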
69
+
70
+
71
+ def monitoring(points: torch.Tensor):
72
+ return {
73
+ 'std': points.std().item(),
74
+ }
75
+
76
+
77
+ def compute_anchor_sampling_weight(
78
+ points: torch.Tensor,
79
+ mask: torch.Tensor,
80
+ radius_2d: int,
81
+ radius_3d: torch.Tensor,
82
+ num_test: int = 64
83
+ ) -> torch.Tensor:
84
+ # Importance sampling to balance the sampled probability of fine structures.
85
+ # NOTE: MoGe-1 uses uniform random sampling instead of importance sampling.
86
+ # This is an incremental trick introduced after the publication of the MoGe-1 paper.
87
+
88
+ height, width = points.shape[-3:-1]
89
+
90
+ pixel_i, pixel_j = torch.meshgrid(
91
+ torch.arange(height, device=points.device),
92
+ torch.arange(width, device=points.device),
93
+ indexing='ij'
94
+ )
95
+
96
+ test_delta_i = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [height, width, num_test]
97
+ test_delta_j = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [height, width, num_test]
98
+ test_i, test_j = pixel_i[..., None] + test_delta_i, pixel_j[..., None] + test_delta_j # [height, width, num_test]
99
+ test_mask = (test_i >= 0) & (test_i < height) & (test_j >= 0) & (test_j < width) # [height, width, num_test]
100
+ test_i, test_j = test_i.clamp(0, height - 1), test_j.clamp(0, width - 1) # [height, width, num_test]
101
+ test_mask = test_mask & mask[..., test_i, test_j] # [..., height, width, num_test]
102
+ test_points = points[..., test_i, test_j, :] # [..., height, width, num_test, 3]
103
+ test_dist = (test_points - points[..., None, :]).norm(dim=-1) # [..., height, width, num_test]
104
+
105
+ weight = 1 / ((test_dist <= radius_3d[..., None]) & test_mask).float().sum(dim=-1).clamp_min(1)
106
+ weight = torch.where(mask, weight, 0)
107
+ weight = weight / weight.sum(dim=(-2, -1), keepdim=True).add(1e-7) # [..., height, width]
108
+ return weight
109
+
110
+
111
+ def affine_invariant_local_loss(
112
+ pred_points: torch.Tensor,
113
+ gt_points: torch.Tensor,
114
+ gt_mask: torch.Tensor,
115
+ focal: torch.Tensor,
116
+ global_scale: torch.Tensor,
117
+ level: Literal[4, 16, 64],
118
+ align_resolution: int = 32,
119
+ num_patches: int = 16,
120
+ beta: float = 0.0,
121
+ trunc: float = 1.0,
122
+ sparsity_aware: bool = False
123
+ ):
124
+ device, dtype = pred_points.device, pred_points.dtype
125
+ *batch_shape, height, width, _ = pred_points.shape
126
+ batch_size = math.prod(batch_shape)
127
+ pred_points, gt_points, gt_mask, focal, global_scale = pred_points.reshape(-1, height, width, 3), gt_points.reshape(-1, height, width, 3), gt_mask.reshape(-1, height, width), focal.reshape(-1), global_scale.reshape(-1) if global_scale is not None else None
128
+
129
+ # Sample patch anchor points indices [num_total_patches]
130
+ radius_2d = math.ceil(0.5 / level * (height ** 2 + width ** 2) ** 0.5)
131
+ radius_3d = 0.5 / level / focal * gt_points[..., 2]
132
+ anchor_sampling_weights = compute_anchor_sampling_weight(gt_points, gt_mask, radius_2d, radius_3d, num_test=64)
133
+ where_mask = torch.where(gt_mask)
134
+ random_selection = torch.multinomial(anchor_sampling_weights[where_mask], num_patches * batch_size, replacement=True)
135
+ patch_batch_idx, patch_anchor_i, patch_anchor_j = [indices[random_selection] for indices in where_mask] # [num_total_patches]
136
+
137
+ # Get patch indices [num_total_patches, patch_h, patch_w]
138
+ patch_i, patch_j = torch.meshgrid(
139
+ torch.arange(-radius_2d, radius_2d + 1, device=device),
140
+ torch.arange(-radius_2d, radius_2d + 1, device=device),
141
+ indexing='ij'
142
+ )
143
+ patch_i, patch_j = patch_i + patch_anchor_i[:, None, None], patch_j + patch_anchor_j[:, None, None]
144
+ patch_mask = (patch_i >= 0) & (patch_i < height) & (patch_j >= 0) & (patch_j < width)
145
+ patch_i, patch_j = patch_i.clamp(0, height - 1), patch_j.clamp(0, width - 1)
146
+
147
+ # Get patch mask and gt patch points
148
+ gt_patch_anchor_points = gt_points[patch_batch_idx, patch_anchor_i, patch_anchor_j]
149
+ gt_patch_radius_3d = 0.5 / level / focal[patch_batch_idx] * gt_patch_anchor_points[:, 2]
150
+ gt_patch_points = gt_points[patch_batch_idx[:, None, None], patch_i, patch_j]
151
+ gt_patch_dist = (gt_patch_points - gt_patch_anchor_points[:, None, None, :]).norm(dim=-1)
152
+ patch_mask &= gt_mask[patch_batch_idx[:, None, None], patch_i, patch_j]
153
+ patch_mask &= gt_patch_dist <= gt_patch_radius_3d[:, None, None]
154
+
155
+ # Pick only non-empty patches
156
+ MINIMUM_POINTS_PER_PATCH = 32
157
+ nonempty = torch.where(patch_mask.sum(dim=(-2, -1)) >= MINIMUM_POINTS_PER_PATCH)
158
+ num_nonempty_patches = nonempty[0].shape[0]
159
+ if num_nonempty_patches == 0:
160
+ return torch.tensor(0.0, dtype=dtype, device=device), {}
161
+
162
+ # Finalize all patch variables
163
+ patch_batch_idx, patch_i, patch_j = patch_batch_idx[nonempty], patch_i[nonempty], patch_j[nonempty]
164
+ patch_mask = patch_mask[nonempty] # [num_nonempty_patches, patch_h, patch_w]
165
+ gt_patch_points = gt_patch_points[nonempty] # [num_nonempty_patches, patch_h, patch_w, 3]
166
+ gt_patch_radius_3d = gt_patch_radius_3d[nonempty] # [num_nonempty_patches]
167
+ gt_patch_anchor_points = gt_patch_anchor_points[nonempty] # [num_nonempty_patches, 3]
168
+ pred_patch_points = pred_points[patch_batch_idx[:, None, None], patch_i, patch_j]
169
+
170
+ # Align patch points
171
+ (pred_patch_points_lr, gt_patch_points_lr), patch_lr_mask = mask_aware_nearest_resize((pred_patch_points, gt_patch_points), mask=patch_mask, size=(align_resolution, align_resolution))
172
+ local_scale, local_shift = align_points_scale_xyz_shift(pred_patch_points_lr.flatten(-3, -2), gt_patch_points_lr.flatten(-3, -2), patch_lr_mask.flatten(-2) / gt_patch_radius_3d[:, None].add(1e-7), trunc=trunc)
173
+ if global_scale is not None:
174
+ scale_differ = local_scale / global_scale[patch_batch_idx]
175
+ patch_valid = (scale_differ > 0.1) & (scale_differ < 10.0) & (global_scale > 0)
176
+ else:
177
+ patch_valid = local_scale > 0
178
+ local_scale, local_shift = torch.where(patch_valid, local_scale, 0), torch.where(patch_valid[:, None], local_shift, 0)
179
+ patch_mask &= patch_valid[:, None, None]
180
+
181
+ pred_patch_points = local_scale[:, None, None, None] * pred_patch_points + local_shift[:, None, None, :] # [num_patches_nonempty, patch_h, patch_w, 3]
182
+
183
+ # Compute loss
184
+ gt_mean = harmonic_mean(gt_points[..., 2], gt_mask, dim=(-2, -1))
185
+ patch_weight = patch_mask.float() / gt_patch_points[..., 2].clamp_min(0.1 * gt_mean[patch_batch_idx, None, None]) # [num_patches_nonempty, patch_h, patch_w]
186
+ loss = _smooth((pred_patch_points - gt_patch_points).abs() * patch_weight[..., None], beta=beta).mean(dim=(-3, -2, -1)) # [num_patches_nonempty]
187
+
188
+ if sparsity_aware:
189
+ # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
190
+ sparsity = patch_mask.float().mean(dim=(-2, -1)) / patch_lr_mask.float().mean(dim=(-2, -1))
191
+ loss = loss / (sparsity + 1e-7)
192
+ loss = torch.scatter_reduce(torch.zeros(batch_size, dtype=dtype, device=device), dim=0, index=patch_batch_idx, src=loss, reduce='sum') / num_patches
193
+ loss = loss.reshape(batch_shape)
194
+
195
+ err = (pred_patch_points.detach() - gt_patch_points).norm(dim=-1) / gt_patch_radius_3d[..., None, None]
196
+
197
+ # Record any scalar metric
198
+ misc = {
199
+ 'truncated_error': weighted_mean(err.clamp_max(1), patch_mask).item(),
200
+ 'delta': weighted_mean((err < 1).float(), patch_mask).item()
201
+ }
202
+
203
+ return loss, misc
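A small numeric illustration of the patch radii used by `affine_invariant_local_loss` above: the 2D radius is half of (image diagonal / level) in pixels, and the 3D radius is the same 1 / (2 * level) fraction scaled by depth over the normalized focal length. The numbers below are illustrative assumptions.

import math

height, width, level = 384, 512, 16
focal = 0.8                                                               # normalized focal length (assumed)
radius_2d = math.ceil(0.5 / level * (height ** 2 + width ** 2) ** 0.5)    # -> 20 pixels
radius_3d_at_2m = 0.5 / level / focal * 2.0                               # -> 0.078125 scene units at depth 2
print(radius_2d, radius_3d_at_2m)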
204
+
205
+ def normal_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
206
+ device, dtype = points.device, points.dtype
207
+ height, width = points.shape[-3:-1]
208
+
209
+ leftup, rightup, leftdown, rightdown = points[..., :-1, :-1, :], points[..., :-1, 1:, :], points[..., 1:, :-1, :], points[..., 1:, 1:, :]
210
+ upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
211
+ leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
212
+ downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
213
+ rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)
214
+
215
+ gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = gt_points[..., :-1, :-1, :], gt_points[..., :-1, 1:, :], gt_points[..., 1:, :-1, :], gt_points[..., 1:, 1:, :]
216
+ gt_upxleft = torch.cross(gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1)
217
+ gt_leftxdown = torch.cross(gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1)
218
+ gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
219
+ gt_rightxup = torch.cross(gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1)
220
+
221
+ mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = mask[..., :-1, :-1], mask[..., :-1, 1:], mask[..., 1:, :-1], mask[..., 1:, 1:]
222
+ mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
223
+ mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
224
+ mask_downxright = mask_leftdown & mask_rightup & mask_leftup
225
+ mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown
226
+
227
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
228
+
229
+ loss = mask_upxleft * _smooth(angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
230
+ + mask_leftxdown * _smooth(angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
231
+ + mask_downxright * _smooth(angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
232
+ + mask_rightxup * _smooth(angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
233
+
234
+ loss = loss.mean() / (4 * max(points.shape[-3:-1]))
235
+
236
+ return loss, {}
237
+
238
+
239
+ def edge_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
240
+ device, dtype = points.device, points.dtype
241
+ height, width = points.shape[-3:-1]
242
+
243
+ dx = points[..., :-1, :, :] - points[..., 1:, :, :]
244
+ dy = points[..., :, :-1, :] - points[..., :, 1:, :]
245
+
246
+ gt_dx = gt_points[..., :-1, :, :] - gt_points[..., 1:, :, :]
247
+ gt_dy = gt_points[..., :, :-1, :] - gt_points[..., :, 1:, :]
248
+
249
+ mask_dx = mask[..., :-1, :] & mask[..., 1:, :]
250
+ mask_dy = mask[..., :, :-1] & mask[..., :, 1:]
251
+
252
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(0.1), math.radians(90), math.radians(3)
253
+
254
+ loss_dx = mask_dx * _smooth(angle_diff_vec3(dx, gt_dx).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
255
+ loss_dy = mask_dy * _smooth(angle_diff_vec3(dy, gt_dy).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
256
+ loss = (loss_dx.mean(dim=(-2, -1)) + loss_dy.mean(dim=(-2, -1))) / (2 * max(points.shape[-3:-1]))
257
+
258
+ return loss, {}
259
+
260
+
261
+ def mask_l2_loss(pred_mask: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor:
262
+ loss = gt_mask_neg.float() * pred_mask.square() + gt_mask_pos.float() * (1 - pred_mask).square()
263
+ loss = loss.mean(dim=(-2, -1))
264
+ return loss, {}
265
+
266
+
267
+ def mask_bce_loss(pred_mask_prob: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor:
268
+ loss = (gt_mask_pos | gt_mask_neg) * F.binary_cross_entropy(pred_mask_prob, gt_mask_pos.float(), reduction='none')
269
+ loss = loss.mean(dim=(-2, -1))
270
+ return loss, {}
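A hedged usage sketch for the two mask losses above, assuming this module is importable as `moge.train.losses`.

import torch
from moge.train.losses import mask_l2_loss, mask_bce_loss  # import path assumed

pred_prob = torch.rand(2, 32, 32)                    # predicted validity probability in [0, 1]
gt_pos = torch.rand(2, 32, 32) > 0.7                 # pixels known to be valid
gt_neg = (~gt_pos) & (torch.rand(2, 32, 32) > 0.5)   # pixels known to be invalid; the rest are unlabeled
l2, _ = mask_l2_loss(pred_prob, gt_pos, gt_neg)      # (2,): squared error on labeled pixels only
bce, _ = mask_bce_loss(pred_prob, gt_pos, gt_neg)    # (2,): cross-entropy on labeled pixels only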
moge/train/utils.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import *
2
+ import fnmatch
3
+
4
+ import sympy
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ def any_match(s: str, patterns: List[str]) -> bool:
10
+ return any(fnmatch.fnmatch(s, pat) for pat in patterns)
11
+
12
+
13
+ def build_optimizer(model: nn.Module, optimizer_config: Dict[str, Any]) -> torch.optim.Optimizer:
14
+ named_param_groups = [
15
+ {
16
+ k: p for k, p in model.named_parameters() if any_match(k, param_group_config['params']['include']) and not any_match(k, param_group_config['params'].get('exclude', []))
17
+ } for param_group_config in optimizer_config['params']
18
+ ]
19
+ excluded_params = [k for k, p in model.named_parameters() if p.requires_grad and not any(k in named_params for named_params in named_param_groups)]
20
+ assert len(excluded_params) == 0, f'The following parameters require grad but are excluded from the optimizer: {excluded_params}'
21
+ optimizer_cls = getattr(torch.optim, optimizer_config['type'])
22
+ optimizer = optimizer_cls([
23
+ {
24
+ **param_group_config,
25
+ 'params': list(params.values()),
26
+ } for param_group_config, params in zip(optimizer_config['params'], named_param_groups)
27
+ ])
28
+ return optimizer
29
+
30
+
31
+ def parse_lr_lambda(s: str) -> Callable[[int], float]:
32
+ epoch = sympy.symbols('epoch')
33
+ lr_lambda = sympy.sympify(s)
34
+ return sympy.lambdify(epoch, lr_lambda, 'math')
35
+
36
+
37
+ def build_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> torch.optim.lr_scheduler._LRScheduler:
38
+ if scheduler_config['type'] == "SequentialLR":
39
+ child_schedulers = [
40
+ build_lr_scheduler(optimizer, child_scheduler_config)
41
+ for child_scheduler_config in scheduler_config['params']['schedulers']
42
+ ]
43
+ return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=child_schedulers, milestones=scheduler_config['params']['milestones'])
44
+ elif scheduler_config['type'] == "LambdaLR":
45
+ lr_lambda = scheduler_config['params']['lr_lambda']
46
+ if isinstance(lr_lambda, str):
47
+ lr_lambda = parse_lr_lambda(lr_lambda)
48
+ elif isinstance(lr_lambda, list):
49
+ lr_lambda = [parse_lr_lambda(l) for l in lr_lambda]
50
+ return torch.optim.lr_scheduler.LambdaLR(
51
+ optimizer,
52
+ lr_lambda=lr_lambda,
53
+ )
54
+ else:
55
+ scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_config['type'])
56
+ scheduler = scheduler_cls(optimizer, **scheduler_config.get('params', {}))
57
+ return scheduler
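A hedged example of the config structure consumed by `build_optimizer` and `build_lr_scheduler` above. The parameter-name patterns, learning rates, and schedule are illustrative assumptions, not the repository's actual training configuration; note that the include/exclude patterns must cover every trainable parameter exactly once.

optimizer_config = {
    'type': 'AdamW',                                              # resolved via getattr(torch.optim, ...)
    'params': [
        {'params': {'include': ['backbone.*']}, 'lr': 1e-5},
        {'params': {'include': ['*'], 'exclude': ['backbone.*']}, 'lr': 1e-4},
    ],
}
scheduler_config = {
    'type': 'LambdaLR',
    'params': {'lr_lambda': '0.95 ** epoch'},                     # parsed by sympy with the symbol `epoch`
}
# optimizer = build_optimizer(model, optimizer_config)
# lr_scheduler = build_lr_scheduler(optimizer, scheduler_config)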
moge/utils/__init__.py ADDED
File without changes
moge/utils/alignment.py ADDED
@@ -0,0 +1,416 @@
1
+ from typing import *
2
+ import math
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.types
10
+ import utils3d
11
+
12
+
13
+ def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
14
+ "Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`."
15
+ shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
16
+ minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
17
+ minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
18
+ indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
19
+ indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
20
+ return torch.return_types.min((minimum, indices))
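A small toy check of `scatter_min` above (assuming the function is in scope, e.g. imported from this module): group-wise minima together with the positions in `src` where they occur.

import torch

src = torch.tensor([3.0, 1.0, 2.0, 5.0])
index = torch.tensor([0, 0, 1, 1])
minimum, argmin = scatter_min(size=2, dim=0, index=index, src=src)
# minimum -> tensor([1., 2.]); argmin -> tensor([1, 2]), i.e. indices into `src`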
21
+
22
+
23
+ def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
24
+ batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
25
+ n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
26
+ splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
27
+ splited_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
28
+ results = []
29
+ for i in range(n_chunks):
30
+ chunk_args = tuple(arg[i] for arg in splited_args)
31
+ chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
32
+ results.append(fn(*chunk_args, **chunk_kwargs))
33
+
34
+ if isinstance(results[0], tuple):
35
+ return tuple(torch.cat(r, dim=0) for r in zip(*results))
36
+ else:
37
+ return torch.cat(results, dim=0)
38
+
39
+
40
+ def _pad_inf(x_: torch.Tensor):
41
+ return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
42
+
43
+
44
+ def _pad_cumsum(cumsum: torch.Tensor):
45
+ return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
46
+
47
+
48
+ def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
49
+ return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
50
+
51
+
52
+ def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
53
+ """
54
+ If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.
55
+
56
+ w_i must be >= 0.
57
+
58
+ ### Parameters:
59
+ - `x`: tensor of shape (..., n)
60
+ - `y`: tensor of shape (..., n)
61
+ - `w`: tensor of shape (..., n)
62
+ - `trunc`: optional, float or tensor of shape (..., n) or None
63
+
64
+ ### Returns:
65
+ - `a`: tensor of shape (...), differentiable
66
+ - `loss`: tensor of shape (...), value of loss function at `a`, detached
67
+ - `index`: tensor of shape (...), where a = y[idx] / x[idx]
68
+ """
69
+ if trunc is None:
70
+ x, y, w = torch.broadcast_tensors(x, y, w)
71
+ sign = torch.sign(x)
72
+ x, y = x * sign, y * sign
73
+ y_div_x = y / x.clamp_min(eps)
74
+ y_div_x, argsort = y_div_x.sort(dim=-1)
75
+
76
+ wx = torch.gather(x * w, dim=-1, index=argsort)
77
+ derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
78
+ search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)
79
+
80
+ a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
81
+ index = argsort.gather(dim=-1, index=search).squeeze(-1)
82
+ loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
83
+
84
+ else:
85
+ # Reshape to (batch_size, n) for simplicity
86
+ x, y, w = torch.broadcast_tensors(x, y, w)
87
+ batch_shape = x.shape[:-1]
88
+ batch_size = math.prod(batch_shape)
89
+ x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])
90
+
91
+ sign = torch.sign(x)
92
+ x, y = x * sign, y * sign
93
+ wx, wy = w * x, w * y
94
+ xyw = torch.stack([x, y, w], dim=-1) # Stacked for convenient gathering
95
+
96
+ y_div_x = A = y / x.clamp_min(eps)
97
+ B = (wy - trunc) / wx.clamp_min(eps)
98
+ C = (wy + trunc) / wx.clamp_min(eps)
99
+ with torch.no_grad():
100
+ # Calculate prefix sums of wx in the sorted orders of A, B, and C
101
+ A, A_argsort = A.sort(dim=-1)
102
+ Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
103
+ A, Q_A = _pad_inf(A), _pad_cumsum(Q_A) # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.
104
+
105
+ B, B_argsort = B.sort(dim=-1)
106
+ Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
107
+ B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)
108
+
109
+ C, C_argsort = C.sort(dim=-1)
110
+ Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
111
+ C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)
112
+
113
+ # Calculate the left and right derivatives of the objective at each candidate scale
114
+ j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
115
+ j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
116
+ j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
117
+ left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
118
+ j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
119
+ j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
120
+ j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
121
+ right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
122
+
123
+ # Find extrema
124
+ is_extrema = (left_derivative < 0) & (right_derivative >= 0)
125
+ is_extrema[..., 0] |= ~is_extrema.any(dim=-1) # In case all derivatives are zero, take the first one as extrema.
126
+ where_extrema_batch, where_extrema_index = torch.where(is_extrema)
127
+
128
+ # Calculate objective value at extrema
129
+ extrema_a = y_div_x[where_extrema_batch, where_extrema_index] # (num_extrema,)
130
+ MAX_ELEMENTS = 4096 ** 2 # Split into small batches (roughly 1 GB of elements) to avoid OOM in case there are too many extrema.
131
+ SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
132
+ extrema_value = torch.cat([
133
+ _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
134
+ for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
135
+ ]) # (num_extrema,)
136
+
137
+ # Find minima among corresponding extrema
138
+ minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value) # (batch_size,)
139
+ index = where_extrema_index[indices]
140
+
141
+ a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
142
+ a = a.reshape(batch_shape)
143
+ loss = minima.reshape(batch_shape)
144
+ index = index.reshape(batch_shape)
145
+
146
+ return a, loss, index
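A hedged usage sketch of `align` above on a case with a known answer; the import path is an assumption.

import torch
from moge.utils.alignment import align  # import path assumed

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
y = 2.0 * x                              # the true scale is exactly 2
w = torch.ones_like(x)
a, loss, index = align(x, y, w)          # untruncated weighted-L1 fit
# a ≈ tensor([2.]), loss ≈ tensor([0.]); index identifies the pivot sample with a = y[index] / x[index]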
147
+
148
+
149
+ def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
150
+ """
151
+ Align `depth_src` to `depth_tgt` with given constant weights.
152
+
153
+ ### Parameters:
154
+ - `depth_src: torch.Tensor` of shape (..., N)
155
+ - `depth_tgt: torch.Tensor` of shape (..., N)
156
+
157
+ """
158
+ scale, _, _ = align(depth_src, depth_tgt, weight, trunc)
159
+
160
+ return scale
161
+
162
+
163
+ def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
164
+ """
165
+ Align `depth_src` to `depth_tgt` with given constant weights.
166
+
167
+ ### Parameters:
168
+ - `depth_src: torch.Tensor` of shape (..., N)
169
+ - `depth_tgt: torch.Tensor` of shape (..., N)
170
+ - `weight: torch.Tensor` of shape (..., N)
171
+ - `trunc: float` or tensor of shape (..., N) or None
172
+
173
+ ### Returns:
174
+ - `scale: torch.Tensor` of shape (...).
175
+ - `shift: torch.Tensor` of shape (...).
176
+ """
177
+ dtype, device = depth_src.dtype, depth_src.device
178
+
179
+ # Flatten batch dimensions for simplicity
180
+ batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
181
+ batch_size = math.prod(batch_shape)
182
+ depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)
183
+
184
+ # Here, we take anchors only for non-zero weights.
185
+ # Although the results would still be correct even if anchor points had zero weight,
186
+ # it would waste computation and could cause instability in some cases, e.g. too many extrema.
187
+ anchors_where_batch, anchors_where_n = torch.where(weight > 0)
188
+
189
+ # Stop gradient when solving optimal anchors
190
+ with torch.no_grad():
191
+ depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n] # (anchors)
192
+ depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n] # (anchors)
193
+
194
+ depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None] # (anchors, n)
195
+ depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None] # (anchors, n)
196
+ weight_anchored = weight[anchors_where_batch, :] # (anchors, n)
197
+
198
+ scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc) # (anchors)
199
+
200
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss) # (batch_size,)
201
+
202
+ # Reproduce by indexing for shorter compute graph
203
+ index_1 = anchors_where_n[index_anchor] # (batch_size,)
204
+ index_2 = index[index_anchor] # (batch_size,)
205
+
206
+ tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
207
+ tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)
208
+
209
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
210
+ shift = tgt_1 - scale * src_1
211
+
212
+ scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)
213
+
214
+ return scale, shift
215
+
216
+ def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
217
+ """
218
+ Align `depth_src` to `depth_tgt` with given constant weights using IRLS.
219
+ """
220
+ dtype, device = depth_src.dtype, depth_src.device
221
+
222
+ w = weight
223
+ x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
224
+ y = depth_tgt
225
+
226
+ for i in range(max_iter):
227
+ beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
228
+ w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)
229
+
230
+ return beta[..., 0], beta[..., 1]
231
+
232
+
233
+ def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
234
+ """
235
+ ### Parameters:
236
+ - `points_src: torch.Tensor` of shape (..., N, 3)
237
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
238
+ - `weight: torch.Tensor` of shape (..., N)
239
+
240
+ ### Returns:
241
+ - `scale: torch.Tensor` of shape (...).
242
+ A positive solution is not guaranteed; filter out non-positive scales before use.
243
+ """
244
+ dtype, device = points_src.dtype, points_src.device
245
+
246
+ scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)
247
+
248
+ return scale
249
+
250
+
251
+ def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
252
+ """
253
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
254
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
255
+
256
+ ### Parameters:
257
+ - `points_src: torch.Tensor` of shape (..., N, 3)
258
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
259
+ - `weights: torch.Tensor` of shape (..., N)
260
+
261
+ ### Returns:
262
+ - `scale: torch.Tensor` of shape (...).
263
+ - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
264
+ """
265
+ dtype, device = points_src.dtype, points_src.device
266
+
267
+ # Flatten batch dimensions for simplicity
268
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
269
+ batch_size = math.prod(batch_shape)
270
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
271
+
272
+ # Take anchors
273
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
274
+ with torch.no_grad():
275
+ zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
276
+ points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
277
+ points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
278
+
279
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
280
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
281
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
282
+
283
+ # Solve optimal scale and shift for each anchor
284
+ MAX_ELEMENTS = 2 ** 20
285
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
286
+
287
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
288
+
289
+ # Reproduce by indexing for shorter compute graph
290
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
291
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
292
+
293
+ zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
294
+ points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
295
+ tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
296
+ tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
297
+
298
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
299
+ shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
300
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
301
+
302
+ return scale, shift
303
+
304
+
305
+ def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
306
+ """
307
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and an xyz shift.
308
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
309
+
310
+ ### Parameters:
311
+ - `points_src: torch.Tensor` of shape (..., N, 3)
312
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
313
+ - `weights: torch.Tensor` of shape (..., N)
314
+
315
+ ### Returns:
316
+ - `scale: torch.Tensor` of shape (...).
317
+ - `shift: torch.Tensor` of shape (..., 3)
318
+ """
319
+ dtype, device = points_src.dtype, points_src.device
320
+
321
+ # Flatten batch dimensions for simplicity
322
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
323
+ batch_size = math.prod(batch_shape)
324
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
325
+
326
+ # Take anchors
327
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
328
+
329
+ with torch.no_grad():
330
+ points_src_anchor = points_src[anchor_where_batch, anchor_where_n] # (anchors, 3)
331
+ points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n] # (anchors, 3)
332
+
333
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
334
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
335
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
336
+
337
+ # Solve optimal scale and shift for each anchor
338
+ MAX_ELEMENTS = 2 ** 20
339
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
340
+
341
+ # Get optimal scale and shift for each batch element
342
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
343
+
344
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
345
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
346
+
347
+ src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
348
+ src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
349
+
350
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
351
+ shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
352
+
353
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
354
+
355
+ return scale, shift
356
+
357
+
358
+ def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
359
+ """
360
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
361
+
362
+ ### Parameters:
363
+ - `points_src: torch.Tensor` of shape (..., N, 3)
364
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
365
+ - `weights: torch.Tensor` of shape (..., N)
366
+
367
+ ### Returns:
368
+ - `shift: torch.Tensor` of shape (..., 3). Only the z component is non-zero;
369
+ the x and y components are zeros. No scale is estimated by this function.
370
+ """
371
+ dtype, device = points_src.dtype, points_src.device
372
+
373
+ shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
374
+ shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)
375
+
376
+ return shift
377
+
378
+
379
+ def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
380
+ """
381
+ Align `points_src` to `points_tgt` with respect to a per-axis (xyz) shift.
382
+
383
+ ### Parameters:
384
+ - `points_src: torch.Tensor` of shape (..., N, 3)
385
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
386
+ - `weights: torch.Tensor` of shape (..., N)
387
+
388
+ ### Returns:
389
+ - `shift: torch.Tensor` of shape (..., 3), a per-axis translation aligning
390
+ `points_src` to `points_tgt`. No scale is estimated by this function.
391
+ """
392
+ dtype, device = points_src.dtype, points_src.device
393
+
394
+ shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)
395
+
396
+ return shift
397
+
398
+
399
+ def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
400
+ """
401
+ Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.
402
+
403
+ ### Parameters:
404
+ - `x: torch.Tensor` of shape (..., N)
405
+ - `y: torch.Tensor` of shape (..., N)
406
+ - `w: torch.Tensor` of shape (..., N)
407
+
408
+ ### Returns:
409
+ - `a: torch.Tensor` of shape (...,)
410
+ - `b: torch.Tensor` of shape (...,)
411
+ """
412
+ w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
413
+ A = torch.stack([w_sqrt * x, w_sqrt], dim=-1) # weight the intercept column as well, consistent with B
414
+ B = (w_sqrt * y)[..., None]
415
+ a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
416
+ return a, b
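A hedged usage sketch for `align_affine_lstsq` above; the import path is an assumption.

import torch
from moge.utils.alignment import align_affine_lstsq  # import path assumed

x = torch.linspace(0.0, 1.0, 100)
y = 3.0 * x + 1.0
a, b = align_affine_lstsq(x, y)   # expect a ≈ 3.0, b ≈ 1.0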
moge/utils/download.py ADDED
@@ -0,0 +1,55 @@
1
+ from pathlib import Path
2
+ from typing import *
3
+ import requests
4
+
5
+ from tqdm import tqdm
6
+
7
+
8
+ __all__ = ["download_file", "download_bytes"]
9
+
10
+
11
+ def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
12
+ # Ensure headers is a dict if not provided
13
+ headers = headers or {}
14
+
15
+ # Initialize local variables
16
+ file_path = Path(filepath)
17
+ downloaded_bytes = 0
18
+
19
+ # Check if we should resume the download
20
+ if resume and file_path.exists():
21
+ downloaded_bytes = file_path.stat().st_size
22
+ headers['Range'] = f"bytes={downloaded_bytes}-"
23
+
24
+ # Make a GET request to fetch the file
25
+ with requests.get(url, stream=True, headers=headers) as response:
26
+ response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
27
+
28
+ # Calculate the total size to download
29
+ total_size = downloaded_bytes + int(response.headers.get('content-length', 0))
30
+
31
+ # Display a progress bar while downloading
32
+ with (
33
+ tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
34
+ open(file_path, 'ab') as file,
35
+ ):
36
+ # Set the initial position of the progress bar
37
+ pbar.update(downloaded_bytes)
38
+
39
+ # Write the content to the file in chunks
40
+ for chunk in response.iter_content(chunk_size=4096):
41
+ file.write(chunk)
42
+ pbar.update(len(chunk))
43
+
44
+
45
+ def download_bytes(url: str, headers: dict = None) -> bytes:
46
+ # Ensure headers is a dict if not provided
47
+ headers = headers or {}
48
+
49
+ # Make a GET request to fetch the file
50
+ with requests.get(url, stream=True, headers=headers) as response:
51
+ response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
52
+
53
+ # Read the content of the response
54
+ return response.content
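A hedged usage sketch of `download_file` above; the URL, destination path, and import path are placeholders/assumptions.

from moge.utils.download import download_file  # import path assumed

download_file(
    'https://example.com/checkpoints/model.pt',   # placeholder URL
    'checkpoints/model.pt',
    resume=True,                                  # a re-run continues a partial file via an HTTP Range request
)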
55
+
moge/utils/geometry_numpy.py ADDED
@@ -0,0 +1,406 @@
1
+ from typing import *
2
+ from functools import partial
3
+ import math
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from scipy.signal import fftconvolve
9
+ import utils3d
10
+
11
+ from .tools import timeit
12
+
13
+
14
+ def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
15
+ if w is None:
16
+ return np.mean(x, axis=axis, keepdims=keepdims)
17
+ else:
18
+ w = w.astype(x.dtype)
19
+ return (x * w).mean(axis=axis, keepdims=keepdims) / np.clip(w.mean(axis=axis, keepdims=keepdims), eps, None)
20
+
21
+
22
+ def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
23
+ if w is None:
24
+ return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis, keepdims=keepdims)
25
+ else:
26
+ w = w.astype(x.dtype)
27
+ return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
28
+
29
+
30
+ def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
31
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
32
+ if aspect_ratio is None:
33
+ aspect_ratio = width / height
34
+
35
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
36
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
37
+
38
+ u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
39
+ v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
40
+ u, v = np.meshgrid(u, v, indexing='xy')
41
+ uv = np.stack([u, v], axis=-1)
42
+ return uv
43
+
44
+
45
+ def focal_to_fov_numpy(focal: np.ndarray):
46
+ return 2 * np.arctan(0.5 / focal)
47
+
48
+
49
+ def fov_to_focal_numpy(fov: np.ndarray):
50
+ return 0.5 / np.tan(fov / 2)
51
+
52
+
53
+ def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
54
+ fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
55
+ fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
56
+ return fov_x, fov_y
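A quick self-consistency sketch for the focal/FoV helpers above, assuming they are in scope (e.g. imported from this module): `fov_to_focal_numpy` and `focal_to_fov_numpy` are inverses of each other.

import numpy as np

fov = np.radians(60.0)
focal = fov_to_focal_numpy(fov)                    # 0.5 / tan(fov / 2)
assert np.isclose(focal_to_fov_numpy(focal), fov)  # 2 * arctan(0.5 / focal) recovers the FoV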
57
+
58
+
59
+ def point_map_to_depth_legacy_numpy(points: np.ndarray):
60
+ height, width = points.shape[-3:-1]
61
+ diagonal = (height ** 2 + width ** 2) ** 0.5
62
+ uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2)
63
+ _, uv = np.broadcast_arrays(points[..., :2], uv)
64
+
65
+ # Solve least squares problem
66
+ b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2)
67
+ A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2)
68
+
69
+ M = A.swapaxes(-2, -1) @ A
70
+ solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
71
+ focal, shift = solution
72
+
73
+ depth = points[..., 2] + shift[..., None, None]
74
+ fov_x = np.arctan(width / diagonal / focal) * 2
75
+ fov_y = np.arctan(height / diagonal / focal) * 2
76
+ return depth, fov_x, fov_y, shift
77
+
78
+
79
+ def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
80
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
81
+ from scipy.optimize import least_squares
82
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
83
+
84
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
85
+ xy_proj = xy / (z + shift)[: , None]
86
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
87
+ err = (f * xy_proj - uv).ravel()
88
+ return err
89
+
90
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
91
+ optim_shift = solution['x'].squeeze().astype(np.float32)
92
+
93
+ xy_proj = xy / (z + optim_shift)[: , None]
94
+ optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
95
+
96
+ return optim_shift, optim_focal
97
+
98
+
99
+ def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
100
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
101
+ from scipy.optimize import least_squares
102
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
103
+
104
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
105
+ xy_proj = xy / (z + shift)[: , None]
106
+ err = (focal * xy_proj - uv).ravel()
107
+ return err
108
+
109
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
110
+ optim_shift = solution['x'].squeeze().astype(np.float32)
111
+
112
+ return optim_shift
113
+
114
+
115
+ def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
116
+ import cv2
117
+ assert points.shape[-1] == 3, "Points should be of shape (H, W, 3)"
118
+
119
+ height, width = points.shape[-3], points.shape[-2]
120
+ diagonal = (height ** 2 + width ** 2) ** 0.5
121
+
122
+ uv = normalized_view_plane_uv_numpy(width=width, height=height)
123
+
124
+ if mask is None:
125
+ points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
126
+ uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
127
+ else:
128
+ (points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size)
129
+
130
+ if points_lr.size < 2:
131
+ return 1., 0.
132
+
133
+ if focal is None:
134
+ focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
135
+ else:
136
+ shift = solve_optimal_shift(uv_lr, points_lr, focal)
137
+
138
+ return focal, shift
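A hedged sanity sketch for `recover_focal_shift_numpy` above: build a synthetic point map from a known normalized focal and z shift, then check that both are approximately recovered. The helpers are assumed to be in scope (e.g. imported from this module), and the chosen constants are arbitrary.

import numpy as np

H, W = 64, 64
uv = normalized_view_plane_uv_numpy(W, H)                       # (H, W, 2)
true_focal, true_shift = 1.2, 0.3
z = np.random.uniform(1.0, 3.0, (H, W)).astype(np.float32)      # arbitrary positive depths
points = np.concatenate([uv * z[..., None] / true_focal, z[..., None] - true_shift], axis=-1)
focal, shift = recover_focal_shift_numpy(points)
# focal ≈ 1.2 and shift ≈ 0.3, up to the solver's tolerance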
139
+
140
+
141
+ def mask_aware_nearest_resize_numpy(
142
+ inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None],
143
+ mask: np.ndarray,
144
+ size: Tuple[int, int],
145
+ return_index: bool = False
146
+ ) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]:
147
+ """
148
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
149
+
150
+ ### Parameters
151
+ - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
152
+ - `mask`: input 2D mask of shape (..., H, W)
153
+ - `size`: target size (width, height)
154
+
155
+ ### Returns
156
+ - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
157
+ - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
158
+ - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
159
+ """
160
+ height, width = mask.shape[-2:]
161
+ target_width, target_height = size
162
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
163
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
164
+ filter_size = filter_h_i * filter_w_i
165
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
166
+
167
+ # Window the original mask and uv
168
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
169
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
170
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
171
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
172
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
173
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
174
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
175
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
176
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
177
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
178
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
179
+
180
+ # Gather the target pixels's local window
181
+ target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
182
+ target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
183
+ target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
184
+
185
+ target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
186
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
187
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
188
+
189
+ # Compute nearest neighbor in the local window for each pixel
190
+ dist = np.square(target_window_centers - target_centers[..., None])
191
+ dist = dist[..., 0, :] + dist[..., 1, :]
192
+ dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
193
+ nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
194
+ nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
195
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
196
+ target_mask = np.any(target_window_mask, axis=-1)
197
+ batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
198
+
199
+ index = (*batch_indices, nearest_i, nearest_j)
200
+
201
+ if inputs is None:
202
+ outputs = None
203
+ elif isinstance(inputs, np.ndarray):
204
+ outputs = inputs[index]
205
+ elif isinstance(inputs, Sequence):
206
+ outputs = tuple(x[index] for x in inputs)
207
+ else:
208
+ raise ValueError(f'Invalid input type: {type(inputs)}')
209
+
210
+ if return_index:
211
+ return outputs, target_mask, index
212
+ else:
213
+ return outputs, target_mask
214
+
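A minimal usage sketch for the function above (hypothetical shapes and values; it relies only on this module's own imports):

```python
import numpy as np

# Downsample a 480x640 point map together with its validity mask to 64x64.
points = np.random.rand(480, 640, 3).astype(np.float32)
valid = points[..., 2] > 0.2                      # (H, W) boolean validity mask

# `size` is (width, height); invalid pixels are never chosen as nearest sources.
points_lr, valid_lr = mask_aware_nearest_resize_numpy(points, valid, size=(64, 64))
# points_lr: (64, 64, 3); valid_lr: (64, 64), True where the local window contained any valid pixel
```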
215
+
216
+ def mask_aware_area_resize_numpy(image: np.ndarray, mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray]:
217
+ """
218
+ Resize a 2D map by mask-aware area (weighted-average) interpolation. Returns the resized map and the mask of valid target pixels.
219
+
220
+ ### Parameters
221
+ - `image`: Input 2D image of shape (..., H, W, C)
222
+ - `mask`: Input 2D mask of shape (..., H, W)
223
+ - `target_width`: target width of the resized map
224
+ - `target_height`: target height of the resized map
225
+
226
+ ### Returns
227
+ - `target_image`: Resized map of shape (..., target_height, target_width, C), averaged over valid source pixels only (the channel dimension is omitted if the input had none).
228
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
229
+ """
230
+ height, width = mask.shape[-2:]
231
+
232
+ if image.shape[-2:] == (height, width):
233
+ omit_channel_dim = True
234
+ else:
235
+ omit_channel_dim = False
236
+ if omit_channel_dim:
237
+ image = image[..., None]
238
+
239
+ image = np.where(mask[..., None], image, 0)
240
+
241
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
242
+ filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1
243
+ filter_size = filter_h_i * filter_w_i
244
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
245
+
246
+ # Window the original mask and uv (non-copy)
247
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
248
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
249
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
250
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
251
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
252
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
253
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
254
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
255
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
256
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
257
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
258
+
259
+ # Gather each target pixel's local window
260
+ target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
261
+ target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
262
+ target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
263
+ target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
264
+
265
+ target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
266
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
267
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
268
+
269
+ # Compute pixel area in the local windows
270
+ target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None])
271
+ target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None])
272
+ target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None)
273
+ target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0)
274
+
275
+ # Weighted sum by area
276
+ target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1)
277
+ target_mask = np.sum(target_window_area, axis=-1) >= 0.25
278
+ target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1)
279
+
280
+ if omit_channel_dim:
281
+ target_image = target_image[..., 0]
282
+
283
+ return target_image, target_mask
284
+
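A corresponding sketch for the area-resize variant (hypothetical inputs); entries where the returned mask is False carry no meaningful value:

```python
import numpy as np

image = np.random.rand(480, 640, 3).astype(np.float32)
valid = np.ones((480, 640), dtype=bool)
valid[:, :100] = False                            # pretend the left strip is invalid

image_lr, valid_lr = mask_aware_area_resize_numpy(image, valid, target_width=160, target_height=120)
# image_lr: (120, 160, 3), averaged over valid source pixels only
# valid_lr: (120, 160), True where at least ~0.25 px of valid area fell into the window
```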
285
+
286
+ def norm3d(x: np.ndarray) -> np.ndarray:
287
+ "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors"
288
+ return np.sqrt(np.square(x[..., 0]) + np.square(x[..., 1]) + np.square(x[..., 2]))
289
+
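A quick sanity check of the shortcut above (hypothetical data):

```python
import numpy as np

v = np.random.randn(1000, 3)
assert np.allclose(norm3d(v), np.linalg.norm(v, axis=-1))
```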
290
+
291
+ def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, thickness: int = 1, tol: float = 0.1):
292
+ disp = np.where(mask, 1 / depth, 0)
293
+ disp_pad = np.pad(disp, (thickness, thickness), constant_values=0)
294
+ mask_pad = np.pad(mask, (thickness, thickness), constant_values=False)
295
+ kernel_size = 2 * thickness + 1
296
+ disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
297
+ mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
298
+
299
+ disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1))
300
+ fg_edge_mask = mask & (disp > (1 + tol) * disp_mean)
301
+ bg_edge_mask = mask & (disp_mean > (1 + tol) * disp)
302
+
303
+ edge_mask = (cv2.dilate(fg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) \
304
+ & (cv2.dilate(bg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0)
305
+
306
+ return edge_mask
307
+
308
+
309
+ def disk_kernel(radius: int) -> np.ndarray:
310
+ """
311
+ Generate disk kernel with given radius.
312
+
313
+ Args:
314
+ radius (int): Radius of the disk (in pixels).
315
+
316
+ Returns:
317
+ np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel.
318
+ """
319
+ # Create coordinate grid centered at (0,0)
320
+ L = np.arange(-radius, radius + 1)
321
+ X, Y = np.meshgrid(L, L)
322
+ # Generate disk: region inside circle with radius R is 1
323
+ kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32)
324
+ # Normalize the kernel
325
+ kernel /= np.sum(kernel)
326
+ return kernel
327
+
328
+
329
+ def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
330
+ """
331
+ Apply disk blur to an image using FFT convolution.
332
+
333
+ Args:
334
+ image (np.ndarray): Input image, can be grayscale or color.
335
+ radius (int): Blur radius (in pixels).
336
+
337
+ Returns:
338
+ np.ndarray: Blurred image.
339
+ """
340
+ if radius == 0:
341
+ return image
342
+ kernel = disk_kernel(radius)
343
+ if image.ndim == 2:
344
+ blurred = fftconvolve(image, kernel, mode='same')
345
+ elif image.ndim == 3:
346
+ channels = []
347
+ for i in range(image.shape[2]):
348
+ blurred_channel = fftconvolve(image[..., i], kernel, mode='same')
349
+ channels.append(blurred_channel)
350
+ blurred = np.stack(channels, axis=-1)
351
+ else:
352
+ raise ValueError("Image must be 2D or 3D.")
353
+ return blurred
354
+
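A small sketch exercising the two helpers above (hypothetical image):

```python
import numpy as np

kernel = disk_kernel(5)
assert kernel.shape == (11, 11) and abs(kernel.sum() - 1.0) < 1e-4   # normalized disk

image = np.random.rand(240, 320, 3).astype(np.float32)
blurred = disk_blur(image, radius=5)              # FFT convolution applied per channel
assert blurred.shape == image.shape
```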
355
+
356
+ def depth_of_field(
357
+ img: np.ndarray,
358
+ disp: np.ndarray,
359
+ focus_disp : float,
360
+ max_blur_radius : int = 10,
361
+ ) -> np.ndarray:
362
+ """
363
+ Apply depth of field effect to an image.
364
+
365
+ Args:
366
+ img (numpy.ndarray): (H, W, 3) input image.
367
+ disp (numpy.ndarray): (H, W) disparity map of the scene.
368
+ focus_disp (float): Disparity value that is kept in focus.
369
370
+ max_blur_radius (int): Maximum blur radius (in pixels).
371
+
372
+ Returns:
373
+ numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
374
+ """
375
+ # Precompute dilated disparity maps, one per candidate blur radius
376
+ max_disp = np.max(disp)
377
+ disp = disp / max_disp
378
+ focus_disp = focus_disp / max_disp
379
+ dilated_disp = []
380
+ for radius in range(max_blur_radius + 1):
381
+ dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1))
382
+
383
+ # Determine the blur radius for each pixel based on the disparity map
384
+ blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
385
+ for radius in range(max_blur_radius + 1):
386
+ dilated_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
387
+ mask = (dilated_blur_radii >= radius) & (dilated_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
388
+ blur_radii[mask] = dilated_blur_radii[mask]
389
+ blur_radii = np.clip(blur_radii, 0, max_blur_radius)
390
+ blur_radii = cv2.blur(blur_radii, (5, 5))
391
+
392
+ # Precompute the blurred image for each blur radius that actually occurs
393
+ unique_radii = np.unique(blur_radii)
394
+ precomputed = {}
395
+ for radius in range(max_blur_radius + 1):
396
+ if radius not in unique_radii:
397
+ continue
398
+ precomputed[radius] = disk_blur(img, radius)
399
+
400
+ # Composite the blurred images per pixel according to its blur radius
401
+ output = np.zeros_like(img)
402
+ for r in unique_radii:
403
+ mask = blur_radii == r
404
+ output[mask] = precomputed[r][mask]
405
+
406
+ return output
moge/utils/geometry_torch.py ADDED
@@ -0,0 +1,354 @@
1
+ from typing import *
2
+ import math
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.types
10
+ import utils3d
11
+
12
+ from .tools import timeit
13
+ from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
14
+
15
+
16
+ def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
17
+ if w is None:
18
+ return x.mean(dim=dim, keepdim=keepdim)
19
+ else:
20
+ w = w.to(x.dtype)
21
+ return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
22
+
23
+
24
+ def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
25
+ if w is None:
26
+ return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
27
+ else:
28
+ w = w.to(x.dtype)
29
+ return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
30
+
31
+
32
+ def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
33
+ if w is None:
34
+ return x.add(eps).log().mean(dim=dim, keepdim=keepdim).exp()
35
+ else:
36
+ w = w.to(x.dtype)
37
+ return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
38
+
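A small sketch of how these masked reductions behave (hypothetical values):

```python
import torch

depth = torch.tensor([[1.0, 2.0], [4.0, 100.0]])
valid = torch.tensor([[True, True], [True, False]])       # ignore the outlier

m_arith = weighted_mean(depth, valid, dim=(-2, -1))        # ~ (1 + 2 + 4) / 3
m_geo = geometric_mean(depth, valid, dim=(-2, -1))         # ~ (1 * 2 * 4) ** (1 / 3)
m_harm = harmonic_mean(depth, valid, dim=(-2, -1))         # ~ 3 / (1 + 1/2 + 1/4)
```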
39
+
40
+ def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
41
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
42
+ if aspect_ratio is None:
43
+ aspect_ratio = width / height
44
+
45
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
46
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
47
+
48
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
49
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
50
+ u, v = torch.meshgrid(u, v, indexing='xy')
51
+ uv = torch.stack([u, v], dim=-1)
52
+ return uv
53
+
54
+
55
+ def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
56
+ kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
57
+ kernel = kernel / kernel.sum()
58
+ kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size).expand(input.shape[1], 1, kernel_size, kernel_size) # one kernel per channel for the grouped convolution below
59
+ input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
60
+ input = F.conv2d(input, kernel, groups=input.shape[1])
61
+ return input
62
+
63
+
64
+ def focal_to_fov(focal: torch.Tensor):
65
+ return 2 * torch.atan(0.5 / focal)
66
+
67
+
68
+ def fov_to_focal(fov: torch.Tensor):
69
+ return 0.5 / torch.tan(fov / 2)
70
+
71
+
72
+ def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
73
+ return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1))
74
+
75
+ def intrinsics_to_fov(intrinsics: torch.Tensor):
76
+ """
77
+ Returns field of view in radians from normalized intrinsics matrix.
78
+ ### Parameters:
79
+ - intrinsics: torch.Tensor of shape (..., 3, 3)
80
+
81
+ ### Returns:
82
+ - fov_x: torch.Tensor of shape (...)
83
+ - fov_y: torch.Tensor of shape (...)
84
+ """
85
+ focal_x = intrinsics[..., 0, 0]
86
+ focal_y = intrinsics[..., 1, 1]
87
+ return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
88
+
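A round-trip sketch for the focal/FoV conversions above (hypothetical normalized intrinsics):

```python
import torch

intrinsics = torch.tensor([[1.2, 0.0, 0.5],
                           [0.0, 1.6, 0.5],
                           [0.0, 0.0, 1.0]])      # normalized pinhole intrinsics
fov_x, fov_y = intrinsics_to_fov(intrinsics)       # in radians
assert torch.isclose(fov_to_focal(fov_x), intrinsics[0, 0])
assert torch.isclose(focal_to_fov(intrinsics[1, 1]), fov_y)
```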
89
+
90
+ def point_map_to_depth_legacy(points: torch.Tensor):
91
+ height, width = points.shape[-3:-1]
92
+ diagonal = (height ** 2 + width ** 2) ** 0.5
93
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
94
+
95
+ # Solve least squares problem
96
+ b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
97
+ A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
98
+
99
+ M = A.transpose(-2, -1) @ A
100
+ solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
101
+ focal, shift = solution.unbind(-1)
102
+
103
+ depth = points[..., 2] + shift[..., None, None]
104
+ fov_x = torch.atan(width / diagonal / focal) * 2
105
+ fov_y = torch.atan(height / diagonal / focal) * 2
106
+ return depth, fov_x, fov_y, shift
107
+
108
+
109
+ def view_plane_uv_to_focal(uv: torch.Tensor):
110
+ normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
111
+ focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
112
+ return focal
113
+
114
+
115
+ def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
116
+ """
117
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
118
+
119
+ Note that it assumes:
120
+ - the optical center is at the center of the map
121
+ - the map is undistorted
122
+ - the map is isometric in the x and y directions
123
+
124
+ ### Parameters:
125
+ - `points: torch.Tensor` of shape (..., H, W, 3)
+ - `mask: torch.BoolTensor` of shape (..., H, W). Optional validity mask of the point map.
+ - `focal: torch.Tensor` of shape (...). Optional known focal length (relative to the half diagonal); if given, only the shift is solved.
126
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
127
+
128
+ ### Returns:
129
+ - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
130
+ - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
131
+ """
132
+ shape = points.shape
133
+ height, width = points.shape[-3], points.shape[-2]
134
+ diagonal = (height ** 2 + width ** 2) ** 0.5
135
+
136
+ points = points.reshape(-1, *shape[-3:])
137
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
138
+ focal = focal.reshape(-1) if focal is not None else None
139
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
140
+
141
+ points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
142
+ uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
143
+ mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
144
+
145
+ uv_lr_np = uv_lr.cpu().numpy()
146
+ points_lr_np = points_lr.detach().cpu().numpy()
147
+ focal_np = focal.cpu().numpy() if focal is not None else None
148
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
149
+ optim_shift, optim_focal = [], []
150
+ for i in range(points.shape[0]):
151
+ points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
152
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
153
+ if uv_lr_i_np.shape[0] < 2:
154
+ optim_focal.append(1)
155
+ optim_shift.append(0)
156
+ continue
157
+ if focal is None:
158
+ optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
159
+ optim_focal.append(float(optim_focal_i))
160
+ else:
161
+ optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
162
+ optim_shift.append(float(optim_shift_i))
163
+ optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
164
+
165
+ if focal is None:
166
+ optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
167
+ else:
168
+ optim_focal = focal.reshape(shape[:-3])
169
+
170
+ return optim_focal, optim_shift
171
+
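A usage sketch with hypothetical inputs; in practice `points` would be an affine-invariant point map predicted by the model. The FoV conversion follows the same convention as `point_map_to_depth_legacy` above.

```python
import torch

B, H, W = 2, 240, 320
points = 0.1 * torch.randn(B, H, W, 3) + torch.tensor([0.0, 0.0, 1.0])   # stand-in prediction
valid = torch.ones(B, H, W, dtype=torch.bool)

focal, shift = recover_focal_shift(points, valid)          # each of shape (B,)
depth = points[..., 2] + shift[..., None, None]             # camera-space depth
diagonal = (H ** 2 + W ** 2) ** 0.5
fov_x = 2 * torch.atan(W / diagonal / focal)                # horizontal field of view (radians)
```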
172
+
173
+ def mask_aware_nearest_resize(
174
+ inputs: Union[torch.Tensor, Sequence[torch.Tensor], None],
175
+ mask: torch.BoolTensor,
176
+ size: Tuple[int, int],
177
+ return_index: bool = False
178
+ ) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]:
179
+ """
180
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
181
+
182
+ ### Parameters
183
+ - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
184
+ - `mask`: input 2D mask of shape (..., H, W)
185
+ - `size`: target size (target_width, target_height)
186
+
187
+ ### Returns
188
+ - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
189
+ - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
190
+ - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
191
+ """
192
+ height, width = mask.shape[-2:]
193
+ target_width, target_height = size
194
+ device = mask.device
195
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
196
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
197
+ filter_size = filter_h_i * filter_w_i
198
+ padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1
199
+
200
+ # Window the original mask and uv
201
+ uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
202
+ indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
203
+ padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
204
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
205
+ padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
206
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
207
+ padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
208
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
209
+ windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
210
+ windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
211
+ windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
212
+
213
+ # Gather the target pixels's local window
214
+ target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
215
+ target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
216
+ target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
217
+
218
+ target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
219
+ target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
220
+ target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
221
+ target_window_indices = target_window_indices.expand_as(target_window_mask)
222
+
223
+ # Compute nearest neighbor in the local window for each pixel
224
+ dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size)
225
+ nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1)
226
+ nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width)
227
+ target_mask = torch.any(target_window_mask, dim=-1)
228
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
229
+ batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
230
+
231
+ index = (*batch_indices, nearest_i, nearest_j)
232
+
233
+ if inputs is None:
234
+ outputs = None
235
+ elif isinstance(inputs, torch.Tensor):
236
+ outputs = inputs[index]
237
+ elif isinstance(inputs, Sequence):
238
+ outputs = tuple(x[index] for x in inputs)
239
+ else:
240
+ raise ValueError(f'Invalid input type: {type(inputs)}')
241
+
242
+ if return_index:
243
+ return outputs, target_mask, index
244
+ else:
245
+ return outputs, target_mask
246
+
247
+
248
+ def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3):
249
+ *batch_shape, height, width = depth.shape
250
+ depth = depth.reshape(-1, 1, height, width)
251
+ mask = mask.reshape(-1, 1, height, width)
252
+ if pooler =='max':
253
+ pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
254
+ output_mask = pooled_depth > depth * (1 + rtol)
255
+ elif pooler =='min':
256
+ pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
257
+ output_mask = pooled_depth < depth * (1 - rtol)
258
+ else:
259
+ raise ValueError(f'Unsupported pooler: {pooler}')
260
+ output_mask = output_mask.reshape(*batch_shape, height, width)
261
+ return output_mask
262
+
263
+
264
+ def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
265
+ device, dtype = depth.device, depth.dtype
266
+
267
+ disp = torch.where(mask, 1 / depth, 0)
268
+ disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
269
+ mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
270
+ disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2]
271
+ mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2]
272
+
273
+ x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype)
274
+ A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3) # [kernel_size ** 2, 3]
275
+ A = mask_window[..., None] * A
276
+ I = torch.eye(3, device=device, dtype=dtype)
277
+
278
+ affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :] # [..., H, W, kernel_size ** 2]
279
+ diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0)
280
+
281
+ edge_mask = mask & (diff > tol).any(dim=-1)
282
+
283
+ disp_mean = weighted_mean(disp_window, mask_window, dim=-1)
284
+ fg_edge_mask = edge_mask & (disp > disp_mean)
285
+ # fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size)
286
+ bg_edge_mask = edge_mask & ~fg_edge_mask
287
+ return fg_edge_mask, bg_edge_mask
288
+
289
+
290
+ def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
291
+ device, dtype = depth.device, depth.dtype
292
+
293
+ disp = torch.where(mask, 1 / depth, 0)
294
+ disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
295
+ mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
296
+ disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2]
297
+ mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2]
298
+
299
+ disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1))
300
+ fg_edge_mask = mask & (disp / disp_mean > 1 + tol)
301
+ bg_edge_mask = mask & (disp_mean / disp > 1 + tol)
302
+
303
+ fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
304
+ bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
305
+
306
+ return fg_edge_mask, bg_edge_mask
307
+
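A minimal sketch of extracting occlusion edges from a depth map (hypothetical inputs):

```python
import torch

depth = torch.rand(1, 240, 320) * 5 + 1                     # stand-in metric depth
valid = torch.ones(1, 240, 320, dtype=torch.bool)
fg_edge, bg_edge = depth_occlusion_edge(depth, valid, kernel_size=3, tol=0.1)
# boolean maps marking the near (foreground) and far (background) side of depth discontinuities
```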
308
+
309
+ def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor:
310
+ kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool)
311
+ for _ in range(iterations):
312
+ input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1))
313
+ mask_window = kernel & utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1))
314
+ if filter =='min':
315
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).amin(dim=(-2, -1)))
316
+ elif filter =='max':
317
+ input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).amax(dim=(-2, -1)))
318
+ elif filter == 'mean':
319
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1)))
320
+ elif filter =='median':
321
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values)
322
+ mask = mask_window.any(dim=(-2, -1))
323
+ return input, mask
324
+
325
+
326
+ def refine_depth_with_normal(depth: torch.Tensor, normal: torch.Tensor, intrinsics: torch.Tensor, iterations: int = 10, damp: float = 1e-3, eps: float = 1e-12, kernel_size: int = 5) -> torch.Tensor:
327
+ device, dtype = depth.device, depth.dtype
328
+ height, width = depth.shape[-2:]
329
+ radius = kernel_size // 2
330
+
331
+ duv = torch.stack(torch.meshgrid(torch.linspace(-radius / width, radius / width, kernel_size, device=device, dtype=dtype), torch.linspace(-radius / height, radius / height, kernel_size, device=device, dtype=dtype), indexing='xy'), dim=-1).to(dtype=dtype, device=device)
332
+
333
+ log_depth = depth.clamp_min_(eps).log()
334
+ log_depth_diff = utils3d.torch.sliding_window_2d(log_depth, window_size=kernel_size, stride=1, dim=(-2, -1)) - log_depth[..., radius:-radius, radius:-radius, None, None]
335
+
336
+ weight = torch.exp(-(log_depth_diff / duv.norm(dim=-1).clamp_min_(eps) / 10).square())
337
+ tot_weight = weight.sum(dim=(-2, -1)).clamp_min_(eps)
338
+
339
+ uv = utils3d.torch.image_uv(height=height, width=width, device=device, dtype=dtype)
340
+ K_inv = torch.inverse(intrinsics)
341
+
342
+ grad = -(normal[..., None, :2] @ K_inv[..., None, None, :2, :2]).squeeze(-2) \
343
+ / (normal[..., None, 2:] + normal[..., None, :2] @ (K_inv[..., None, None, :2, :2] @ uv[..., :, None] + K_inv[..., None, None, :2, 2:])).squeeze(-2)
344
+ laplacian = (weight * ((utils3d.torch.sliding_window_2d(grad, window_size=kernel_size, stride=1, dim=(-3, -2)) + grad[..., radius:-radius, radius:-radius, :, None, None]) * (duv.permute(2, 0, 1) / 2)).sum(dim=-3)).sum(dim=(-2, -1))
345
+
346
+ laplacian = laplacian.clamp(-0.1, 0.1)
347
+ log_depth_refine = log_depth.clone()
348
+
349
+ for _ in range(iterations):
350
+ log_depth_refine[..., radius:-radius, radius:-radius] = 0.1 * log_depth_refine[..., radius:-radius, radius:-radius] + 0.9 * (damp * log_depth[..., radius:-radius, radius:-radius] - laplacian + (weight * utils3d.torch.sliding_window_2d(log_depth_refine, window_size=kernel_size, stride=1, dim=(-2, -1))).sum(dim=(-2, -1))) / (tot_weight + damp)
351
+
352
+ depth_refine = log_depth_refine.exp()
353
+
354
+ return depth_refine
moge/utils/io.py ADDED
@@ -0,0 +1,236 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ from typing import IO
4
+ import zipfile
5
+ import json
6
+ import io
7
+ from typing import *
8
+ from pathlib import Path
9
+ import re
10
+ from PIL import Image, PngImagePlugin
11
+
12
+ import numpy as np
13
+ import cv2
14
+
15
+ from .tools import timeit
16
+
17
+
18
+ def save_glb(
19
+ save_path: Union[str, os.PathLike],
20
+ vertices: np.ndarray,
21
+ faces: np.ndarray,
22
+ vertex_uvs: np.ndarray,
23
+ texture: np.ndarray,
24
+ vertex_normals: Optional[np.ndarray] = None,
25
+ ):
26
+ import trimesh
27
+ import trimesh.visual
28
+ from PIL import Image
29
+
30
+ trimesh.Trimesh(
31
+ vertices=vertices,
32
+ vertex_normals=vertex_normals,
33
+ faces=faces,
34
+ visual = trimesh.visual.texture.TextureVisuals(
35
+ uv=vertex_uvs,
36
+ material=trimesh.visual.material.PBRMaterial(
37
+ baseColorTexture=Image.fromarray(texture),
38
+ metallicFactor=0.5,
39
+ roughnessFactor=1.0
40
+ )
41
+ ),
42
+ process=False
43
+ ).export(save_path)
44
+
45
+
46
+ def save_ply(
47
+ save_path: Union[str, os.PathLike],
48
+ vertices: np.ndarray,
49
+ faces: np.ndarray,
50
+ vertex_colors: np.ndarray,
51
+ vertex_normals: Optional[np.ndarray] = None,
52
+ ):
53
+ import trimesh
54
+ import trimesh.visual
55
+ from PIL import Image
56
+
57
+ trimesh.Trimesh(
58
+ vertices=vertices,
59
+ faces=faces,
60
+ vertex_colors=vertex_colors,
61
+ vertex_normals=vertex_normals,
62
+ process=False
63
+ ).export(save_path)
64
+
65
+
66
+ def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
67
+ """
68
+ Read an image and return a uint8 RGB array of shape (H, W, 3).
69
+ """
70
+ if isinstance(path, (str, os.PathLike)):
71
+ data = Path(path).read_bytes()
72
+ else:
73
+ data = path.read()
74
+ image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
75
+ return image
76
+
77
+
78
+ def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
79
+ """
80
+ Write an image given a uint8 RGB array of shape (H, W, 3), encoded as JPEG.
81
+ """
82
+ data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
83
+ if isinstance(path, (str, os.PathLike)):
84
+ Path(path).write_bytes(data)
85
+ else:
86
+ path.write(data)
87
+
88
+
89
+ def read_depth(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, float]:
90
+ """
91
+ Read a depth image. Returns the float32 depth array of shape (H, W) and its unit (None if not stored).
92
+ """
93
+ if isinstance(path, (str, os.PathLike)):
94
+ data = Path(path).read_bytes()
95
+ else:
96
+ data = path.read()
97
+ pil_image = Image.open(io.BytesIO(data))
98
+ near = float(pil_image.info.get('near'))
99
+ far = float(pil_image.info.get('far'))
100
+ unit = float(pil_image.info.get('unit')) if 'unit' in pil_image.info else None
101
+ depth = np.array(pil_image)
102
+ mask_nan, mask_inf = depth == 0, depth == 65535
103
+ depth = (depth.astype(np.float32) - 1) / 65533
104
+ depth = near ** (1 - depth) * far ** depth
105
+ depth[mask_nan] = np.nan
106
+ depth[mask_inf] = np.inf
107
+ return depth, unit
108
+
109
+
110
+ def write_depth(
111
+ path: Union[str, os.PathLike, IO],
112
+ depth: np.ndarray,
113
+ unit: float = None,
114
+ max_range: float = 1e5,
115
+ compression_level: int = 7,
116
+ ):
117
+ """
118
+ Encode and write a depth image as 16-bit PNG format.
119
+ ### Parameters:
120
+ - `path: Union[str, os.PathLike, IO]`
121
+ The file path or file object to write to.
122
+ - `depth: np.ndarray`
123
+ The depth array, float32 array of shape (H, W).
124
+ May contain `NaN` for invalid values and `Inf` for infinite values.
125
+ - `unit: float = None`
126
+ The unit of the depth values.
127
+
128
+ Depth values are encoded as follows:
129
+ - 0: unknown
130
+ - 1 ~ 65534: depth values, logarithmically quantized between `near` and `far`
131
+ - 65535: infinity
132
+
133
+ metadata is stored in the PNG file as text fields:
134
+ - `near`: the minimum depth value
135
+ - `far`: the maximum depth value
136
+ - `unit`: the unit of the depth values (optional)
137
+ """
138
+ mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth), np.isinf(depth)
139
+
140
+ depth = depth.astype(np.float32)
141
+ mask_finite = depth
142
+ near = max(depth[mask_values].min(), 1e-5)
143
+ far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
144
+ depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534
145
+ depth[mask_nan] = 0
146
+ depth[mask_inf] = 65535
147
+
148
+ pil_image = Image.fromarray(depth)
149
+ pnginfo = PngImagePlugin.PngInfo()
150
+ pnginfo.add_text('near', str(near))
151
+ pnginfo.add_text('far', str(far))
152
+ if unit is not None:
153
+ pnginfo.add_text('unit', str(unit))
154
+ pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
155
+
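A round-trip sketch for the depth codec above ('example_depth.png' is a hypothetical path):

```python
import numpy as np

depth = np.random.uniform(0.5, 20.0, (240, 320)).astype(np.float32)
depth[:8, :8] = np.nan                    # unknown pixels are stored as 0
depth[-8:, -8:] = np.inf                  # infinite depth (e.g. sky) is stored as 65535

write_depth('example_depth.png', depth, unit=1.0)
decoded, unit = read_depth('example_depth.png')
# decoded matches depth up to the 16-bit logarithmic quantization; NaN / Inf are preserved
```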
156
+
157
+ def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
158
+ """
159
+ Read a segmentation mask
160
+ ### Parameters:
161
+ - `path: Union[str, os.PathLike, IO]`
162
+ The file path or file object to read from.
163
+ ### Returns:
164
+ - `Tuple[np.ndarray, Dict[str, int]]`
165
+ A tuple containing:
166
+ - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
167
+ - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}.
168
+ """
169
+ if isinstance(path, (str, os.PathLike)):
170
+ data = Path(path).read_bytes()
171
+ else:
172
+ data = path.read()
173
+ pil_image = Image.open(io.BytesIO(data))
174
+ labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
175
+ mask = np.array(pil_image)
176
+ return mask, labels
177
+
178
+
179
+ def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
180
+ """
181
+ Write a segmentation mask and label mapping, as PNG format.
182
+ ### Parameters:
183
+ - `path: Union[str, os.PathLike, IO]`
184
+ The file path or file object to write to.
185
+ - `mask: np.ndarray`
186
+ The segmentation mask, uint8 or uint16 array of shape (H, W).
187
+ - `labels: Dict[str, int] = None`
188
+ The label mapping, a dictionary of {label_name: label_id}.
189
+ - `compression_level: int = 7`
190
+ The compression level for PNG compression.
191
+ """
192
+ assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
193
+ pil_image = Image.fromarray(mask)
194
+ pnginfo = PngImagePlugin.PngInfo()
195
+ if labels is not None:
196
+ labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':'))
197
+ pnginfo.add_text('labels', labels_json)
198
+ pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
199
+
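A round-trip sketch for the segmentation codec ('example_seg.png' and the label names are hypothetical):

```python
import numpy as np

seg = np.zeros((120, 160), dtype=np.uint8)
seg[40:80, 60:120] = 1
write_segmentation('example_seg.png', seg, labels={'background': 0, 'person': 1})
mask, labels = read_segmentation('example_seg.png')
assert labels == {'background': 0, 'person': 1}
assert (mask == seg).all()
```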
200
+
201
+
202
+ def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
203
+ """
204
+ Read a normal image, return float32 normal array of shape (H, W, 3).
205
+ """
206
+ if isinstance(path, (str, os.PathLike)):
207
+ data = Path(path).read_bytes()
208
+ else:
209
+ data = path.read()
210
+ normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
211
+ mask_nan = np.all(normal == 0, axis=-1)
212
+ normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
213
+ normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12)
214
+ normal[mask_nan] = np.nan
215
+ return normal
216
+
217
+
218
+ def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray:
219
+ """
220
+ Write a normal image, input float32 normal array of shape (H, W, 3).
221
+ """
222
+ mask_nan = np.isnan(normal).any(axis=-1)
223
+ normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
224
+ normal[mask_nan] = 0
225
+ data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
226
+ if isinstance(path, (str, os.PathLike)):
227
+ Path(path).write_bytes(data)
228
+ else:
229
+ path.write(data)
230
+
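And a round-trip sketch for the normal codec ('example_normal.png' is a hypothetical path):

```python
import numpy as np

normal = np.random.randn(120, 160, 3).astype(np.float32)
normal /= np.linalg.norm(normal, axis=-1, keepdims=True)    # unit normals
write_normal('example_normal.png', normal)
decoded = read_normal('example_normal.png')
# decoded ≈ normal up to 16-bit quantization
```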
231
+
232
+ def read_meta(path: Union[str, os.PathLike, IO]) -> Dict[str, Any]:
233
+ return json.loads(Path(path).read_text())
234
+
235
+ def write_meta(path: Union[str, os.PathLike, IO], meta: Dict[str, Any]):
236
+ Path(path).write_text(json.dumps(meta))