Merge method

#1
by Retreatcost - opened

Hello 👋

Did you create the merging method yourself?

By the way, great quality of the model cards lately, really like them!

Unfortunately I don't have enough time to try them all though 😞

Thanks! As for the merge method: I used GPT-5 to make a custom GPT and uploaded all the .py files from mergekit/merge_methods, plus the markdown of https://github.com/arcee-ai/mergekit/blob/main/docs/create_a_merge_method.md. It works well, but it can forget the import header, so I recommend starting a fresh context when coding it. Then drop the new method .py into the mergekit merge_methods folder, edit __init__.py to add the new method (import mergekit.merge_methods.yourmethod), and restart. The registration sketch is below, followed by the .py code for harmony_forge.
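A minimal sketch of that __init__.py edit, assuming you save the file below as harmony_forge.py (the module name just has to match the filename):

# in mergekit/merge_methods/__init__.py, next to the existing imports
import mergekit.merge_methods.harmony_forge  # noqa: F401 - importing runs the @merge_method decorator and registers "harmony_forge"

After that the method can be referenced by name in a merge config (merge_method: harmony_forge); if easy_define wires the float arguments through, focus and blend should be settable as parameters there.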

harmony_forge.py
from mergekit.merge_methods.easy_define import merge_method
import torch
import torch.nn.functional as F
from typing import List

@merge_method("harmony_forge")
@torch.no_grad()
def harmony_forge(
    tensors: List[torch.Tensor],
    focus: float = 1.0,   # ↑ -> more peaked weights (softmax scale)
    blend: float = 0.5,   # 0=spatial-only .. 1=frequency-only
    **kwargs              # optional: base=<tensor>, return_debug=True
) -> torch.Tensor:
    """
    Harmony Forge (two knobs):
      - Works with or without a base tensor (task-vector mode if base provided).
      - Two controls: focus (decisiveness) and blend (spatial vs frequency).
      - Internally auto-tunes consensus, outlier control, and stability guards.
      - Now memory-safe for very large models (JL projection capped).
    """
    # ------- Basic checks -------
    n = len(tensors)
    if n == 0:
        raise ValueError("harmony_forge: need at least one tensor")
    if n == 1 and "base" not in kwargs:
        return tensors[0]

    # ------- Stack & setup -------
    stack = torch.stack(tensors)            # (n, ...)
    device = stack.device
    work_dtype = torch.float32
    eps = 1e-12

    # ------- Center (base or robust median) -------
    base = kwargs.get("base", None)
    if base is not None:
        center = base.to(device=stack.device, dtype=stack.dtype)
    else:
        center = stack.median(dim=0).values

    deltas = (stack - center)
    Z = deltas.flatten(1).to(work_dtype)    # (n, P)
    P = Z.shape[1]

    # ------- Scale calibration (auto) -------
    norms = torch.norm(Z, dim=1) + eps
    mean_norm = norms.mean()
    scale = (mean_norm / norms).unsqueeze(1)
    Z = Z * scale
    deltas = deltas * scale.view(n, *([1] * (deltas.ndim - 1)))

    # ------- Robust normalization helper -------
    def robust_0_1(x: torch.Tensor) -> torch.Tensor:
        med = x.median()
        mad = (x - med).abs().median() + eps
        z = (x - med) / (1.4826 * mad + eps)
        z = z.clamp(-4.0, 4.0)
        return (z - z.min()) / (z.max() - z.min() + eps)

    # ------- Spatial uniqueness & consensus -------
    Zs = F.normalize(Z, dim=1, eps=eps)
    sims_sp = Zs @ Zs.T
    mean_sim_sp = (sims_sp.sum(dim=1) - 1.0) / max(n - 1, 1)
    u_spatial = robust_0_1((1.0 - mean_sim_sp).clamp(0.0, 1.0))

    # ------- JL projections (memory-safe) -------
    rng = torch.Generator(device=("cuda" if device.type == "cuda" else "cpu"))
    rng.manual_seed(0xACED)
    num_projs = int(kwargs.get("jl_num_projs", 6))
    max_proj_cols = int(kwargs.get("jl_max_cols", 500_000))
    jl_k = int(kwargs.get("jl_k", 128))
    inv_sqrtP = (min(P, max_proj_cols)) ** -0.5
    proj_uniqs = []

    if P > max_proj_cols:
        idx = torch.randperm(P, device=device, generator=rng)[:max_proj_cols]
        Z_sub = Z[:, idx]
        P_eff = Z_sub.shape[1]
    else:
        Z_sub = Z
        P_eff = P

    for _ in range(num_projs):
        R = torch.randn(P_eff, jl_k, dtype=work_dtype, device=device, generator=rng) * inv_sqrtP
        Zp = F.normalize(Z_sub @ R, dim=1, eps=eps)
        ms = (Zp @ Zp.T)
        ms = (ms.sum(dim=1) - 1.0) / max(n - 1, 1)
        proj_uniqs.append((1.0 - ms).clamp(0.0, 1.0))
    u_jl = robust_0_1(torch.stack(proj_uniqs, dim=0).mean(dim=0))

    # ------- Sign consensus & outlier detection -------
    Zs = F.normalize(Z, dim=1, eps=eps)
    maj_dir = F.normalize(Zs.sum(dim=0, keepdim=True), dim=1, eps=eps)
    c_sign = robust_0_1((Zs @ maj_dir.T).squeeze(1).clamp(-1, 1) * 0.5 + 0.5)

    delta_norm = torch.norm(Z, dim=1)
    outlier_pen = 1.0 - robust_0_1(delta_norm)

    # ------- Frequency features & consensus -------
    if P > 4:
        Z_fft = Z[:, :: max(1, Z.shape[1] // 1_000_000)] if Z.shape[1] > 1_000_000 else Z
        spec = torch.fft.rfft(Z_fft, dim=1)
        mag = spec.abs() + eps
        mags = F.normalize(mag, dim=1, eps=eps)
        sims_hm = mags @ mags.T
        mean_sim_hm = (sims_hm.sum(dim=1) - 1.0) / max(n - 1, 1)
        u_freq = robust_0_1((1.0 - mean_sim_hm).clamp(0.0, 1.0))

        P_fft = mag.shape[1]
        low_hi_cut = max(1, int(0.05 * P_fft))
        mid_lo = max(1, int(0.10 * P_fft))
        mid_hi = min(P_fft, max(mid_lo + 1, int(0.40 * P_fft)))

        low_e = mag[:, :low_hi_cut].sum(dim=1)
        mid_e = mag[:, mid_lo:mid_hi].sum(dim=1)
        high_e = mag[:, mid_hi:].sum(dim=1)
        total_e = low_e + mid_e + high_e + eps

        s_mid = robust_0_1(mid_e / total_e)
        s_bal = robust_0_1(1.0 - (low_e / total_e - high_e / total_e).abs().clamp(0, 1))

        unit = spec / (spec.abs() + eps)
        unit_flat = torch.view_as_real(unit).flatten(1)
        unit_flat = F.normalize(unit_flat, dim=1, eps=eps)
        phase_sims = unit_flat @ unit_flat.T
        mean_phase = (phase_sims.sum(dim=1) - 1.0) / max(n - 1, 1)
        c_freq = robust_0_1((mean_phase + 1.0) / 2.0)
    else:
        # fallback for tiny tensors
        u_freq = s_mid = s_bal = c_freq = torch.zeros(n, dtype=work_dtype, device=device)

    # ------- Stability guard -------
    rough = []
    for i in range(n):
        z_i = Z[i]
        side = int(z_i.numel() ** 0.5)
        side = max(1, side)
        mat = z_i[: side * side].reshape(side, side)
        m = mat.to(work_dtype)
        v = F.normalize(torch.randn(m.shape[1], device=device, dtype=work_dtype), dim=0, eps=eps)
        u = F.normalize(m @ v, dim=0, eps=eps)
        sv = torch.dot(u, m @ v).abs()
        rough.append(float(sv.item()))
    rough = torch.tensor(rough, device=device, dtype=work_dtype)
    guard = 1.0 - robust_0_1(rough)

    # ------- Adaptive pooling -------
    def adaptive_pool(S: torch.Tensor) -> torch.Tensor:
        var = S.var(dim=1) + eps
        invvar = 1.0 / var
        w = invvar / invvar.sum()
        return (S * w.unsqueeze(1)).sum(dim=0)

    spatial_pool = adaptive_pool(torch.stack([u_spatial, u_jl, c_sign], dim=0))
    freq_pool = adaptive_pool(torch.stack([u_freq, s_mid, s_bal, c_freq], dim=0))

    # ------- Blend + gating -------
    blend = float(min(max(blend, 0.0), 1.0))
    base_goodness = (1.0 - blend) * spatial_pool + blend * freq_pool

    cons_pool = adaptive_pool(torch.stack([c_sign, c_freq, 1.0 - outlier_pen], dim=0))
    gate_strength = float(robust_0_1(base_goodness).var().clamp(0, 1).item())
    gate = (1.0 - 0.5 * gate_strength) + 0.5 * gate_strength * (0.6 * cons_pool + 0.4 * guard)

    g = base_goodness * gate
    closeness = robust_0_1(1.0 / (delta_norm + 1e-8))
    flatness = 1.0 - robust_0_1(g).var().clamp(0, 1)
    g = g + 0.02 * float(flatness.item()) * closeness

    # ------- Map to weights -------
    focus = float(max(0.0, min(focus, 50.0)))  # cap for stability
    g_med = g.median()
    g_mad = (g - g_med).abs().median() + eps
    g_z = (g - g_med) / (1.4826 * g_mad + eps)
    w = torch.softmax(focus * g_z, dim=0).to(stack.dtype)

    merged = torch.einsum("i...,i->...", stack, w)

    # ------- Debug output -------
    if kwargs.get("return_debug", False) or kwargs.get("on_debug"):
        dbg = {
            "weights": w,
            "signals": {
                "u_spatial": u_spatial, "u_jl": u_jl,
                "u_freq": u_freq, "s_mid": s_mid, "s_bal": s_bal, "c_freq": c_freq,
                "c_sign": c_sign, "outlier_pen": outlier_pen, "guard": guard,
                "spatial_pool": spatial_pool, "freq_pool": freq_pool,
                "base_goodness": base_goodness, "gate": gate, "final_g": g,
            },
            "settings": {"focus": focus, "blend": blend, "base_used": base is not None},
        }
        if kwargs.get("on_debug"):
            try:
                kwargs["on_debug"](dbg)
            except Exception:
                pass
        if kwargs.get("return_debug", False):
            return merged, dbg

    return merged
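
If you want to sanity-check the function outside of a real merge, a quick test on dummy tensors should be enough (this assumes the @merge_method decorator leaves the function directly callable; if it doesn't in your mergekit version, just comment the decorator out for the test):

import torch
from mergekit.merge_methods.harmony_forge import harmony_forge  # the file above

# three fake "fine-tune" tensors of the same shape, plus an optional base
tensors = [torch.randn(256, 256) for _ in range(3)]
base = torch.randn(256, 256)

merged = harmony_forge(tensors, focus=2.0, blend=0.3, base=base)
print(merged.shape)  # torch.Size([256, 256])

# return_debug=True also returns the per-model softmax weights and the internal signals
merged, dbg = harmony_forge(tensors, focus=2.0, blend=0.3, base=base, return_debug=True)
print(dbg["weights"])   # one weight per input tensor, sums to ~1.0
print(dbg["settings"])  # {'focus': 2.0, 'blend': 0.3, 'base_used': True}

In an actual merge you would just set merge_method: harmony_forge in the YAML and let mergekit call it tensor by tensor.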


Cool, thanks for sharing!

Will definitely try it out.

The code block seems to warp it a bit; copy the code from my edits instead, that version seems to work.
