Merge method
Hello
Did you create the merging method yourself?
By the way, the quality of your model cards has been great lately, I really like them!
Unfortunately I don't have enough time to try them all, though.
Thanks! As for the merge method, I used GPT-5: I made a custom GPT and uploaded all the .py files from mergekit/merge_methods, plus the markdown source of https://github.com/arcee-ai/mergekit/blob/main/docs/create_a_merge_method.md. It works well, but it can forget the import header section, so I recommend starting a new context when coding it. Then drop the new method's .py file into the mergekit merge_methods folder, edit __init__.py to add the new method (import mergekit.merge_methods.yourmethod), and restart. Below is a minimal registration sketch, followed by the .py code for harmony_forge.
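The registration sketch assumes the file is saved as mergekit/merge_methods/harmony_forge.py; importing the module should be enough, because the @merge_method decorator registers the method when the module is imported:

# in mergekit/merge_methods/__init__.py
import mergekit.merge_methods.harmony_forge  # noqa: F401  # registers "harmony_forge"

After that, a merge config should be able to reference merge_method: harmony_forge, with focus and blend exposed as parameters; check the mergekit docs for the exact config schema.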
harmony_forge.py
from mergekit.merge_methods.easy_define import merge_method
import torch
import torch.nn.functional as F
from typing import List
@merge_method("harmony_forge")
@torch.no_grad()
def harmony_forge(
    tensors: List[torch.Tensor],
    focus: float = 1.0,   # β -> more peaked weights (softmax scale)
    blend: float = 0.5,   # 0=spatial-only .. 1=frequency-only
    **kwargs              # optional: base=<tensor>, return_debug=True
) -> torch.Tensor:
    """
    Harmony Forge (two knobs):
      - Works with or without a base tensor (task-vector mode if base provided).
      - Two controls: focus (decisiveness) and blend (spatial vs frequency).
      - Internally auto-tunes consensus, outlier control, and stability guards.
      - Now memory-safe for very large models (JL projection capped).
    """
    # ------- Basic checks -------
    n = len(tensors)
    if n == 0:
        raise ValueError("harmony_forge: need at least one tensor")
    if n == 1 and "base" not in kwargs:
        return tensors[0]
    # ------- Stack & setup -------
    stack = torch.stack(tensors)            # (n, ...)
    device = stack.device
    work_dtype = torch.float32
    eps = 1e-12
    # ------- Center (base or robust median) -------
    base = kwargs.get("base", None)
    if base is not None:
        center = base.to(device=stack.device, dtype=stack.dtype)
    else:
        center = stack.median(dim=0).values
    deltas = (stack - center)
    Z = deltas.flatten(1).to(work_dtype)    # (n, P)
    P = Z.shape[1]
    # ------- Scale calibration (auto) -------
    norms = torch.norm(Z, dim=1) + eps
    mean_norm = norms.mean()
    scale = (mean_norm / norms).unsqueeze(1)
    Z = Z * scale
    deltas = deltas * scale.view(n, *([1] * (deltas.ndim - 1)))
    # ------- Robust normalization helper -------
    def robust_0_1(x: torch.Tensor) -> torch.Tensor:
        med = x.median()
        mad = (x - med).abs().median() + eps
        z = (x - med) / (1.4826 * mad + eps)
        z = z.clamp(-4.0, 4.0)
        return (z - z.min()) / (z.max() - z.min() + eps)
    # ------- Spatial uniqueness & consensus -------
    Zs = F.normalize(Z, dim=1, eps=eps)
    sims_sp = Zs @ Zs.T
    mean_sim_sp = (sims_sp.sum(dim=1) - 1.0) / max(n - 1, 1)
    u_spatial = robust_0_1((1.0 - mean_sim_sp).clamp(0.0, 1.0))
    # ------- JL projections (memory-safe) -------
    rng = torch.Generator(device=("cuda" if device.type == "cuda" else "cpu"))
    rng.manual_seed(0xACED)
    num_projs = int(kwargs.get("jl_num_projs", 6))
    max_proj_cols = int(kwargs.get("jl_max_cols", 500_000))
    jl_k = int(kwargs.get("jl_k", 128))
    inv_sqrtP = (min(P, max_proj_cols)) ** -0.5
    proj_uniqs = []
    if P > max_proj_cols:
        idx = torch.randperm(P, device=device, generator=rng)[:max_proj_cols]
        Z_sub = Z[:, idx]
        P_eff = Z_sub.shape[1]
    else:
        Z_sub = Z
        P_eff = P
    for _ in range(num_projs):
        R = torch.randn(P_eff, jl_k, dtype=work_dtype, device=device, generator=rng) * inv_sqrtP
        Zp = F.normalize(Z_sub @ R, dim=1, eps=eps)
        ms = (Zp @ Zp.T)
        ms = (ms.sum(dim=1) - 1.0) / max(n - 1, 1)
        proj_uniqs.append((1.0 - ms).clamp(0.0, 1.0))
    u_jl = robust_0_1(torch.stack(proj_uniqs, dim=0).mean(dim=0))
    # ------- Sign consensus & outlier detection -------
    Zs = F.normalize(Z, dim=1, eps=eps)
    maj_dir = F.normalize(Zs.sum(dim=0, keepdim=True), dim=1, eps=eps)
    c_sign = robust_0_1((Zs @ maj_dir.T).squeeze(1).clamp(-1, 1) * 0.5 + 0.5)
    delta_norm = torch.norm(Z, dim=1)
    outlier_pen = 1.0 - robust_0_1(delta_norm)
    # ------- Frequency features & consensus -------
    if P > 4:
        Z_fft = Z[:, :: max(1, Z.shape[1] // 1_000_000)] if Z.shape[1] > 1_000_000 else Z
        spec = torch.fft.rfft(Z_fft, dim=1)
        mag = spec.abs() + eps
        mags = F.normalize(mag, dim=1, eps=eps)
        sims_hm = mags @ mags.T
        mean_sim_hm = (sims_hm.sum(dim=1) - 1.0) / max(n - 1, 1)
        u_freq = robust_0_1((1.0 - mean_sim_hm).clamp(0.0, 1.0))
        P_fft = mag.shape[1]
        low_hi_cut = max(1, int(0.05 * P_fft))
        mid_lo = max(1, int(0.10 * P_fft))
        mid_hi = min(P_fft, max(mid_lo + 1, int(0.40 * P_fft)))
        low_e = mag[:, :low_hi_cut].sum(dim=1)
        mid_e = mag[:, mid_lo:mid_hi].sum(dim=1)
        high_e = mag[:, mid_hi:].sum(dim=1)
        total_e = low_e + mid_e + high_e + eps
        s_mid = robust_0_1(mid_e / total_e)
        s_bal = robust_0_1(1.0 - (low_e / total_e - high_e / total_e).abs().clamp(0, 1))
        unit = spec / (spec.abs() + eps)
        unit_flat = torch.view_as_real(unit).flatten(1)
        unit_flat = F.normalize(unit_flat, dim=1, eps=eps)
        phase_sims = unit_flat @ unit_flat.T
        mean_phase = (phase_sims.sum(dim=1) - 1.0) / max(n - 1, 1)
        c_freq = robust_0_1((mean_phase + 1.0) / 2.0)
    else:
        # fallback for tiny tensors
        u_freq = s_mid = s_bal = c_freq = torch.zeros(n, dtype=work_dtype, device=device)
    # ------- Stability guard -------
    rough = []
    for i in range(n):
        z_i = Z[i]
        side = int(z_i.numel() ** 0.5)
        side = max(1, side)
        mat = z_i[: side * side].reshape(side, side)
        m = mat.to(work_dtype)
        v = F.normalize(torch.randn(m.shape[1], device=device, dtype=work_dtype), dim=0, eps=eps)
        u = F.normalize(m @ v, dim=0, eps=eps)
        sv = torch.dot(u, m @ v).abs()
        rough.append(float(sv.item()))
    rough = torch.tensor(rough, device=device, dtype=work_dtype)
    guard = 1.0 - robust_0_1(rough)
    # ------- Adaptive pooling -------
    def adaptive_pool(S: torch.Tensor) -> torch.Tensor:
        var = S.var(dim=1) + eps
        invvar = 1.0 / var
        w = invvar / invvar.sum()
        return (S * w.unsqueeze(1)).sum(dim=0)
    spatial_pool = adaptive_pool(torch.stack([u_spatial, u_jl, c_sign], dim=0))
    freq_pool = adaptive_pool(torch.stack([u_freq, s_mid, s_bal, c_freq], dim=0))
    # ------- Blend + gating -------
    blend = float(min(max(blend, 0.0), 1.0))
    base_goodness = (1.0 - blend) * spatial_pool + blend * freq_pool
    cons_pool = adaptive_pool(torch.stack([c_sign, c_freq, 1.0 - outlier_pen], dim=0))
    gate_strength = float(robust_0_1(base_goodness).var().clamp(0, 1).item())
    gate = (1.0 - 0.5 * gate_strength) + 0.5 * gate_strength * (0.6 * cons_pool + 0.4 * guard)
    g = base_goodness * gate
    closeness = robust_0_1(1.0 / (delta_norm + 1e-8))
    flatness = 1.0 - robust_0_1(g).var().clamp(0, 1)
    g = g + 0.02 * float(flatness.item()) * closeness
    # ------- Map to weights -------
    focus = float(max(0.0, min(focus, 50.0)))  # cap for stability
    g_med = g.median()
    g_mad = (g - g_med).abs().median() + eps
    g_z = (g - g_med) / (1.4826 * g_mad + eps)
    w = torch.softmax(focus * g_z, dim=0).to(stack.dtype)
    merged = torch.einsum("i...,i->...", stack, w)
    # ------- Debug output -------
    if kwargs.get("return_debug", False) or kwargs.get("on_debug"):
        dbg = {
            "weights": w,
            "signals": {
                "u_spatial": u_spatial, "u_jl": u_jl,
                "u_freq": u_freq, "s_mid": s_mid, "s_bal": s_bal, "c_freq": c_freq,
                "c_sign": c_sign, "outlier_pen": outlier_pen, "guard": guard,
                "spatial_pool": spatial_pool, "freq_pool": freq_pool,
                "base_goodness": base_goodness, "gate": gate, "final_g": g,
            },
            "settings": {"focus": focus, "blend": blend, "base_used": base is not None},
        }
        if kwargs.get("on_debug"):
            try: kwargs["on_debug"](dbg)
            except Exception: pass
        if kwargs.get("return_debug", False):
            return merged, dbg
    return merged
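If it helps, here is a quick smoke test I would sketch for the function on dummy tensors, assuming the easy_define decorator leaves the plain Python function directly callable (I have not verified that; if it wraps it, test the undecorated body instead). The shapes and parameter values are made up.

quick_test.py (sketch)
import torch
from mergekit.merge_methods.harmony_forge import harmony_forge

# three dummy "model" tensors plus an optional base tensor (hypothetical shapes)
models = [torch.randn(64, 64) for _ in range(3)]
base = torch.randn(64, 64)

# task-vector mode: deltas are taken against the provided base
merged = harmony_forge(models, focus=2.0, blend=0.3, base=base)
print(merged.shape)  # torch.Size([64, 64])

# return_debug=True returns (merged, dbg) with per-model weights and signals
merged, dbg = harmony_forge(models, focus=2.0, blend=0.3, return_debug=True)
print(dbg["weights"])  # softmax weights over the three inputs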
Cool, thanks for sharing!
Will definitely try it out.
The code block seems to warp it a bit; copy the code from my edits, that version seems to work.