from comfy.ldm.modules import attention as comfy_attention
import logging
import comfy.model_patcher
import comfy.utils
import comfy.sd
import torch
import comfy.model_management as mm
from comfy.cli_args import args

sageattn_modes = [
    "disabled",
    "auto",
    "auto_speed",
    "auto_quality",
    "sageattn_qk_int8_pv_fp16_cuda",
    "sageattn_qk_int8_pv_fp16_triton",
    "sageattn_qk_int8_pv_fp8_cuda",
    "sageattn_qk_int8_pv_fp8_cuda++",
]

_initialized = False
# Avoid spamming logs each attention call
_sage_warned_once = False
_sage_generic_warned_once = False
_original_functions = {}

# Runtime override knobs (may be set by other nodes, e.g., CADE2 Beta)
# CURRENT_PV_ACCUM can be None, "fp32+fp16" or "fp32+fp32"
CURRENT_PV_ACCUM = None
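
# Hedged sketch (not part of the original API; added only to document how the override is read
# by the patched attention below): the PV accumulator can be forced globally via CURRENT_PV_ACCUM,
# or per call via transformer_options = {"sageattn": {"pv_accum_dtype": "fp32+fp32"}}.
def _example_force_pv_accum(pv_accum_dtype="fp32+fp32"):
    """Set the global PV-accumulator override; pass None to restore kernel defaults."""
    global CURRENT_PV_ACCUM
    CURRENT_PV_ACCUM = pv_accum_dtype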

# Lightweight attention-entropy probe (for AQClip Attn-mode)
_attn_entropy_enabled = False
_attn_entropy_last = None  # torch.Tensor | None, shape (B,1,h',w') in [0,1]
_attn_probe_heads_cap = 4
_attn_probe_tokens_cap = 1024

def enable_attention_entropy_capture(enable: bool, max_tokens: int = 1024, max_heads: int = 4):
    """Toggle capturing a tiny attention entropy map during optimized_attention.
    Stores a normalized map per forward pass; consumer may upsample to latent size.
    """
    global _attn_entropy_enabled, _attn_probe_tokens_cap, _attn_probe_heads_cap, _attn_entropy_last
    _attn_entropy_enabled = bool(enable)
    _attn_probe_tokens_cap = int(max(128, min(16384, max_tokens)))
    _attn_probe_heads_cap = int(max(1, min(32, max_heads)))
    if not _attn_entropy_enabled:
        _attn_entropy_last = None

def get_attention_entropy_map(clear: bool = False):
    """Return last captured attention entropy map (B,1,h',w') in [0,1] or None."""
    global _attn_entropy_last
    out = _attn_entropy_last
    if clear:
        _attn_entropy_last = None
    return out
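
# Hedged usage sketch (illustrative only; never called from this module): shows how a consumer
# node could capture one entropy map and upsample it to an assumed latent size `latent_hw`.
def _example_read_attention_entropy(latent_hw=(64, 64)):
    """Enable the probe, read the last captured map, and resize it to `latent_hw`."""
    enable_attention_entropy_capture(True, max_tokens=1024, max_heads=4)
    # ... the patched optimized_attention runs somewhere here (e.g., one sampler step) ...
    ent = get_attention_entropy_map(clear=True)  # (B,1,h',w') in [0,1], or None if nothing ran
    enable_attention_entropy_capture(False)
    if ent is None:
        return None
    return torch.nn.functional.interpolate(ent, size=latent_hw, mode="bilinear", align_corners=False)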

# ------------------------ KV pruning (self-attention) ------------------------
_kv_prune_enabled = False
_kv_prune_keep = 0.85
_kv_prune_min_tokens = 128

def set_kv_prune(enable: bool, keep: float = 0.85, min_tokens: int = 128):
    """Enable lightweight K/V token pruning inside optimized attention.
    - Applies only to self-attention (len(Q)==len(K)).
    - Keeps top-`keep` fraction of keys/values by L2 energy of K, averaged over heads.
    - Skips pruning when an attention mask is provided (shape mismatch risk).
    """
    global _kv_prune_enabled, _kv_prune_keep, _kv_prune_min_tokens
    _kv_prune_enabled = bool(enable)
    try:
        _kv_prune_keep = float(max(0.5, min(1.0, keep)))
    except Exception:
        _kv_prune_keep = 0.85
    try:
        _kv_prune_min_tokens = int(max(1, min_tokens))
    except Exception:
        _kv_prune_min_tokens = 128
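
# Hedged standalone illustration (random tensors, not wired into the patch): the same top-k
# selection the patched attention applies to K/V in the NHD layout, kept here only to document
# the pruning rule (keep the highest-energy key tokens, gather the matching values).
def _example_kv_prune_nhd(B=1, N=256, H=8, D=64, keep_frac=0.85):
    """Keep the top `keep_frac` fraction of key tokens by mean-over-heads L2 energy."""
    import math
    k = torch.randn(B, N, H, D)
    v = torch.randn(B, N, H, D)
    keep = max(1, int(math.ceil(keep_frac * N)))
    imp = k.pow(2).sum(dim=-1).mean(dim=2)  # (B, N) per-token importance
    top = torch.topk(imp, k=keep, dim=1, largest=True, sorted=False).indices
    idx = top.unsqueeze(-1).unsqueeze(-1).expand(B, keep, H, D)
    return torch.gather(k, dim=1, index=idx), torch.gather(v, dim=1, index=idx)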

if not _initialized:
    _original_functions["orig_attention"] = comfy_attention.optimized_attention
    _original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model
    _original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models
    _initialized = True

class MGSagpuBaseLoader:
    original_linear = None
    cublas_patched = False

    @torch.compiler.disable()
    def _patch_modules(self, patch_cublaslinear, sage_attention):
        from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight

        if sage_attention != "disabled":
            print("Patching comfy attention to use sageattn")
            try:
                from sageattention import sageattn
                from sageattention import (
                    sageattn_qk_int8_pv_fp16_cuda,
                    sageattn_qk_int8_pv_fp16_triton,
                    sageattn_qk_int8_pv_fp8_cuda,
                    sageattn_qk_int8_pv_fp8_cuda_sm90,
                )
            except ImportError:
                from SageAttention import sageattn
                from SageAttention import (
                    sageattn_qk_int8_pv_fp16_cuda,
                    sageattn_qk_int8_pv_fp16_triton,
                    sageattn_qk_int8_pv_fp8_cuda,
                    sageattn_qk_int8_pv_fp8_cuda_sm90,
                )
            def set_sage_func(sage_attention):
                # Helper: pick best kernel for current GPU
                def select_auto(quality: bool):
                    def _auto(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
                        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) if torch.cuda.is_available() else (0, 0)
                        try:
                            if major == 12 and minor == 0:
                                # RTX 50 series
                                pv = "fp32+fp32" if quality else "fp32+fp16"
                                return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype=pv, tensor_layout=tensor_layout)
                            elif major == 9:
                                # H100 family (SM90): the sm90 kernel runs with fp32+fp32 accumulation in both modes
                                pv = "fp32+fp32"
                                return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype=pv, tensor_layout=tensor_layout)
                            elif major == 8 and minor == 9:
                                pv = "fp32+fp32" if quality else "fp32+fp16"
                                return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype=pv, tensor_layout=tensor_layout)
                            elif major == 8 and minor in (0, 6):
                                # Ampere
                                # Prefer CUDA kernel when possible
                                return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
                        except Exception:
                            pass
                        # Generic auto (library decides), works across arch when available
                        return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
                    return _auto
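                # Summary of the auto mapping above (device capability -> kernel / pv accumulator):
                #   SM 12.0 (RTX 50xx):  qk_int8_pv_fp8_cuda, fp32+fp16 (speed) or fp32+fp32 (quality)
                #   SM 9.0  (H100):      qk_int8_pv_fp8_cuda_sm90, fp32+fp32
                #   SM 8.9  (Ada):       qk_int8_pv_fp8_cuda, fp32+fp16 (speed) or fp32+fp32 (quality)
                #   SM 8.0/8.6 (Ampere): qk_int8_pv_fp16_cuda, fp32
                #   anything else / failure: generic sageattn (library auto-select)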
                if sage_attention in ("auto", "auto_speed"):
                    return select_auto(quality=False)
                elif sage_attention == "auto_quality":
                    return select_auto(quality=True)
                elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
                        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
                    return func
                elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
                        return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
                    return func
                elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
                        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
                    return func
                elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
                    # Uses sageattn_qk_int8_pv_fp8_cuda imported above (same function name under either package spelling).
                    # This variant requires SM89 (Ada 8.9). On newer GPUs (e.g., SM90),
                    # fall back to generic auto selection to avoid kernel assertion.
                    try:
                        if torch.cuda.is_available():
                            major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
                            if not (major == 8 and minor == 9):
                                logging.warning(f"sageattn_qk_int8_pv_fp8_cuda++ requires SM89, but detected SM{major}{minor}. Falling back to auto kernel selection.")
                                def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
                                    return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
                                return func
                    except Exception:
                        pass
                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
                        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
                    return func

            sage_func = set_sage_func(sage_attention)

            @torch.compiler.disable()
            def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, transformer_options=None, **kwargs):
                if skip_reshape:
                    b, _, _, dim_head = q.shape
                    tensor_layout="HND"
                else:
                    b, _, dim_head = q.shape
                    dim_head //= heads
                    q, k, v = map(
                        lambda t: t.view(b, -1, heads, dim_head),
                        (q, k, v),
                    )
                    tensor_layout="NHD"
                if mask is not None:
                    # add a batch dimension if there isn't already one
                    if mask.ndim == 2:
                        mask = mask.unsqueeze(0)
                    # add a heads dimension if there isn't already one
                    if mask.ndim == 3:
                        mask = mask.unsqueeze(1)
                # Prefer trying sage kernels; allow runtime overrides via transformer_options or CURRENT_PV_ACCUM

                # Optional K/V pruning for self-attention (token-level top-k)
                try:
                    if _kv_prune_enabled and (mask is None):
                        import math
                        if tensor_layout == "NHD":
                            # q,k,v: B,N,H,D
                            Bn, Nq, Hn, Dh = q.shape
                            Nk = k.shape[1]
                            if Nq == Nk and Nk >= _kv_prune_min_tokens:
                                keep = max(1, int(math.ceil(float(_kv_prune_keep) * Nk)))
                                if keep < Nk:
                                    # importance: mean over heads of L2 norm of K per token
                                    imp = (k.pow(2).sum(dim=-1)).mean(dim=2)  # B,N
                                    top = torch.topk(imp, k=keep, dim=1, largest=True, sorted=False).indices
                                    idx = top.unsqueeze(-1).unsqueeze(-1).expand(Bn, keep, Hn, Dh)
                                    k = torch.gather(k, dim=1, index=idx)
                                    v = torch.gather(v, dim=1, index=idx)
                        else:
                            # HND: q,k,v: B,H,N,D
                            Bb, Hn, Nq, Dh = q.shape
                            Nk = k.shape[2]
                            if Nq == Nk and Nk >= _kv_prune_min_tokens:
                                keep = max(1, int(math.ceil(float(_kv_prune_keep) * Nk)))
                                if keep < Nk:
                                    imp = (k.pow(2).sum(dim=-1)).mean(dim=1)  # B,N
                                    top = torch.topk(imp, k=keep, dim=1, largest=True, sorted=False).indices
                                    idx = top.unsqueeze(1).unsqueeze(-1).expand(Bb, Hn, keep, Dh)
                                    k = torch.gather(k, dim=2, index=idx)
                                    v = torch.gather(v, dim=2, index=idx)
                except Exception:
                    # On any issue, skip pruning silently
                    pass

                try:
                    pv_override = None
                    if transformer_options and isinstance(transformer_options, dict):
                        so = transformer_options.get("sageattn")
                        if isinstance(so, dict):
                            pv_override = so.get("pv_accum_dtype", None)
                    if pv_override is None:
                        pv_override = CURRENT_PV_ACCUM

                    if pv_override is not None:
                        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout, pv_accum_dtype=pv_override)
                    else:
                        out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
                except Exception as e:
                    global _sage_generic_warned_once
                    if not _sage_generic_warned_once:
                        logging.warning(f"Error running sage attention: {e}. Falling back.")
                        _sage_generic_warned_once = True
                    try:
                        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
                    except Exception:
                        # Final fallback to PyTorch attention, silent after first warning
                        if tensor_layout == "NHD":
                            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
                        return comfy_attention.attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, transformer_options=transformer_options, **kwargs)
                # Optional tiny attention-entropy probe (avoid heavy compute)
                try:
                    if _attn_entropy_enabled:
                        with torch.inference_mode():
                            if tensor_layout == "HND":
                                # q: B,H,N,D -> B,N,H,D for uniform handling
                                q_probe = q.transpose(1, 2)
                                k_probe = k.transpose(1, 2)
                            else:
                                q_probe = q
                                k_probe = k
                            B_, N_, H_, Dh = q_probe.shape
                            # Cap heads and tokens
                            h_cap = min(H_, _attn_probe_heads_cap)
                            step = max(1, N_ // _attn_probe_tokens_cap)
                            q_s = q_probe[:, ::step, :h_cap, :].transpose(1, 2)  # B,h,q,d
                            k_s = k_probe[:, ::step, :h_cap, :].transpose(1, 2)  # B,h,k,d
                            scale = (float(Dh) ** -0.5)
                            # logits: B,h,q,k
                            logits = torch.matmul(q_s * scale, k_s.transpose(-1, -2))
                            p = torch.softmax(logits, dim=-1)
                            # entropy per query
                            eps = 1e-9
                            Hq = -(p * (p.clamp_min(eps).log())).sum(dim=-1)  # B,h,q
                            Hq = Hq.mean(dim=1)  # B,q
                            # reshape to approx grid
                            import math
                            Q = Hq.shape[-1]
                            w = int(math.sqrt(Q))
                            w = max(1, w)
                            h = max(1, Q // w)
                            if h * w > Q:
                                Hq = Hq[..., : (h * w)]
                            elif h * w < Q:
                                # pad with last
                                pad = (h * w) - Q
                                if pad > 0:
                                    Hq = torch.cat([Hq, Hq[..., -1:].expand(B_, pad)], dim=-1)
                            Hmap = Hq.reshape(B_, 1, h, w)
                            # normalize per-sample to [0,1]
                            Hmin = Hmap.amin(dim=(2, 3), keepdim=True)
                            Hmax = Hmap.amax(dim=(2, 3), keepdim=True)
                            Hn = (Hmap - Hmin) / (Hmax - Hmin + 1e-6)
                            global _attn_entropy_last
                            _attn_entropy_last = Hn.detach()
                except Exception:
                    pass

                if tensor_layout == "HND":
                    if not skip_output_reshape:
                        out = (
                            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
                        )
                else:
                    if skip_output_reshape:
                        out = out.transpose(1, 2)
                    else:
                        out = out.reshape(b, -1, heads * dim_head)
                return out

            comfy_attention.optimized_attention = attention_sage
            comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
            comfy.ldm.flux.math.optimized_attention = attention_sage
            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
            comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
            comfy.ldm.wan.model.optimized_attention = attention_sage

        else:
            print("Restoring initial comfy attention")
            comfy_attention.optimized_attention = _original_functions.get("orig_attention")
            comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
            comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
            comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
            comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")

        if patch_cublaslinear:
            if not MGSagpuBaseLoader.cublas_patched:
                MGSagpuBaseLoader.original_linear = disable_weight_init.Linear
                try:
                    from cublas_ops import CublasLinear
                except ImportError:
                    raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")

                class PatchedLinear(CublasLinear, CastWeightBiasOp):
                    def reset_parameters(self):
                        pass

                    def forward_comfy_cast_weights(self, input):
                        weight, bias = cast_bias_weight(self, input)
                        return torch.nn.functional.linear(input, weight, bias)

                    def forward(self, *args, **kwargs):
                        if self.comfy_cast_weights:
                            return self.forward_comfy_cast_weights(*args, **kwargs)
                        else:
                            return super().forward(*args, **kwargs)

                disable_weight_init.Linear = PatchedLinear
                MGSagpuBaseLoader.cublas_patched = True
        else:
            if MGSagpuBaseLoader.cublas_patched:
                disable_weight_init.Linear = MGSagpuBaseLoader.original_linear
                MGSagpuBaseLoader.cublas_patched = False

from comfy.patcher_extension import CallbacksMP
class MGSagpuAttention(MGSagpuBaseLoader):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "sage_attention": (sageattn_modes, {"default": "disabled", "tooltip": "Globally patches comfy attention to use sageattn. Once patched, run this node again with the 'disabled' option to revert to the normal attention."}),
        }}

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "patch"
    DESCRIPTION = "Node for patching the attention mode. This doesn't use the model patching system, so it can't be disabled without running the node again with the 'disabled' option."
    EXPERIMENTAL = False
    CATEGORY = "MagicNodes"

    def patch(self, model, sage_attention):
        model_clone = model.clone()
        @torch.compiler.disable()
        def patch_attention_enable(model):
            self._patch_modules(False, sage_attention)
        @torch.compiler.disable()
        def patch_attention_disable(model):
            self._patch_modules(False, "disabled")
        
        model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention_enable)
        model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_attention_disable)
        
        return model_clone,
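
# Hedged note (illustrative, not executed): the returned clone patches attention right before
# sampling (ON_PRE_RUN) and restores the original comfy attention on cleanup (ON_CLEANUP).
# If this module were registered directly, the usual ComfyUI convention would be a mapping
# roughly like the commented example below; the actual registration may live elsewhere.
#
#   NODE_CLASS_MAPPINGS = {"MGSagpuAttention": MGSagpuAttention}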
 


# Legacy compile helpers removed

# Legacy video helpers removed
import inspect as _inspect
try:
    from comfy.ldm.modules import attention as _cm_attn
except Exception as _e:
    _cm_attn = None

_nag_patch_active = False
_nag_params = {"scale": 5.0, "tau": 2.5, "alpha": 0.25}
_original_functions.setdefault("orig_crossattn_forward", None)
_original_functions.setdefault("orig_crossattn_sig", None)

def _call_orig_crossattn(self, x, context=None, **kwargs):
    """Call the original CrossAttention.forward with kwargs filtered to its signature."""
    f = _original_functions.get("orig_crossattn_forward", None)
    if f is None:
        # Should not happen; just try current method
        return self.__class__.forward(self, x, context=context, **kwargs)
    sig = _original_functions.get("orig_crossattn_sig", None)
    if sig is None:
        try:
            sig = _inspect.signature(f)
            _original_functions["orig_crossattn_sig"] = sig
        except Exception:
            sig = None
    if sig is not None:
        allowed = set(sig.parameters.keys())
        fkwargs = {k: v for k, v in kwargs.items() if k in allowed}
    else:
        fkwargs = kwargs
    try:
        return f(self, x, context=context, **fkwargs)
    except TypeError:
        # Some builds have (x, context=None, value=None, mask=None) only
        fkwargs.pop("attn_precision", None)
        fkwargs.pop("transformer_options", None)
        try:
            return f(self, x, context=context, **fkwargs)
        except Exception:
            # Give up; call current method (unpatched) to avoid crashing
            return self.__class__.forward(self, x, context=context, **kwargs)

def _kj_crossattn_forward_nag(self, x, context=None, value=None, mask=None, **kwargs):
    # If the patch is not active or the context does not carry a cond/uncond pair, defer to the original.
    if (not _nag_patch_active) or (_cm_attn is None):
        return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
    try:
        if context is None or not torch.is_tensor(context):
            return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)

        # Expect batch 2 with [uncond, cond]; if not, fall back
        if context.shape[0] < 2:
            return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)

        # Split branches. In most samplers order is [uncond, cond].
        # If x has batch==2, split it likewise; else use the same x for both calls.
        x_has_pair = (torch.is_tensor(x) and x.shape[0] == 2)
        x_u = x[0:1] if x_has_pair else x
        x_c = x[1:2] if x_has_pair else x

        c_u, c_c = context[0:1], context[1:2]

        # value may also be batched
        v = kwargs.get("value", value)
        if torch.is_tensor(v) and v.shape[0] == 2:
            v_u, v_c = v[0:1], v[1:2]
        else:
            v_u = v_c = v

        # Get per-branch outputs using the ORIGINAL forward
        # - Neg branch (for real uncond stream)
        out_u = _call_orig_crossattn(self, x_u, context=c_u, value=v_u, mask=mask, **kwargs)
        # - Pos branch
        z_pos = _call_orig_crossattn(self, x_c, context=c_c, value=v_c, mask=mask, **kwargs)
        # - "Neg guidance" term computed with *positive query but negative context*
        z_neg = _call_orig_crossattn(self, x_c, context=c_u, value=v_u, mask=mask, **kwargs)

        # NAG mixing in the attention output space
        phi = float(_nag_params.get("scale", 5.0))
        tau = float(_nag_params.get("tau", 2.5))
        alpha = float(_nag_params.get("alpha", 0.25))

        g = z_pos * phi - z_neg * (phi - 1.0)
        # L1-norm based clipping to limit deviation from Z+
        def _l1_norm(t):
            return torch.sum(torch.abs(t), dim=-1, keepdim=True).clamp_min(1e-6)
        s_pos = _l1_norm(z_pos)
        s_g   = _l1_norm(g)
        scale = (s_pos * tau) / s_g
        g = torch.where((s_g > s_pos * tau), g * scale, g)

        z_guided = g * alpha + z_pos * (1.0 - alpha)
        if x_has_pair:
            return torch.cat([out_u, z_guided], dim=0)
        else:
            return z_guided
    except Exception as e:
        # If anything goes wrong, use the original forward.
        return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)

def enable_crossattention_nag_patch(enable: bool, nag_scale: float = 5.0, nag_tau: float = 2.5, nag_alpha: float = 0.25):
    """Enable/disable a safe CrossAttention forward wrapper that applies NAG to the positive branch only.

    This does not modify model weights and is fully reversible. The wrapper filters unknown
    kwargs against the original forward's signature to avoid errors on older Comfy builds.
    """
    global _nag_patch_active, _nag_params
    if _cm_attn is None:
        return False
    if enable:
        _nag_params = {"scale": float(nag_scale), "tau": float(nag_tau), "alpha": float(nag_alpha)}
        if _original_functions.get("orig_crossattn_forward", None) is None:
            try:
                _original_functions["orig_crossattn_forward"] = _cm_attn.CrossAttention.forward
                try:
                    _original_functions["orig_crossattn_sig"] = _inspect.signature(_cm_attn.CrossAttention.forward)
                except Exception:
                    _original_functions["orig_crossattn_sig"] = None
            except Exception:
                return False
        # Patch in our wrapper
        try:
            _cm_attn.CrossAttention.forward = _kj_crossattn_forward_nag
            _nag_patch_active = True
            return True
        except Exception:
            return False
    else:
        # Restore original if we have it
        if _original_functions.get("orig_crossattn_forward", None) is not None:
            try:
                _cm_attn.CrossAttention.forward = _original_functions["orig_crossattn_forward"]
            except Exception:
                pass
        _nag_patch_active = False
        return True
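
# Hedged standalone sketch (random tensors, not wired into the patch): documents the NAG mix
# used in _kj_crossattn_forward_nag above: extrapolate Z+ away from Z-, clip the result to at
# most `tau` times the L1 scale of Z+, then blend back into Z+ with weight `alpha`.
def _example_nag_mix(phi=5.0, tau=2.5, alpha=0.25):
    z_pos = torch.randn(1, 77, 320)   # attention output of the positive branch (shape is illustrative)
    z_neg = torch.randn_like(z_pos)   # same query attended against the negative context
    g = z_pos * phi - z_neg * (phi - 1.0)
    s_pos = z_pos.abs().sum(dim=-1, keepdim=True).clamp_min(1e-6)
    s_g = g.abs().sum(dim=-1, keepdim=True).clamp_min(1e-6)
    g = torch.where(s_g > s_pos * tau, g * (s_pos * tau) / s_g, g)
    return g * alpha + z_pos * (1.0 - alpha)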
# ===============================================================================

PatchSageAttention = MGSagpuAttention