"""
MagicLatentAdapter: two-in-one latent adapter for ComfyUI.

- Mode "generate": creates a latent with the grid size appropriate for the
  target model (optionally mixing in an input image via the VAE), then adapts
  the channel count.
- Mode "adapt": takes an incoming LATENT and adapts its channel count to match
  the model.

Family switch: "auto / SD / SDXL / FLUX" influences only the stride fallback
when no VAE is provided. In "auto" we query the VAE stride when possible and
fall back to 8.

All code and comments are ASCII/English; no file re-encoding is performed.
"""

from __future__ import annotations

import torch
import torch.nn.functional as F

import comfy.sample as _sample
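
# Minimal usage sketch (assumptions: "model" comes from an upstream loader node
# and LATENT values follow the standard ComfyUI {"samples": tensor} convention;
# illustrative only, not part of the node contract):
#
#   node = MagicLatentAdapter()
#   (latent,) = node.run(model, "generate", "auto", 512, 512, 1, 1.0, 0.0)
#   latent["samples"]  # e.g. shape (1, 4, 64, 64) for a stride-8 SD model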


class MagicLatentAdapter:
    """Generate or adapt a LATENT to fit the target model's expectations."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {}),
                "mode": (["generate", "adapt"], {"default": "generate"}),
                "family": (["auto", "SD", "SDXL", "FLUX"], {"default": "auto"}),
                # Generation parameters (used in "generate" mode).
                "width": ("INT", {"default": 512, "min": 8, "max": 8192, "step": 8}),
                "height": ("INT", {"default": 512, "min": 8, "max": 8192, "step": 8}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
                "sigma": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "bias": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.1}),
                "mix_image": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                # Incoming latent, consumed in "adapt" mode.
                "latent": ("LATENT", {}),
                # VAE for stride detection and optional image mixing.
                "vae": ("VAE", {}),
                "image": ("IMAGE", {}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("LATENT",)
    FUNCTION = "run"
    CATEGORY = "MagicNodes"

    @staticmethod
    def _detect_stride(vae, family: str) -> int:
        # Prefer the stride reported by the VAE itself when one is connected.
        if vae is not None:
            try:
                # "spacial" is the ComfyUI API spelling.
                s = int(vae.spacial_compression_decode())
                if s > 0:
                    return s
            except Exception:
                pass
        # Fallback by family: SD, SDXL and FLUX all use a stride-8 latent grid,
        # so "auto" and the explicit families currently share the same value.
        return 8
    @staticmethod
    def _latent_format(model) -> tuple[int, int]:
        """Return (channels, dimensions) from model.latent_format.

        dimensions: 2 -> NCHW, 3 -> NCDHW.
        """
        try:
            lf = model.get_model_object("latent_format")
            ch = int(getattr(lf, "latent_channels", 4))
            dims = int(getattr(lf, "latent_dimensions", 2))
            if dims not in (2, 3):
                dims = 2
            return ch, dims
        except Exception:
            # Conservative default: classic 4-channel NCHW (SD-style) latent.
            return 4, 2

    @staticmethod
    def _adapt_channels(model, z: torch.Tensor, preserve_zero: bool = False) -> torch.Tensor:
        """Adapt channel count and dims to the model's latent_format.

        If preserve_zero is set and the latent is all zeros, pad with zeros
        instead of noise.
        """
        target_c, target_dims = MagicLatentAdapter._latent_format(model)

        # Let ComfyUI repair empty-latent channel mismatches first, if it can.
        try:
            z = _sample.fix_empty_latent_channels(model, z)
        except Exception:
            pass

        # Adapt dimensionality: NCHW <-> NCDHW.
        if target_dims == 3 and z.ndim == 4:
            z = z.unsqueeze(2)
        elif target_dims == 2 and z.ndim == 5:
            if z.shape[2] == 1:
                z = z.squeeze(2)
            else:
                z = z[:, :, :1].squeeze(2)

        # Adapt channel count: truncate extras; pad missing channels with
        # zeros (preserve_zero) or noise.
        if z.ndim == 4:
            B, C, H, W = z.shape
            if C == target_c:
                return z
            if C > target_c:
                return z[:, :target_c]
            dev, dt = z.device, z.dtype
            if preserve_zero and torch.count_nonzero(z) == 0:
                pad = torch.zeros(B, target_c - C, H, W, device=dev, dtype=dt)
            else:
                pad = torch.randn(B, target_c - C, H, W, device=dev, dtype=dt)
            return torch.cat([z, pad], dim=1)
        elif z.ndim == 5:
            B, C, D, H, W = z.shape
            if C == target_c:
                return z
            if C > target_c:
                return z[:, :target_c]
            dev, dt = z.device, z.dtype
            if preserve_zero and torch.count_nonzero(z) == 0:
                pad = torch.zeros(B, target_c - C, D, H, W, device=dev, dtype=dt)
            else:
                pad = torch.randn(B, target_c - C, D, H, W, device=dev, dtype=dt)
            return torch.cat([z, pad], dim=1)
        else:
            return z

    @staticmethod
    def _mix_image_into_latent(vae, image_bhwc: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Encode the image with the VAE and add it onto z.

        Returns z unchanged if no VAE/image is given or encoding fails.
        """
        if vae is None or image_bhwc is None:
            return z
        try:
            # Pad the image up to a multiple of the VAE stride before encoding.
            try:
                stride = int(vae.spacial_compression_decode())
            except Exception:
                stride = 8
            h, w = image_bhwc.shape[1:3]

            def _align_up(x, s):
                return int(((x + s - 1) // s) * s)

            Ht, Wt = _align_up(h, stride), _align_up(w, stride)
            x = image_bhwc
            if (Ht != h) or (Wt != w):
                pad_h = Ht - h
                pad_w = Wt - w
                x_nchw = x.movedim(-1, 1)
                x_nchw = F.pad(x_nchw, (0, pad_w, 0, pad_h), mode="replicate")
                x = x_nchw.movedim(1, -1)
            # Encode RGB only (drop any alpha channel).
            enc = vae.encode(x[:, :, :, :3])

            # Match dimensionality (e.g. a video VAE emitting NCDHW while the
            # target latent is NCHW, or vice versa).
            while enc.ndim < z.ndim:
                enc = enc.unsqueeze(2)
            while enc.ndim > z.ndim:
                # Keep only the first slice along the extra (depth) dim so the
                # spatial dims survive, then drop that dim.
                enc = enc[:, :, :1].squeeze(2)
            if enc.shape[0] != z.shape[0]:
                # Broadcast the first encoded sample across the batch.
                enc = enc[:1]
                enc = enc.repeat(z.shape[0], *([1] * (enc.ndim - 1)))

            # Match spatial size.
            if enc.ndim in (4, 5) and enc.shape[2:] != z.shape[2:]:
                enc = F.interpolate(enc, size=z.shape[2:], mode="nearest")

            # Intersect channel counts before mixing.
            if enc.shape[1] != z.shape[1]:
                cmin = min(enc.shape[1], z.shape[1])
                enc = enc[:, :cmin]
                z = z[:, :cmin]
            return enc + z
        except Exception:
            return z

    def run(
        self,
        model,
        mode: str,
        family: str,
        width: int,
        height: int,
        batch_size: int,
        sigma: float,
        bias: float,
        mix_image: bool = False,
        latent=None,
        vae=None,
        image=None,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if mode == "adapt":
            if latent is None or "samples" not in latent:
                # No latent connected: synthesize an empty latent of the
                # requested size instead of failing.
                stride = self._detect_stride(vae, family)
                h8, w8 = max(1, height // stride), max(1, width // stride)
                target_c, target_dims = self._latent_format(model)
                if target_dims == 3:
                    z = torch.zeros(batch_size, target_c, 1, h8, w8, device=device)
                else:
                    z = torch.zeros(batch_size, target_c, h8, w8, device=device)
            else:
                z = latent["samples"].to(device)
            z = self._adapt_channels(model, z, preserve_zero=True)
            return ({"samples": z},)

        # Mode "generate": Gaussian noise scaled by sigma and shifted by bias.
        stride = self._detect_stride(vae, family)
        h8, w8 = max(1, height // stride), max(1, width // stride)
        target_c, target_dims = self._latent_format(model)
        if target_dims == 3:
            z = torch.randn(batch_size, target_c, 1, h8, w8, device=device) * float(sigma) + float(bias)
        else:
            z = torch.randn(batch_size, target_c, h8, w8, device=device) * float(sigma) + float(bias)

        if mix_image and (vae is not None) and (image is not None):
            # Blend the VAE-encoded image into the noise.
            img = image.to(device)
            z = self._mix_image_into_latent(vae, img, z)

        z = self._adapt_channels(model, z, preserve_zero=False)
        return ({"samples": z},)