Commit ff49a48
Parent(s): a2337f4

add dependencies

Changed files:
- requirements.txt +8 -0
- tsr/__pycache__/system.cpython-310.pyc +0 -0
- tsr/__pycache__/utils.cpython-310.pyc +0 -0
- tsr/models/__pycache__/camera.cpython-310.pyc +0 -0
- tsr/models/__pycache__/isosurface.cpython-310.pyc +0 -0
- tsr/models/__pycache__/nerf_renderer.cpython-310.pyc +0 -0
- tsr/models/__pycache__/network_utils.cpython-310.pyc +0 -0
- tsr/models/isosurface.py +48 -0
- tsr/models/nerf_renderer.py +180 -0
- tsr/models/network_utils.py +124 -0
- tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc +0 -0
- tsr/models/tokenizers/__pycache__/image.cpython-310.pyc +0 -0
- tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc +0 -0
- tsr/models/tokenizers/image.py +67 -0
- tsr/models/tokenizers/triplane.py +45 -0
- tsr/models/transformer/__pycache__/attention.cpython-310.pyc +0 -0
- tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc +0 -0
- tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc +0 -0
- tsr/models/transformer/attention.py +628 -0
- tsr/models/transformer/basic_transformer_block.py +314 -0
- tsr/models/transformer/transformer_1d.py +216 -0
- tsr/system.py +203 -0
- tsr/utils.py +492 -0
    	
requirements.txt
ADDED

@@ -0,0 +1,8 @@
+omegaconf==2.3.0
+Pillow==10.1.0
+einops==0.7.0
+git+https://github.com/tatsy/torchmcubes.git
+transformers==4.35.0
+trimesh==4.0.5
+rembg
+huggingface-hub
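As a quick sanity check (not part of the commit), the pinned dependencies can be smoke-tested after running `pip install -r requirements.txt`; the import names below are the standard import names for these packages.

# Throwaway smoke test: confirm the pinned packages import and report a few versions.
import omegaconf, PIL, einops, torchmcubes, transformers, trimesh, rembg, huggingface_hub

print(omegaconf.__version__, PIL.__version__, transformers.__version__, trimesh.__version__)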
    	
tsr/__pycache__/system.cpython-310.pyc
ADDED
Binary file (5.41 kB)

tsr/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (13.8 kB)

tsr/models/__pycache__/camera.cpython-310.pyc
ADDED
Binary file (1.48 kB)

tsr/models/__pycache__/isosurface.cpython-310.pyc
ADDED
Binary file (2.04 kB)

tsr/models/__pycache__/nerf_renderer.cpython-310.pyc
ADDED
Binary file (5.29 kB)

tsr/models/__pycache__/network_utils.cpython-310.pyc
ADDED
Binary file (3.42 kB)
    	
tsr/models/isosurface.py
ADDED

@@ -0,0 +1,48 @@
+from typing import Callable, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torchmcubes import marching_cubes
+
+
+class IsosurfaceHelper(nn.Module):
+    points_range: Tuple[float, float] = (0, 1)
+
+    @property
+    def grid_vertices(self) -> torch.FloatTensor:
+        raise NotImplementedError
+
+
+class MarchingCubeHelper(IsosurfaceHelper):
+    def __init__(self, resolution: int) -> None:
+        super().__init__()
+        self.resolution = resolution
+        self.mc_func: Callable = marching_cubes
+        self._grid_vertices: Optional[torch.FloatTensor] = None
+
+    @property
+    def grid_vertices(self) -> torch.FloatTensor:
+        if self._grid_vertices is None:
+            # keep the vertices on CPU so that we can support very large resolution
+            x, y, z = (
+                torch.linspace(*self.points_range, self.resolution),
+                torch.linspace(*self.points_range, self.resolution),
+                torch.linspace(*self.points_range, self.resolution),
+            )
+            x, y, z = torch.meshgrid(x, y, z, indexing="ij")
+            verts = torch.cat(
+                [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
+            ).reshape(-1, 3)
+            self._grid_vertices = verts
+        return self._grid_vertices
+
+    def forward(
+        self,
+        level: torch.FloatTensor,
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        level = -level.view(self.resolution, self.resolution, self.resolution)
+        v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
+        v_pos = v_pos[..., [2, 1, 0]]
+        v_pos = v_pos / (self.resolution - 1.0)
+        return v_pos, t_pos_idx
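For orientation, here is a minimal sketch (not part of the commit) of how MarchingCubeHelper might be exercised, assuming the tsr package from this Space is importable and torchmcubes is installed; the sphere-shaped level set and the resolution value are illustrative assumptions.

import torch
from tsr.models.isosurface import MarchingCubeHelper

helper = MarchingCubeHelper(resolution=64)
pts = helper.grid_vertices                    # (64**3, 3) grid points in [0, 1]
level = 0.25 - (pts - 0.5).norm(dim=-1)       # toy level set: positive inside a sphere
v_pos, t_pos_idx = helper(level)              # mesh vertices in [0, 1] and triangle indices
print(v_pos.shape, t_pos_idx.shape)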
    	
tsr/models/nerf_renderer.py
ADDED

@@ -0,0 +1,180 @@
+from dataclasses import dataclass, field
+from typing import Dict
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, reduce
+
+from ..utils import (
+    BaseModule,
+    chunk_batch,
+    get_activation,
+    rays_intersect_bbox,
+    scale_tensor,
+)
+
+
+class TriplaneNeRFRenderer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        radius: float
+
+        feature_reduction: str = "concat"
+        density_activation: str = "trunc_exp"
+        density_bias: float = -1.0
+        color_activation: str = "sigmoid"
+        num_samples_per_ray: int = 128
+        randomized: bool = False
+
+    cfg: Config
+
+    def configure(self) -> None:
+        assert self.cfg.feature_reduction in ["concat", "mean"]
+        self.chunk_size = 0
+
+    def set_chunk_size(self, chunk_size: int):
+        assert (
+            chunk_size >= 0
+        ), "chunk_size must be a non-negative integer (0 for no chunking)."
+        self.chunk_size = chunk_size
+
+    def query_triplane(
+        self,
+        decoder: torch.nn.Module,
+        positions: torch.Tensor,
+        triplane: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        input_shape = positions.shape[:-1]
+        positions = positions.view(-1, 3)
+
+        # positions in (-radius, radius)
+        # normalized to (-1, 1) for grid sample
+        positions = scale_tensor(
+            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
+        )
+
+        def _query_chunk(x):
+            indices2D: torch.Tensor = torch.stack(
+                (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
+                dim=-3,
+            )
+            out: torch.Tensor = F.grid_sample(
+                rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
+                rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
+                align_corners=False,
+                mode="bilinear",
+            )
+            if self.cfg.feature_reduction == "concat":
+                out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
+            elif self.cfg.feature_reduction == "mean":
+                out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
+            else:
+                raise NotImplementedError
+
+            net_out: Dict[str, torch.Tensor] = decoder(out)
+            return net_out
+
+        if self.chunk_size > 0:
+            net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
+        else:
+            net_out = _query_chunk(positions)
+
+        net_out["density_act"] = get_activation(self.cfg.density_activation)(
+            net_out["density"] + self.cfg.density_bias
+        )
+        net_out["color"] = get_activation(self.cfg.color_activation)(
+            net_out["features"]
+        )
+
+        net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
+
+        return net_out
+
+    def _forward(
+        self,
+        decoder: torch.nn.Module,
+        triplane: torch.Tensor,
+        rays_o: torch.Tensor,
+        rays_d: torch.Tensor,
+        **kwargs,
+    ):
+        rays_shape = rays_o.shape[:-1]
+        rays_o = rays_o.view(-1, 3)
+        rays_d = rays_d.view(-1, 3)
+        n_rays = rays_o.shape[0]
+
+        t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
+        t_near, t_far = t_near[rays_valid], t_far[rays_valid]
+
+        t_vals = torch.linspace(
+            0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
+        )
+        t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
+        z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None]  # (N_rays, N_samples)
+
+        xyz = (
+            rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
+        )  # (N_rays, N_sample, 3)
+
+        mlp_out = self.query_triplane(
+            decoder=decoder,
+            positions=xyz,
+            triplane=triplane,
+        )
+
+        eps = 1e-10
+        # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
+        deltas = t_vals[1:] - t_vals[:-1]  # (N_rays, N_samples)
+        alpha = 1 - torch.exp(
+            -deltas * mlp_out["density_act"][..., 0]
+        )  # (N_rays, N_samples)
+        accum_prod = torch.cat(
+            [
+                torch.ones_like(alpha[:, :1]),
+                torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
+            ],
+            dim=-1,
+        )
+        weights = alpha * accum_prod  # (N_rays, N_samples)
+        comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2)  # (N_rays, 3)
+        opacity_ = weights.sum(dim=-1)  # (N_rays)
+
+        comp_rgb = torch.zeros(
+            n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
+        )
+        opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
+        comp_rgb[rays_valid] = comp_rgb_
+        opacity[rays_valid] = opacity_
+
+        comp_rgb += 1 - opacity[..., None]
+        comp_rgb = comp_rgb.view(*rays_shape, 3)
+
+        return comp_rgb
+
+    def forward(
+        self,
+        decoder: torch.nn.Module,
+        triplane: torch.Tensor,
+        rays_o: torch.Tensor,
+        rays_d: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        if triplane.ndim == 4:
+            comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
+        else:
+            comp_rgb = torch.stack(
+                [
+                    self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
+                    for i in range(triplane.shape[0])
+                ],
+                dim=0,
+            )
+
+        return comp_rgb
+
+    def train(self, mode=True):
+        self.randomized = mode and self.cfg.randomized
+        return super().train(mode=mode)
+
+    def eval(self):
+        self.randomized = False
+        return super().eval()
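The alpha-compositing step inside _forward is the standard discrete volume-rendering weight computation. The self-contained sketch below (not part of the commit, with assumed shapes and random values) isolates just that step on dummy tensors.

import torch

n_rays, n_samples = 4, 128                          # assumed sizes for illustration
density = torch.rand(n_rays, n_samples)             # stands in for mlp_out["density_act"][..., 0]
deltas = torch.full((n_samples,), 1.0 / n_samples)  # uniform step sizes, as in _forward

eps = 1e-10
alpha = 1 - torch.exp(-deltas * density)            # per-sample opacity
accum_prod = torch.cat(
    [torch.ones_like(alpha[:, :1]), torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1)],
    dim=-1,
)                                                   # transmittance up to each sample
weights = alpha * accum_prod                        # per-sample contribution to the pixel
print(weights.sum(dim=-1))                          # per-ray opacity, each value in [0, 1]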
    	
tsr/models/network_utils.py
ADDED

@@ -0,0 +1,124 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ..utils import BaseModule
+
+
+class TriplaneUpsampleNetwork(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        in_channels: int
+        out_channels: int
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.upsample = nn.ConvTranspose2d(
+            self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
+        )
+
+    def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
+        triplanes_up = rearrange(
+            self.upsample(
+                rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
+            ),
+            "(B Np) Co Hp Wp -> B Np Co Hp Wp",
+            Np=3,
+        )
+        return triplanes_up
+
+
+class NeRFMLP(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        in_channels: int
+        n_neurons: int
+        n_hidden_layers: int
+        activation: str = "relu"
+        bias: bool = True
+        weight_init: Optional[str] = "kaiming_uniform"
+        bias_init: Optional[str] = None
+
+    cfg: Config
+
+    def configure(self) -> None:
+        layers = [
+            self.make_linear(
+                self.cfg.in_channels,
+                self.cfg.n_neurons,
+                bias=self.cfg.bias,
+                weight_init=self.cfg.weight_init,
+                bias_init=self.cfg.bias_init,
+            ),
+            self.make_activation(self.cfg.activation),
+        ]
+        for i in range(self.cfg.n_hidden_layers - 1):
+            layers += [
+                self.make_linear(
+                    self.cfg.n_neurons,
+                    self.cfg.n_neurons,
+                    bias=self.cfg.bias,
+                    weight_init=self.cfg.weight_init,
+                    bias_init=self.cfg.bias_init,
+                ),
+                self.make_activation(self.cfg.activation),
+            ]
+        layers += [
+            self.make_linear(
+                self.cfg.n_neurons,
+                4,  # density 1 + features 3
+                bias=self.cfg.bias,
+                weight_init=self.cfg.weight_init,
+                bias_init=self.cfg.bias_init,
+            )
+        ]
+        self.layers = nn.Sequential(*layers)
+
+    def make_linear(
+        self,
+        dim_in,
+        dim_out,
+        bias=True,
+        weight_init=None,
+        bias_init=None,
+    ):
+        layer = nn.Linear(dim_in, dim_out, bias=bias)
+
+        if weight_init is None:
+            pass
+        elif weight_init == "kaiming_uniform":
+            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
+        else:
+            raise NotImplementedError
+
+        if bias:
+            if bias_init is None:
+                pass
+            elif bias_init == "zero":
+                torch.nn.init.zeros_(layer.bias)
+            else:
+                raise NotImplementedError
+
+        return layer
+
+    def make_activation(self, activation):
+        if activation == "relu":
+            return nn.ReLU(inplace=True)
+        elif activation == "silu":
+            return nn.SiLU(inplace=True)
+        else:
+            raise NotImplementedError
+
+    def forward(self, x):
+        inp_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+
+        features = self.layers(x)
+        features = features.reshape(*inp_shape, -1)
+        out = {"density": features[..., 0:1], "features": features[..., 1:4]}
+
+        return out
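NeRFMLP's layer stack is a plain fully connected network whose four outputs are split into one density channel and three color features. A rough stand-alone equivalent, for illustration only; the layer sizes below are assumptions, not values from this commit.

import torch
import torch.nn as nn

# Assumed sizes: 40-dim triplane features, 64 hidden units, two hidden layers.
mlp = nn.Sequential(
    nn.Linear(40, 64), nn.ReLU(inplace=True),
    nn.Linear(64, 64), nn.ReLU(inplace=True),
    nn.Linear(64, 4),  # 1 density channel + 3 feature channels, as in NeRFMLP.configure
)

x = torch.rand(1024, 40)                                 # flattened query points
features = mlp(x)
out = {"density": features[..., 0:1], "features": features[..., 1:4]}
print(out["density"].shape, out["features"].shape)       # torch.Size([1024, 1]) torch.Size([1024, 3])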
    	
tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc
ADDED
Binary file (18.6 kB)

tsr/models/tokenizers/__pycache__/image.cpython-310.pyc
ADDED
Binary file (2.42 kB)

tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc
ADDED
Binary file (1.76 kB)
    	
tsr/models/tokenizers/image.py
ADDED

@@ -0,0 +1,67 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from transformers.models.vit.modeling_vit import ViTModel
+
+from ...utils import BaseModule
+
+
+class DINOSingleImageTokenizer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
+        enable_gradient_checkpointing: bool = False
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.model: ViTModel = ViTModel(
+            ViTModel.config_class.from_pretrained(
+                hf_hub_download(
+                    repo_id=self.cfg.pretrained_model_name_or_path,
+                    filename="config.json",
+                )
+            )
+        )
+
+        if self.cfg.enable_gradient_checkpointing:
+            self.model.encoder.gradient_checkpointing = True
+
+        self.register_buffer(
+            "image_mean",
+            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
+            persistent=False,
+        )
+        self.register_buffer(
+            "image_std",
+            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
+            persistent=False,
+        )
+
+    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        packed = False
+        if images.ndim == 4:
+            packed = True
+            images = images.unsqueeze(1)
+
+        batch_size, n_input_views = images.shape[:2]
+        images = (images - self.image_mean) / self.image_std
+        out = self.model(
+            rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
+        )
+        local_features, global_features = out.last_hidden_state, out.pooler_output
+        local_features = local_features.permute(0, 2, 1)
+        local_features = rearrange(
+            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
+        )
+        if packed:
+            local_features = local_features.squeeze(1)
+
+        return local_features
+
+    def detokenize(self, *args, **kwargs):
+        raise NotImplementedError
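DINOSingleImageTokenizer.forward is mostly shape bookkeeping around the ViT call. The sketch below (not part of the commit) replays that bookkeeping on dummy tensors, with the ViT replaced by a random hidden-state tensor; the token and hidden sizes (197, 768) are assumptions based on a 224x224 input to a ViT-B/16.

import torch
from einops import rearrange

B, N, C, H, W = 2, 1, 3, 224, 224                        # assumed batch, views, image size
images = torch.rand(B, N, C, H, W)

mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1)
images = (images - mean) / std                           # ImageNet normalization, as in configure()

flat = rearrange(images, "B N C H W -> (B N) C H W")     # what forward() feeds to the ViT
last_hidden_state = torch.rand(B * N, 197, 768)          # stand-in for out.last_hidden_state

local = last_hidden_state.permute(0, 2, 1)               # (B*N, Ct, Nt)
local = rearrange(local, "(B N) Ct Nt -> B N Ct Nt", B=B)
print(flat.shape, local.shape)                           # (2, 3, 224, 224) and (2, 1, 768, 197)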
    	
tsr/models/tokenizers/triplane.py
ADDED

@@ -0,0 +1,45 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from ...utils import BaseModule
+
+
+class Triplane1DTokenizer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        plane_size: int
+        num_channels: int
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.embeddings = nn.Parameter(
+            torch.randn(
+                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
+                dtype=torch.float32,
+            )
+            * 1
+            / math.sqrt(self.cfg.num_channels)
+        )
+
+    def forward(self, batch_size: int) -> torch.Tensor:
+        return rearrange(
+            repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
+            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
+        )
+
+    def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
+        batch_size, Ct, Nt = tokens.shape
+        assert Nt == self.cfg.plane_size**2 * 3
+        assert Ct == self.cfg.num_channels
+        return rearrange(
+            tokens,
+            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
+            Np=3,
+            Hp=self.cfg.plane_size,
+            Wp=self.cfg.plane_size,
+        )
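Triplane1DTokenizer flattens three learned feature planes into a single 1D token sequence, and detokenize restores them. The round trip can be checked with plain einops on a random tensor; the sizes below are assumptions for illustration, not values from the commit.

import torch
from einops import rearrange, repeat

B, Ct, Hp = 2, 32, 16                          # assumed batch size, channels, plane size
planes = torch.randn(3, Ct, Hp, Hp)            # stands in for the learned embeddings

tokens = rearrange(
    repeat(planes, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=B),
    "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
)                                              # as in forward(): (B, Ct, 3 * Hp * Hp)

restored = rearrange(
    tokens, "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", Np=3, Hp=Hp, Wp=Hp
)                                              # as in detokenize()
assert torch.equal(restored[0], planes)        # the flatten/unflatten round trip is exact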
    	
tsr/models/transformer/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (15.3 kB)

tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc
ADDED
Binary file (9.96 kB)

tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc
ADDED
Binary file (7.47 kB)
    	
tsr/models/transformer/attention.py
ADDED

@@ -0,0 +1,628 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Attention(nn.Module):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`):
+            The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+        heads (`int`, *optional*, defaults to 8):
+            The number of heads to use for multi-head attention.
+        dim_head (`int`, *optional*, defaults to 64):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the attention computation to `float32`.
+        upcast_softmax (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the softmax computation to `float32`.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the group norm in the cross attention.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        norm_num_groups (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the group norm in the attention.
+        spatial_norm_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the spatial normalization.
+        out_bias (`bool`, *optional*, defaults to `True`):
+            Set to `True` to use a bias in the output linear layer.
+        scale_qk (`bool`, *optional*, defaults to `True`):
+            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+        only_cross_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+            `added_kv_proj_dim` is not `None`.
+        eps (`float`, *optional*, defaults to 1e-5):
+            An additional value added to the denominator in group normalization that is used for numerical stability.
+        rescale_output_factor (`float`, *optional*, defaults to 1.0):
+            A factor to rescale the output by dividing it with this value.
+        residual_connection (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add the residual connection to the output.
+        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+            Set to `True` if the attention block is loaded from a deprecated state dict.
+        processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+            `AttnProcessor` otherwise.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
+        added_kv_proj_dim: Optional[int] = None,
+        norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = (
+            cross_attention_dim if cross_attention_dim is not None else query_dim
+        )
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        # we make use of this private variable to know whether this class is loaded
+        # with an deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(
+                num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
+            )
+        else:
+            self.group_norm = None
+
+        self.spatial_norm = None
+
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
+            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
         | 
| 147 | 
            +
                            # (batch_size, seq_len, added_kv_proj_dim) before being projected
         | 
| 148 | 
            +
                            # to (batch_size, seq_len, cross_attention_dim). The norm is applied
         | 
| 149 | 
            +
                            # before the projection, so we need to use `added_kv_proj_dim` as
         | 
| 150 | 
            +
                            # the number of channels for the group norm.
         | 
| 151 | 
            +
                            norm_cross_num_channels = added_kv_proj_dim
         | 
| 152 | 
            +
                        else:
         | 
| 153 | 
            +
                            norm_cross_num_channels = self.cross_attention_dim
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                        self.norm_cross = nn.GroupNorm(
         | 
| 156 | 
            +
                            num_channels=norm_cross_num_channels,
         | 
| 157 | 
            +
                            num_groups=cross_attention_norm_num_groups,
         | 
| 158 | 
            +
                            eps=1e-5,
         | 
| 159 | 
            +
                            affine=True,
         | 
| 160 | 
            +
                        )
         | 
| 161 | 
            +
                    else:
         | 
| 162 | 
            +
                        raise ValueError(
         | 
| 163 | 
            +
                            f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
         | 
| 164 | 
            +
                        )
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    linear_cls = nn.Linear
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    self.linear_cls = linear_cls
         | 
| 169 | 
            +
                    self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    if not self.only_cross_attention:
         | 
| 172 | 
            +
                        # only relevant for the `AddedKVProcessor` classes
         | 
| 173 | 
            +
                        self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
         | 
| 174 | 
            +
                        self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
         | 
| 175 | 
            +
                    else:
         | 
| 176 | 
            +
                        self.to_k = None
         | 
| 177 | 
            +
                        self.to_v = None
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    if self.added_kv_proj_dim is not None:
         | 
| 180 | 
            +
                        self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
         | 
| 181 | 
            +
                        self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    self.to_out = nn.ModuleList([])
         | 
| 184 | 
            +
                    self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
         | 
| 185 | 
            +
                    self.to_out.append(nn.Dropout(dropout))
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    # set attention processor
         | 
| 188 | 
            +
                    # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         | 
| 189 | 
            +
                    # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         | 
| 190 | 
            +
                    # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         | 
| 191 | 
            +
                    if processor is None:
         | 
| 192 | 
            +
                        processor = (
         | 
| 193 | 
            +
                            AttnProcessor2_0()
         | 
| 194 | 
            +
                            if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
         | 
| 195 | 
            +
                            else AttnProcessor()
         | 
| 196 | 
            +
                        )
         | 
| 197 | 
            +
                    self.set_processor(processor)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
    def set_processor(self, processor: "AttnProcessor") -> None:
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
        the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * heads, seq_len, dim // heads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def get_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        r"""
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(
        self,
        attention_mask: torch.Tensor,
        target_length: int,
        batch_size: int,
        out_dim: int = 3,
    ) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (
                    attention_mask.shape[0],
                    attention_mask.shape[1],
                    target_length,
                )
                padding = torch.zeros(
                    padding_shape,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(
        self, encoder_hidden_states: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
        `Attention` class.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert (
            self.norm_cross is not None
        ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states

    @torch.no_grad()
    def fuse_projections(self, fuse=True):
        is_cross_attention = self.cross_attention_dim != self.query_dim
        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        if not is_cross_attention:
            # fetch weight matrices.
            concatenated_weights = torch.cat(
                [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            # create a new single projection layer and copy over the weights.
            self.to_qkv = self.linear_cls(
                in_features, out_features, bias=False, device=device, dtype=dtype
            )
            self.to_qkv.weight.copy_(concatenated_weights)

        else:
            concatenated_weights = torch.cat(
                [self.to_k.weight.data, self.to_v.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_kv = self.linear_cls(
                in_features, out_features, bias=False, device=device, dtype=dtype
            )
            self.to_kv.weight.copy_(concatenated_weights)

        self.fused_projections = fuse


class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

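The two processor classes above compute the same attention: `AttnProcessor` materializes the score matrix with `torch.baddbmm` followed by a softmax, while `AttnProcessor2_0` delegates to `torch.nn.functional.scaled_dot_product_attention`. For orientation only, a minimal usage sketch (not part of the commit; it assumes the repository root is on `PYTHONPATH` so the `tsr` package is importable, and the tensor sizes are illustrative):

import torch

from tsr.models.transformer.attention import Attention, AttnProcessor

# 8 heads of 64 channels each -> inner_dim = 512, matching query_dim.
attn = Attention(query_dim=512, heads=8, dim_head=64)

tokens = torch.randn(2, 16, 512)  # (batch_size, seq_len, query_dim)
out = attn(tokens)                # self-attention; output shape is (2, 16, 512)

# Force the explicit baddbmm/softmax path instead of scaled_dot_product_attention.
attn.set_processor(AttnProcessor())
out = attn(tokens)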
    	
        tsr/models/transformer/basic_transformer_block.py
    ADDED
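The file below defines a pre-LayerNorm transformer block: self-attention, optional cross-attention, and a feed-forward network, each wrapped in a residual connection. A hypothetical usage sketch before the listing (illustrative sizes, not part of the commit; it relies on the `FeedForward` helper referenced further down in the file and assumes the repository root is on `PYTHONPATH`):

import torch

from tsr.models.transformer.basic_transformer_block import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=512,                  # width of the token stream
    num_attention_heads=8,
    attention_head_dim=64,
    cross_attention_dim=768,  # width of the conditioning tokens
)

x = torch.randn(2, 16, 512)              # (batch_size, tokens, dim)
ctx = torch.randn(2, 77, 768)            # conditioning sequence for cross-attention
y = block(x, encoder_hidden_states=ctx)  # same shape as x: (2, 16, 512)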
    
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from .attention import Attention


class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        assert norm_type == "layer_norm"

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim
                if not double_self_attention
                else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states
            if self.only_cross_attention
            else None,
            attention_mask=attention_mask,
        )

        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
         | 
| 169 | 
            +
                            )
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
         | 
| 172 | 
            +
                        ff_output = torch.cat(
         | 
| 173 | 
            +
                            [
         | 
| 174 | 
            +
                                self.ff(hid_slice)
         | 
| 175 | 
            +
                                for hid_slice in norm_hidden_states.chunk(
         | 
| 176 | 
            +
                                    num_chunks, dim=self._chunk_dim
         | 
| 177 | 
            +
                                )
         | 
| 178 | 
            +
                            ],
         | 
| 179 | 
            +
                            dim=self._chunk_dim,
         | 
| 180 | 
            +
                        )
         | 
| 181 | 
            +
                    else:
         | 
| 182 | 
            +
                        ff_output = self.ff(norm_hidden_states)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    hidden_states = ff_output + hidden_states
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    return hidden_states
         | 
| 187 | 
            +
             | 
| 188 | 
            +
             | 
| 189 | 
            +
            class FeedForward(nn.Module):
         | 
| 190 | 
            +
                r"""
         | 
| 191 | 
            +
                A feed-forward layer.
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                Parameters:
         | 
| 194 | 
            +
                    dim (`int`): The number of channels in the input.
         | 
| 195 | 
            +
                    dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
         | 
| 196 | 
            +
                    mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
         | 
| 197 | 
            +
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 198 | 
            +
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         | 
| 199 | 
            +
                    final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.
         | 
| 200 | 
            +
                """
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                def __init__(
         | 
| 203 | 
            +
                    self,
         | 
| 204 | 
            +
                    dim: int,
         | 
| 205 | 
            +
                    dim_out: Optional[int] = None,
         | 
| 206 | 
            +
                    mult: int = 4,
         | 
| 207 | 
            +
                    dropout: float = 0.0,
         | 
| 208 | 
            +
                    activation_fn: str = "geglu",
         | 
| 209 | 
            +
                    final_dropout: bool = False,
         | 
| 210 | 
            +
                ):
         | 
| 211 | 
            +
                    super().__init__()
         | 
| 212 | 
            +
                    inner_dim = int(dim * mult)
         | 
| 213 | 
            +
                    dim_out = dim_out if dim_out is not None else dim
         | 
| 214 | 
            +
                    linear_cls = nn.Linear
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    if activation_fn == "gelu":
         | 
| 217 | 
            +
                        act_fn = GELU(dim, inner_dim)
         | 
| 218 | 
            +
                    elif activation_fn == "gelu-approximate":
         | 
| 219 | 
            +
                        act_fn = GELU(dim, inner_dim, approximate="tanh")
         | 
| 220 | 
            +
                    elif activation_fn == "geglu":
         | 
| 221 | 
            +
                        act_fn = GEGLU(dim, inner_dim)
         | 
| 222 | 
            +
                    elif activation_fn == "geglu-approximate":
         | 
| 223 | 
            +
                        act_fn = ApproximateGELU(dim, inner_dim)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    self.net = nn.ModuleList([])
         | 
| 226 | 
            +
                    # project in
         | 
| 227 | 
            +
                    self.net.append(act_fn)
         | 
| 228 | 
            +
                    # project dropout
         | 
| 229 | 
            +
                    self.net.append(nn.Dropout(dropout))
         | 
| 230 | 
            +
                    # project out
         | 
| 231 | 
            +
                    self.net.append(linear_cls(inner_dim, dim_out))
         | 
| 232 | 
            +
                    # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         | 
| 233 | 
            +
                    if final_dropout:
         | 
| 234 | 
            +
                        self.net.append(nn.Dropout(dropout))
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         | 
| 237 | 
            +
                    for module in self.net:
         | 
| 238 | 
            +
                        hidden_states = module(hidden_states)
         | 
| 239 | 
            +
                    return hidden_states
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            class GELU(nn.Module):
         | 
| 243 | 
            +
                r"""
         | 
| 244 | 
            +
                GELU activation function with tanh approximation support with `approximate="tanh"`.
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                Parameters:
         | 
| 247 | 
            +
                    dim_in (`int`): The number of channels in the input.
         | 
| 248 | 
            +
                    dim_out (`int`): The number of channels in the output.
         | 
| 249 | 
            +
                    approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
         | 
| 250 | 
            +
                """
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
         | 
| 253 | 
            +
                    super().__init__()
         | 
| 254 | 
            +
                    self.proj = nn.Linear(dim_in, dim_out)
         | 
| 255 | 
            +
                    self.approximate = approximate
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         | 
| 258 | 
            +
                    if gate.device.type != "mps":
         | 
| 259 | 
            +
                        return F.gelu(gate, approximate=self.approximate)
         | 
| 260 | 
            +
                    # mps: gelu is not implemented for float16
         | 
| 261 | 
            +
                    return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
         | 
| 262 | 
            +
                        dtype=gate.dtype
         | 
| 263 | 
            +
                    )
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                def forward(self, hidden_states):
         | 
| 266 | 
            +
                    hidden_states = self.proj(hidden_states)
         | 
| 267 | 
            +
                    hidden_states = self.gelu(hidden_states)
         | 
| 268 | 
            +
                    return hidden_states
         | 
| 269 | 
            +
             | 
| 270 | 
            +
             | 
| 271 | 
            +
            class GEGLU(nn.Module):
         | 
| 272 | 
            +
                r"""
         | 
| 273 | 
            +
                A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                Parameters:
         | 
| 276 | 
            +
                    dim_in (`int`): The number of channels in the input.
         | 
| 277 | 
            +
                    dim_out (`int`): The number of channels in the output.
         | 
| 278 | 
            +
                """
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                def __init__(self, dim_in: int, dim_out: int):
         | 
| 281 | 
            +
                    super().__init__()
         | 
| 282 | 
            +
                    linear_cls = nn.Linear
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    self.proj = linear_cls(dim_in, dim_out * 2)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         | 
| 287 | 
            +
                    if gate.device.type != "mps":
         | 
| 288 | 
            +
                        return F.gelu(gate)
         | 
| 289 | 
            +
                    # mps: gelu is not implemented for float16
         | 
| 290 | 
            +
                    return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                def forward(self, hidden_states, scale: float = 1.0):
         | 
| 293 | 
            +
                    args = ()
         | 
| 294 | 
            +
                    hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
         | 
| 295 | 
            +
                    return hidden_states * self.gelu(gate)
         | 
| 296 | 
            +
             | 
| 297 | 
            +
             | 
| 298 | 
            +
            class ApproximateGELU(nn.Module):
         | 
| 299 | 
            +
                r"""
         | 
| 300 | 
            +
                The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
         | 
| 301 | 
            +
                https://arxiv.org/abs/1606.08415.
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                Parameters:
         | 
| 304 | 
            +
                    dim_in (`int`): The number of channels in the input.
         | 
| 305 | 
            +
                    dim_out (`int`): The number of channels in the output.
         | 
| 306 | 
            +
                """
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                def __init__(self, dim_in: int, dim_out: int):
         | 
| 309 | 
            +
                    super().__init__()
         | 
| 310 | 
            +
                    self.proj = nn.Linear(dim_in, dim_out)
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 313 | 
            +
                    x = self.proj(x)
         | 
| 314 | 
            +
                    return x * torch.sigmoid(1.702 * x)
         | 
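
A minimal sketch of the chunked feed-forward path that `set_chunk_feed_forward` enables in `BasicTransformerBlock` above. The `nn.Sequential` stand-in and the tensor sizes are illustrative assumptions, not the module's actual `FeedForward`; the point is that splitting one dimension (here the token dimension; the block's default `_chunk_dim` is 0) bounds the peak activation memory of the MLP without changing the result:

    import torch
    import torch.nn as nn

    # Illustrative stand-in for FeedForward: an MLP applied independently to each token.
    ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    hidden = torch.randn(2, 1024, 64)  # (batch, tokens, channels), made-up sizes

    full = ff(hidden)  # one large intermediate activation
    chunked = torch.cat(
        [ff(chunk) for chunk in hidden.chunk(4, dim=1)],  # 4 smaller intermediates
        dim=1,
    )
    assert torch.allclose(full, chunked, atol=1e-6)
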
    	
        tsr/models/transformer/transformer_1d.py
    ADDED
    
    | @@ -0,0 +1,216 @@ | |
| 1 | 
            +
            from dataclasses import dataclass, field
         | 
| 2 | 
            +
            from typing import Any, Dict, Optional
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn.functional as F
         | 
| 6 | 
            +
            from torch import nn
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from ...utils import BaseModule
         | 
| 9 | 
            +
            from .basic_transformer_block import BasicTransformerBlock
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class Transformer1D(BaseModule):
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                A 1D Transformer model for sequence data.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                Parameters:
         | 
| 17 | 
            +
                    num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
         | 
| 18 | 
            +
                    attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
         | 
| 19 | 
            +
                    in_channels (`int`, *optional*):
         | 
| 20 | 
            +
                        The number of channels in the input and output (specify if the input is **continuous**).
         | 
| 21 | 
            +
                    num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
         | 
| 22 | 
            +
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 23 | 
            +
                    cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         | 
| 24 | 
            +
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
         | 
| 25 | 
            +
                    num_embeds_ada_norm ( `int`, *optional*):
         | 
| 26 | 
            +
                        The number of diffusion steps used during training. Pass if at least one of the norm_layers is
         | 
| 27 | 
            +
                        `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
         | 
| 28 | 
            +
                        added to the hidden states.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                        During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
         | 
| 31 | 
            +
                    attention_bias (`bool`, *optional*):
         | 
| 32 | 
            +
                        Configure if the `TransformerBlocks` attention should contain a bias parameter.
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                @dataclass
         | 
| 36 | 
            +
                class Config(BaseModule.Config):
         | 
| 37 | 
            +
                    num_attention_heads: int = 16
         | 
| 38 | 
            +
                    attention_head_dim: int = 88
         | 
| 39 | 
            +
                    in_channels: Optional[int] = None
         | 
| 40 | 
            +
                    out_channels: Optional[int] = None
         | 
| 41 | 
            +
                    num_layers: int = 1
         | 
| 42 | 
            +
                    dropout: float = 0.0
         | 
| 43 | 
            +
                    norm_num_groups: int = 32
         | 
| 44 | 
            +
                    cross_attention_dim: Optional[int] = None
         | 
| 45 | 
            +
                    attention_bias: bool = False
         | 
| 46 | 
            +
                    activation_fn: str = "geglu"
         | 
| 47 | 
            +
                    only_cross_attention: bool = False
         | 
| 48 | 
            +
                    double_self_attention: bool = False
         | 
| 49 | 
            +
                    upcast_attention: bool = False
         | 
| 50 | 
            +
                    norm_type: str = "layer_norm"
         | 
| 51 | 
            +
                    norm_elementwise_affine: bool = True
         | 
| 52 | 
            +
                    gradient_checkpointing: bool = False
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                cfg: Config
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                def configure(self) -> None:
         | 
| 57 | 
            +
                    self.num_attention_heads = self.cfg.num_attention_heads
         | 
| 58 | 
            +
                    self.attention_head_dim = self.cfg.attention_head_dim
         | 
| 59 | 
            +
                    inner_dim = self.num_attention_heads * self.attention_head_dim
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    linear_cls = nn.Linear
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    # 2. Define input layers
         | 
| 64 | 
            +
                    self.in_channels = self.cfg.in_channels
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    self.norm = torch.nn.GroupNorm(
         | 
| 67 | 
            +
                        num_groups=self.cfg.norm_num_groups,
         | 
| 68 | 
            +
                        num_channels=self.cfg.in_channels,
         | 
| 69 | 
            +
                        eps=1e-6,
         | 
| 70 | 
            +
                        affine=True,
         | 
| 71 | 
            +
                    )
         | 
| 72 | 
            +
                    self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    # 3. Define transformers blocks
         | 
| 75 | 
            +
                    self.transformer_blocks = nn.ModuleList(
         | 
| 76 | 
            +
                        [
         | 
| 77 | 
            +
                            BasicTransformerBlock(
         | 
| 78 | 
            +
                                inner_dim,
         | 
| 79 | 
            +
                                self.num_attention_heads,
         | 
| 80 | 
            +
                                self.attention_head_dim,
         | 
| 81 | 
            +
                                dropout=self.cfg.dropout,
         | 
| 82 | 
            +
                                cross_attention_dim=self.cfg.cross_attention_dim,
         | 
| 83 | 
            +
                                activation_fn=self.cfg.activation_fn,
         | 
| 84 | 
            +
                                attention_bias=self.cfg.attention_bias,
         | 
| 85 | 
            +
                                only_cross_attention=self.cfg.only_cross_attention,
         | 
| 86 | 
            +
                                double_self_attention=self.cfg.double_self_attention,
         | 
| 87 | 
            +
                                upcast_attention=self.cfg.upcast_attention,
         | 
| 88 | 
            +
                                norm_type=self.cfg.norm_type,
         | 
| 89 | 
            +
                                norm_elementwise_affine=self.cfg.norm_elementwise_affine,
         | 
| 90 | 
            +
                            )
         | 
| 91 | 
            +
                            for d in range(self.cfg.num_layers)
         | 
| 92 | 
            +
                        ]
         | 
| 93 | 
            +
                    )
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    # 4. Define output layers
         | 
| 96 | 
            +
                    self.out_channels = (
         | 
| 97 | 
            +
                        self.cfg.in_channels
         | 
| 98 | 
            +
                        if self.cfg.out_channels is None
         | 
| 99 | 
            +
                        else self.cfg.out_channels
         | 
| 100 | 
            +
                    )
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    self.gradient_checkpointing = self.cfg.gradient_checkpointing
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                def forward(
         | 
| 107 | 
            +
                    self,
         | 
| 108 | 
            +
                    hidden_states: torch.Tensor,
         | 
| 109 | 
            +
                    encoder_hidden_states: Optional[torch.Tensor] = None,
         | 
| 110 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 111 | 
            +
                    encoder_attention_mask: Optional[torch.Tensor] = None,
         | 
| 112 | 
            +
                ):
         | 
| 113 | 
            +
                    """
         | 
| 114 | 
            +
                The [`Transformer1D`] forward method.
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    Args:
         | 
| 117 | 
            +
                    hidden_states (`torch.FloatTensor` of shape `(batch size, channel, sequence length)`):
         | 
| 118 | 
            +
                            Input `hidden_states`.
         | 
| 119 | 
            +
                        encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
         | 
| 120 | 
            +
                            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
         | 
| 121 | 
            +
                            self-attention.
         | 
| 122 | 
            +
                        timestep ( `torch.LongTensor`, *optional*):
         | 
| 123 | 
            +
                            Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
         | 
| 124 | 
            +
                        class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
         | 
| 125 | 
            +
                            Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
         | 
| 126 | 
            +
                            `AdaLayerZeroNorm`.
         | 
| 127 | 
            +
                        cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
         | 
| 128 | 
            +
                            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
         | 
| 129 | 
            +
                            `self.processor` in
         | 
| 130 | 
            +
                            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
         | 
| 131 | 
            +
                        attention_mask ( `torch.Tensor`, *optional*):
         | 
| 132 | 
            +
                            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
         | 
| 133 | 
            +
                            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
         | 
| 134 | 
            +
                            negative values to the attention scores corresponding to "discard" tokens.
         | 
| 135 | 
            +
                        encoder_attention_mask ( `torch.Tensor`, *optional*):
         | 
| 136 | 
            +
                            Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                                * Mask `(batch, sequence_length)` True = keep, False = discard.
         | 
| 139 | 
            +
                                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                            If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
         | 
| 142 | 
            +
                            above. This bias will be added to the cross-attention scores.
         | 
| 143 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 144 | 
            +
                            Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
         | 
| 145 | 
            +
                            tuple.
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    Returns:
         | 
| 148 | 
            +
                        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
         | 
| 149 | 
            +
                        `tuple` where the first element is the sample tensor.
         | 
| 150 | 
            +
                    """
         | 
| 151 | 
            +
                    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         | 
| 152 | 
            +
                    #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         | 
| 153 | 
            +
                    #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
         | 
| 154 | 
            +
                    # expects mask of shape:
         | 
| 155 | 
            +
                    #   [batch, key_tokens]
         | 
| 156 | 
            +
                    # adds singleton query_tokens dimension:
         | 
| 157 | 
            +
                    #   [batch,                    1, key_tokens]
         | 
| 158 | 
            +
                    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
         | 
| 159 | 
            +
                    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
         | 
| 160 | 
            +
                    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
         | 
| 161 | 
            +
                    if attention_mask is not None and attention_mask.ndim == 2:
         | 
| 162 | 
            +
                        # assume that mask is expressed as:
         | 
| 163 | 
            +
                        #   (1 = keep,      0 = discard)
         | 
| 164 | 
            +
                        # convert mask into a bias that can be added to attention scores:
         | 
| 165 | 
            +
                        #       (keep = +0,     discard = -10000.0)
         | 
| 166 | 
            +
                        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
         | 
| 167 | 
            +
                        attention_mask = attention_mask.unsqueeze(1)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    # convert encoder_attention_mask to a bias the same way we do for attention_mask
         | 
| 170 | 
            +
                    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
         | 
| 171 | 
            +
                        encoder_attention_mask = (
         | 
| 172 | 
            +
                            1 - encoder_attention_mask.to(hidden_states.dtype)
         | 
| 173 | 
            +
                        ) * -10000.0
         | 
| 174 | 
            +
                        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    # 1. Input
         | 
| 177 | 
            +
                    batch, _, seq_len = hidden_states.shape
         | 
| 178 | 
            +
                    residual = hidden_states
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                    hidden_states = self.norm(hidden_states)
         | 
| 181 | 
            +
                    inner_dim = hidden_states.shape[1]
         | 
| 182 | 
            +
                    hidden_states = hidden_states.permute(0, 2, 1).reshape(
         | 
| 183 | 
            +
                        batch, seq_len, inner_dim
         | 
| 184 | 
            +
                    )
         | 
| 185 | 
            +
                    hidden_states = self.proj_in(hidden_states)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    # 2. Blocks
         | 
| 188 | 
            +
                    for block in self.transformer_blocks:
         | 
| 189 | 
            +
                        if self.training and self.gradient_checkpointing:
         | 
| 190 | 
            +
                            hidden_states = torch.utils.checkpoint.checkpoint(
         | 
| 191 | 
            +
                                block,
         | 
| 192 | 
            +
                                hidden_states,
         | 
| 193 | 
            +
                                attention_mask,
         | 
| 194 | 
            +
                                encoder_hidden_states,
         | 
| 195 | 
            +
                                encoder_attention_mask,
         | 
| 196 | 
            +
                                use_reentrant=False,
         | 
| 197 | 
            +
                            )
         | 
| 198 | 
            +
                        else:
         | 
| 199 | 
            +
                            hidden_states = block(
         | 
| 200 | 
            +
                                hidden_states,
         | 
| 201 | 
            +
                                attention_mask=attention_mask,
         | 
| 202 | 
            +
                                encoder_hidden_states=encoder_hidden_states,
         | 
| 203 | 
            +
                                encoder_attention_mask=encoder_attention_mask,
         | 
| 204 | 
            +
                            )
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                    # 3. Output
         | 
| 207 | 
            +
                    hidden_states = self.proj_out(hidden_states)
         | 
| 208 | 
            +
                    hidden_states = (
         | 
| 209 | 
            +
                        hidden_states.reshape(batch, seq_len, inner_dim)
         | 
| 210 | 
            +
                        .permute(0, 2, 1)
         | 
| 211 | 
            +
                        .contiguous()
         | 
| 212 | 
            +
                    )
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    output = hidden_states + residual
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    return output
         | 
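
A short illustration of the mask-to-bias conversion at the top of `Transformer1D.forward` above; the values are made up. A `(batch, key_tokens)` keep/discard mask becomes an additive bias with a singleton query dimension, which broadcasts over attention scores of shape `(batch, heads, query_tokens, key_tokens)`:

    import torch

    attention_mask = torch.tensor([[1, 1, 0, 1]])              # 1 = keep, 0 = discard
    bias = (1 - attention_mask.to(torch.float32)) * -10000.0   # keep -> 0, discard -> -10000
    bias = bias.unsqueeze(1)                                    # (batch, 1, key_tokens)

    assert bias.shape == (1, 1, 4)
    # After the bias is added, discarded keys score -10000 and vanish in the softmax.
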
    	
        tsr/system.py
    ADDED
    
    | @@ -0,0 +1,203 @@ | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            from dataclasses import dataclass, field
         | 
| 4 | 
            +
            from typing import List, Union
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import PIL.Image
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import torch.nn.functional as F
         | 
| 10 | 
            +
            import trimesh
         | 
| 11 | 
            +
            from einops import rearrange
         | 
| 12 | 
            +
            from huggingface_hub import hf_hub_download
         | 
| 13 | 
            +
            from omegaconf import OmegaConf
         | 
| 14 | 
            +
            from PIL import Image
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from .models.isosurface import MarchingCubeHelper
         | 
| 17 | 
            +
            from .utils import (
         | 
| 18 | 
            +
                BaseModule,
         | 
| 19 | 
            +
                ImagePreprocessor,
         | 
| 20 | 
            +
                find_class,
         | 
| 21 | 
            +
                get_spherical_cameras,
         | 
| 22 | 
            +
                scale_tensor,
         | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            class TSR(BaseModule):
         | 
| 27 | 
            +
                @dataclass
         | 
| 28 | 
            +
                class Config(BaseModule.Config):
         | 
| 29 | 
            +
                    cond_image_size: int
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    image_tokenizer_cls: str
         | 
| 32 | 
            +
                    image_tokenizer: dict
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    tokenizer_cls: str
         | 
| 35 | 
            +
                    tokenizer: dict
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    backbone_cls: str
         | 
| 38 | 
            +
                    backbone: dict
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    post_processor_cls: str
         | 
| 41 | 
            +
                    post_processor: dict
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    decoder_cls: str
         | 
| 44 | 
            +
                    decoder: dict
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    renderer_cls: str
         | 
| 47 | 
            +
                    renderer: dict
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                cfg: Config
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                @classmethod
         | 
| 52 | 
            +
                def from_pretrained(
         | 
| 53 | 
            +
                    cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
         | 
| 54 | 
            +
                ):
         | 
| 55 | 
            +
                    if os.path.isdir(pretrained_model_name_or_path):
         | 
| 56 | 
            +
                        config_path = os.path.join(pretrained_model_name_or_path, config_name)
         | 
| 57 | 
            +
                        weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
         | 
| 58 | 
            +
                    else:
         | 
| 59 | 
            +
                        config_path = hf_hub_download(
         | 
| 60 | 
            +
                            repo_id=pretrained_model_name_or_path, filename=config_name
         | 
| 61 | 
            +
                        )
         | 
| 62 | 
            +
                        weight_path = hf_hub_download(
         | 
| 63 | 
            +
                            repo_id=pretrained_model_name_or_path, filename=weight_name
         | 
| 64 | 
            +
                        )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    cfg = OmegaConf.load(config_path)
         | 
| 67 | 
            +
                    OmegaConf.resolve(cfg)
         | 
| 68 | 
            +
                    model = cls(cfg)
         | 
| 69 | 
            +
                    ckpt = torch.load(weight_path, map_location="cpu")
         | 
| 70 | 
            +
                    model.load_state_dict(ckpt)
         | 
| 71 | 
            +
                    return model
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def configure(self):
         | 
| 74 | 
            +
                    self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
         | 
| 75 | 
            +
                        self.cfg.image_tokenizer
         | 
| 76 | 
            +
                    )
         | 
| 77 | 
            +
                    self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
         | 
| 78 | 
            +
                    self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
         | 
| 79 | 
            +
                    self.post_processor = find_class(self.cfg.post_processor_cls)(
         | 
| 80 | 
            +
                        self.cfg.post_processor
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
                    self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
         | 
| 83 | 
            +
                    self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
         | 
| 84 | 
            +
                    self.image_processor = ImagePreprocessor()
         | 
| 85 | 
            +
                    self.isosurface_helper = None
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                def forward(
         | 
| 88 | 
            +
                    self,
         | 
| 89 | 
            +
                    image: Union[
         | 
| 90 | 
            +
                        PIL.Image.Image,
         | 
| 91 | 
            +
                        np.ndarray,
         | 
| 92 | 
            +
                        torch.FloatTensor,
         | 
| 93 | 
            +
                        List[PIL.Image.Image],
         | 
| 94 | 
            +
                        List[np.ndarray],
         | 
| 95 | 
            +
                        List[torch.FloatTensor],
         | 
| 96 | 
            +
                    ],
         | 
| 97 | 
            +
                    device: str,
         | 
| 98 | 
            +
                ) -> torch.FloatTensor:
         | 
| 99 | 
            +
                    rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
         | 
| 100 | 
            +
                        device
         | 
| 101 | 
            +
                    )
         | 
| 102 | 
            +
                    batch_size = rgb_cond.shape[0]
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    input_image_tokens: torch.Tensor = self.image_tokenizer(
         | 
| 105 | 
            +
                        rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
         | 
| 106 | 
            +
                    )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    input_image_tokens = rearrange(
         | 
| 109 | 
            +
                        input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
         | 
| 110 | 
            +
                    )
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    tokens: torch.Tensor = self.tokenizer(batch_size)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    tokens = self.backbone(
         | 
| 115 | 
            +
                        tokens,
         | 
| 116 | 
            +
                        encoder_hidden_states=input_image_tokens,
         | 
| 117 | 
            +
                    )
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
         | 
| 120 | 
            +
                    return scene_codes
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def render(
         | 
| 123 | 
            +
                    self,
         | 
| 124 | 
            +
                    scene_codes,
         | 
| 125 | 
            +
                    n_views: int,
         | 
| 126 | 
            +
                    elevation_deg: float = 0.0,
         | 
| 127 | 
            +
                    camera_distance: float = 1.9,
         | 
| 128 | 
            +
                    fovy_deg: float = 40.0,
         | 
| 129 | 
            +
                    height: int = 256,
         | 
| 130 | 
            +
                    width: int = 256,
         | 
| 131 | 
            +
                    return_type: str = "pil",
         | 
| 132 | 
            +
                ):
         | 
| 133 | 
            +
                    rays_o, rays_d = get_spherical_cameras(
         | 
| 134 | 
            +
                        n_views, elevation_deg, camera_distance, fovy_deg, height, width
         | 
| 135 | 
            +
                    )
         | 
| 136 | 
            +
                    rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    def process_output(image: torch.FloatTensor):
         | 
| 139 | 
            +
                        if return_type == "pt":
         | 
| 140 | 
            +
                            return image
         | 
| 141 | 
            +
                        elif return_type == "np":
         | 
| 142 | 
            +
                            return image.detach().cpu().numpy()
         | 
| 143 | 
            +
                        elif return_type == "pil":
         | 
| 144 | 
            +
                            return Image.fromarray(
         | 
| 145 | 
            +
                                (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
         | 
| 146 | 
            +
                            )
         | 
| 147 | 
            +
                        else:
         | 
| 148 | 
            +
                            raise NotImplementedError
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    images = []
         | 
| 151 | 
            +
                    for scene_code in scene_codes:
         | 
| 152 | 
            +
                        images_ = []
         | 
| 153 | 
            +
                        for i in range(n_views):
         | 
| 154 | 
            +
                            with torch.no_grad():
         | 
| 155 | 
            +
                                image = self.renderer(
         | 
| 156 | 
            +
                                    self.decoder, scene_code, rays_o[i], rays_d[i]
         | 
| 157 | 
            +
                                )
         | 
| 158 | 
            +
                            images_.append(process_output(image))
         | 
| 159 | 
            +
                        images.append(images_)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    return images
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                def set_marching_cubes_resolution(self, resolution: int):
         | 
| 164 | 
            +
                    if (
         | 
| 165 | 
            +
                        self.isosurface_helper is not None
         | 
| 166 | 
            +
                        and self.isosurface_helper.resolution == resolution
         | 
| 167 | 
            +
                    ):
         | 
| 168 | 
            +
                        return
         | 
| 169 | 
            +
                    self.isosurface_helper = MarchingCubeHelper(resolution)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 20.0):
         | 
| 172 | 
            +
                    self.set_marching_cubes_resolution(resolution)
         | 
| 173 | 
            +
                    meshes = []
         | 
| 174 | 
            +
                    for scene_code in scene_codes:
         | 
| 175 | 
            +
                        with torch.no_grad():
         | 
| 176 | 
            +
                            density = self.renderer.query_triplane(
         | 
| 177 | 
            +
                                self.decoder,
         | 
| 178 | 
            +
                                scale_tensor(
         | 
| 179 | 
            +
                                    self.isosurface_helper.grid_vertices.to(scene_codes.device),
         | 
| 180 | 
            +
                                    self.isosurface_helper.points_range,
         | 
| 181 | 
            +
                                    (-self.renderer.cfg.radius, self.renderer.cfg.radius),
         | 
| 182 | 
            +
                                ),
         | 
| 183 | 
            +
                                scene_code,
         | 
| 184 | 
            +
                            )["density_act"]
         | 
| 185 | 
            +
                        v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
         | 
| 186 | 
            +
                        v_pos = scale_tensor(
         | 
| 187 | 
            +
                            v_pos,
         | 
| 188 | 
            +
                            self.isosurface_helper.points_range,
         | 
| 189 | 
            +
                            (-self.renderer.cfg.radius, self.renderer.cfg.radius),
         | 
| 190 | 
            +
                        )
         | 
| 191 | 
            +
                        with torch.no_grad():
         | 
| 192 | 
            +
                            color = self.renderer.query_triplane(
         | 
| 193 | 
            +
                                self.decoder,
         | 
| 194 | 
            +
                                v_pos,
         | 
| 195 | 
            +
                                scene_code,
         | 
| 196 | 
            +
                            )["color"]
         | 
| 197 | 
            +
                        mesh = trimesh.Trimesh(
         | 
| 198 | 
            +
                            vertices=v_pos.cpu().numpy(),
         | 
| 199 | 
            +
                            faces=t_pos_idx.cpu().numpy(),
         | 
| 200 | 
            +
                            vertex_colors=color.cpu().numpy(),
         | 
| 201 | 
            +
                        )
         | 
| 202 | 
            +
                        meshes.append(mesh)
         | 
| 203 | 
            +
                    return meshes
         | 
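
A usage sketch of the `TSR` system defined above. The repo id and file names passed to `from_pretrained` are assumptions (they follow the publicly released TripoSR checkpoint), and `input.png` / `mesh.obj` are placeholder paths:

    import torch
    from PIL import Image

    from tsr.system import TSR

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = TSR.from_pretrained(
        "stabilityai/TripoSR",      # assumed repo id
        config_name="config.yaml",  # assumed config file name
        weight_name="model.ckpt",   # assumed weight file name
    )
    model.to(device)

    image = Image.open("input.png")                 # placeholder input image
    scene_codes = model(image, device=device)       # triplane scene codes
    renders = model.render(scene_codes, n_views=4)  # per-scene lists of PIL images
    meshes = model.extract_mesh(scene_codes)        # list of trimesh.Trimesh
    meshes[0].export("mesh.obj")
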
    	
        tsr/utils.py
    ADDED
    
    | @@ -0,0 +1,492 @@ | |
import importlib
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio
import numpy as np
import PIL.Image
import rembg
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from PIL import Image


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
    return scfg


def find_class(cls_string):
    module_string = ".".join(cls_string.split(".")[:-1])
    cls_name = cls_string.split(".")[-1]
    module = importlib.import_module(module_string, package=None)
    cls = getattr(module, cls_name)
    return cls


def get_intrinsic_from_fov(fov, H, W, bs=-1):
    focal_length = 0.5 * H / np.tan(0.5 * fov)
    intrinsic = np.identity(3, dtype=np.float32)
    intrinsic[0, 0] = focal_length
    intrinsic[1, 1] = focal_length
    intrinsic[0, 2] = W / 2.0
    intrinsic[1, 2] = H / 2.0

    if bs > 0:
        intrinsic = intrinsic[None].repeat(bs, axis=0)

    return torch.from_numpy(intrinsic)
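A small usage sketch for get_intrinsic_from_fov (illustrative only; the values are made up). Note that `fov` is expected in radians and the focal length is derived from the image height:

# Illustrative only -- not part of tsr/utils.py.
import numpy as np

H, W = 512, 512
fov = np.deg2rad(40.0)                                # the function expects radians
K = get_intrinsic_from_fov(fov, H, W)                 # (3, 3) torch tensor
K_batched = get_intrinsic_from_fov(fov, H, W, bs=4)   # (4, 3, 3), repeated along a batch axis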


class BaseModule(nn.Module):
    @dataclass
    class Config:
        pass

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        raise NotImplementedError


class ImagePreprocessor:
    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, torch.Tensor):
            pass

        batched = image.ndim == 4

        if not batched:
            image = image[None, ...]
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
            image = self.convert_and_resize(image, size)
        else:
            if not isinstance(image, list):
                image = [image]
            image = [self.convert_and_resize(im, size) for im in image]
            image = torch.stack(image, dim=0)
        return image
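A usage sketch for ImagePreprocessor (illustrative only; the file path is hypothetical). It accepts a single image, a list of images, or an already-batched array/tensor, and returns float images in [0, 1] resized to `size`:

# Illustrative only -- not part of tsr/utils.py.
from PIL import Image

preprocessor = ImagePreprocessor()
img = Image.open("input.png").convert("RGB")   # hypothetical input path
single = preprocessor(img, size=512)           # (1, 512, 512, 3), values in [0, 1]
batch = preprocessor([img, img], size=512)     # (2, 512, 512, 3)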


def rays_intersect_bbox(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    radius: float,
    near: float = 0.0,
    valid_thresh: float = 0.01,
):
    input_shape = rays_o.shape[:-1]
    rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
    rays_d_valid = torch.where(
        rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
    )
    if type(radius) in [int, float]:
        radius = torch.FloatTensor(
            [[-radius, radius], [-radius, radius], [-radius, radius]]
        ).to(rays_o.device)
    radius = (
        1.0 - 1.0e-3
    ) * radius  # tighten the radius to make sure the intersection point lies in the bounding box
    interx0 = (radius[..., 1] - rays_o) / rays_d_valid
    interx1 = (radius[..., 0] - rays_o) / rays_d_valid
    t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
    t_far = torch.maximum(interx0, interx1).amin(dim=-1)

    # check whether a ray intersects the bbox or not
    rays_valid = t_far - t_near > valid_thresh

    t_near[torch.where(~rays_valid)] = 0.0
    t_far[torch.where(~rays_valid)] = 0.0

    t_near = t_near.view(*input_shape, 1)
    t_far = t_far.view(*input_shape, 1)
    rays_valid = rays_valid.view(*input_shape)

    return t_near, t_far, rays_valid


def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    if chunk_size <= 0:
        return func(*args, **kwargs)
    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."
    out = defaultdict(list)
    out_type = None
    # max(1, B) to support B == 0
    for i in range(0, max(1, B), chunk_size):
        out_chunk = func(
            *[
                arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
            chunk_length = len(out_chunk)
            out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
        elif isinstance(out_chunk, dict):
            pass
        else:
            print(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
            exit(1)
        for k, v in out_chunk.items():
            v = v if torch.is_grad_enabled() else v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all([vv is None for vv in v]):
            # allow None in return value
            out_merged[k] = None
        elif all([isinstance(vv, torch.Tensor) for vv in v]):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    if out_type is torch.Tensor:
        return out_merged[0]
    elif out_type in [tuple, list]:
        return out_type([out_merged[i] for i in range(chunk_length)])
    elif out_type is dict:
        return out_merged
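chunk_batch evaluates a function over the leading batch dimension in fixed-size chunks and concatenates the per-chunk results, which keeps memory bounded for large point queries. A minimal sketch (illustrative only; `square` is a made-up function):

# Illustrative only -- not part of tsr/utils.py.
import torch

def square(x: torch.Tensor) -> torch.Tensor:
    return x ** 2

points = torch.randn(100_000, 3)
out = chunk_batch(square, 8192, points)   # same result as square(points), computed 8192 rows at a time
assert torch.allclose(out, square(points))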


ValidScale = Union[Tuple[float, float], torch.FloatTensor]


def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, torch.FloatTensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat


def get_activation(name) -> Callable:
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none":
        return lambda x: x
    elif name == "exp":
        return lambda x: torch.exp(x)
    elif name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    elif name == "tanh":
        return lambda x: torch.tanh(x)
    elif name == "softplus":
        return lambda x: F.softplus(x)
    else:
        try:
            return getattr(F, name)
        except AttributeError:
            raise ValueError(f"Unknown activation function: {name}")
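A small usage sketch for scale_tensor and get_activation (illustrative only):

# Illustrative only -- not part of tsr/utils.py.
import torch

x = torch.rand(4, 3)                      # values in [0, 1]
y = scale_tensor(x, (0, 1), (-1, 1))      # linearly rescaled to [-1, 1]
act = get_activation("softplus")          # other names fall back to torch.nn.functional
z = act(y)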


def get_ray_directions(
    H: int,
    W: int,
    focal: Union[float, Tuple[float, float]],
    principal: Optional[Tuple[float, float]] = None,
    use_pixel_centers: bool = True,
    normalize: bool = True,
) -> torch.FloatTensor:
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems

    Inputs:
        H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether to use pixel centers
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    pixel_center = 0.5 if use_pixel_centers else 0

    if isinstance(focal, float):
        fx, fy = focal, focal
        cx, cy = W / 2, H / 2
    else:
        fx, fy = focal
        assert principal is not None
        cx, cy = principal

    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32) + pixel_center,
        torch.arange(H, dtype=torch.float32) + pixel_center,
        indexing="xy",
    )

    directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)

    if normalize:
        directions = F.normalize(directions, dim=-1)

    return directions


def get_rays(
    directions,
    c2w,
    keepdim=False,
    noise_scale=0.0,
    normalize=False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    # Rotate ray directions from camera coordinate to the world coordinate
    assert directions.shape[-1] == 3

    if directions.ndim == 2:  # (N_rays, 3)
        if c2w.ndim == 2:  # (4, 4)
            c2w = c2w[None, :, :]
        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 3:  # (H, W, 3)
        assert c2w.ndim in [2, 3]
        if c2w.ndim == 2:  # (4, 4)
            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
                -1
            )  # (H, W, 3)
            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
        elif c2w.ndim == 3:  # (B, 4, 4)
            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
                -1
            )  # (B, H, W, 3)
            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 4:  # (B, H, W, 3)
        assert c2w.ndim == 3  # (B, 4, 4)
        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
            -1
        )  # (B, H, W, 3)
        rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)

    # add camera noise to avoid grid-like artifact
    # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
    if noise_scale > 0:
        rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
        rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale

    if normalize:
        rays_d = F.normalize(rays_d, dim=-1)
    if not keepdim:
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    return rays_o, rays_d


def get_spherical_cameras(
    n_views: int,
    elevation_deg: float,
    camera_distance: float,
    fovy_deg: float,
    height: int,
    width: int,
):
    azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
    elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
    camera_distances = torch.full_like(elevation_deg, camera_distance)

    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180

    # convert spherical coordinates to cartesian coordinates
    # right hand coordinate system, x back, y right, z up
    # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )

    # default scene center at origin
    center = torch.zeros_like(camera_positions)
    # default camera up direction as +z
    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)

    fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180

    lookat = F.normalize(center - camera_positions, dim=-1)
    right = F.normalize(torch.cross(lookat, up), dim=-1)
    up = F.normalize(torch.cross(right, lookat), dim=-1)
    c2w3x4 = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
    c2w[:, 3, 3] = 1.0

    # get directions by dividing directions_unit_focal by focal length
    focal_length = 0.5 * height / torch.tan(0.5 * fovy)
    directions_unit_focal = get_ray_directions(
        H=height,
        W=width,
        focal=1.0,
    )
    directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
    directions[:, :, :, :2] = (
        directions[:, :, :, :2] / focal_length[:, None, None, None]
    )
    # must use normalize=True to normalize directions here
    rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)

    return rays_o, rays_d
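get_spherical_cameras produces per-view ray origins and directions for a circular orbit around the origin at a fixed elevation. A usage sketch (illustrative only; the resolution, distance, and angles are made up):

# Illustrative only -- not part of tsr/utils.py.
n_views = 30
rays_o, rays_d = get_spherical_cameras(
    n_views=n_views,
    elevation_deg=0.0,
    camera_distance=1.9,
    fovy_deg=40.0,
    height=256,
    width=256,
)
# rays_o, rays_d: (n_views, 256, 256, 3); rays_d is normalized because get_rays is called with normalize=True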


def remove_background(
    image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image


def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = PIL.Image.fromarray(new_image)
    return new_image
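remove_background and resize_foreground are typically chained: segment the object, then crop and re-pad it so the foreground occupies a fixed fraction of a square frame. A sketch (illustrative only; the input path is hypothetical):

# Illustrative only -- not part of tsr/utils.py.
import rembg
from PIL import Image

session = rembg.new_session()
img = Image.open("chair.png")                          # hypothetical input image
img = remove_background(img, rembg_session=session)    # RGBA with background removed
img = resize_foreground(img, ratio=0.85)               # foreground fills ~85% of the square frame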


def save_video(
    frames: List[PIL.Image.Image],
    output_path: str,
    fps: int = 30,
):
    # use imageio to save video
    frames = [np.array(frame) for frame in frames]
    writer = imageio.get_writer(output_path, fps=fps)
    for frame in frames:
        writer.append_data(frame)
    writer.close()


_dir2vec = {
    "+x": np.array([1, 0, 0]),
    "+y": np.array([0, 1, 0]),
    "+z": np.array([0, 0, 1]),
    "-x": np.array([-1, 0, 0]),
    "-y": np.array([0, -1, 0]),
    "-z": np.array([0, 0, -1]),
}


def to_gradio_3d_orientation(vertices):
    z_, x_ = _dir2vec["+y"], _dir2vec["-z"]
    y_ = np.cross(z_, x_)
    std2mesh = np.stack([x_, y_, z_], axis=0).T
    vertices = np.dot(std2mesh, vertices.T).T
    return vertices
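to_gradio_3d_orientation appears to remap vertex coordinates from the pipeline's +z-up convention to the orientation expected by a Gradio 3D viewer. A sketch applying it to a mesh (illustrative only; `mesh` is assumed to be a trimesh.Trimesh produced elsewhere in this repository):

# Illustrative only -- not part of tsr/utils.py.
mesh.vertices = to_gradio_3d_orientation(mesh.vertices)
mesh.export("mesh_for_gradio.glb")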
