diff --git "a/mod/easy/mg_cade25_easy.py" "b/mod/easy/mg_cade25_easy.py"
new file mode 100644
--- /dev/null
+++ "b/mod/easy/mg_cade25_easy.py"
@@ -0,0 +1,2968 @@
+"""CADE 2.5: refined adaptive enhancer with reference clean and accumulation override.
+"""
+
+from __future__ import annotations # moved/renamed module: mg_cade25
+
+import torch
+import os
+import numpy as np
+import torch.nn.functional as F
+
+import nodes
+import comfy.model_management as model_management
+
+from ..hard.mg_adaptive import AdaptiveSamplerHelper
+from ..hard.mg_zesmart_sampler_v1_1 import _build_hybrid_sigmas
+import comfy.sample as _sample
+import comfy.samplers as _samplers
+import comfy.utils as _utils
+from ..hard.mg_upscale_module import MagicUpscaleModule, clear_gpu_and_ram_cache
+from ..hard.mg_controlfusion import _build_depth_map as _cf_build_depth_map
+from ..hard.mg_ids import IntelligentDetailStabilizer
+from .. import mg_sagpu_attention as sa_patch
+from .preset_loader import get as load_preset
+from ..hard.mg_controlfusion import _pyracanny as _cf_pyracanny, _build_depth_map as _cf_build_depth
+# FDG/NAG experimental paths removed for now; keeping code lean
+
+_ONNX_RT = None
+_ONNX_SESS = {} # name -> onnxruntime.InferenceSession
+_ONNX_WARNED = False
+_ONNX_DEBUG = False
+_ONNX_FORCE_CPU = True # pin ONNX to CPU for deterministic behavior
+_ONNX_COUNT_DEBUG = True # print detected counts (faces/hands/persons) when True (temporarily forced ON)
+
+# Lazy CLIPSeg cache
+_CLIPSEG_MODEL = None
+_CLIPSEG_PROC = None
+_CLIPSEG_DEV = "cpu"
+_CLIPSEG_FORCE_CPU = True # pin CLIPSeg to CPU to avoid device drift
+
+# ONNX keypoints (wholebody/pose) parsing toggles (set by UI at runtime)
+_ONNX_KPTS_ENABLE = False
+_ONNX_KPTS_SIGMA = 2.5
+_ONNX_KPTS_GAIN = 1.5
+_ONNX_KPTS_CONF = 0.20
+
+# Per-iteration spatial guidance mask (B,1,H,W) in [0,1]; used by cfg_func when enabled
+CURRENT_ONNX_MASK_BCHW = None
+
+
+# --- AQClip-Lite: adaptive soft quantile clipping in latent space (tile overlap) ---
+@torch.no_grad()
+def _aqclip_lite(latent_bchw: torch.Tensor,
+                 tile: int = 32,
+                 stride: int = 16,
+                 alpha: float = 2.0,
+                 ema_state: dict | None = None,
+                 ema_beta: float = 0.8) -> tuple[torch.Tensor, dict]:
+    try:
+        z = latent_bchw
+        B, C, H, W = z.shape
+        dev, dt = z.device, z.dtype
+        ksize = max(8, min(int(tile), min(H, W)))
+        kstride = max(1, min(int(stride), ksize))
+
+        # Confidence proxy: gradient magnitude on channel-mean latent
+        zm = z.mean(dim=1, keepdim=True)
+        kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=dev, dtype=dt).view(1, 1, 3, 3)
+        ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=dev, dtype=dt).view(1, 1, 3, 3)
+        gx = F.conv2d(zm, kx, padding=1)
+        gy = F.conv2d(zm, ky, padding=1)
+        gmag = torch.sqrt(gx * gx + gy * gy)
+        gpool = F.avg_pool2d(gmag, kernel_size=ksize, stride=kstride)
+        gmax = gpool.amax(dim=(2, 3), keepdim=True).clamp_min(1e-6)
+        Hn = (gpool / gmax).squeeze(1) # B,h',w'
+        L = Hn.shape[1] * Hn.shape[2]
+        Hn = Hn.reshape(B, L)
+
+        # Map confidence -> quantiles
+        ql = 0.5 * (Hn ** 2)
+        qh = 1.0 - 0.5 * ((1.0 - Hn) ** 2)
+
+        # Per-tile mean/std
+        unf = F.unfold(z, kernel_size=ksize, stride=kstride) # B, C*ksize*ksize, L
+        M = unf.shape[1]
+        mu = unf.mean(dim=1).to(torch.float32) # B,L
+        var = (unf.to(torch.float32) - mu.unsqueeze(1)).pow(2).mean(dim=1)
+        sigma = (var + 1e-12).sqrt()
+
+        # Normal inverse approximation: ndtri(q) = sqrt(2)*erfinv(2q-1)
+        def _ndtri(q: torch.Tensor) -> torch.Tensor:
+            return (2.0 ** 0.5) * 
torch.special.erfinv(q.mul(2.0).sub(1.0).clamp(-0.999999, 0.999999)) + k_neg = _ndtri(ql).abs() + k_pos = _ndtri(qh).abs() + lo = mu - k_neg * sigma + hi = mu + k_pos * sigma + + # EMA smooth + if ema_state is None: + ema_state = {} + b = float(max(0.0, min(0.999, ema_beta))) + if 'lo' in ema_state and 'hi' in ema_state and ema_state['lo'].shape == lo.shape: + lo = b * ema_state['lo'] + (1.0 - b) * lo + hi = b * ema_state['hi'] + (1.0 - b) * hi + ema_state['lo'] = lo.detach() + ema_state['hi'] = hi.detach() + + # Soft tanh clip (vectorized in unfold domain) + mid = (lo + hi) * 0.5 + half = (hi - lo) * 0.5 + half = half.clamp_min(1e-6) + y = (unf.to(torch.float32) - mid.unsqueeze(1)) / half.unsqueeze(1) + y = torch.tanh(float(alpha) * y) + unf_clipped = mid.unsqueeze(1) + half.unsqueeze(1) * y + unf_clipped = unf_clipped.to(dt) + + out = F.fold(unf_clipped, output_size=(H, W), kernel_size=ksize, stride=kstride) + ones = torch.ones((B, M, L), device=dev, dtype=dt) + w = F.fold(ones, output_size=(H, W), kernel_size=ksize, stride=kstride).clamp_min(1e-6) + out = out / w + return out, ema_state + except Exception: + return latent_bchw, (ema_state or {}) + + +def _try_init_onnx(models_dir: str): + """Initialize onnxruntime and load all .onnx models in models_dir. + We prefer GPU providers when available, but gracefully fall back to CPU. + """ + global _ONNX_RT, _ONNX_SESS, _ONNX_WARNED + import os + if _ONNX_RT is None: + try: + import onnxruntime as ort + _ONNX_RT = ort + except Exception: + if not _ONNX_WARNED: + print("[CADE2.5][ONNX] onnxruntime not available, skipping ONNX detectors.") + _ONNX_WARNED = True + return False + + # Build provider preference list + try: + avail = set(_ONNX_RT.get_available_providers()) + except Exception: + avail = set() + pref = [] + if _ONNX_FORCE_CPU: + pref = ["CPUExecutionProvider"] + else: + for p in ("CUDAExecutionProvider", "DmlExecutionProvider", "CPUExecutionProvider"): + if p in avail or p == "CPUExecutionProvider": + pref.append(p) + if _ONNX_DEBUG: + try: + print(f"[CADE2.5][ONNX] Available providers: {sorted(list(avail))}") + print(f"[CADE2.5][ONNX] Provider preference: {pref}") + except Exception: + pass + + # Load any .onnx in models_dir + try: + for fname in os.listdir(models_dir): + if not fname.lower().endswith('.onnx'): + continue + if fname in _ONNX_SESS: + continue + full = os.path.join(models_dir, fname) + try: + _ONNX_SESS[fname] = _ONNX_RT.InferenceSession(full, providers=pref) + if _ONNX_DEBUG: + try: + print(f"[CADE2.5][ONNX] Loaded model: {fname}") + except Exception: + pass + except Exception as e: + if not _ONNX_WARNED: + print(f"[CADE2.5][ONNX] failed to load {fname}: {e}") + except Exception as e: + if not _ONNX_WARNED: + print(f"[CADE2.5][ONNX] cannot list models in {models_dir}: {e}") + if not _ONNX_SESS and not _ONNX_WARNED: + print("[CADE2.5][ONNX] No ONNX models found in", models_dir) + _ONNX_WARNED = True + return len(_ONNX_SESS) > 0 + + +def _try_init_clipseg(): + """Lazy-load CLIPSeg processor + model and choose device. + Returns True on success. 
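+    Uses the "CIDAS/clipseg-rd64-refined" checkpoint from transformers and
+    caches the processor/model in module globals; the device is pinned to
+    CPU while _CLIPSEG_FORCE_CPU is set, otherwise CUDA when available.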
+ """ + global _CLIPSEG_MODEL, _CLIPSEG_PROC, _CLIPSEG_DEV + if (_CLIPSEG_MODEL is not None) and (_CLIPSEG_PROC is not None): + return True + try: + from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation # type: ignore + except Exception: + if not globals().get("_CLIPSEG_WARNED", False): + print("[CADE2.5][CLIPSeg] transformers not available; CLIPSeg disabled.") + globals()["_CLIPSEG_WARNED"] = True + return False + try: + _CLIPSEG_PROC = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") + _CLIPSEG_MODEL = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") + if _CLIPSEG_FORCE_CPU: + _CLIPSEG_DEV = "cpu" + else: + _CLIPSEG_DEV = "cuda" if torch.cuda.is_available() else "cpu" + _CLIPSEG_MODEL = _CLIPSEG_MODEL.to(_CLIPSEG_DEV) + _CLIPSEG_MODEL.eval() + return True + except Exception as e: + print(f"[CADE2.5][CLIPSeg] failed to load model: {e}") + return False + + +def _clipseg_build_mask(image_bhwc: torch.Tensor, + text: str, + preview: int = 224, + threshold: float = 0.4, + blur: float = 7.0, + dilate: int = 4, + gain: float = 1.0, + ref_embed: torch.Tensor | None = None, + clip_vision=None, + ref_threshold: float = 0.03) -> torch.Tensor | None: + """Return BHWC single-channel mask [0,1] from CLIPSeg. + - Uses cached CLIPSeg model; gracefully returns None on failure. + - Applies optional threshold/blur/dilate and scaling gain. + - If clip_vision + ref_embed provided, gates mask by CLIP-Vision distance. + """ + if not text or not isinstance(text, str): + return None + if not _try_init_clipseg(): + return None + try: + # Prepare preview image (CPU PIL) + target = int(max(16, min(1024, preview))) + img = image_bhwc.detach().to('cpu') + B, H, W, C = img.shape + x = img[0].movedim(-1, 0).unsqueeze(0) # 1,C,H,W + x = F.interpolate(x, size=(target, target), mode='bilinear', align_corners=False) + x = x.clamp(0, 1) + arr = (x[0].movedim(0, -1).numpy() * 255.0).astype('uint8') + from PIL import Image # lazy import + pil_img = Image.fromarray(arr) + + # Run CLIPSeg + import re + prompts = [t.strip() for t in re.split(r"[\|,;\n]+", text) if t.strip()] + if not prompts: + prompts = [text.strip()] + prompts = prompts[:8] + inputs = _CLIPSEG_PROC(text=prompts, images=[pil_img] * len(prompts), return_tensors="pt") + inputs = {k: v.to(_CLIPSEG_DEV) for k, v in inputs.items()} + with torch.inference_mode(): + outputs = _CLIPSEG_MODEL(**inputs) # type: ignore + # logits: [N, H', W'] for N prompts + logits = outputs.logits # [N,h,w] + if logits.ndim == 2: + logits = logits.unsqueeze(0) + prob = torch.sigmoid(logits) # [N,h,w] + # Soft-OR fuse across prompts + prob = 1.0 - torch.prod(1.0 - prob.clamp(0, 1), dim=0, keepdim=True) # [1,h,w] + prob = prob.unsqueeze(1) # [1,1,h,w] + # Resize to original image size + prob = F.interpolate(prob, size=(H, W), mode='bilinear', align_corners=False) + m = prob[0, 0].to(dtype=image_bhwc.dtype, device=image_bhwc.device) + # Threshold + blur (approx) + if threshold > 0.0: + m = torch.where(m > float(threshold), m, torch.zeros_like(m)) + # Gaussian blur via our depthwise helper + if blur > 0.0: + rad = int(max(1, min(7, round(blur)))) + m = _gaussian_blur_nchw(m.unsqueeze(0).unsqueeze(0), sigma=float(max(0.5, blur)), radius=rad)[0, 0] + # Dilation via max-pool + if int(dilate) > 0: + k = int(dilate) * 2 + 1 + p = int(dilate) + m = F.max_pool2d(m.unsqueeze(0).unsqueeze(0), kernel_size=k, stride=1, padding=p)[0, 0] + # Optional CLIP-Vision gating by reference distance + if (clip_vision is not None) and (ref_embed is not None): 
+            try:
+                cur = _encode_clip_image(image_bhwc, clip_vision, target_res=224)
+                dist = _clip_cosine_distance(cur, ref_embed)
+                if dist > float(ref_threshold):
+                    # up to +50% gain if the result drifted far from the reference
+                    gate = 1.0 + min(0.5, (dist - float(ref_threshold)) * 4.0)
+                    m = m * gate
+            except Exception:
+                pass
+        m = (m * float(max(0.0, gain))).clamp(0, 1)
+        return m.unsqueeze(0).unsqueeze(-1) # BHWC with B=1,C=1
+    except Exception as e:
+        if not globals().get("_CLIPSEG_WARNED", False):
+            print(f"[CADE2.5][CLIPSeg] mask failed: {e}")
+            globals()["_CLIPSEG_WARNED"] = True
+        return None
+
+
+def _np_to_mask_tensor(np_map: np.ndarray, out_h: int, out_w: int, device, dtype):
+    """Convert numpy heatmap [H,W] or [1,H,W] or [H,W,1] to BHWC torch mask with B=1 and resize to out_h,out_w."""
+    if np_map.ndim == 3:
+        np_map = np_map.reshape(np_map.shape[-2], np_map.shape[-1]) if (np_map.shape[0] == 1) else np_map.squeeze()
+    if np_map.ndim != 2:
+        return None
+    t = torch.from_numpy(np_map.astype(np.float32))
+    t = t.clamp_min(0.0)
+    t = t.unsqueeze(0).unsqueeze(0) # B=1,C=1,H,W
+    t = F.interpolate(t, size=(out_h, out_w), mode="bilinear", align_corners=False)
+    t = t.permute(0, 2, 3, 1).to(device=device, dtype=dtype) # B,H,W,C
+    return t.clamp(0, 1)
+
+
+# --- Firefly/Hot-pixel remover (image space, BHWC in 0..1) ---
+def _median_pool3x3_bhwc(img_bhwc: torch.Tensor) -> torch.Tensor:
+    B, H, W, C = img_bhwc.shape
+    x = img_bhwc.permute(0, 3, 1, 2) # B,C,H,W
+    unfold = F.unfold(x, kernel_size=3, padding=1) # B, 9*C, H*W
+    unfold = unfold.view(B, x.shape[1], 9, H, W) # B,C,9,H,W
+    med, _ = torch.median(unfold, dim=2) # B,C,H,W
+    return med.permute(0, 2, 3, 1) # B,H,W,C
+
+
+def _despeckle_fireflies(img_bhwc: torch.Tensor,
+                         thr: float = 0.985,
+                         max_iso: float | None = None,
+                         grad_gate: float = 0.25) -> torch.Tensor:
+    try:
+        dev, dt = img_bhwc.device, img_bhwc.dtype
+        B, H, W, C = img_bhwc.shape
+        # Scale-aware window
+        s = max(H, W) / 1024.0
+        k = 3 if s <= 1.1 else (5 if s <= 2.0 else 7)
+        pad = k // 2
+        # Value/Saturation from RGB (fast, no colorspace conv required)
+        R, G, Bc = img_bhwc[..., 0], img_bhwc[..., 1], img_bhwc[..., 2]
+        V = torch.maximum(R, torch.maximum(G, Bc))
+        m = torch.minimum(R, torch.minimum(G, Bc))
+        S = 1.0 - (m / (V + 1e-6))
+        # Dynamic bright threshold from top tail; allow manual override for very high thr
+        try:
+            q = float(torch.quantile(V.reshape(-1), 0.9995).item())
+            thr_eff = float(thr) if float(thr) >= 0.99 else max(float(thr), min(0.997, q))
+        except Exception:
+            thr_eff = float(thr)
+        v_thr = max(0.985, thr_eff)
+        s_thr = 0.06
+        cand = (V > v_thr) & (S < s_thr)
+        # gradient gate to protect real edges/highlights
+        lum = (0.2126 * R + 0.7152 * G + 0.0722 * Bc)
+        kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=dev, dtype=dt).view(1, 1, 3, 3)
+        ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=dev, dtype=dt).view(1, 1, 3, 3)
+        gx = F.conv2d(lum.unsqueeze(1), kx, padding=1)
+        gy = F.conv2d(lum.unsqueeze(1), ky, padding=1)
+        grad = torch.sqrt(gx * gx + gy * gy).squeeze(1)
+        safe_gate = float(grad_gate) * (k / 3.0) ** 0.5
+        cand = cand & (grad < safe_gate)
+        if not cand.any():
+            return img_bhwc
+        # Prefer connected components (OpenCV) to drop small bright specks
+        try:
+            import cv2
+            masks = []
+            for b in range(cand.shape[0]):
+                msk = cand[b].detach().to('cpu').numpy().astype('uint8') * 255
+                num, labels, stats, _ = cv2.connectedComponentsWithStats(msk, connectivity=8)
+                rem = np.zeros_like(msk, dtype='uint8')
+                # Size threshold grows with k
+                area_max = 
int(max(3, round((k * k) * 0.8))) + for lbl in range(1, num): + area = stats[lbl, cv2.CC_STAT_AREA] + if area <= area_max: + rem[labels == lbl] = 255 + masks.append(torch.from_numpy(rem > 0)) + rm = torch.stack(masks, dim=0).to(device=dev) # B,H,W (bool) + rm = rm.unsqueeze(-1) # B,H,W,1 + if not rm.any(): + return img_bhwc + med = _median_pool3x3_bhwc(img_bhwc) + return torch.where(rm, med, img_bhwc) + except Exception: + # Fallback: isolation via local density + dens = F.avg_pool2d(cand.float().unsqueeze(1), k, 1, pad).squeeze(1) + max_iso_eff = (2.0 / (k * k)) if (max_iso is None) else float(max_iso) + iso = cand & (dens < max_iso_eff) + if not iso.any(): + return img_bhwc + med = _median_pool3x3_bhwc(img_bhwc) + return torch.where(iso.unsqueeze(-1), med, img_bhwc) + except Exception: + return img_bhwc + + +def _try_heatmap_from_outputs(outputs: list, preview_hw: tuple[int, int]): + """Return [H,W] heatmap from model outputs if possible. + Supports: + - Segmentation logits/probabilities (NCHW / NHWC) + - Keypoints arrays -> gaussian disks on points + - Bounding boxes -> soft rectangles + """ + if not outputs: + return None + + Ht, Wt = int(preview_hw[0]), int(preview_hw[1]) + + def to_float(arr): + if arr.dtype not in (np.float32, np.float64): + try: + arr = arr.astype(np.float32) + except Exception: + return None + return arr + + def sigmoid(x): + return 1.0 / (1.0 + np.exp(-x)) + + # 1) Prefer any spatial heatmap first + for out in outputs: + try: + arr = np.asarray(out) + except Exception: + continue + arr = to_float(arr) + if arr is None: + continue + if arr.ndim == 4: + n, a, b, c = arr.shape + if c <= 4 and a >= 8 and b >= 8: + if c == 1: + hm = sigmoid(arr[0, :, :, 0]) if np.max(np.abs(arr)) > 1.5 else arr[0, :, :, 0] + else: + ex = np.exp(arr[0] - np.max(arr[0], axis=-1, keepdims=True)) + prob = ex / np.clip(ex.sum(axis=-1, keepdims=True), 1e-6, None) + hm = 1.0 - prob[..., 0] if prob.shape[-1] > 1 else prob[..., 0] + return hm.astype(np.float32) + else: + if a == 1: + ch = arr[0, 0] + hm = sigmoid(ch) if np.max(np.abs(ch)) > 1.5 else ch + return hm.astype(np.float32) + else: + x = arr[0] + x = x - np.max(x, axis=0, keepdims=True) + ex = np.exp(x) + prob = ex / np.clip(np.sum(ex, axis=0, keepdims=True), 1e-6, None) + bg = prob[0] if prob.shape[0] > 1 else prob[0] + hm = 1.0 - bg + return hm.astype(np.float32) + if arr.ndim == 3: + if arr.shape[0] == 1 and arr.shape[1] >= 8 and arr.shape[2] >= 8: + return arr[0].astype(np.float32) + if arr.ndim == 2 and arr.shape[0] >= 8 and arr.shape[1] >= 8: + return arr.astype(np.float32) + + # 2) Try keypoints and boxes + heat = np.zeros((Ht, Wt), dtype=np.float32) + + def draw_gaussian(hm, cx, cy, sigma=2.5, amp=1.0): + r = max(1, int(3 * sigma)) + xs = np.arange(-r, r + 1, dtype=np.float32) + ys = np.arange(-r, r + 1, dtype=np.float32) + gx = np.exp(-(xs**2) / (2 * sigma * sigma)) + gy = np.exp(-(ys**2) / (2 * sigma * sigma)) + g = np.outer(gy, gx) * float(amp) + x0 = int(round(cx)) - r + y0 = int(round(cy)) - r + x1 = x0 + g.shape[1] + y1 = y0 + g.shape[0] + if x1 < 0 or y1 < 0 or x0 >= Wt or y0 >= Ht: + return + xs0 = max(0, x0) + ys0 = max(0, y0) + xs1 = min(Wt, x1) + ys1 = min(Ht, y1) + gx0 = xs0 - x0 + gy0 = ys0 - y0 + gx1 = gx0 + (xs1 - xs0) + gy1 = gy0 + (ys1 - ys0) + hm[ys0:ys1, xs0:xs1] = np.maximum(hm[ys0:ys1, xs0:xs1], g[gy0:gy1, gx0:gx1]) + + def draw_soft_rect(hm, x0, y0, x1, y1, edge=3.0): + x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) + if x1 <= 0 or y1 <= 0 or x0 >= Wt or y0 >= Ht: + return + xs0 = max(0, min(x0, x1)) 
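+        # Clamp corners into heatmap bounds (min/max also tolerate swapped x0>x1 / y0>y1)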
+ ys0 = max(0, min(y0, y1)) + xs1 = min(Wt, max(x0, x1)) + ys1 = min(Ht, max(y0, y1)) + if xs1 - xs0 <= 0 or ys1 - ys0 <= 0: + return + hm[ys0:ys1, xs0:xs1] = np.maximum(hm[ys0:ys1, xs0:xs1], 1.0) + # feather edges with simple blur-like falloff + if edge > 0: + rad = int(edge) + if rad > 0: + # quick separable triangle filter + line = np.linspace(0, 1, rad + 1, dtype=np.float32)[1:] + for d in range(1, rad + 1): + w = line[d - 1] + if ys0 - d >= 0: + hm[ys0 - d:ys0, xs0:xs1] = np.maximum(hm[ys0 - d:ys0, xs0:xs1], w) + if ys1 + d <= Ht: + hm[ys1:ys1 + d, xs0:xs1] = np.maximum(hm[ys1:ys1 + d, xs0:xs1], w) + if xs0 - d >= 0: + hm[max(0, ys0 - d):min(Ht, ys1 + d), xs0 - d:xs0] = np.maximum( + hm[max(0, ys0 - d):min(Ht, ys1 + d), xs0 - d:xs0], w) + if xs1 + d <= Wt: + hm[max(0, ys0 - d):min(Ht, ys1 + d), xs1:xs1 + d] = np.maximum( + hm[max(0, ys0 - d):min(Ht, ys1 + d), xs1:xs1 + d], w) + + # Inspect outputs to find plausible keypoints/boxes + for out in outputs: + try: + arr = np.asarray(out) + except Exception: + continue + arr = to_float(arr) + if arr is None: + continue + a = arr + # Squeeze batch dims like [1,N,4] -> [N,4] + while a.ndim > 2 and a.shape[0] == 1: + a = np.squeeze(a, axis=0) + # Keypoints: [N,2] or [N,3] or [K, N, 2/3] (relax N limit; subsample if huge) + if a.ndim == 2 and a.shape[-1] in (2, 3): + pts = a + elif a.ndim == 3 and a.shape[-1] in (2, 3): + pts = a.reshape(-1, a.shape[-1]) + else: + pts = None + if pts is not None: + # Coordinates range guess: if max>1.2 -> absolute; else normalized + maxv = float(np.nanmax(np.abs(pts[:, :2]))) if pts.size else 0.0 + for px, py, *rest in pts: + if np.isnan(px) or np.isnan(py): + continue + if maxv <= 1.2: + cx = float(px) * (Wt - 1) + cy = float(py) * (Ht - 1) + else: + cx = float(px) + cy = float(py) + base_sig = max(1.5, min(Ht, Wt) / 128.0) + if _ONNX_KPTS_ENABLE: + draw_gaussian(heat, cx, cy, sigma=base_sig * float(_ONNX_KPTS_SIGMA), amp=float(_ONNX_KPTS_GAIN)) + else: + draw_gaussian(heat, cx, cy, sigma=base_sig) + continue + + # Wholebody-style packed keypoints: [N, K*3] with triples (x,y,conf) + if _ONNX_KPTS_ENABLE and a.ndim == 2 and a.shape[-1] >= 6 and (a.shape[-1] % 3) == 0: + K = a.shape[-1] // 3 + if K >= 5 and K <= 256: + # Guess coordinate range once + with np.errstate(invalid='ignore'): + maxv = float(np.nanmax(np.abs(a[:, :2]))) if a.size else 0.0 + for i in range(a.shape[0]): + row = a[i] + kp = row.reshape(K, 3) + for (px, py, pc) in kp: + if np.isnan(px) or np.isnan(py): + continue + if np.isfinite(pc) and pc < float(_ONNX_KPTS_CONF): + continue + if maxv <= 1.2: + cx = float(px) * (Wt - 1) + cy = float(py) * (Ht - 1) + else: + cx = float(px) + cy = float(py) + base_sig = max(1.0, min(Ht, Wt) / 128.0) + draw_gaussian(heat, cx, cy, sigma=base_sig * float(_ONNX_KPTS_SIGMA), amp=float(_ONNX_KPTS_GAIN)) + continue + # 1D edge-case: single detection row (post-processed model output) + if _ONNX_KPTS_ENABLE and a.ndim == 1 and a.shape[0] >= 6: + D = a.shape[0] + parsed = False + # try [xyxy, conf, cls, kpts] or [xyxy, conf, kpts] or [xyxy, kpts] or [kpts] + for offset in (6, 5, 4, 0): + t = D - offset + if t >= 6 and (t % 3) == 0: + k = t // 3 + kp = a[offset:offset + 3 * k].reshape(k, 3) + parsed = True + break + if parsed: + with np.errstate(invalid='ignore'): + maxv = float(np.nanmax(np.abs(kp[:, :2]))) if kp.size else 0.0 + for (px, py, pc) in kp: + if np.isnan(px) or np.isnan(py): + continue + if np.isfinite(pc) and pc < float(_ONNX_KPTS_CONF): + continue + if maxv <= 1.2: + cx = float(px) * (Wt - 1) + cy = 
float(py) * (Ht - 1) + else: + cx = float(px) + cy = float(py) + base_sig = max(1.0, min(Ht, Wt) / 128.0) + draw_gaussian(heat, cx, cy, sigma=base_sig * float(_ONNX_KPTS_SIGMA), amp=float(_ONNX_KPTS_GAIN)) + continue + # Boxes: [N,4+] (x0,y0,x1,y1) or [N, (x,y,w,h, [conf, ...])]; relax N limit (handle YOLO-style outputs) + if a.ndim == 2 and a.shape[-1] >= 4: + boxes = a + elif a.ndim == 3 and a.shape[-1] >= 4: + # choose the smallest first two dims as N + if a.shape[0] == 1: + boxes = a.reshape(-1, a.shape[-1]) + else: + boxes = a.reshape(-1, a.shape[-1]) + else: + boxes = None + if boxes is not None: + # Optional score gating (try to find a confidence column) + score = None + if boxes.shape[-1] >= 5: + score = boxes[:, 4] + # If trailing columns look like probabilities in [0,1], mix the best one; if they look like class ids, ignore + if boxes.shape[-1] > 5: + try: + tail = boxes[:, 5:] + tmax = np.max(tail, axis=-1) + # Heuristic: treat as probs if within [0,1] and not integer-like; else assume class ids + maybe_prob = np.all((tmax >= 0.0) & (tmax <= 1.0)) + frac = np.abs(tmax - np.round(tmax)) + maybe_classid = (np.mean(frac < 1e-6) > 0.9) and (np.max(tmax) >= 1.0) + if maybe_prob and not maybe_classid: + score = score * tmax + except Exception: + pass + # Keep top-K by score if available + if score is not None: + try: + order = np.argsort(-score) + keep = order[: min(12, order.shape[0])] + boxes = boxes[keep] + score = score[keep] + except Exception: + score = None + + xy = boxes[:, :4] + maxv = float(np.nanmax(np.abs(xy))) if xy.size else 0.0 + if maxv <= 1.2: + x0 = xy[:, 0] * (Wt - 1) + y0 = xy[:, 1] * (Ht - 1) + x1 = xy[:, 2] * (Wt - 1) + y1 = xy[:, 3] * (Ht - 1) + else: + x0, y0, x1, y1 = xy[:, 0], xy[:, 1], xy[:, 2], xy[:, 3] + # Heuristic: if many boxes are inverted, treat as [x,y,w,h] + invalid = np.sum((x1 <= x0) | (y1 <= y0)) + if invalid > 0.5 * x0.shape[0]: + x, y, w, h = x0, y0, x1, y1 + x0 = x - w * 0.5 + y0 = y - h * 0.5 + x1 = x + w * 0.5 + y1 = y + h * 0.5 + for i in range(x0.shape[0]): + if score is not None and np.isfinite(score[i]) and score[i] < 0.05: + continue + draw_soft_rect(heat, x0[i], y0[i], x1[i], y1[i], edge=3.0) + + # Embedded keypoints in YOLO-style rows: try to parse trailing triples (x,y,conf) + if _ONNX_KPTS_ENABLE and boxes.shape[-1] > 6: + D = boxes.shape[-1] + for i in range(boxes.shape[0]): + row = boxes[i] + parsed = False + # try [xyxy, conf, cls, kpts] or [xyxy, conf, kpts] or [xyxy, kpts] + for offset in (6, 5, 4): + t = D - offset + if t >= 6 and t % 3 == 0: + k = t // 3 + kp = row[offset:offset + 3 * k].reshape(k, 3) + parsed = True + break + if not parsed: + continue + for (px, py, pc) in kp: + if np.isnan(px) or np.isnan(py): + continue + if pc < float(_ONNX_KPTS_CONF): + continue + if maxv <= 1.2: + cx = float(px) * (Wt - 1) + cy = float(py) * (Ht - 1) + else: + cx = float(px) + cy = float(py) + base_sig = max(1.0, min(Ht, Wt) / 128.0) + draw_gaussian(heat, cx, cy, sigma=base_sig * float(_ONNX_KPTS_SIGMA), amp=float(_ONNX_KPTS_GAIN)) + + if heat.max() > 0: + heat = np.clip(heat, 0.0, 1.0) + return heat + return None + + +def _onnx_build_mask(image_bhwc: torch.Tensor, preview: int, sensitivity: float, models_dir: str, anomaly_gain: float = 1.0) -> torch.Tensor: + """Return BHWC single-channel mask [0,1] fused from all auto-detected ONNX models. + - Auto-loads any .onnx in models_dir. + - Heuristically extracts spatial heatmaps; non-spatial outputs are ignored. + - Uses soft-OR fusion across models. 
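+    - Soft-OR across models: fused = 1 - (1 - m1) * (1 - m2) * ..., so e.g.
+      two masks of 0.5 fuse to 0.75 rather than saturating at 1.0.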
Models whose filename contains 'anomaly' are scaled by anomaly_gain. + """ + if not _try_init_onnx(models_dir): + # Explicit hint when debugging counts + try: + if globals().get("_ONNX_COUNT_DEBUG", False): + print("[CADE2.5][ONNX] inactive: onnxruntime not available or init failed") + except Exception: + pass + return torch.zeros((image_bhwc.shape[0], image_bhwc.shape[1], image_bhwc.shape[2], 1), device=image_bhwc.device, dtype=image_bhwc.dtype) + + if not _ONNX_SESS: + # Explicit hint when debugging counts + try: + if globals().get("_ONNX_COUNT_DEBUG", False): + print(f"[CADE2.5][ONNX] inactive: no .onnx models loaded from {models_dir}") + except Exception: + pass + return torch.zeros((image_bhwc.shape[0], image_bhwc.shape[1], image_bhwc.shape[2], 1), device=image_bhwc.device, dtype=image_bhwc.dtype) + + # One-time session summary when counting is enabled + if globals().get("_ONNX_COUNT_DEBUG", False): + try: + names = list(_ONNX_SESS.keys()) + short = names[:3] + more = "..." if len(names) > 3 else "" + print(f"[CADE2.5][ONNX] sessions={len(names)} models={short}{more}") + except Exception: + pass + + B, H, W, C = image_bhwc.shape + device = image_bhwc.device + dtype = image_bhwc.dtype + + # Process per-batch image + masks = [] + img_cpu = image_bhwc.detach().to('cpu') + for b in range(B): + masks_b = [] + counts_b: dict[str, int] = {} + if globals().get("_ONNX_COUNT_DEBUG", False): + try: + print(f"[CADE2.5][ONNX] build mask image[{b}] preview={int(max(16, min(1024, preview)))}") + except Exception: + pass + # Prepare base BCHW tensor and default preview size; per-model resize comes later + target = int(max(16, min(1024, preview))) + xb = img_cpu[b].movedim(-1, 0).unsqueeze(0) # 1,C,H,W + if _ONNX_DEBUG: + try: + print(f"[CADE2.5][ONNX] Build mask for image[{b}] -> preview {target}x{target}") + except Exception: + pass + + for name, sess in list(_ONNX_SESS.items()): + try: + inputs = sess.get_inputs() + if not inputs: + continue + in_name = inputs[0].name + in_shape = inputs[0].shape if hasattr(inputs[0], 'shape') else None + # Choose layout automatically based on the presence of channel dim=3 + if isinstance(in_shape, (list, tuple)) and len(in_shape) == 4: + dim_vals = [] + for d in in_shape: + try: + dim_vals.append(int(d)) + except Exception: + dim_vals.append(-1) + if dim_vals[-1] == 3: + layout = "NHWC" + else: + layout = "NCHW" + else: + layout = "NCHW?" 
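+                    # Layout guess above: rank-4 input with a trailing dim of 3 -> NHWC
+                    # (e.g. [1,640,640,3]); otherwise NCHW; '?' flags an unknown/dynamic shape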
+ if _ONNX_DEBUG: + try: + print(f"[CADE2.5][ONNX] Model '{name}' in_shape={in_shape} layout={layout}") + except Exception: + pass + # Build per-model sized variants (respect fixed input shapes when provided) + th, tw = target, target + try: + if isinstance(in_shape, (list, tuple)) and len(in_shape) == 4: + dd = [] + for d in in_shape: + try: + dd.append(int(d)) + except Exception: + dd.append(-1) + if layout == "NCHW" and dd[2] > 8 and dd[3] > 8: + th, tw = int(dd[2]), int(dd[3]) + if layout.startswith("NHWC") and dd[1] > 8 and dd[2] > 8: + th, tw = int(dd[1]), int(dd[2]) + except Exception: + th, tw = target, target + + x_stretch_m = F.interpolate(xb, size=(th, tw), mode='bilinear', align_corners=False).clamp(0, 1) + if th == tw: + x_letter_m = _letterbox_nchw(xb, th).clamp(0, 1) + else: + sq = max(th, tw) + x_letter_sq = _letterbox_nchw(xb, sq).clamp(0, 1) + x_letter_m = F.interpolate(x_letter_sq, size=(th, tw), mode='bilinear', align_corners=False).clamp(0, 1) + + variants = [ + ("stretch-RGB", x_stretch_m), + ("letterbox-RGB", x_letter_m), + ("stretch-BGR", x_stretch_m[:, [2, 1, 0], :, :]), + ("letterbox-BGR", x_letter_m[:, [2, 1, 0], :, :]), + ] + + # Try multiple input variants and scales + hm = None + chosen = None + for vname, vx in variants: + if layout.startswith("NHWC"): + xin = vx.permute(0, 2, 3, 1) + else: + xin = vx + for scale in (1.0, 255.0): + inp = (xin * float(scale)).numpy().astype(np.float32) + feed = {in_name: inp} + outs = sess.run(None, feed) + if _ONNX_DEBUG: + try: + shapes = [] + for o in outs: + try: + shapes.append(tuple(np.asarray(o).shape)) + except Exception: + shapes.append("?") + print(f"[CADE2.5][ONNX] '{name}' {vname} scale={scale} -> outs shapes {shapes}") + except Exception: + pass + hm = _try_heatmap_from_outputs(outs, (target, target)) + if _ONNX_DEBUG: + try: + if hm is None: + print(f"[CADE2.5][ONNX] '{name}' {vname} scale={scale}: no spatial heatmap detected") + else: + print(f"[CADE2.5][ONNX] '{name}' {vname} scale={scale}: heat stats min={np.min(hm):.4f} max={np.max(hm):.4f} mean={np.mean(hm):.4f}") + except Exception: + pass + if hm is not None and np.max(hm) > 0: + chosen = (vname, scale) + break + if hm is not None and np.max(hm) > 0: + break + if hm is None: + continue + # Scale by sensitivity and optional anomaly gain + gain = float(max(0.0, sensitivity)) + if 'anomaly' in name.lower(): + gain *= float(max(0.0, anomaly_gain)) + hm = np.clip(hm * gain, 0.0, 1.0) + # Heuristic rejection of stripe artifacts (horizontal/vertical banding) + try: + rm = np.mean(hm, axis=1) # row means (H) + cm = np.mean(hm, axis=0) # col means (W) + rstd = float(np.std(rm)) + cstd = float(np.std(cm)) + zig_r = float(np.mean(np.abs(np.diff(rm)))) + zig_c = float(np.mean(np.abs(np.diff(cm)))) + horiz_bands = (rstd > 10.0 * max(cstd, 1e-6)) and (zig_r > 0.02) + vert_bands = (cstd > 10.0 * max(rstd, 1e-6)) and (zig_c > 0.02) + if horiz_bands or vert_bands: + if _ONNX_DEBUG: + print(f"[CADE2.5][ONNX] '{name}' rejected as stripe artifact (rstd={rstd:.4f} cstd={cstd:.4f} zig_r={zig_r:.4f} zig_c={zig_c:.4f})") + hm = None + except Exception: + pass + tmask = _np_to_mask_tensor(hm, H, W, device, dtype) + if tmask is not None: + masks_b.append(tmask) + # Optional counting for debugging: derive counts from raw outputs if possible + try: + if globals().get("_ONNX_COUNT_DEBUG", False): + lower = name.lower() + is_face = ("face" in lower) + is_hand = ("hand" in lower) + is_pose = any(w in lower for w in ["wholebody", "pose", "person", "body"]) and not (is_face or is_hand) + 
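+                            # Debug-only tallies parsed from the raw outputs below;
+                            # they feed the count printout and never alter the mask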
+ boxes_cnt = 0 + persons_cnt = 0 + wrists_cnt = 0 + + try: + for outx in outs: + arr0 = np.asarray(outx) + if arr0 is None: + continue + a = arr0.astype(np.float32, copy=False) + while a.ndim > 2 and a.shape[0] == 1: + try: + a = np.squeeze(a, axis=0) + except Exception: + break + + # Face/Hand: count only plausible (N,D) detection tables, not spatial maps/grids + if (is_face or is_hand) and a.ndim == 2: + N, D = a.shape + if D >= 4 and D <= 64 and N <= 512 and not (N >= 32 and D >= 32): + n = N + # Try to locate a confidence column + conf = None + if D >= 5: + c = a[:, 4] + if np.all(np.isfinite(c)) and 0.0 <= float(np.nanmin(c)) and float(np.nanmax(c)) <= 1.0: + conf = c + if conf is None: + c = a[:, -1] + if np.all(np.isfinite(c)) and 0.0 <= float(np.nanmin(c)) and float(np.nanmax(c)) <= 1.0: + conf = c + if conf is not None: + n = int(np.sum(conf >= 0.25)) + boxes_cnt = max(boxes_cnt, int(n)) + + # Pose: prefer keypoints formats only + if is_pose: + # Packed per row [N, K*3] + if a.ndim == 2 and a.shape[-1] >= 6 and (a.shape[-1] % 3) == 0: + persons_cnt = max(persons_cnt, int(a.shape[0])) + try: + K = a.shape[-1] // 3 + if K >= 11: + lw = a[:, 9*3 + 2] + rw = a[:, 10*3 + 2] + wrists_cnt = max(wrists_cnt, int(np.sum(lw >= 0.2) + np.sum(rw >= 0.2))) + except Exception: + pass + # [N,K,2/3] + if a.ndim == 3 and a.shape[-1] in (2, 3): + persons_cnt = max(persons_cnt, int(a.shape[0])) + try: + if a.shape[1] >= 11 and a.shape[-1] == 3: + lw = a[:, 9, 2] + rw = a[:, 10, 2] + wrists_cnt = max(wrists_cnt, int(np.sum(lw >= 0.2) + np.sum(rw >= 0.2))) + except Exception: + pass + except Exception: + pass + + # Map to categories by model name using derived counts + if is_face: + if boxes_cnt > 0: + counts_b["faces"] = counts_b.get("faces", 0) + boxes_cnt + elif is_hand: + if boxes_cnt > 0: + counts_b["hands"] = counts_b.get("hands", 0) + boxes_cnt + elif is_pose: + if persons_cnt > 0: + counts_b["persons"] = counts_b.get("persons", 0) + persons_cnt + # Fallback hands from wrists or 2 per person + if wrists_cnt > 0: + counts_b["hands"] = counts_b.get("hands", 0) + wrists_cnt + elif persons_cnt > 0: + counts_b["hands"] = counts_b.get("hands", 0) + (2 * persons_cnt) + except Exception: + pass + if _ONNX_DEBUG: + try: + area = float(tmask.movedim(-1,1).mean().item()) + if chosen is not None: + vname, scale = chosen + print(f"[CADE2.5][ONNX] '{name}' via {vname} x{scale} area={area:.4f}") + else: + print(f"[CADE2.5][ONNX] '{name}' contribution area={area:.4f}") + except Exception: + pass + except Exception: + # Ignore failing models + continue + if not masks_b: + masks.append(torch.zeros((1, H, W, 1), device=device, dtype=dtype)) + if _ONNX_DEBUG or globals().get("_ONNX_COUNT_DEBUG", False): + try: + print(f"[CADE2.5][ONNX] Detected (image[{b}]): none (no contributing models)") + except Exception: + pass + else: + # Soft-OR fusion: 1 - prod(1 - m) + stack = torch.stack([masks_b[i] for i in range(len(masks_b))], dim=0) # M,1,H,W,1? 
actually B dims kept as 1 + fused = 1.0 - torch.prod(1.0 - stack.clamp(0, 1), dim=0) + # Light smoothing via bilinear down/up (anti alias) + ch = fused.permute(0, 3, 1, 2) # B=1,C=1,H,W + dd = F.interpolate(ch, scale_factor=0.5, mode='bilinear', align_corners=False, recompute_scale_factor=False) + uu = F.interpolate(dd, size=(H, W), mode='bilinear', align_corners=False) + fused = uu.permute(0, 2, 3, 1).clamp(0, 1) + if _ONNX_DEBUG or globals().get("_ONNX_COUNT_DEBUG", False): + try: + area = float(fused.movedim(-1,1).mean().item()) + if _ONNX_DEBUG: + print(f"[CADE2.5][ONNX] Fused area (image[{b}])={area:.4f}") + # Print per-image counts if requested + if globals().get("_ONNX_COUNT_DEBUG", False): + if counts_b: + faces = counts_b.get("faces", 0) + hands = counts_b.get("hands", 0) + persons = counts_b.get("persons", 0) + print(f"[CADE2.5][ONNX] Detected (image[{b}]): faces={faces} hands={hands} persons={persons}") + else: + print(f"[CADE2.5][ONNX] Detected (image[{b}]): counts unavailable (no categories or cv2 missing), area={area:.4f}") + except Exception: + pass + masks.append(fused) + + return torch.cat(masks, dim=0) + +def _sampler_names(): + try: + import comfy.samplers + return comfy.samplers.KSampler.SAMPLERS + except Exception: + return ["euler"] + + +def _scheduler_names(): + try: + import comfy.samplers + scheds = list(comfy.samplers.KSampler.SCHEDULERS) + if "MGHybrid" not in scheds: + scheds.append("MGHybrid") + return scheds + except Exception: + return ["normal", "MGHybrid"] + + +def safe_decode(vae, lat, tile=512, ovlp=64): + h, w = lat["samples"].shape[-2:] + if min(h, w) > 1024: + # Increase overlap for ultra-hires to reduce seam artifacts + ov = 128 if max(h, w) > 2048 else ovlp + return vae.decode_tiled(lat["samples"], tile_x=tile, tile_y=tile, overlap=ov) + return vae.decode(lat["samples"]) + + +def safe_encode(vae, img, tile=512, ovlp=64): + import math, torch.nn.functional as F + h, w = img.shape[1:3] + try: + stride = int(vae.spacial_compression_decode()) + except Exception: + stride = 8 + if stride <= 0: + stride = 8 + def _align_up(x, s): + return int(((x + s - 1) // s) * s) + Ht = _align_up(h, stride) + Wt = _align_up(w, stride) + x = img + if (Ht != h) or (Wt != w): + # pad on bottom/right using replicate to avoid black borders + pad_h = Ht - h + pad_w = Wt - w + x_nchw = img.movedim(-1, 1) + x_nchw = F.pad(x_nchw, (0, pad_w, 0, pad_h), mode='replicate') + x = x_nchw.movedim(1, -1) + if min(Ht, Wt) > 1024: + ov = 128 if max(Ht, Wt) > 2048 else ovlp + return vae.encode_tiled(x[:, :, :, :3], tile_x=tile, tile_y=tile, overlap=ov) + return vae.encode(x[:, :, :, :3]) + + + +def _gaussian_kernel(kernel_size: int, sigma: float, device=None): + x, y = torch.meshgrid( + torch.linspace(-1, 1, kernel_size, device=device), + torch.linspace(-1, 1, kernel_size, device=device), + indexing="ij", + ) + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + + +def _sharpen_image(image: torch.Tensor, sharpen_radius: int, sigma: float, alpha: float): + if sharpen_radius == 0: + return (image,) + + image = image.to(model_management.get_torch_device()) + batch_size, height, width, channels = image.shape + + kernel_size = sharpen_radius * 2 + 1 + kernel = _gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha * 10) + kernel = kernel.to(dtype=image.dtype) + center = kernel_size // 2 + kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 + kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) + + tensor_image = 
image.permute(0, 3, 1, 2) + tensor_image = F.pad(tensor_image, (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius), 'reflect') + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] + sharpened = sharpened.permute(0, 2, 3, 1) + + result = torch.clamp(sharpened, 0, 1) + return (result.to(model_management.intermediate_device()),) + + +def _encode_clip_image(image: torch.Tensor, clip_vision, target_res: int) -> torch.Tensor: + # image: BHWC in [0,1] + img = image.movedim(-1, 1) # BCHW + img = F.interpolate(img, size=(target_res, target_res), mode="bilinear", align_corners=False) + img = (img * 2.0) - 1.0 + embeds = clip_vision.encode_image(img)["image_embeds"] + embeds = F.normalize(embeds, dim=-1) + return embeds + + +def _clip_cosine_distance(a: torch.Tensor, b: torch.Tensor) -> float: + if a.shape != b.shape: + m = min(a.shape[0], b.shape[0]) + a = a[:m] + b = b[:m] + sim = (a * b).sum(dim=-1).mean().clamp(-1.0, 1.0).item() + return 1.0 - sim + + +def _soft_symmetry_blend(image_bhwc: torch.Tensor, + mask_bhwc: torch.Tensor, + alpha: float = 0.03, + lp_sigma: float = 1.5) -> torch.Tensor: + """Gently mix a mirrored low-frequency component inside mask. + - image_bhwc: [B,H,W,C] in [0,1] + - mask_bhwc: [B,H,W,1] in [0,1] + """ + try: + if image_bhwc is None or mask_bhwc is None: + return image_bhwc + if image_bhwc.ndim != 4 or mask_bhwc.ndim != 4: + return image_bhwc + B, H, W, C = image_bhwc.shape + if C < 3: + return image_bhwc + # Mirror along width + mirror = torch.flip(image_bhwc, dims=[2]) + # Low-pass both + x = image_bhwc.movedim(-1, 1) + y = mirror.movedim(-1, 1) + rad = max(1, int(round(lp_sigma))) + x_lp = _gaussian_blur_nchw(x, sigma=float(lp_sigma), radius=rad) + y_lp = _gaussian_blur_nchw(y, sigma=float(lp_sigma), radius=rad) + # High-pass from original + hp = x - x_lp + # Blend LPs inside mask + m = mask_bhwc.movedim(-1, 1).clamp(0, 1) + a = float(max(0.0, min(0.2, alpha))) + base = (1.0 - a * m) * x_lp + (a * m) * y_lp + res = (base + hp).movedim(1, -1).clamp(0, 1) + return res.to(image_bhwc.dtype) + except Exception: + return image_bhwc + + +def _gaussian_blur_nchw(x: torch.Tensor, sigma: float = 1.0, radius: int = 1) -> torch.Tensor: + """Lightweight depthwise Gaussian blur for NCHW tensors. + Uses reflect padding and a normalized kernel built by _gaussian_kernel. + """ + if radius <= 0: + return x + ksz = radius * 2 + 1 + kernel = _gaussian_kernel(ksz, sigma, device=x.device).to(dtype=x.dtype) + kernel = kernel.repeat(x.shape[1], 1, 1).unsqueeze(1) # [C,1,K,K] + x_pad = F.pad(x, (radius, radius, radius, radius), mode='reflect') + y = F.conv2d(x_pad, kernel, padding=0, groups=x.shape[1]) + return y + + +def _letterbox_nchw(x: torch.Tensor, target: int, pad_val: float = 114.0 / 255.0) -> torch.Tensor: + """Letterbox a BCHW tensor to target x target with constant padding (YOLO-style). + Preserves aspect ratio, centers content, pads with pad_val. 
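+    For example, a 1024x768 (WxH) input letterboxed to target=640 becomes
+    640x480 content with 80 rows of pad_val above and below:
+        >>> x = torch.zeros(1, 3, 768, 1024)
+        >>> _letterbox_nchw(x, 640).shape
+        torch.Size([1, 3, 640, 640])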
+ """ + if x.ndim != 4: + return F.interpolate(x, size=(target, target), mode='bilinear', align_corners=False) + b, c, h, w = x.shape + if h == 0 or w == 0: + return F.interpolate(x, size=(target, target), mode='bilinear', align_corners=False) + r = float(min(target / max(1, h), target / max(1, w))) + nh = max(1, int(round(h * r))) + nw = max(1, int(round(w * r))) + y = F.interpolate(x, size=(nh, nw), mode='bilinear', align_corners=False) + pt = (target - nh) // 2 + pb = target - nh - pt + pl = (target - nw) // 2 + pr = target - nw - pl + if pt < 0 or pb < 0 or pl < 0 or pr < 0: + # Fallback stretch if rounding went weird + return F.interpolate(x, size=(target, target), mode='bilinear', align_corners=False) + return F.pad(y, (pl, pr, pt, pb), mode='constant', value=float(pad_val)) + + +def _fdg_filter(delta: torch.Tensor, low_gain: float, high_gain: float, sigma: float = 1.0, radius: int = 1) -> torch.Tensor: + """Frequency-Decoupled Guidance: split delta into low/high bands and reweight. + delta: [B,C,H,W] + """ + low = _gaussian_blur_nchw(delta, sigma=sigma, radius=radius) + high = delta - low + return low * float(low_gain) + high * float(high_gain) + + +def _fdg_energy_fraction(delta: torch.Tensor, sigma: float = 1.0, radius: int = 1) -> torch.Tensor: + """Return fraction of high-frequency energy: E_high / (E_low + E_high).""" + low = _gaussian_blur_nchw(delta, sigma=sigma, radius=radius) + high = delta - low + e_low = (low * low).mean(dim=(1, 2, 3), keepdim=True) + e_high = (high * high).mean(dim=(1, 2, 3), keepdim=True) + frac = e_high / (e_low + e_high + 1e-8) + return frac + + +def _fdg_split_three(delta: torch.Tensor, + sigma_lo: float = 0.8, + sigma_hi: float = 2.0, + radius: int = 1) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + sig_lo = float(max(0.05, sigma_lo)) + sig_hi = float(max(sig_lo + 1e-3, sigma_hi)) + blur_lo = _gaussian_blur_nchw(delta, sigma=sig_lo, radius=radius) + blur_hi = _gaussian_blur_nchw(delta, sigma=sig_hi, radius=radius) + low = blur_hi + mid = blur_lo - blur_hi + high = delta - blur_lo + return low, mid, high + + +def _wrap_model_with_guidance(model, guidance_mode: str, rescale_multiplier: float, momentum_beta: float, cfg_curve: float, perp_damp: float, use_zero_init: bool=False, zero_init_steps: int=0, fdg_low: float = 0.6, fdg_high: float = 1.3, fdg_sigma: float = 1.0, ze_zero_steps: int = 0, ze_adaptive: bool = False, ze_r_switch_hi: float = 0.6, ze_r_switch_lo: float = 0.45, fdg_low_adaptive: bool = False, fdg_low_min: float = 0.45, fdg_low_max: float = 0.7, fdg_ema_beta: float = 0.8, use_local_mask: bool = False, mask_inside: float = 1.0, mask_outside: float = 1.0, + midfreq_enable: bool = False, midfreq_gain: float = 0.0, midfreq_sigma_lo: float = 0.8, midfreq_sigma_hi: float = 2.0, + mahiro_plus_enable: bool = False, mahiro_plus_strength: float = 0.5, + eps_scale_enable: bool = False, eps_scale: float = 0.0): + + """Clone model and attach a cfg mixing function implementing RescaleCFG/FDG, CFGZero*/FD, or hybrid ZeResFDG. 
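+    In brief, mirroring cfg_func below: CFGZero* projects cond onto uncond
+    (alpha = <c,u>/<u,u>) and steers along the residual cond - alpha*uncond;
+    RescaleCFG matches the std of the guided v-prediction to the cond branch;
+    the FD/FDG variants reweight low/high-frequency bands of the guidance
+    delta; ZeResFDG switches between the projected and rescaled paths based
+    on the high-frequency energy fraction, with EMA smoothing and hysteresis.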
+ guidance_mode: 'default' | 'RescaleCFG' | 'RescaleFDG' | 'CFGZero*' | 'CFGZeroFD' | 'ZeResFDG' + """ + if guidance_mode == "default": + return model + m = model.clone() + + # State for momentum and sigma normalization across steps + prev_delta = {"t": None} + sigma_seen = {"max": None, "min": None} + # Spectral switching/adaptive low state + spec_state = {"ema": None, "mode": "CFGZeroFD"} + + def cfg_func(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + sigma = args.get("sigma", None) + x_orig = args.get("input", None) + + # Local spatial gain from CURRENT_ONNX_MASK_BCHW, resized to cond spatial size + def _local_gain_for(hw): + if not bool(use_local_mask): + return None + m = globals().get("CURRENT_ONNX_MASK_BCHW", None) + if m is None: + return None + try: + Ht, Wt = int(hw[0]), int(hw[1]) + g = m.to(device=cond.device, dtype=cond.dtype) + if g.shape[-2] != Ht or g.shape[-1] != Wt: + g = F.interpolate(g, size=(Ht, Wt), mode='bilinear', align_corners=False) + gi = float(mask_inside) + go = float(mask_outside) + gain = g * gi + (1.0 - g) * go # [B,1,H,W] + return gain + except Exception: + return None + + # Allow hybrid switch per-step + mode = guidance_mode + if guidance_mode == "ZeResFDG": + if bool(ze_adaptive): + try: + delta_raw = args["cond"] - args["uncond"] + frac_b = _fdg_energy_fraction(delta_raw, sigma=float(fdg_sigma), radius=1) # [B,1,1,1] + frac = float(frac_b.mean().clamp(0.0, 1.0).item()) + except Exception: + frac = 0.0 + if spec_state["ema"] is None: + spec_state["ema"] = frac + else: + beta = float(max(0.0, min(0.99, fdg_ema_beta))) + spec_state["ema"] = beta * float(spec_state["ema"]) + (1.0 - beta) * frac + r = float(spec_state["ema"]) + # Hysteresis: switch up/down with two thresholds + if spec_state["mode"] == "CFGZeroFD" and r >= float(ze_r_switch_hi): + spec_state["mode"] = "RescaleFDG" + elif spec_state["mode"] == "RescaleFDG" and r <= float(ze_r_switch_lo): + spec_state["mode"] = "CFGZeroFD" + mode = spec_state["mode"] + else: + try: + sigmas = args["model_options"]["transformer_options"]["sample_sigmas"] + matched_idx = (sigmas == args["timestep"][0]).nonzero() + if len(matched_idx) > 0: + current_idx = matched_idx.item() + else: + current_idx = 0 + except Exception: + current_idx = 0 + mode = "CFGZeroFD" if current_idx <= int(ze_zero_steps) else "RescaleFDG" + + if mode in ("CFGZero*", "CFGZeroFD"): + # Optional zero-init for the first N steps + if use_zero_init and "model_options" in args and args.get("timestep") is not None: + try: + sigmas = args["model_options"]["transformer_options"]["sample_sigmas"] + matched_idx = (sigmas == args["timestep"][0]).nonzero() + if len(matched_idx) > 0: + current_idx = matched_idx.item() + else: + # fallback lookup + current_idx = 0 + if current_idx <= int(zero_init_steps): + return cond * 0.0 + except Exception: + pass + # Project cond onto uncond subspace (batch-wise alpha) + bsz = cond.shape[0] + pos_flat = cond.view(bsz, -1) + neg_flat = uncond.view(bsz, -1) + dot = torch.sum(pos_flat * neg_flat, dim=1, keepdim=True) + denom = torch.sum(neg_flat * neg_flat, dim=1, keepdim=True).clamp_min(1e-8) + alpha = (dot / denom).view(bsz, *([1] * (cond.dim() - 1))) + resid = cond - uncond * alpha + # Adaptive low gain if enabled + low_gain_eff = float(fdg_low) + if bool(fdg_low_adaptive) and spec_state["ema"] is not None: + s = float(spec_state["ema"]) # 0..1 fraction of high-frequency energy + lmin = float(fdg_low_min) + lmax = float(fdg_low_max) + low_gain_eff = max(0.0, min(2.0, lmin + 
(lmax - lmin) * s)) + if mode == "CFGZeroFD": + resid = _fdg_filter(resid, low_gain=low_gain_eff, high_gain=fdg_high, sigma=float(fdg_sigma), radius=1) + # Apply local spatial gain to residual guidance + lg = _local_gain_for((cond.shape[-2], cond.shape[-1])) + if lg is not None: + resid = resid * lg.expand(-1, resid.shape[1], -1, -1) + noise_pred = uncond * alpha + cond_scale * resid + return noise_pred + + # RescaleCFG/FDG path (with optional momentum/perp damping and S-curve shaping) + delta = cond - uncond + pd = float(max(0.0, min(1.0, perp_damp))) + if pd > 0.0 and (prev_delta["t"] is not None) and (prev_delta["t"].shape == delta.shape): + prev = prev_delta["t"] + denom = (prev * prev).sum(dim=(1,2,3), keepdim=True).clamp_min(1e-6) + coeff = ((delta * prev).sum(dim=(1,2,3), keepdim=True) / denom) + parallel = coeff * prev + delta = delta - pd * parallel + beta = float(max(0.0, min(0.95, momentum_beta))) + if beta > 0.0: + if prev_delta["t"] is None or prev_delta["t"].shape != delta.shape: + prev_delta["t"] = delta.detach() + delta = (1.0 - beta) * delta + beta * prev_delta["t"] + prev_delta["t"] = delta.detach() + cond = uncond + delta + else: + prev_delta["t"] = delta.detach() + # After momentum: optionally apply FDG and rebuild cond + if mode == "RescaleFDG": + # Adaptive low gain if enabled + low_gain_eff = float(fdg_low) + if bool(fdg_low_adaptive) and spec_state["ema"] is not None: + s = float(spec_state["ema"]) # 0..1 + lmin = float(fdg_low_min) + lmax = float(fdg_low_max) + low_gain_eff = max(0.0, min(2.0, lmin + (lmax - lmin) * s)) + delta_fdg = _fdg_filter(delta, low_gain=low_gain_eff, high_gain=fdg_high, sigma=float(fdg_sigma), radius=1) + # Optional mid-frequency emphasis blended on top + if bool(midfreq_enable) and abs(float(midfreq_gain)) > 1e-6: + lo_b, mid_b, hi_b = _fdg_split_three(delta, sigma_lo=float(midfreq_sigma_lo), sigma_hi=float(midfreq_sigma_hi), radius=1) + lg = _local_gain_for((cond.shape[-2], cond.shape[-1])) + if lg is not None: + mid_b = mid_b * lg.expand(-1, mid_b.shape[1], -1, -1) + delta_fdg = delta_fdg + float(midfreq_gain) * mid_b + lg = _local_gain_for((cond.shape[-2], cond.shape[-1])) + if lg is not None: + delta_fdg = delta_fdg * lg.expand(-1, delta_fdg.shape[1], -1, -1) + cond = uncond + delta_fdg + else: + lg = _local_gain_for((cond.shape[-2], cond.shape[-1])) + if lg is not None: + delta = delta * lg.expand(-1, delta.shape[1], -1, -1) + cond = uncond + delta + + cond_scale_eff = cond_scale + if cfg_curve > 0.0 and (sigma is not None): + s = sigma + if s.ndim > 1: + s = s.flatten() + s_max = float(torch.max(s).item()) + s_min = float(torch.min(s).item()) + if sigma_seen["max"] is None: + sigma_seen["max"] = s_max + sigma_seen["min"] = s_min + else: + sigma_seen["max"] = max(sigma_seen["max"], s_max) + sigma_seen["min"] = min(sigma_seen["min"], s_min) + lo = max(1e-6, sigma_seen["min"]) + hi = max(lo * (1.0 + 1e-6), sigma_seen["max"]) + t = (torch.log(s + 1e-6) - torch.log(torch.tensor(lo, device=sigma.device))) / (torch.log(torch.tensor(hi, device=sigma.device)) - torch.log(torch.tensor(lo, device=sigma.device)) + 1e-6) + t = t.clamp(0.0, 1.0) + k = 6.0 * float(cfg_curve) + s_curve = torch.tanh((t - 0.5) * k) + gain = 1.0 + 0.15 * float(cfg_curve) * s_curve + if gain.ndim > 0: + gain = gain.mean().item() + cond_scale_eff = cond_scale * float(gain) + + # Epsilon scaling (exposure bias correction): early steps get multiplier closer to (1 + eps_scale) + eps_mult = 1.0 + if bool(eps_scale_enable) and (sigma is not None): + try: + s = sigma + if 
s.ndim > 1: + s = s.flatten() + s_max = float(torch.max(s).item()) + s_min = float(torch.min(s).item()) + if sigma_seen["max"] is None: + sigma_seen["max"] = s_max + sigma_seen["min"] = s_min + else: + sigma_seen["max"] = max(sigma_seen["max"], s_max) + sigma_seen["min"] = min(sigma_seen["min"], s_min) + lo = max(1e-6, sigma_seen["min"]) + hi = max(lo * (1.0 + 1e-6), sigma_seen["max"]) + t_lin = (torch.log(s + 1e-6) - torch.log(torch.tensor(lo, device=sigma.device))) / (torch.log(torch.tensor(hi, device=sigma.device)) - torch.log(torch.tensor(lo, device=sigma.device)) + 1e-6) + t_lin = t_lin.clamp(0.0, 1.0) + w_early = (1.0 - t_lin).mean().item() + eps_mult = float(1.0 + eps_scale * w_early) + except Exception: + eps_mult = float(1.0 + eps_scale) + + if sigma is None or x_orig is None: + return uncond + cond_scale * (cond - uncond) + sigma_ = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) + x = x_orig / (sigma_ * sigma_ + 1.0) + v_cond = ((x - (x_orig - cond)) * (sigma_ ** 2 + 1.0) ** 0.5) / (sigma_) + v_uncond = ((x - (x_orig - uncond)) * (sigma_ ** 2 + 1.0) ** 0.5) / (sigma_) + v_cfg = v_uncond + cond_scale_eff * (v_cond - v_uncond) + ro_pos = torch.std(v_cond, dim=(1, 2, 3), keepdim=True) + ro_cfg = torch.std(v_cfg, dim=(1, 2, 3), keepdim=True).clamp_min(1e-6) + v_rescaled = v_cfg * (ro_pos / ro_cfg) + v_final = float(rescale_multiplier) * v_rescaled + (1.0 - float(rescale_multiplier)) * v_cfg + eps = x_orig - (x - (v_final * eps_mult) * sigma_ / (sigma_ * sigma_ + 1.0) ** 0.5) + return eps + + m.set_model_sampler_cfg_function(cfg_func, disable_cfg1_optimization=True) + + # Optional directional post-mix (Muse Blend), global, no ONNX + if bool(mahiro_plus_enable): + s_clamp = float(max(0.0, min(1.0, mahiro_plus_strength))) + mb_state = {"ema": None} + + def _sqrt_sign(x: torch.Tensor) -> torch.Tensor: + return x.sign() * torch.sqrt(x.abs().clamp_min(1e-12)) + + def _hp_split(x: torch.Tensor, radius: int = 1, sigma: float = 1.0): + low = _gaussian_blur_nchw(x, sigma=sigma, radius=radius) + high = x - low + return low, high + + def _sched_gain(args) -> float: + # Gentle mid-steps boost: triangle peak at the middle of schedule + try: + sigmas = args["model_options"]["transformer_options"]["sample_sigmas"] + idx_t = args.get("timestep", None) + if idx_t is None: + return 1.0 + matched = (sigmas == idx_t[0]).nonzero() + if len(matched) == 0: + return 1.0 + i = float(matched.item()) + n = float(sigmas.shape[0]) + if n <= 1: + return 1.0 + phase = i / (n - 1.0) + tri = 1.0 - abs(2.0 * phase - 1.0) + return float(0.6 + 0.4 * tri) # 0.6 at edges -> 1.0 mid + except Exception: + return 1.0 + + def mahiro_plus_post(args): + try: + scale = args.get('cond_scale', 1.0) + cond_p = args['cond_denoised'] + uncond_p = args['uncond_denoised'] + cfg = args['denoised'] + + # Orthogonalize positive to negative direction (batch-wise) + bsz = cond_p.shape[0] + pos_flat = cond_p.view(bsz, -1) + neg_flat = uncond_p.view(bsz, -1) + dot = torch.sum(pos_flat * neg_flat, dim=1, keepdim=True) + denom = torch.sum(neg_flat * neg_flat, dim=1, keepdim=True).clamp_min(1e-8) + alpha = (dot / denom).view(bsz, *([1] * (cond_p.dim() - 1))) + c_orth = cond_p - uncond_p * alpha + + leap_raw = float(scale) * c_orth + # Light high-pass emphasis for detail, protect low-frequency tone + low, high = _hp_split(leap_raw, radius=1, sigma=1.0) + leap = 0.35 * low + 1.00 * high + + # Directional agreement (global cosine over flattened dims) + u_leap = float(scale) * uncond_p + merge = 0.5 * (leap + cfg) + nu = 
_sqrt_sign(u_leap).flatten(1) + nm = _sqrt_sign(merge).flatten(1) + sim = F.cosine_similarity(nu, nm, dim=1).mean() + a = torch.clamp((sim + 1.0) * 0.5, 0.0, 1.0) + # Small EMA for temporal smoothness + if mb_state["ema"] is None: + mb_state["ema"] = float(a) + else: + mb_state["ema"] = 0.8 * float(mb_state["ema"]) + 0.2 * float(a) + a_eff = float(mb_state["ema"]) + w = a_eff * cfg + (1.0 - a_eff) * leap + + # Gentle energy match to CFG + dims = tuple(range(1, w.dim())) + ro_w = torch.std(w, dim=dims, keepdim=True).clamp_min(1e-6) + ro_cfg = torch.std(cfg, dim=dims, keepdim=True).clamp_min(1e-6) + w_res = w * (ro_cfg / ro_w) + + # Schedule gain over steps (mid stronger) + s_eff = s_clamp * _sched_gain(args) + out = (1.0 - s_eff) * cfg + s_eff * w_res + return out + except Exception: + return args['denoised'] + + try: + m.set_model_sampler_post_cfg_function(mahiro_plus_post) + except Exception: + pass + + # Quantile clamp stabilizer (per-sample): soft range limit for denoised tensor + # Always on, under the hood. Helps prevent rare exploding values. + def _qclamp_post(args): + try: + x = args.get("denoised", None) + if x is None: + return args["denoised"] + dt = x.dtype + xf = x.to(dtype=torch.float32) + B = xf.shape[0] + lo_q, hi_q = 0.001, 0.999 + out = [] + for i in range(B): + t = xf[i].reshape(-1) + try: + lo = torch.quantile(t, lo_q) + hi = torch.quantile(t, hi_q) + except Exception: + n = t.numel() + k_lo = max(1, int(n * lo_q)) + k_hi = max(1, int(n * hi_q)) + lo = torch.kthvalue(t, k_lo).values + hi = torch.kthvalue(t, k_hi).values + out.append(xf[i].clamp(min=lo, max=hi)) + y = torch.stack(out, dim=0).to(dtype=dt) + return y + except Exception: + return args["denoised"] + + try: + m.set_model_sampler_post_cfg_function(_qclamp_post) + except Exception: + pass + + return m + + +def _edge_mask(image_bhwc: torch.Tensor, + threshold: float = 0.20, + blur: float = 1.0) -> torch.Tensor: + """Return a simple edge mask BHWC [0,1] using Sobel magnitude on luminance. + It is resolution-agnostic and intentionally lightweight. 
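+    Gradient magnitudes are normalized by their 98th percentile, so the
+    threshold is relative: threshold=0.2 binarizes at 20% of the near-max
+    gradient before the optional blur softens the result.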
+ """ + try: + img = image_bhwc + lum = (0.2126 * img[..., 0] + 0.7152 * img[..., 1] + 0.0722 * img[..., 2]) + kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=img.device, dtype=img.dtype).view(1, 1, 3, 3) + ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=img.device, dtype=img.dtype).view(1, 1, 3, 3) + g = torch.sqrt((F.conv2d(lum.unsqueeze(1), kx, padding=1) ** 2) + (F.conv2d(lum.unsqueeze(1), ky, padding=1) ** 2)) + g = g.squeeze(1) + # Robust normalization via 98th percentile + try: + q = torch.quantile(g.flatten(), 0.98).clamp_min(1e-6) + except Exception: + q = torch.topk(g.flatten(), max(1, int(g.numel() * 0.02))).values.min().clamp_min(1e-6) + m = (g / q).clamp(0, 1) + if threshold > 0.0: + m = (m > float(threshold)).to(img.dtype) + if blur > 0.0: + rad = int(max(1, min(5, round(float(blur))))) + m = _gaussian_blur_nchw(m.unsqueeze(1), sigma=float(max(0.5, blur)), radius=rad).squeeze(1) + return m.unsqueeze(-1) + except Exception: + return torch.zeros((image_bhwc.shape[0], image_bhwc.shape[1], image_bhwc.shape[2], 1), device=image_bhwc.device, dtype=image_bhwc.dtype) + + + +def _cf_edges_post(acc_t: torch.Tensor, + edge_width: float, + edge_smooth: float, + edge_single_line: bool, + edge_single_strength: float) -> torch.Tensor: + try: + import cv2, numpy as _np + img = (acc_t.clamp(0,1).detach().to('cpu').numpy()*255.0).astype(_np.uint8) + # Thickness adjust + if float(edge_width) != 0.0: + s = float(edge_width) + k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)) + op = cv2.dilate if s > 0 else cv2.erode + it = int(max(0, min(6, round(abs(s) * 4.0)))) + frac = abs(s) * 4.0 - it + for _ in range(max(0, it)): + img = op(img, k, iterations=1) + if frac > 1e-6: + y2 = op(img, k, iterations=1) + img = ((1.0-frac)*img.astype(_np.float32) + frac*y2.astype(_np.float32)).astype(_np.uint8) + # Collapse double lines to single centerline + if bool(edge_single_line) and float(edge_single_strength) > 1e-6: + try: + s = float(edge_single_strength) + close = cv2.morphologyEx(img, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3)), iterations=1) + if hasattr(cv2, 'ximgproc') and hasattr(cv2.ximgproc, 'thinning'): + sk = cv2.ximgproc.thinning(close) + else: + iters = max(1, int(round(2 + 6*s))) + sk = _np.zeros_like(close) + src = close.copy() + elem = cv2.getStructuringElement(cv2.MORPH_CROSS, (3,3)) + for _ in range(iters): + er = cv2.erode(src, elem, iterations=1) + opn = cv2.morphologyEx(er, cv2.MORPH_OPEN, elem) + tmp = cv2.subtract(er, opn) + sk = cv2.bitwise_or(sk, tmp) + src = er + if not _np.any(src): + break + img = ((1.0 - s) * img.astype(_np.float32) + s * sk.astype(_np.float32)).astype(_np.uint8) + except Exception: + pass + # Smooth + if float(edge_smooth) > 1e-6: + sigma = max(0.1, min(2.0, float(edge_smooth) * 1.2)) + img = cv2.GaussianBlur(img, (0,0), sigmaX=sigma) + out = torch.from_numpy((img.astype(_np.float32)/255.0)).to(device=acc_t.device, dtype=acc_t.dtype) + return out.clamp(0,1) + except Exception: + # Torch fallback: light blur-only and basic thicken/thin + y = acc_t + if float(edge_width) > 1e-6: + k = max(1, int(round(float(edge_width) * 2))) + p = k + y = F.max_pool2d(y.unsqueeze(0).unsqueeze(0), kernel_size=2*k+1, stride=1, padding=p)[0,0] + if float(edge_width) < -1e-6: + k = max(1, int(round(abs(float(edge_width)) * 2))) + p = k + maxed = F.max_pool2d((1.0 - y).unsqueeze(0).unsqueeze(0), kernel_size=2*k+1, stride=1, padding=p)[0,0] + y = 1.0 - maxed + if float(edge_smooth) > 1e-6: + s = max(1, 
int(round(float(edge_smooth)*2))) + y = F.avg_pool2d(y.unsqueeze(0).unsqueeze(0), kernel_size=2*s+1, stride=1, padding=s)[0,0] + return y.clamp(0,1) + + +def _build_cf_edge_mask_from_step(image_bhwc: torch.Tensor, preset_step: str) -> torch.Tensor | None: + try: + p = load_preset("mg_controlfusion", preset_step) + # Safe converters (preset values may be blank strings) + def _safe_int(val, default): + try: + iv = int(val) + return iv if iv > 0 else default + except Exception: + return default + def _safe_float(val, default): + try: + return float(val) + except Exception: + return default + + # Read CF params with safe defaults + enable_depth = bool(p.get('enable_depth', True)) + depth_model_path = str(p.get('depth_model_path', '')) + depth_resolution = _safe_int(p.get('depth_resolution', 768), 768) + hires_mask_auto = bool(p.get('hires_mask_auto', True)) + pyra_low = _safe_int(p.get('pyra_low', 109), 109) + pyra_high = _safe_int(p.get('pyra_high', 147), 147) + pyra_resolution = _safe_int(p.get('pyra_resolution', 1024), 1024) + edge_thin_iter = int(p.get('edge_thin_iter', 0)) + edge_boost = _safe_float(p.get('edge_boost', 0.0), 0.0) + smart_tune = bool(p.get('smart_tune', False)) + smart_boost = _safe_float(p.get('smart_boost', 0.2), 0.2) + edge_width = _safe_float(p.get('edge_width', 0.0), 0.0) + edge_smooth = _safe_float(p.get('edge_smooth', 0.0), 0.0) + edge_single_line = bool(p.get('edge_single_line', False)) + edge_single_strength = _safe_float(p.get('edge_single_strength', 0.0), 0.0) + edge_depth_gate = bool(p.get('edge_depth_gate', False)) + edge_depth_gamma = _safe_float(p.get('edge_depth_gamma', 1.5), 1.5) + edge_alpha = _safe_float(p.get('edge_alpha', 1.0), 1.0) + # Treat blend_factor as extra gain for edges (depth is not mixed here) + blend_factor = _safe_float(p.get('blend_factor', 0.02), 0.02) + # ControlNet multipliers section — use edge_strength_mul as an additional gain for edge mask + edge_strength_mul = _safe_float(p.get('edge_strength_mul', 1.0), 1.0) + + # Build edges with CF PyraCanny + ed = _cf_pyracanny(image_bhwc, pyra_low, pyra_high, pyra_resolution, + edge_thin_iter, edge_boost, smart_tune, smart_boost, + preserve_aspect=bool(hires_mask_auto)) + ed = _cf_edges_post(ed, edge_width, edge_smooth, edge_single_line, edge_single_strength) + # Depth-gate edges if enabled + if edge_depth_gate and enable_depth: + try: + depth = _cf_build_depth(image_bhwc, int(depth_resolution), str(depth_model_path), bool(hires_mask_auto)) + g = depth.clamp(0,1) ** float(edge_depth_gamma) + ed = (ed * g).clamp(0,1) + except Exception: + pass + # Apply opacity + edge strength + blend factor (as extra gain for edges only) + total_gain = max(0.0, float(edge_alpha)) * max(0.0, float(edge_strength_mul)) * max(0.0, float(blend_factor)) + # Keep at least alpha if blend_factor is set to 0 in presets + if total_gain == 0.0: + total_gain = max(0.0, float(edge_alpha)) + ed = (ed * total_gain).clamp(0,1) + # Return BHWC single-channel + return ed.unsqueeze(0).unsqueeze(-1) + except Exception: + return None +def _mask_dilate(mask_bhw1: torch.Tensor, k: int = 3) -> torch.Tensor: + if k <= 1: + return mask_bhw1 + m = mask_bhw1.movedim(-1, 1) + m = F.max_pool2d(m, kernel_size=k, stride=1, padding=k // 2) + return m.movedim(1, -1) + + +def _mask_erode(mask_bhw1: torch.Tensor, k: int = 3) -> torch.Tensor: + if k <= 1: + return mask_bhw1 + m = mask_bhw1.movedim(-1, 1) + e = 1.0 - F.max_pool2d(1.0 - m, kernel_size=k, stride=1, padding=k // 2) + return e.movedim(1, -1) + + +class ComfyAdaptiveDetailEnhancer25: + 
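+    """CADE 2.5 full node: iterative adaptive detail enhancement.
+
+    Each iteration can build a CLIPSeg/edge/ONNX guidance mask, wrap the model
+    with the selected guidance mode (RescaleCFG, FDG and CFGZero* variants),
+    run a ksampler pass, then gently damp denoise/cfg based on latent drift or
+    CLIP-Vision distance to a reference image. Parameter values are resolved
+    against the "mg_cade25" Step presets (see apply_cade2).
+    """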
@classmethod + def INPUT_TYPES(cls): + return { + "required": { + "preset_step": (["Step 1", "Step 2", "Step 3", "Step 4"], {"default": "Step 1", "tooltip": "Choose the Step preset. Toggle Custom below to apply UI values; otherwise Step preset values are used."}), + "custom": ("BOOLEAN", {"default": False, "tooltip": "Custom override: when enabled, your UI values override the selected Step for visible controls; hidden parameters still come from the Step preset."}), "model": ("MODEL", {}), + "positive": ("CONDITIONING", {}), + "negative": ("CONDITIONING", {}), + "vae": ("VAE", {}), + "latent": ("LATENT", {}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1}), + "denoise": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.0001}), + "sampler_name": (_sampler_names(), {"default": _sampler_names()[0]}), + "scheduler": (_scheduler_names(), {"default": _scheduler_names()[0]}), + "iterations": ("INT", {"default": 1, "min": 1, "max": 1000}), + "steps_delta": ("FLOAT", {"default": 0.0, "min": -1000.0, "max": 1000.0, "step": 0.01}), + "cfg_delta": ("FLOAT", {"default": 0.0, "min": -100.0, "max": 100.0, "step": 0.01}), + "denoise_delta": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.0001}), + "apply_sharpen": ("BOOLEAN", {"default": False}), + "apply_upscale": ("BOOLEAN", {"default": False}), + "apply_ids": ("BOOLEAN", {"default": False}), + "clip_clean": ("BOOLEAN", {"default": False}), + "ids_strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + "upscale_method": (MagicUpscaleModule.upscale_methods, {"default": "lanczos"}), + "scale_by": ("FLOAT", {"default": 1.2, "min": 1.0, "max": 8.0, "step": 0.01}), + "scale_delta": ("FLOAT", {"default": 0.0, "min": -8.0, "max": 8.0, "step": 0.01}), + "noise_offset": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 0.5, "step": 0.01}), + "threshold": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "RMS latent drift threshold (smaller = more damping)."}), + }, + "optional": { + "Sharpnes_strenght": ("FLOAT", {"default": 0.300, "min": 0.0, "max": 1.0, "step": 0.001}), + "latent_compare": ("BOOLEAN", {"default": False, "tooltip": "Use latent drift to gently damp params (safer than overwriting latents)."}), + "accumulation": (["default", "fp32+fp16", "fp32+fp32"], {"default": "default", "tooltip": "Override SageAttention PV accumulation mode for this node run."}), + "reference_clean": ("BOOLEAN", {"default": False, "tooltip": "Use CLIP-Vision similarity to a reference image to stabilize output."}), + "reference_image": ("IMAGE", {}), + "clip_vision": ("CLIP_VISION", {}), + "ref_preview": ("INT", {"default": 224, "min": 64, "max": 512, "step": 16}), + "ref_threshold": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 0.2, "step": 0.001}), + "ref_cooldown": ("INT", {"default": 1, "min": 1, "max": 8}), + + # ONNX detectors (beta) unified toggle for Hands/Face/Pose + "onnx_detectors": ("BOOLEAN", {"default": False, "tooltip": "Use auto ONNX detectors (any .onnx in models) to refine artifact mask."}), + "onnx_preview": ("INT", {"default": 224, "min": 64, "max": 512, "step": 16, "tooltip": "Square preview size fed to ONNX models."}), + "onnx_sensitivity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 2.0, "step": 0.01, "tooltip": "Global gain for fused ONNX mask."}), + "onnx_anomaly_gain": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, 
"step": 0.01, "tooltip": "Extra gain for 'anomaly' models (e.g., anomaly_det.onnx)."}), + + # Guidance controls + "guidance_mode": (["default", "RescaleCFG", "RescaleFDG", "CFGZero*", "CFGZeroFD", "ZeResFDG"], {"default": "RescaleCFG", "tooltip": "Rescale (stable), RescaleFDG (spectral), CFGZero*, CFGZeroFD, or hybrid ZeResFDG."}), + "rescale_multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Blend between rescaled and plain CFG (like comfy RescaleCFG)."}), + "momentum_beta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 0.95, "step": 0.01, "tooltip": "EMA momentum in eps-space for (cond-uncond), 0 to disable."}), + "cfg_curve": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "S-curve shaping of cond_scale across steps (0=flat)."}), + "perp_damp": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Remove a small portion of the component parallel to previous delta (0-1)."}), + + # NAG (Normalized Attention Guidance) toggles + "use_nag": ("BOOLEAN", {"default": False, "tooltip": "Apply NAG inside CrossAttention (positive branch) during this node."}), + "nag_scale": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 50.0, "step": 0.1}), + "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.01}), + "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}), + + # CFGZero* extras + "use_zero_init": ("BOOLEAN", {"default": False, "tooltip": "For CFGZero*, zero out first few steps."}), + "zero_init_steps": ("INT", {"default": 0, "min": 0, "max": 20, "step": 1}), + + # FDG controls (placed last to avoid reordering existing fields) + "fdg_low": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 2.0, "step": 0.01, "tooltip": "Low-frequency gain (<1 to restrain masses)."}), + "fdg_high": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 2.5, "step": 0.01, "tooltip": "High-frequency gain (>1 to boost details)."}), + "fdg_sigma": ("FLOAT", {"default": 1.0, "min": 0.5, "max": 2.5, "step": 0.05, "tooltip": "Gaussian sigma for FDG low-pass split."}), + "ze_res_zero_steps": ("INT", {"default": 2, "min": 0, "max": 20, "step": 1, "tooltip": "Hybrid: number of initial steps to use CFGZeroFD before switching to RescaleFDG."}), + + # Adaptive spectral switch (ZeRes) and adaptive low gain + "ze_adaptive": ("BOOLEAN", {"default": False, "tooltip": "Enable spectral switch: CFGZeroFD, RescaleFDG by HF/LF ratio (EMA)."}), + "ze_r_switch_hi": ("FLOAT", {"default": 0.60, "min": 0.10, "max": 0.95, "step": 0.01, "tooltip": "Switch to RescaleFDG when EMA fraction of high-frequency."}), + "ze_r_switch_lo": ("FLOAT", {"default": 0.45, "min": 0.05, "max": 0.90, "step": 0.01, "tooltip": "Switch back to CFGZeroFD when EMA fraction (hysteresis)."}), + "fdg_low_adaptive": ("BOOLEAN", {"default": False, "tooltip": "Adapt fdg_low by HF fraction (EMA)."}), + "fdg_low_min": ("FLOAT", {"default": 0.45, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Lower bound for adaptive fdg_low."}), + "fdg_low_max": ("FLOAT", {"default": 0.70, "min": 0.0, "max": 2.0, "step": 0.01, "tooltip": "Upper bound for adaptive fdg_low."}), + "fdg_ema_beta": ("FLOAT", {"default": 0.80, "min": 0.0, "max": 0.99, "step": 0.01, "tooltip": "EMA smoothing for spectral ratio (higher = smoother)."}), + + # ONNX local guidance (placed last to avoid reordering) + "onnx_local_guidance": ("BOOLEAN", {"default": False, "tooltip": "Modulate guidance spatially by ONNX mask."}), + "onnx_mask_inside": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, 
"step": 0.01, "tooltip": "Multiplier for guidance inside mask (protects)."}), + "onnx_mask_outside": ("FLOAT", {"default": 1.0, "min": 0.5, "max": 2.0, "step": 0.01, "tooltip": "Multiplier for guidance outside mask."}), + "onnx_debug": ("BOOLEAN", {"default": False, "tooltip": "Print ONNX mask area per iteration."}), + + # ONNX wholebody keypoints local heatmap (placed last) + "onnx_kpts_enable": ("BOOLEAN", {"default": False, "tooltip": "Parse YOLO wholebody keypoints and add local heatmap."}), + "onnx_kpts_sigma": ("FLOAT", {"default": 2.5, "min": 0.5, "max": 8.0, "step": 0.1, "tooltip": "Keypoint Gaussian sigma multiplier."}), + "onnx_kpts_gain": ("FLOAT", {"default": 1.5, "min": 0.1, "max": 5.0, "step": 0.1, "tooltip": "Keypoint heat amplitude multiplier."}), + "onnx_kpts_conf": ("FLOAT", {"default": 0.20, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Keypoint confidence threshold."}), + + # Muse Blend global directional post-mix + "muse_blend": ("BOOLEAN", {"default": False, "tooltip": "Enable Muse Blend: gentle directional positive blend (global)."}), + "muse_blend_strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Overall influence of Muse Blend over baseline CFG (0..1)."}), + # Exposure Bias Correction (epsilon scaling) + "eps_scale_enable": ("BOOLEAN", {"default": False, "tooltip": "Exposure Bias Correction: scale predicted noise early in schedule."}), + "eps_scale": ("FLOAT", {"default": 0.005, "min": -1.0, "max": 1.0, "step": 0.0005, "tooltip": "Signed scaling near early steps (recommended ~0.0045; use with care)."}), + "clipseg_enable": ("BOOLEAN", {"default": False, "tooltip": "Use CLIPSeg to build a text-driven mask (e.g., 'eyes | hands | face')."}), + "clipseg_text": ("STRING", {"default": "", "multiline": False}), + "clipseg_preview": ("INT", {"default": 224, "min": 64, "max": 512, "step": 16}), + "clipseg_threshold": ("FLOAT", {"default": 0.40, "min": 0.0, "max": 1.0, "step": 0.05}), + "clipseg_blur": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 15.0, "step": 0.1}), + "clipseg_dilate": ("INT", {"default": 4, "min": 0, "max": 10, "step": 1}), + "clipseg_gain": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.01}), + "clipseg_blend": (["fuse", "replace", "intersect"], {"default": "fuse", "tooltip": "How to combine CLIPSeg with ONNX mask."}), + "clipseg_ref_gate": ("BOOLEAN", {"default": False, "tooltip": "If reference provided, boost mask when far from reference (CLIP-Vision)."}), + "clipseg_ref_threshold": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 0.2, "step": 0.001}), + + # Polish mode (final hi-res refinement) + "polish_enable": ("BOOLEAN", {"default": False, "tooltip": "Polish: keep low-frequency shape from reference while allowing high-frequency details to refine."}), + "polish_keep_low": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "How much low-frequency (global form, lighting) to take from reference image (0=use current, 1=use reference)."}), + "polish_edge_lock": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Edge lock strength: protects edges from sideways drift (0=off, 1=strong)."}), + "polish_sigma": ("FLOAT", {"default": 1.0, "min": 0.3, "max": 3.0, "step": 0.1, "tooltip": "Radius for low/high split: larger keeps bigger shapes as 'low' (global form)."}), + "polish_start_after": ("INT", {"default": 1, "min": 0, "max": 3, "step": 1, "tooltip": "Enable polish after N iterations (0=immediately)."}), + "polish_keep_low_ramp": ("FLOAT", {"default": 
0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Starting share of low-frequency mix; ramps to polish_keep_low over remaining iterations."}), + + }, + } + + RETURN_TYPES = ("LATENT", "IMAGE", "INT", "FLOAT", "FLOAT", "IMAGE") + RETURN_NAMES = ("LATENT", "IMAGE", "steps", "cfg", "denoise", "mask_preview") + FUNCTION = "apply_cade2" + CATEGORY = "MagicNodes" + + def apply_cade2(self, model, vae, positive, negative, latent, seed, steps, cfg, denoise, + sampler_name, scheduler, noise_offset, iterations=1, steps_delta=0.0, + cfg_delta=0.0, denoise_delta=0.0, apply_sharpen=False, + apply_upscale=False, apply_ids=False, clip_clean=False, + ids_strength=0.5, upscale_method="lanczos", scale_by=1.2, scale_delta=0.0, + Sharpnes_strenght=0.300, threshold=0.03, latent_compare=False, accumulation="default", + reference_clean=False, reference_image=None, clip_vision=None, ref_preview=224, ref_threshold=0.03, ref_cooldown=1, + onnx_detectors=False, onnx_preview=224, onnx_sensitivity=0.5, onnx_anomaly_gain=1.0, + guidance_mode="RescaleCFG", rescale_multiplier=0.7, momentum_beta=0.0, cfg_curve=0.0, perp_damp=0.0, + use_nag=False, nag_scale=4.0, nag_tau=2.5, nag_alpha=0.25, + use_zero_init=False, zero_init_steps=0, + fdg_low=0.6, fdg_high=1.3, fdg_sigma=1.0, ze_res_zero_steps=2, + ze_adaptive=False, ze_r_switch_hi=0.60, ze_r_switch_lo=0.45, + fdg_low_adaptive=False, fdg_low_min=0.45, fdg_low_max=0.70, fdg_ema_beta=0.80, + onnx_local_guidance=False, onnx_mask_inside=1.0, onnx_mask_outside=1.0, onnx_debug=False, + onnx_kpts_enable=False, onnx_kpts_sigma=2.5, onnx_kpts_gain=1.5, onnx_kpts_conf=0.20, + muse_blend=False, muse_blend_strength=0.5, + eps_scale_enable=False, eps_scale=0.005, + clipseg_enable=False, clipseg_text="", clipseg_preview=224, + clipseg_threshold=0.40, clipseg_blur=7.0, clipseg_dilate=4, + clipseg_gain=1.0, clipseg_blend="fuse", clipseg_ref_gate=False, clipseg_ref_threshold=0.03, + polish_enable=False, polish_keep_low=0.4, polish_edge_lock=0.2, polish_sigma=1.0, + polish_start_after=1, polish_keep_low_ramp=0.2, + preset_step="Step 1", custom_override=False): + # Load base preset for the selected Step. When custom_override is True, + # visible UI controls (top-level) are kept from UI; hidden ones still come from preset. 
+ try: + p = load_preset("mg_cade25", preset_step) if isinstance(preset_step, str) else {} + except Exception: + p = {} + def pv(name, cur, top=False): + return cur if (top and bool(custom_override)) else p.get(name, cur) + seed = int(pv("seed", seed, top=True)) + steps = int(pv("steps", steps, top=True)) + cfg = float(pv("cfg", cfg, top=True)) + denoise = float(pv("denoise", denoise, top=True)) + sampler_name = str(pv("sampler_name", sampler_name, top=True)) + scheduler = str(pv("scheduler", scheduler, top=True)) + iterations = int(pv("iterations", iterations)) + steps_delta = float(pv("steps_delta", steps_delta)) + cfg_delta = float(pv("cfg_delta", cfg_delta)) + denoise_delta = float(pv("denoise_delta", denoise_delta)) + apply_sharpen = bool(pv("apply_sharpen", apply_sharpen)) + apply_upscale = bool(pv("apply_upscale", apply_upscale)) + apply_ids = bool(pv("apply_ids", apply_ids)) + clip_clean = bool(pv("clip_clean", clip_clean)) + ids_strength = float(pv("ids_strength", ids_strength)) + upscale_method = str(pv("upscale_method", upscale_method)) + scale_by = float(pv("scale_by", scale_by)) + scale_delta = float(pv("scale_delta", scale_delta)) + noise_offset = float(pv("noise_offset", noise_offset)) + threshold = float(pv("threshold", threshold)) + Sharpnes_strenght = float(pv("Sharpnes_strenght", Sharpnes_strenght)) + latent_compare = bool(pv("latent_compare", latent_compare)) + accumulation = str(pv("accumulation", accumulation)) + reference_clean = bool(pv("reference_clean", reference_clean)) + ref_preview = int(pv("ref_preview", ref_preview)) + ref_threshold = float(pv("ref_threshold", ref_threshold)) + ref_cooldown = int(pv("ref_cooldown", ref_cooldown)) + onnx_detectors = bool(pv("onnx_detectors", onnx_detectors)) + onnx_preview = int(pv("onnx_preview", onnx_preview)) + onnx_sensitivity = float(pv("onnx_sensitivity", onnx_sensitivity)) + onnx_anomaly_gain = float(pv("onnx_anomaly_gain", onnx_anomaly_gain)) + guidance_mode = str(pv("guidance_mode", guidance_mode)) + rescale_multiplier = float(pv("rescale_multiplier", rescale_multiplier)) + momentum_beta = float(pv("momentum_beta", momentum_beta)) + cfg_curve = float(pv("cfg_curve", cfg_curve)) + perp_damp = float(pv("perp_damp", perp_damp)) + use_nag = bool(pv("use_nag", use_nag)) + nag_scale = float(pv("nag_scale", nag_scale)) + nag_tau = float(pv("nag_tau", nag_tau)) + nag_alpha = float(pv("nag_alpha", nag_alpha)) + use_zero_init = bool(pv("use_zero_init", use_zero_init)) + zero_init_steps = int(pv("zero_init_steps", zero_init_steps)) + fdg_low = float(pv("fdg_low", fdg_low)) + fdg_high = float(pv("fdg_high", fdg_high)) + fdg_sigma = float(pv("fdg_sigma", fdg_sigma)) + ze_res_zero_steps = int(pv("ze_res_zero_steps", ze_res_zero_steps)) + ze_adaptive = bool(pv("ze_adaptive", ze_adaptive)) + ze_r_switch_hi = float(pv("ze_r_switch_hi", ze_r_switch_hi)) + ze_r_switch_lo = float(pv("ze_r_switch_lo", ze_r_switch_lo)) + fdg_low_adaptive = bool(pv("fdg_low_adaptive", fdg_low_adaptive)) + fdg_low_min = float(pv("fdg_low_min", fdg_low_min)) + fdg_low_max = float(pv("fdg_low_max", fdg_low_max)) + fdg_ema_beta = float(pv("fdg_ema_beta", fdg_ema_beta)) + # AQClip-Lite (hidden in Easy UI, controllable via presets) + aqclip_enable = bool(pv("aqclip_enable", False)) + aq_tile = int(pv("aq_tile", 32)) + aq_stride = int(pv("aq_stride", 16)) + aq_alpha = float(pv("aq_alpha", 2.0)) + aq_ema_beta = float(pv("aq_ema_beta", 0.8)) + midfreq_enable = bool(pv("midfreq_enable", False)) + midfreq_gain = float(pv("midfreq_gain", 0.0)) + midfreq_sigma_lo = 
float(pv("midfreq_sigma_lo", 0.8)) + midfreq_sigma_hi = float(pv("midfreq_sigma_hi", 2.0)) + onnx_local_guidance = bool(pv("onnx_local_guidance", onnx_local_guidance)) + onnx_mask_inside = float(pv("onnx_mask_inside", onnx_mask_inside)) + onnx_mask_outside = float(pv("onnx_mask_outside", onnx_mask_outside)) + onnx_debug = bool(pv("onnx_debug", onnx_debug)) + onnx_kpts_enable = bool(pv("onnx_kpts_enable", onnx_kpts_enable)) + onnx_kpts_sigma = float(pv("onnx_kpts_sigma", onnx_kpts_sigma)) + onnx_kpts_gain = float(pv("onnx_kpts_gain", onnx_kpts_gain)) + onnx_kpts_conf = float(pv("onnx_kpts_conf", onnx_kpts_conf)) + muse_blend = bool(pv("muse_blend", muse_blend)) + muse_blend_strength = float(pv("muse_blend_strength", muse_blend_strength)) + eps_scale_enable = bool(pv("eps_scale_enable", eps_scale_enable)) + eps_scale = float(pv("eps_scale", eps_scale)) + clipseg_enable = bool(pv("clipseg_enable", clipseg_enable)) + clipseg_text = str(pv("clipseg_text", clipseg_text, top=True)) + clipseg_preview = int(pv("clipseg_preview", clipseg_preview)) + clipseg_threshold = float(pv("clipseg_threshold", clipseg_threshold)) + clipseg_blur = float(pv("clipseg_blur", clipseg_blur)) + clipseg_dilate = int(pv("clipseg_dilate", clipseg_dilate)) + clipseg_gain = float(pv("clipseg_gain", clipseg_gain)) + clipseg_blend = str(pv("clipseg_blend", clipseg_blend)) + clipseg_ref_gate = bool(pv("clipseg_ref_gate", clipseg_ref_gate)) + clipseg_ref_threshold = float(pv("clipseg_ref_threshold", clipseg_ref_threshold)) + polish_enable = bool(pv("polish_enable", polish_enable)) + polish_keep_low = float(pv("polish_keep_low", polish_keep_low)) + polish_edge_lock = float(pv("polish_edge_lock", polish_edge_lock)) + polish_sigma = float(pv("polish_sigma", polish_sigma)) + polish_start_after = int(pv("polish_start_after", polish_start_after)) + polish_keep_low_ramp = float(pv("polish_keep_low_ramp", polish_keep_low_ramp)) + # CADE Seg: per-step toggle to include CF edges into Seg mask + seg_use_cf_edges = bool(pv("seg_use_cf_edges", True)) + # Hard reset of any sticky globals from prior runs + try: + global CURRENT_ONNX_MASK_BCHW, _ONNX_KPTS_ENABLE, _ONNX_KPTS_SIGMA, _ONNX_KPTS_GAIN, _ONNX_KPTS_CONF + CURRENT_ONNX_MASK_BCHW = None + # Reset KPTS toggles to sane defaults; they will be set again if enabled below + _ONNX_KPTS_ENABLE = False + _ONNX_KPTS_SIGMA = 2.5 + _ONNX_KPTS_GAIN = 1.5 + _ONNX_KPTS_CONF = 0.20 + except Exception: + pass + + image = safe_decode(vae, latent) + + tuned_steps, tuned_cfg, tuned_denoise = AdaptiveSamplerHelper().tune( + image, steps, cfg, denoise) + + current_steps = tuned_steps + current_cfg = tuned_cfg + current_denoise = tuned_denoise + # Work on a detached copy to avoid mutating input latent across runs + try: + current_latent = {"samples": latent["samples"].clone()} + except Exception: + current_latent = {"samples": latent["samples"]} + current_scale = scale_by + + # Derive a user-friendly step tag for logs + try: + _ps = str(preset_step) + _num = ''.join(ch for ch in _ps if ch.isdigit()) + step_tag = f"Step:{_num}" if _num else _ps + except Exception: + step_tag = str(preset_step) + + # Smart seed selection (Sobol + light probing) when effective seed==0 and not in custom override mode + try: + if int(seed) == 0 and not bool(custom_override): + seed = _smart_seed_select( + model, vae, positive, negative, current_latent, + str(sampler_name), str(scheduler), float(current_cfg), float(current_denoise), + base_seed=0, step_tag=step_tag, + clip_vision=clip_vision, reference_image=reference_image, 
clipseg_text=str(clipseg_text)) + except Exception: + pass + + # Visual separation and start marker after seed is finalized + try: + print("") + except Exception: + pass + try: + print(f"\x1b[32m==== {step_tag}, Starting main job ====\x1b[0m") + except Exception: + pass + + ref_embed = None + if reference_clean and (clip_vision is not None) and (reference_image is not None): + try: + ref_embed = _encode_clip_image(reference_image, clip_vision, ref_preview) + except Exception: + ref_embed = None + + # Pre-disable any lingering NAG patch from previous runs and set PV accumulation for this node + try: + sa_patch.enable_crossattention_nag_patch(False) + except Exception: + pass + prev_accum = getattr(sa_patch, "CURRENT_PV_ACCUM", None) + sa_patch.CURRENT_PV_ACCUM = None if accumulation == "default" else accumulation + # Enable NAG patch if requested + try: + sa_patch.enable_crossattention_nag_patch(bool(use_nag), float(nag_scale), float(nag_tau), float(nag_alpha)) + except Exception: + pass + + # Enable attention-entropy probe for AQClip Attn-mode (read from preset) + try: + aq_attn = bool(p.get("aq_attn", False)) if isinstance(p, dict) else False + if hasattr(sa_patch, "enable_attention_entropy_capture"): + sa_patch.enable_attention_entropy_capture(aq_attn, max_tokens=1024, max_heads=4) + except Exception: + pass + + # Enable KV pruning for self-attention (read from preset) + try: + kv_enable = bool(p.get("kv_prune_enable", False)) if isinstance(p, dict) else False + kv_keep = float(p.get("kv_keep", 0.85)) if isinstance(p, dict) else 0.85 + kv_min_tokens = int(p.get("kv_min_tokens", 128)) if isinstance(p, dict) else 128 + if hasattr(sa_patch, "set_kv_prune"): + sa_patch.set_kv_prune(kv_enable, kv_keep, kv_min_tokens) + except Exception: + pass + + onnx_mask_last = None + try: + with torch.inference_mode(): + __cade_noop = 0 # ensure non-empty with-block + + # Preflight: reset sticky state and build external masks once (CPU-pinned) + try: + CURRENT_ONNX_MASK_BCHW = None + except Exception: + pass + pre_mask = None + pre_area = 0.0 + # ONNX detectors disabled in Easy: prefer CLIPSeg + edge fusion + onnx_detectors = False + # Build CLIPSeg mask once + if bool(clipseg_enable) and isinstance(clipseg_text, str) and clipseg_text.strip() != "": + try: + cmask = _clipseg_build_mask(image, clipseg_text, int(clipseg_preview), float(clipseg_threshold), float(clipseg_blur), int(clipseg_dilate), float(clipseg_gain), None, None, float(clipseg_ref_threshold)) + if cmask is not None: + if pre_mask is None: + pre_mask = cmask + else: + if clipseg_blend == "replace": + pre_mask = cmask + elif clipseg_blend == "intersect": + pre_mask = (pre_mask * cmask).clamp(0, 1) + else: + pre_mask = (1.0 - (1.0 - pre_mask) * (1.0 - cmask)).clamp(0, 1) + except Exception: + pass + # Edge mask from ControlFusion Step (with depth gating) when enabled; fallback to Sobel + if bool(seg_use_cf_edges): + try: + emask = _build_cf_edge_mask_from_step(image, str(preset_step)) + except Exception: + emask = None + if emask is None: + try: + emask = _edge_mask(image, threshold=0.20, blur=1.0) + except Exception: + emask = None + if emask is not None: + pre_mask = emask if pre_mask is None else (1.0 - (1.0 - pre_mask) * (1.0 - emask)).clamp(0, 1) + if pre_mask is not None: + onnx_mask_last = pre_mask + om = pre_mask.movedim(-1, 1) + pre_area = float(om.mean().item()) + if bool(onnx_local_guidance): + try: + if 0.02 <= pre_area <= 0.35: + CURRENT_ONNX_MASK_BCHW = om.clamp(0, 1).to(model_management.get_torch_device()) + else: + 
CURRENT_ONNX_MASK_BCHW = None + except Exception: + CURRENT_ONNX_MASK_BCHW = None + # One-time damping from area (disabled by default) + if False: + try: + if pre_area > 0.005: + damp = 1.0 - min(0.04, 0.008 + pre_area * 0.02) + current_denoise = max(0.10, current_denoise * damp) + current_cfg = max(1.0, current_cfg * (1.0 - 0.003)) + except Exception: + pass + # Preflight symmetry disabled (kept for experiments only) + if False: + try: + img0 = image + sym_mask = _clipseg_build_mask(img0, "face | head | torso | shoulders", preview=int(clipseg_preview), threshold=0.45, blur=5.0, dilate=2, gain=1.0) + if sym_mask is not None: + img_sym = _soft_symmetry_blend(img0, sym_mask, alpha=0.012, lp_sigma=1.75) + current_latent = {"samples": safe_encode(vae, img_sym)} + image = img_sym + except Exception: + pass + # Compact status + try: + provs = [] + if _ONNX_RT is not None: + provs = list(_ONNX_RT.get_available_providers()) + clipseg_status = "on" if bool(clipseg_enable) and isinstance(clipseg_text, str) and clipseg_text.strip() != "" else "off" + kpts = f"kpts={'on' if bool(onnx_kpts_enable) else 'off'} sigma={float(onnx_kpts_sigma):.2f} gain={float(onnx_kpts_gain):.2f} conf={float(onnx_kpts_conf):.2f}" + # print preflight info only in debug sessions (muted by default) + if False: + print(f"[CADE2.5][preflight] onnx_sessions={len(_ONNX_SESS)} providers={provs if provs else ['CPU']} clipseg={clipseg_status} device={'cpu' if _CLIPSEG_FORCE_CPU else _CLIPSEG_DEV} mask_area={pre_area:.4f} {kpts}") + except Exception: + pass + # Freeze per-iteration external mask rebuild + onnx_detectors = False + clipseg_enable = False + # Depth gate cache for micro-detail injection (reuse per resolution) + depth_gate_cache = {"size": None, "mask": None} + for i in range(iterations): + if i % 2 == 0: + clear_gpu_and_ram_cache() + + prev_samples = current_latent["samples"].clone().detach() + + iter_seed = seed + i * 7777 + if noise_offset > 0.0: + # Deterministic noise offset tied to iter_seed + fade = 1.0 - (i / max(1, iterations)) + try: + gen = torch.Generator(device='cpu') + except Exception: + gen = torch.Generator() + gen.manual_seed(int(iter_seed) & 0xFFFFFFFF) + eps = torch.randn( + size=current_latent["samples"].shape, + dtype=current_latent["samples"].dtype, + device='cpu', + generator=gen, + ).to(current_latent["samples"].device) + current_latent["samples"] += (noise_offset * fade) * eps + + # Pre-sampling ONNX detectors: handled once below (kept compact) + + # Pre-sampling ONNX detectors (build mask and optionally adjust params for this iteration) + if onnx_detectors and (i % max(1, 1) == 0): + try: + import os + models_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "models") + img_preview = safe_decode(vae, current_latent) + # Set toggles for this iteration + globals()["_ONNX_DEBUG"] = bool(onnx_debug) + globals()["_ONNX_COUNT_DEBUG"] = True # force counts ON for debugging session + globals()["_ONNX_KPTS_ENABLE"] = bool(onnx_kpts_enable) + globals()["_ONNX_KPTS_SIGMA"] = float(onnx_kpts_sigma) + globals()["_ONNX_KPTS_GAIN"] = float(onnx_kpts_gain) + globals()["_ONNX_KPTS_CONF"] = float(onnx_kpts_conf) + onnx_mask = _onnx_build_mask(img_preview, int(onnx_preview), float(onnx_sensitivity), models_dir, float(onnx_anomaly_gain)) + if onnx_mask is not None: + onnx_mask_last = onnx_mask + om = onnx_mask.movedim(-1, 1) + area = float(om.mean().item()) + if bool(onnx_debug): + print(f"[CADE2.5][ONNX] iter={i} mask_area={area:.4f}") + if area > 0.005: + damp = 1.0 - min(0.25, 0.06 
+ onnx_sensitivity * 0.05 + area * 0.25) + current_denoise = max(0.10, current_denoise * damp) + current_cfg = max(1.0, current_cfg * (1.0 - 0.015 * onnx_sensitivity)) + # Prepare spatial mask for cfg_func if requested + if bool(onnx_local_guidance): + # store BCHW mask in global for this iteration (only for reasonable areas) + if 0.02 <= area <= 0.35: + CURRENT_ONNX_MASK_BCHW = om.clamp(0, 1).to(model_management.get_torch_device()) + else: + CURRENT_ONNX_MASK_BCHW = None + else: + CURRENT_ONNX_MASK_BCHW = None + except Exception: + CURRENT_ONNX_MASK_BCHW = None + + # CF edge mask (from current image) and fusion (only when enabled) + if bool(seg_use_cf_edges): + try: + img_prev2 = safe_decode(vae, current_latent) + em2 = _build_cf_edge_mask_from_step(img_prev2, str(preset_step)) + if em2 is not None: + if onnx_mask_last is None: + onnx_mask_last = em2 + else: + onnx_mask_last = (1.0 - (1.0 - onnx_mask_last) * (1.0 - em2)).clamp(0, 1) + om = onnx_mask_last.movedim(-1, 1) + area = float(om.mean().item()) + if bool(onnx_local_guidance): + if 0.02 <= area <= 0.35: + CURRENT_ONNX_MASK_BCHW = om.clamp(0, 1).to(model_management.get_torch_device()) + else: + CURRENT_ONNX_MASK_BCHW = None + except Exception: + pass + + # CLIPSeg mask (optional) and fusion with ONNX + try: + if bool(clipseg_enable) and isinstance(clipseg_text, str) and clipseg_text.strip() != "": + cmask = _clipseg_build_mask(img_prev2, clipseg_text, int(clipseg_preview), float(clipseg_threshold), float(clipseg_blur), int(clipseg_dilate), float(clipseg_gain), ref_embed if bool(clipseg_ref_gate) else None, clip_vision if bool(clipseg_ref_gate) else None, float(clipseg_ref_threshold)) + if cmask is not None: + if onnx_mask_last is None: + fused = cmask + else: + if clipseg_blend == "replace": + fused = cmask + elif clipseg_blend == "intersect": + fused = (onnx_mask_last * cmask).clamp(0, 1) + else: + fused = (1.0 - (1.0 - onnx_mask_last) * (1.0 - cmask)).clamp(0, 1) + onnx_mask_last = fused + om = fused.movedim(-1, 1) + area = float(om.mean().item()) + if bool(onnx_debug): + print(f"[CADE2.5][MASK] iter={i} mask_area={area:.4f}") + if area > 0.005: + damp = 1.0 - min(0.10, 0.02 + float(onnx_sensitivity) * 0.02 + area * 0.08) + current_denoise = max(0.10, current_denoise * damp) + current_cfg = max(1.0, current_cfg * (1.0 - 0.008 * float(onnx_sensitivity))) + if bool(onnx_local_guidance): + if 0.02 <= area <= 0.35: + CURRENT_ONNX_MASK_BCHW = om.clamp(0, 1).to(model_management.get_torch_device()) + else: + CURRENT_ONNX_MASK_BCHW = None + except Exception: + pass + + # Guidance override via cfg_func when requested + sampler_model = _wrap_model_with_guidance( + model, guidance_mode, rescale_multiplier, momentum_beta, cfg_curve, perp_damp, + use_zero_init=bool(use_zero_init), zero_init_steps=int(zero_init_steps), + fdg_low=float(fdg_low), fdg_high=float(fdg_high), fdg_sigma=float(fdg_sigma), + midfreq_enable=bool(midfreq_enable), midfreq_gain=float(midfreq_gain), midfreq_sigma_lo=float(midfreq_sigma_lo), midfreq_sigma_hi=float(midfreq_sigma_hi), + ze_zero_steps=int(ze_res_zero_steps), + ze_adaptive=bool(ze_adaptive), ze_r_switch_hi=float(ze_r_switch_hi), ze_r_switch_lo=float(ze_r_switch_lo), + fdg_low_adaptive=bool(fdg_low_adaptive), fdg_low_min=float(fdg_low_min), fdg_low_max=float(fdg_low_max), fdg_ema_beta=float(fdg_ema_beta), + use_local_mask=bool(onnx_local_guidance), mask_inside=float(onnx_mask_inside), mask_outside=float(onnx_mask_outside), + mahiro_plus_enable=bool(muse_blend), mahiro_plus_strength=float(muse_blend_strength), + 
eps_scale_enable=bool(eps_scale_enable), eps_scale=float(eps_scale) + ) + + # Local best-of-2 in ROI (hands/face), only on early iteration to limit overhead + try: + do_local_refine = False # disable local best-of-2 by default + if do_local_refine: + img_roi = safe_decode(vae, current_latent) + roi = _clipseg_build_mask(img_roi, "hand | hands | face", preview=max(192, int(clipseg_preview//2)), threshold=0.40, blur=5.0, dilate=2, gain=1.0) + if roi is None and onnx_mask_last is not None: + roi = torch.clamp(onnx_mask_last, 0.0, 1.0) + if roi is not None: + # Area gating + try: + ra = float(roi.mean().item()) + except Exception: + ra = 0.0 + if not (0.02 <= ra <= 0.15): + raise Exception("ROI area out of range; skip local_refine") + # Light erosion to avoid halo influence + try: + m = roi[..., 0].unsqueeze(1) + # disable erosion effect (kernel=1) + ero = 1.0 - F.max_pool2d(1.0 - m, kernel_size=1, stride=1, padding=0) + roi = ero.clamp(0, 1).movedim(1, -1) + except Exception: + pass + # micro sampling params + micro_steps = int(max(2, min(4, round(max(1, current_steps) * 0.05)))) + micro_denoise = float(min(0.22, max(0.10, current_denoise * 0.30))) + s1 = int((iter_seed ^ 0x9E3779B1) & 0xFFFFFFFFFFFFFFFF) + s2 = int((iter_seed ^ 0x85EBCA77) & 0xFFFFFFFFFFFFFFFF) + # Candidate A + lat_in_a = {"samples": current_latent["samples"].clone()} + lat_a, = nodes.common_ksampler( + sampler_model, s1, micro_steps, current_cfg, sampler_name, scheduler, + positive, negative, lat_in_a, denoise=micro_denoise) + img_a = safe_decode(vae, lat_a) + # Candidate B + lat_in_b = {"samples": current_latent["samples"].clone()} + lat_b, = nodes.common_ksampler( + sampler_model, s2, micro_steps, current_cfg, sampler_name, scheduler, + positive, negative, lat_in_b, denoise=micro_denoise) + img_b = safe_decode(vae, lat_b) + + # Score inside ROI + def _roi_stats(img, roi_mask): + try: + m = roi_mask[..., 0].clamp(0, 1) + R, G, Bc = img[..., 0], img[..., 1], img[..., 2] + lum = (0.2126 * R + 0.7152 * G + 0.0722 * Bc) + # edges + kx = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], device=img.device, dtype=img.dtype).view(1,1,3,3) + ky = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], device=img.device, dtype=img.dtype).view(1,1,3,3) + gx = F.conv2d(lum.unsqueeze(1), kx, padding=1) + gy = F.conv2d(lum.unsqueeze(1), ky, padding=1) + g = torch.sqrt(gx*gx + gy*gy).squeeze(1) + wmean = (g*m).mean() / (m.mean()+1e-6) + # speckles + V = torch.maximum(R, torch.maximum(G, Bc)) + mi = torch.minimum(R, torch.minimum(G, Bc)) + S = 1.0 - (mi / (V + 1e-6)) + cand = (V > 0.98) & (S < 0.12) + speck = (cand.float()*m).mean() / (m.mean()+1e-6) + lmean = (lum*m).mean() / (m.mean()+1e-6) + return float(wmean.item()), float(speck.item()), float(lmean.item()) + except Exception: + return 0.0, 0.5, 0.5 + + ed_a, sp_a, lm_a = _roi_stats(img_a, roi) + ed_b, sp_b, lm_b = _roi_stats(img_b, roi) + edge_target = 0.08 + score_a = -abs(ed_a - edge_target) - 0.8*sp_a - 0.10*abs(lm_a - 0.5) + score_b = -abs(ed_b - edge_target) - 0.8*sp_b - 0.10*abs(lm_b - 0.5) + + # Optional CLIP-Vision ref boost + if ref_embed is not None and clip_vision is not None: + try: + emb_a = _encode_clip_image(img_a, clip_vision, target_res=224) + emb_b = _encode_clip_image(img_b, clip_vision, target_res=224) + sim_a = float((emb_a * ref_embed).sum(dim=-1).mean().clamp(-1.0, 1.0).item()) + sim_b = float((emb_b * ref_embed).sum(dim=-1).mean().clamp(-1.0, 1.0).item()) + score_a += 0.25 * (0.5*(sim_a+1.0)) + score_b += 0.25 * (0.5*(sim_b+1.0)) + except Exception: + pass + + if score_b > 
score_a: + current_latent = lat_b + else: + current_latent = lat_a + except Exception: + pass + + if str(scheduler) == "MGHybrid": + try: + # Build ZeSmart hybrid sigmas with safe defaults + sigmas = _build_hybrid_sigmas( + sampler_model, int(current_steps), str(sampler_name), "hybrid", + mix=0.5, denoise=float(current_denoise), jitter=0.01, seed=int(iter_seed), + _debug=False, tail_smooth=0.15, auto_hybrid_tail=True, auto_tail_strength=0.4, + ) + # Prepare latent + noise like in MG_ZeSmartSampler + lat_img = current_latent["samples"] + lat_img = _sample.fix_empty_latent_channels(sampler_model, lat_img) + batch_inds = current_latent.get("batch_index", None) + noise = _sample.prepare_noise(lat_img, int(iter_seed), batch_inds) + noise_mask = current_latent.get("noise_mask", None) + callback = nodes.latent_preview.prepare_callback(sampler_model, int(current_steps)) + disable_pbar = not _utils.PROGRESS_BAR_ENABLED + sampler_obj = _samplers.sampler_object(str(sampler_name)) + samples = _sample.sample_custom( + sampler_model, noise, float(current_cfg), sampler_obj, sigmas, + positive, negative, lat_img, + noise_mask=noise_mask, callback=callback, + disable_pbar=disable_pbar, seed=int(iter_seed) + ) + current_latent = {**current_latent} + current_latent["samples"] = samples + except Exception as e: + # Fallback to original path if anything goes wrong + print(f"[CADE2.5][MGHybrid] fallback to common_ksampler due to: {e}") + current_latent, = nodes.common_ksampler( + sampler_model, iter_seed, int(current_steps), current_cfg, sampler_name, _scheduler_names()[0], + positive, negative, current_latent, denoise=current_denoise) + else: + current_latent, = nodes.common_ksampler( + sampler_model, iter_seed, int(current_steps), current_cfg, sampler_name, scheduler, + positive, negative, current_latent, denoise=current_denoise) + + if bool(latent_compare): + latent_diff = current_latent["samples"] - prev_samples + rms = torch.sqrt(torch.mean(latent_diff * latent_diff)) + drift = float(rms.item()) + if drift > float(threshold): + overshoot = max(0.0, drift - float(threshold)) + damp = 1.0 - min(0.15, overshoot * 2.0) + current_denoise = max(0.20, current_denoise * damp) + cfg_damp = 0.997 if damp > 0.9 else 0.99 + current_cfg = max(1.0, current_cfg * cfg_damp) + + # AQClip-Lite: adaptive soft clipping in latent space (before decode) + try: + if bool(aqclip_enable): + if 'aq_state' not in locals(): + aq_state = None + H_override = None + try: + if bool(aq_attn) and hasattr(sa_patch, "get_attention_entropy_map"): + Hm = sa_patch.get_attention_entropy_map(clear=False) + if Hm is not None: + H_override = F.interpolate(Hm, size=(current_latent["samples"].shape[-2], current_latent["samples"].shape[-1]), mode="bilinear", align_corners=False) + except Exception: + H_override = None + z_new, aq_state = _aqclip_lite( + current_latent["samples"], + tile=int(aq_tile), stride=int(aq_stride), + alpha=float(aq_alpha), ema_state=aq_state, ema_beta=float(aq_ema_beta), + H_override=H_override, + ) + current_latent["samples"] = z_new + except Exception: + pass + + image = safe_decode(vae, current_latent) + + # Polish mode: keep global form (low frequencies) from reference while letting details refine + if bool(polish_enable) and (i >= int(polish_start_after)): + try: + # Prepare tensors + img = image + ref = reference_image if (reference_image is not None) else img + if ref.shape[1] != img.shape[1] or ref.shape[2] != img.shape[2]: + # resize reference to match current image + ref_n = ref.movedim(-1, 1) + ref_n = 
F.interpolate(ref_n, size=(img.shape[1], img.shape[2]), mode='bilinear', align_corners=False) + ref = ref_n.movedim(1, -1) + x = img.movedim(-1, 1) + r = ref.movedim(-1, 1) + # Low/high split via Gaussian blur + rad = max(1, int(round(float(polish_sigma) * 2))) + low_x = _gaussian_blur_nchw(x, sigma=float(polish_sigma), radius=rad) + low_r = _gaussian_blur_nchw(r, sigma=float(polish_sigma), radius=rad) + high_x = x - low_x + # Mix low from reference and current with ramp + # a starts from polish_keep_low_ramp and linearly ramps to polish_keep_low over remaining iterations + try: + denom = max(1, int(iterations) - int(polish_start_after)) + t = max(0.0, min(1.0, (i - int(polish_start_after)) / denom)) + except Exception: + t = 1.0 + a0 = float(polish_keep_low_ramp) + at = float(polish_keep_low) + a = a0 + (at - a0) * t + low_mix = low_r * a + low_x * (1.0 - a) + new = low_mix + high_x + # Micro-detail injection on tail: very light HF boost gated by edges+depth + try: + phase = (i + 1) / max(1, int(iterations)) + ramp = max(0.0, min(1.0, (phase - 0.70) / 0.30)) + if ramp > 0.0: + micro = x - _gaussian_blur_nchw(x, sigma=0.6, radius=1) + gray = x.mean(dim=1, keepdim=True) + sobel_x = torch.tensor([[[-1,0,1],[-2,0,2],[-1,0,1]]], dtype=gray.dtype, device=gray.device).unsqueeze(1) + sobel_y = torch.tensor([[[-1,-2,-1],[0,0,0],[1,2,1]]], dtype=gray.dtype, device=gray.device).unsqueeze(1) + gx = F.conv2d(gray, sobel_x, padding=1) + gy = F.conv2d(gray, sobel_y, padding=1) + mag = torch.sqrt(gx*gx + gy*gy) + m_edge = (mag - mag.amin()) / (mag.amax() - mag.amin() + 1e-8) + g_edge = (1.0 - m_edge).clamp(0.0, 1.0).pow(0.65) + try: + sz = (int(img.shape[1]), int(img.shape[2])) + if depth_gate_cache.get("size") != sz or depth_gate_cache.get("mask") is None: + model_path = os.path.join(os.path.dirname(__file__), '..', 'depth-anything', 'depth_anything_v2_vitl.pth') + dm = _cf_build_depth_map(img, res=512, model_path=model_path, hires_mode=True) + depth_gate_cache = {"size": sz, "mask": dm} + dm = depth_gate_cache.get("mask") + if dm is not None: + g_depth = (dm.movedim(-1, 1).clamp(0,1)) ** 1.35 + else: + g_depth = torch.ones_like(g_edge) + except Exception: + g_depth = torch.ones_like(g_edge) + g = (g_edge * g_depth).clamp(0.0, 1.0) + micro_boost = 0.018 * ramp + new = new + micro_boost * (micro * g) + except Exception: + pass + # Edge-lock: protect edges from drift by biasing toward low_mix along edges + el = float(polish_edge_lock) + if el > 1e-6: + # Sobel edge magnitude on grayscale + gray = x.mean(dim=1, keepdim=True) + sobel_x = torch.tensor([[[-1,0,1],[-2,0,2],[-1,0,1]]], dtype=gray.dtype, device=gray.device).unsqueeze(1) + sobel_y = torch.tensor([[[-1,-2,-1],[0,0,0],[1,2,1]]], dtype=gray.dtype, device=gray.device).unsqueeze(1) + gx = F.conv2d(gray, sobel_x, padding=1) + gy = F.conv2d(gray, sobel_y, padding=1) + mag = torch.sqrt(gx*gx + gy*gy) + m = (mag - mag.amin()) / (mag.amax() - mag.amin() + 1e-8) + # Blend toward low_mix near edges + new = new * (1.0 - el*m) + (low_mix) * (el*m) + img2 = new.movedim(1, -1).clamp(0,1) + # Feed back to latent for next steps + current_latent = {"samples": safe_encode(vae, img2)} + image = img2 + except Exception: + pass + + # ONNX detectors (beta): fuse hands/face/pose mask if available (post-sampling; skip if already set) + if onnx_detectors and (i % max(1, 1) == 0) and (onnx_mask_last is None): + try: + import os + models_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "models") + globals()["_ONNX_DEBUG"] = False + 
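+                            # These module globals are read by _onnx_build_mask during
+                            # the call below, so push the per-run UI values first.
+                            # Note: onnx_detectors is forced off earlier in this method,
+                            # so this post-sampling branch is currently dormant.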
globals()["_ONNX_KPTS_ENABLE"] = bool(onnx_kpts_enable) + globals()["_ONNX_KPTS_SIGMA"] = float(onnx_kpts_sigma) + globals()["_ONNX_KPTS_GAIN"] = float(onnx_kpts_gain) + globals()["_ONNX_KPTS_CONF"] = float(onnx_kpts_conf) + onnx_mask = _onnx_build_mask(image, int(onnx_preview), float(onnx_sensitivity), models_dir, float(onnx_anomaly_gain)) + if onnx_mask is not None: + onnx_mask_last = onnx_mask + om = onnx_mask.movedim(-1,1) + area = float(om.mean().item()) + # verbose post-mask log removed; keep single compact log above + if area > 0.005: + damp = 1.0 - min(0.25, 0.06 + onnx_sensitivity*0.05 + area*0.25) + current_denoise = max(0.10, current_denoise * damp) + current_cfg = max(1.0, current_cfg * (1.0 - 0.015*onnx_sensitivity)) + except Exception: + pass + + if reference_clean and (ref_embed is not None) and (i % max(1, ref_cooldown) == 0): + try: + cur_embed = _encode_clip_image(image, clip_vision, ref_preview) + dist = _clip_cosine_distance(cur_embed, ref_embed) + if dist > ref_threshold: + current_denoise = max(0.10, current_denoise * 0.9) + current_cfg = max(1.0, current_cfg * 0.99) + except Exception: + pass + + if apply_upscale and current_scale != 1.0: + current_latent, image = MagicUpscaleModule().process_upscale( + current_latent, vae, upscale_method, current_scale) + # After upscale at large sizes, add a tiny HF sprinkle gated by edges+depth + try: + H, W = int(image.shape[1]), int(image.shape[2]) + if max(H, W) > 1536: + # Simple BHWC blur + def _gb_bhwc(im: torch.Tensor, radius: float, sigma: float) -> torch.Tensor: + if radius <= 0.0: + return im + pad = int(max(1, round(radius))) + ksz = pad * 2 + 1 + k = _gaussian_kernel(ksz, sigma, device=im.device).to(dtype=im.dtype) + k = k.unsqueeze(0).unsqueeze(0) + b, h, w, c = im.shape + xch = im.permute(0, 3, 1, 2) + y = F.conv2d(F.pad(xch, (pad, pad, pad, pad), mode='reflect'), k.repeat(c, 1, 1, 1), groups=c) + return y.permute(0, 2, 3, 1) + blur = _gb_bhwc(image, radius=1.0, sigma=0.8) + hf = (image - blur).clamp(-1, 1) + lum = (0.2126 * image[..., 0] + 0.7152 * image[..., 1] + 0.0722 * image[..., 2]) + kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=lum.device, dtype=lum.dtype).view(1, 1, 3, 3) + ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=lum.device, dtype=lum.dtype).view(1, 1, 3, 3) + g = torch.sqrt(F.conv2d(lum.unsqueeze(1), kx, padding=1)**2 + F.conv2d(lum.unsqueeze(1), ky, padding=1)**2).squeeze(1) + m = (g - g.amin()) / (g.amax() - g.amin() + 1e-8) + g_edge = (1.0 - m).clamp(0,1).pow(0.5).unsqueeze(-1) + try: + sz = (H, W) + if depth_gate_cache.get("size") != sz or depth_gate_cache.get("mask") is None: + model_path = os.path.join(os.path.dirname(__file__), '..', 'depth-anything', 'depth_anything_v2_vitl.pth') + dm = _cf_build_depth_map(image, res=512, model_path=model_path, hires_mode=True) + depth_gate_cache = {"size": sz, "mask": dm} + dm = depth_gate_cache.get("mask") + if dm is not None: + g_depth = dm.clamp(0,1) ** 1.35 + else: + g_depth = torch.ones_like(g_edge) + except Exception: + g_depth = torch.ones_like(g_edge) + g_tot = (g_edge * g_depth).clamp(0,1) + image = (image + 0.045 * hf * g_tot).clamp(0,1) + except Exception: + pass + current_cfg = max(4.0, current_cfg * (1.0 / current_scale)) + current_denoise = max(0.15, current_denoise * (1.0 / current_scale)) + + current_steps = max(1, current_steps - steps_delta) + current_cfg = max(0.0, current_cfg - cfg_delta) + current_denoise = max(0.0, current_denoise - denoise_delta) + current_scale = max(1.0, current_scale - scale_delta) + 
+                    if apply_upscale and current_scale != 1.0 and max(image.shape[1:3]) > 1024:
+                        current_latent = {"samples": safe_encode(vae, image)}
+
+        finally:
+            # Always disable NAG patch and clear local mask, even on errors
+            try:
+                sa_patch.enable_crossattention_nag_patch(False)
+            except Exception:
+                pass
+            # Disable KV pruning as well (avoid leaking state)
+            try:
+                if hasattr(sa_patch, "set_kv_prune"):
+                    sa_patch.set_kv_prune(False, 1.0, 128)
+            except Exception:
+                pass
+            try:
+                sa_patch.CURRENT_PV_ACCUM = prev_accum
+            except Exception:
+                pass
+            try:
+                CURRENT_ONNX_MASK_BCHW = None
+            except Exception:
+                pass
+
+        if apply_ids:
+            image, = IntelligentDetailStabilizer().stabilize(image, ids_strength)
+
+        if apply_sharpen:
+            image, = _sharpen_image(image, 2, 1.0, Sharpnes_strenght)
+
+        # ONNX mask preview as IMAGE (RGB)
+        if onnx_mask_last is None:
+            onnx_mask_last = torch.zeros((image.shape[0], image.shape[1], image.shape[2], 1), device=image.device, dtype=image.dtype)
+        onnx_mask_img = onnx_mask_last.repeat(1, 1, 1, 3).clamp(0, 1)
+
+        # Final pass: remove isolated hot whites ("fireflies") without touching real edges/highlights
+        try:
+            image = _despeckle_fireflies(image, thr=0.998, max_iso=4.0/9.0, grad_gate=0.15)
+        except Exception:
+            pass
+
+        return current_latent, image, int(current_steps), float(current_cfg), float(current_denoise), onnx_mask_img
+
+
+# === Easy UI wrapper: show only top-level controls ===
+class CADEEasyUI(ComfyAdaptiveDetailEnhancer25):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "preset_step": (["Step 1", "Step 2", "Step 3", "Step 4"], {"default": "Step 1", "tooltip": "Choose the Step preset. Toggle Custom below to apply UI values; otherwise Step preset values are used."}),
+                "custom": ("BOOLEAN", {"default": False, "tooltip": "Custom override: when enabled, your UI values override the selected Step for visible controls; hidden parameters still come from the Step preset."}),
+                "model": ("MODEL", {}),
+                "positive": ("CONDITIONING", {}),
+                "negative": ("CONDITIONING", {}),
+                "vae": ("VAE", {}),
+                "latent": ("LATENT", {}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True, "tooltip": "Seed 0 = SmartSeed (Sobol + light probe). Non-zero = fixed seed (deterministic)."}),
+                "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
+                "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1}),
+                "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.0001}),
+                "sampler_name": (_sampler_names(), {"default": _sampler_names()[0]}),
+                "scheduler": (_scheduler_names(), {"default": "MGHybrid"}),
+            },
+            "optional": {
+                # Reference inputs must remain available in Easy
+                "reference_image": ("IMAGE", {}),
+                "clip_vision": ("CLIP_VISION", {}),
+                # CLIPSeg prompt
+                "clipseg_text": ("STRING", {"default": "", "multiline": False, "tooltip": "Tells this step what to focus on (e.g., hand, feet, face). Separate terms with commas."}),
+            },
+        }
+
+    # Easy outputs (hide steps/cfg/denoise)
+    RETURN_TYPES = ("LATENT", "IMAGE", "IMAGE")
+    RETURN_NAMES = ("LATENT", "IMAGE", "mask_preview")
+    FUNCTION = "apply_easy"
+
+    def apply_easy(self,
+                   preset_step,
+                   model, positive, negative, vae, latent,
+                   seed, steps, cfg, denoise, sampler_name, scheduler,
+                   clipseg_text="", reference_image=None, clip_vision=None, custom=False):
+        lat, img, _s, _c, _d, mask = super().apply_cade2(
+            model, vae, positive, negative, latent,
+            int(seed), int(steps), float(cfg), float(denoise),
+            str(sampler_name), str(scheduler), 0.0,
+            preset_step=str(preset_step), custom_override=bool(custom), clipseg_text=str(clipseg_text),
+            reference_image=reference_image,
+            clip_vision=clip_vision,
+        )
+        return lat, img, mask
+
+
+# === Smart seed helpers (Sobol/Halton + light probing) ===
+def _splitmix64(x: int) -> int:
+    x = (x + 0x9E3779B97F4A7C15) & 0xFFFFFFFFFFFFFFFF
+    z = x
+    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9 & 0xFFFFFFFFFFFFFFFF
+    z = (z ^ (z >> 27)) * 0x94D049BB133111EB & 0xFFFFFFFFFFFFFFFF
+    z = z ^ (z >> 31)
+    return z & 0xFFFFFFFFFFFFFFFF
+
+def _halton_single(index: int, base: int) -> float:
+    f = 1.0
+    r = 0.0
+    i = index
+    while i > 0:
+        f = f / base
+        r = r + f * (i % base)
+        i //= base
+    return r
+
+def _sobol_like_2d(n: int, anchor: int) -> tuple[float, float]:
+    # lightweight 2D low-discrepancy via Halton(2,3) scrambled by anchor
+    i = n + 1 + (anchor % 9973)
+    return (_halton_single(i, 2), _halton_single(i, 3))
+
+def _edge_density(img_bhwc: torch.Tensor) -> float:
+    lum = (0.2126 * img_bhwc[..., 0] + 0.7152 * img_bhwc[..., 1] + 0.0722 * img_bhwc[..., 2])
+    kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=img_bhwc.device, dtype=img_bhwc.dtype).view(1, 1, 3, 3)
+    ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=img_bhwc.device, dtype=img_bhwc.dtype).view(1, 1, 3, 3)
+    gx = F.conv2d(lum.unsqueeze(1), kx, padding=1)
+    gy = F.conv2d(lum.unsqueeze(1), ky, padding=1)
+    g = torch.sqrt(gx * gx + gy * gy)
+    return float(g.mean().item())
+
+def _speckle_fraction(img_bhwc: torch.Tensor) -> float:
+    # reuse the S/V-based candidate mask from the despeckle logic (no replacement) to estimate the speckle fraction
+    R, G, Bc = img_bhwc[..., 0], img_bhwc[..., 1], img_bhwc[..., 2]
+    V = torch.maximum(R, torch.maximum(G, Bc))
+    mi = torch.minimum(R, torch.minimum(G, Bc))
+    S = 1.0 - (mi / (V + 1e-6))
+    v_thr = 0.98
+    s_thr = 0.12
+    cand = (V > v_thr) & (S < s_thr)
+    return float(cand.float().mean().item())
+
+def _smart_seed_state_path() -> str:
+    import os
+    base = os.path.join(os.path.dirname(__file__), "..", "state")
+    os.makedirs(base, exist_ok=True)
+    return os.path.join(base, "smart_seed.json")
+
+def _smart_seed_counter(anchor: int) -> int:
+    import json
+    path = _smart_seed_state_path()
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+    except Exception:
+        data = {}
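+    # One counter per anchor: repeated runs with identical settings advance
+    # through the Halton sequence instead of re-probing the same candidate seeds.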
key = hex(anchor & 0xFFFFFFFFFFFFFFFF) + n = int(data.get(key, 0)) + data[key] = n + 1 + try: + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + except Exception: + pass + return n + +def _smart_seed_select(model, + vae, + positive, + negative, + latent, + sampler_name: str, + scheduler: str, + cfg: float, + denoise: float, + base_seed: int | None = None, + k: int = 6, + probe_steps: int = 6, + clip_vision=None, + reference_image=None, + clipseg_text: str = "", + step_tag: str | None = None) -> int: + # Log start of SmartSeed selection + try: + try: + # Visual separation before SmartSeed block + print("") + print("") + if step_tag: + print(f"\x1b[34m==== {step_tag}, Smart_seed_random: Start ====\x1b[0m") + else: + print("\x1b[34m==== Smart_seed_random: Start ====\x1b[0m") + except Exception: + pass + + # Optional: precompute CLIP-Vision embedding of reference image + ref_embed = None + if (clip_vision is not None) and (reference_image is not None): + try: + ref_embed = _encode_clip_image(reference_image, clip_vision, target_res=224) + except Exception: + ref_embed = None + + # Anchor from latent shape + sampler/scheduler (+ cfg/denoise) + sh = latent["samples"].shape if isinstance(latent, dict) and "samples" in latent else None + anchor = 1469598103934665603 # FNV offset basis + for v in (sh[2] if sh else 0, sh[3] if sh else 0, len(str(sampler_name)), len(str(scheduler)), int(cfg * 1000), int(denoise * 1000)): + anchor = _splitmix64(anchor ^ int(v)) + # Advance a persistent counter per anchor to vary indices between runs + offset = _smart_seed_counter(anchor) * 7 + + # Build K candidate seeds from Halton(2,3) + cands: list[int] = [] + for i in range(k): + u, v = _sobol_like_2d(offset + i, anchor) + lo = int(u * (1 << 32)) & 0xFFFFFFFF + hi = int(v * (1 << 32)) & 0xFFFFFFFF + seed64 = _splitmix64((hi << 32) ^ lo ^ anchor) + cands.append(seed64 & 0xFFFFFFFFFFFFFFFF) + + best_seed = cands[0] + best_score = -1e9 + for sd in cands: + try: + # quick KSampler preview at low steps + lat_in = {"samples": latent["samples"].clone()} if isinstance(latent, dict) else latent + lat_out, = nodes.common_ksampler( + model, int(sd), int(probe_steps), float(cfg), str(sampler_name), str(scheduler), + positive, negative, lat_in, denoise=float(min(denoise, 0.65)) + ) + img = safe_decode(vae, lat_out) + # Base score: edge density toward a target + low speckle + balanced exposure + ed = _edge_density(img) + speck = _speckle_fraction(img) + lum = float(img.mean().item()) + edge_target = 0.10 + score = -abs(ed - edge_target) - 2.0 * speck - 0.5 * abs(lum - 0.5) + + # Perceptual metrics: luminance std and Laplacian variance (downscaled) + try: + lum_t = (0.2126 * img[..., 0] + 0.7152 * img[..., 1] + 0.0722 * img[..., 2]) + lstd = float(lum_t.std().item()) + lch = lum_t.unsqueeze(1) + lch_small = F.interpolate(lch, size=(128, 128), mode='bilinear', align_corners=False) + lap_k = torch.tensor([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]], device=lch_small.device, dtype=lch_small.dtype).view(1, 1, 3, 3) + lap = F.conv2d(lch_small, lap_k, padding=1) + lap_var = float(lap.var().item()) + score += 0.15 * lstd + 0.10 * lap_var + except Exception: + pass + + # Semantic alignment via CLIP-Vision when available + if ref_embed is not None and clip_vision is not None: + try: + cand_embed = _encode_clip_image(img, clip_vision, target_res=224) + sim = float((cand_embed * ref_embed).sum(dim=-1).mean().clamp(-1.0, 1.0).item()) + sim01 = 0.5 * (sim + 1.0) + score += 0.75 * sim01 
+ except Exception: + pass + + # Focus coverage via CLIPSeg when text provided + if isinstance(clipseg_text, str) and clipseg_text.strip() != "": + try: + cmask = _clipseg_build_mask(img, clipseg_text, preview=192, threshold=0.40, blur=5.0, dilate=2, gain=1.0) + if cmask is not None: + area = float(cmask.mean().item()) + cov_target = 0.06 + cov_score = 1.0 - min(1.0, abs(area - cov_target) / max(cov_target, 1e-3)) + score += 0.30 * cov_score + except Exception: + pass + if score > best_score: + best_score = score + best_seed = sd + except Exception: + continue + + # Log end with selected seed + try: + if step_tag: + print(f"\x1b[34m==== {step_tag}, Smart_seed_random: End. Seed is: {int(best_seed & 0xFFFFFFFFFFFFFFFF)} ====\x1b[0m") + else: + print(f"\x1b[34m==== Smart_seed_random: End. Seed is: {int(best_seed & 0xFFFFFFFFFFFFFFFF)} ====\x1b[0m") + except Exception: + pass + return int(best_seed & 0xFFFFFFFFFFFFFFFF) + except Exception: + # Fallback to time-based random + try: + import time + fallback_seed = int(_splitmix64(int(time.time_ns()))) + except Exception: + fallback_seed = int(base_seed or 0) + try: + if step_tag: + print(f"\x1b[34m==== {step_tag}, Smart_seed_random: End. Seed is: {fallback_seed} ====\x1b[0m") + else: + print(f"\x1b[34m==== Smart_seed_random: End. Seed is: {fallback_seed} ====\x1b[0m") + except Exception: + pass + return fallback_seed + +
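+
+# Illustrative sketch (comments only): how one SmartSeed candidate is derived.
+# The anchor folds latent H/W, sampler/scheduler and cfg/denoise via _splitmix64,
+# and the persistent counter advances the Halton index between runs:
+#   n = _smart_seed_counter(anchor) * 7
+#   u, v = _sobol_like_2d(n, anchor)              # Halton(2,3) point in [0,1)^2
+#   lo = int(u * (1 << 32)) & 0xFFFFFFFF
+#   hi = int(v * (1 << 32)) & 0xFFFFFFFF
+#   seed = _splitmix64((hi << 32) ^ lo ^ anchor)  # probed with a short ksampler pass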