MagicNodes/mod/mg_latent_adapter.py
"""
MagicLatentAdapter: two-in-one latent adapter for ComfyUI.
- Mode "generate": creates a latent of appropriate grid size for the target model
(optionally mixing an input image via VAE), then adapts channels.
- Mode "adapt": takes an incoming LATENT and adapts channel count to match the model.
Family switch: "auto / SD / SDXL / FLUX" influences only stride fallback when VAE
is not provided. In AUTO we query VAE stride if possible and fall back to 8.
No file re-encodings are performed; all code is ASCII/English as requested.
"""
from __future__ import annotations

import torch
import torch.nn.functional as F

import comfy.sample as _sample
class MagicLatentAdapter:
    """Generate or adapt a LATENT to fit the target model's expectations."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {}),
                "mode": (["generate", "adapt"], {"default": "generate"}),
                "family": (["auto", "SD", "SDXL", "FLUX"], {"default": "auto"}),
                # Generation params (ignored in adapt mode)
                "width": ("INT", {"default": 512, "min": 8, "max": 8192, "step": 8}),
                "height": ("INT", {"default": 512, "min": 8, "max": 8192, "step": 8}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
                "sigma": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "bias": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.1}),
                "mix_image": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                # For adapt mode
                "latent": ("LATENT", {}),
                # For image mixing in generate mode
                "vae": ("VAE", {}),
                "image": ("IMAGE", {}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("LATENT",)
    FUNCTION = "run"
    CATEGORY = "MagicNodes"
    @staticmethod
    def _detect_stride(vae, family: str) -> int:
        # Prefer VAE stride if available
        if vae is not None:
            try:
                s = int(vae.spacial_compression_decode())
                if s > 0:
                    return s
            except Exception:
                pass
        # Fallback per-family (conservative)
        fam = (family or "auto").lower()
        if fam in ("sd", "sdxl", "flux"):
            return 8
        return 8  # sensible default
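    # Illustrative stride math (assumed, not an exhaustive table): with a stride of 8,
    # a 1024x1024 request yields a 128x128 latent grid; VAEs with a different
    # compression factor may report their own value via spacial_compression_decode().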
    @staticmethod
    def _latent_format(model) -> tuple[int, int]:
        """Return (channels, dimensions) from model.latent_format.

        dimensions: 2 -> NCHW, 3 -> NCDHW.
        """
        try:
            lf = model.get_model_object("latent_format")
            ch = int(getattr(lf, "latent_channels", 4))
            dims = int(getattr(lf, "latent_dimensions", 2))
            if dims not in (2, 3):
                dims = 2
            return ch, dims
        except Exception:
            return 4, 2
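    # For reference (assumed typical values, not guaranteed by this code): SD 1.5 and
    # SDXL latent formats report 4 channels / 2 dims, FLUX-style formats report
    # 16 channels / 2 dims, and some video/3D formats use dims == 3 (NCDHW).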
    @staticmethod
    def _adapt_channels(model, z: torch.Tensor, preserve_zero: bool = False) -> torch.Tensor:
        """Adapts channel count and dims to the model's latent_format.

        If preserve_zero and the latent is all zeros, pad with zeros instead of noise.
        """
        target_c, target_dims = MagicLatentAdapter._latent_format(model)
        # First, let Comfy add depth dim for empty latents when needed
        try:
            z = _sample.fix_empty_latent_channels(model, z)
        except Exception:
            pass
        # Align dimensions
        if target_dims == 3 and z.ndim == 4:
            z = z.unsqueeze(2)  # N C 1 H W
        elif target_dims == 2 and z.ndim == 5:
            if z.shape[2] == 1:
                z = z.squeeze(2)
            else:
                z = z[:, :, :1].squeeze(2)
        # Align channels
        if z.ndim == 4:
            B, C, H, W = z.shape
            if C == target_c:
                return z
            if C > target_c:
                return z[:, :target_c]
            dev, dt = z.device, z.dtype
            if preserve_zero and torch.count_nonzero(z) == 0:
                pad = torch.zeros(B, target_c - C, H, W, device=dev, dtype=dt)
            else:
                pad = torch.randn(B, target_c - C, H, W, device=dev, dtype=dt)
            return torch.cat([z, pad], dim=1)
        elif z.ndim == 5:
            B, C, D, H, W = z.shape
            if C == target_c:
                return z
            if C > target_c:
                return z[:, :target_c]
            dev, dt = z.device, z.dtype
            if preserve_zero and torch.count_nonzero(z) == 0:
                pad = torch.zeros(B, target_c - C, D, H, W, device=dev, dtype=dt)
            else:
                pad = torch.randn(B, target_c - C, D, H, W, device=dev, dtype=dt)
            return torch.cat([z, pad], dim=1)
        else:
            return z
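    # Illustrative example (hypothetical model handle, assuming its latent_format reports
    # 16 channels): an all-zero 4-channel SD-style latent is zero-padded when
    # preserve_zero=True, otherwise the missing channels are filled with Gaussian noise.
    #   z = torch.zeros(1, 4, 64, 64)
    #   z = MagicLatentAdapter._adapt_channels(flux_model, z, preserve_zero=True)
    #   # z.shape -> torch.Size([1, 16, 64, 64])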
    @staticmethod
    def _mix_image_into_latent(vae, image_bhwc: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        if vae is None or image_bhwc is None:
            return z
        try:
            # Align image spatial size to the VAE grid by padding (replicate) if needed
            try:
                stride = int(vae.spacial_compression_decode())
            except Exception:
                stride = 8
            h, w = image_bhwc.shape[1:3]

            def _align_up(x, s):
                return int(((x + s - 1) // s) * s)

            Ht, Wt = _align_up(h, stride), _align_up(w, stride)
            x = image_bhwc
            if (Ht != h) or (Wt != w):
                pad_h = Ht - h
                pad_w = Wt - w
                x_nchw = x.movedim(-1, 1)
                x_nchw = F.pad(x_nchw, (0, pad_w, 0, pad_h), mode='replicate')
                x = x_nchw.movedim(1, -1)
            enc = vae.encode(x[:, :, :, :3])
            # Match dimensionality: add or drop depth dims until enc lines up with z
            while enc.ndim < z.ndim:
                enc = enc.unsqueeze(2)  # add depth dim if needed
            while enc.ndim > z.ndim:
                # drop the extra depth dim, keeping the first slice (H/W stay intact)
                if enc.ndim == 5 and enc.shape[2] == 1:
                    enc = enc.squeeze(2)
                else:
                    enc = enc[:, :, 0]
            # If the batch mismatches, use the first encoding and tile it
            if enc.shape[0] != z.shape[0]:
                enc = enc[:1]
                enc = enc.repeat(z.shape[0], *([1] * (enc.ndim - 1)))
            # Resize spatial dims if needed (nearest)
            if enc.ndim == 4:
                if enc.shape[2:] != z.shape[2:]:
                    enc = F.interpolate(enc, size=z.shape[2:], mode="nearest")
            elif enc.ndim == 5:
                if enc.shape[2:] != z.shape[2:]:
                    enc = F.interpolate(enc, size=z.shape[2:], mode="nearest")
            # Channel adapt for mixing safety
            if enc.shape[1] != z.shape[1]:
                cmin = min(enc.shape[1], z.shape[1])
                enc = enc[:, :cmin]
                z = z[:, :cmin]
            return enc + z
        except Exception:
            return z
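    # Note on the mix: the encoded image is simply added to the noise (enc + z), so it acts
    # as a bias on the initial latent rather than a denoise-strength blend. A hedged sketch
    # of the shapes involved, assuming a 512x512 RGB image and an SD-style 4-channel VAE:
    #   image: [1, 512, 512, 3] -> vae.encode -> enc: [1, 4, 64, 64]
    #   z:     [1, 4, 64, 64]   -> returned mix: enc + z (same shape)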
    def run(
        self,
        model,
        mode: str,
        family: str,
        width: int,
        height: int,
        batch_size: int,
        sigma: float,
        bias: float,
        mix_image: bool = False,
        latent=None,
        vae=None,
        image=None,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if mode == "adapt":
            if latent is None or "samples" not in latent:
                # Produce an empty latent matching the model's latent_format
                stride = self._detect_stride(vae, family)
                h8, w8 = max(1, height // stride), max(1, width // stride)
                target_c, target_dims = self._latent_format(model)
                if target_dims == 3:
                    z = torch.zeros(batch_size, target_c, 1, h8, w8, device=device)
                else:
                    z = torch.zeros(batch_size, target_c, h8, w8, device=device)
            else:
                z = latent["samples"].to(device)
            z = self._adapt_channels(model, z, preserve_zero=True)
            return ({"samples": z},)

        # generate
        stride = self._detect_stride(vae, family)
        h8, w8 = max(1, height // stride), max(1, width // stride)
        target_c, target_dims = self._latent_format(model)
        if target_dims == 3:
            z = torch.randn(batch_size, target_c, 1, h8, w8, device=device) * float(sigma) + float(bias)
        else:
            z = torch.randn(batch_size, target_c, h8, w8, device=device) * float(sigma) + float(bias)
        if mix_image and (vae is not None) and (image is not None):
            # image is BHWC 0..1
            img = image.to(device)
            z = self._mix_image_into_latent(vae, img, z)
        # Final channel adaptation
        z = self._adapt_channels(model, z, preserve_zero=False)
        return ({"samples": z},)
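# Minimal usage sketch (illustrative only; `model` and `vae` are assumed to come from a
# checkpoint loader elsewhere in the ComfyUI graph, and node registration such as
# NODE_CLASS_MAPPINGS is handled outside this module):
#   node = MagicLatentAdapter()
#   (latent,) = node.run(model, mode="generate", family="auto",
#                        width=1024, height=1024, batch_size=1,
#                        sigma=1.0, bias=0.0)
#   latent["samples"].shape  # e.g. [1, 16, 128, 128] for a 16-channel latent_format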