"""
MagicLatentAdapter: two-in-one latent adapter for ComfyUI.

- Mode "generate": creates a latent of appropriate grid size for the target model
  (optionally mixing an input image via VAE), then adapts channels.
- Mode "adapt": takes an incoming LATENT and adapts channel count to match the model.

Family switch: "auto / SD / SDXL / FLUX" influences only stride fallback when VAE
is not provided. In AUTO we query VAE stride if possible and fall back to 8.

No file re-encodings are performed; all code is ASCII/English as requested.
"""

from __future__ import annotations

import torch
import torch.nn.functional as F

import comfy.sample as _sample


class MagicLatentAdapter:
    """Generate or adapt a LATENT to fit the target model's expectations."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {}),
                "mode": (["generate", "adapt"], {"default": "generate"}),
                "family": (["auto", "SD", "SDXL", "FLUX"], {"default": "auto"}),

                # Generation params (ignored in adapt mode)
                "width": ("INT", {"default": 512, "min": 8, "max": 8192, "step": 8}),
                "height": ("INT", {"default": 512, "min": 8, "max": 8192, "step": 8}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
                "sigma": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "bias": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.1}),
                "mix_image": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                # For adapt mode
                "latent": ("LATENT", {}),
                # For image mixing in generate mode
                "vae": ("VAE", {}),
                "image": ("IMAGE", {}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("LATENT",)
    FUNCTION = "run"
    CATEGORY = "MagicNodes"

    @staticmethod
    def _detect_stride(vae, family: str) -> int:
        # Prefer VAE stride if available
        if vae is not None:
            try:
                s = int(vae.spacial_compression_decode())
                if s > 0:
                    return s
            except Exception:
                pass
        # Per-family fallback: SD, SDXL and FLUX all use a spatial compression
        # factor of 8, which is also the conservative default for "auto".
        return 8

    @staticmethod
    def _latent_format(model) -> tuple[int, int]:
        """Return (channels, dimensions) from model.latent_format.
        dimensions: 2 -> NCHW, 3 -> NCDHW.
        """
        try:
            lf = model.get_model_object("latent_format")
            ch = int(getattr(lf, "latent_channels", 4))
            dims = int(getattr(lf, "latent_dimensions", 2))
            if dims not in (2, 3):
                dims = 2
            return ch, dims
        except Exception:
            return 4, 2

    @staticmethod
    def _adapt_channels(model, z: torch.Tensor, preserve_zero: bool = False) -> torch.Tensor:
        """Adapts channel count and dims to the model's latent_format.
        If preserve_zero and the latent is all zeros, pad with zeros instead of noise.
        """
        target_c, target_dims = MagicLatentAdapter._latent_format(model)

        # First, let Comfy normalize empty latents (channel count / depth dim)
        try:
            z = _sample.fix_empty_latent_channels(model, z)
        except Exception:
            pass

        # Align dimensions
        if target_dims == 3 and z.ndim == 4:
            z = z.unsqueeze(2)  # N C 1 H W
        elif target_dims == 2 and z.ndim == 5:
            if z.shape[2] == 1:
                z = z.squeeze(2)
            else:
                z = z[:, :, :1].squeeze(2)  # keep only the first depth frame

        # Align channels: truncate extras, or pad missing ones.
        if z.ndim in (4, 5):
            c = z.shape[1]
            if c == target_c:
                return z
            if c > target_c:
                return z[:, :target_c]
            pad_shape = (z.shape[0], target_c - c) + tuple(z.shape[2:])
            if preserve_zero and torch.count_nonzero(z) == 0:
                pad = torch.zeros(pad_shape, device=z.device, dtype=z.dtype)
            else:
                pad = torch.randn(pad_shape, device=z.device, dtype=z.dtype)
            return torch.cat([z, pad], dim=1)
        return z

    @staticmethod
    def _mix_image_into_latent(vae, image_bhwc: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        if vae is None or image_bhwc is None:
            return z
        try:
            # Align the image's spatial size to the VAE grid by replicate-padding
            try:
                stride = int(vae.spacial_compression_decode())
            except Exception:
                stride = 8
            h, w = image_bhwc.shape[1:3]
            def _align_up(x, s):
                return int(((x + s - 1) // s) * s)
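            # e.g. _align_up(513, 8) == 520: round up to the next stride multiple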
            Ht, Wt = _align_up(h, stride), _align_up(w, stride)
            x = image_bhwc
            if (Ht != h) or (Wt != w):
                pad_h = Ht - h
                pad_w = Wt - w
                x_nchw = x.movedim(-1, 1)
                x_nchw = F.pad(x_nchw, (0, pad_w, 0, pad_h), mode='replicate')
                x = x_nchw.movedim(1, -1)
            enc = vae.encode(x[:, :, :, :3])  # RGB only; drop any alpha channel
            # Align rank with the target latent: add a depth axis when the
            # model is 3D, or drop extra depth axes, keeping the first frame.
            while enc.ndim < z.ndim:
                enc = enc.unsqueeze(2)
            while enc.ndim > z.ndim:
                enc = enc[:, :, 0]
            # If the batch sizes mismatch, tile the first encoding.
            if enc.shape[0] != z.shape[0]:
                enc = enc[:1]
                enc = enc.repeat(z.shape[0], *([1] * (enc.ndim - 1)))
            # Resize spatially (nearest-neighbour) if the grids disagree
            if enc.shape[2:] != z.shape[2:]:
                enc = F.interpolate(enc, size=z.shape[2:], mode="nearest")
            # Trim both tensors to a common channel count before summing
            if enc.shape[1] != z.shape[1]:
                cmin = min(enc.shape[1], z.shape[1])
                enc = enc[:, :cmin]
                z = z[:, :cmin]
            return enc + z
        except Exception:
            return z

    def run(
        self,
        model,
        mode: str,
        family: str,
        width: int,
        height: int,
        batch_size: int,
        sigma: float,
        bias: float,
        mix_image: bool = False,
        latent=None,
        vae=None,
        image=None,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if mode == "adapt":
            if latent is None or "samples" not in latent:
                # Produce an empty latent matching model's latent_format
                stride = self._detect_stride(vae, family)
                lh, lw = max(1, height // stride), max(1, width // stride)
                target_c, target_dims = self._latent_format(model)
                if target_dims == 3:
                    z = torch.zeros(batch_size, target_c, 1, lh, lw, device=device)
                else:
                    z = torch.zeros(batch_size, target_c, lh, lw, device=device)
            else:
                z = latent["samples"].to(device)
            z = self._adapt_channels(model, z, preserve_zero=True)
            return ({"samples": z},)

        # generate: fresh noise scaled by sigma and shifted by bias
        stride = self._detect_stride(vae, family)
        lh, lw = max(1, height // stride), max(1, width // stride)
        target_c, target_dims = self._latent_format(model)
        if target_dims == 3:
            z = torch.randn(batch_size, target_c, 1, lh, lw, device=device) * float(sigma) + float(bias)
        else:
            z = torch.randn(batch_size, target_c, lh, lw, device=device) * float(sigma) + float(bias)
        if mix_image and (vae is not None) and (image is not None):
            # ComfyUI IMAGE tensors are BHWC in [0, 1]
            img = image.to(device)
            z = self._mix_image_into_latent(vae, img, z)
        # Final channel adaptation
        z = self._adapt_channels(model, z, preserve_zero=False)
        return ({"samples": z},)