# TopK SAE decoder Triton kernels
# Copyright 2025 T-Tech
# This code is adapted from Facebook Research under the
# Creative Commons Attribution-NonCommercial 4.0 International License.
# Original code can be found at: https://github.com/facebookresearch/memory
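#
# This module implements the sparse "embedding bag" decode used by a TopK SAE,
#     recon[b] = bias + sum_k vals[b, k] * weight[indices[b, k]],
# as Triton forward/backward kernels, together with a torch.autograd.Function
# wrapper and MSE loss helpers (a Triton-backed version and a pure-PyTorch
# reference).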
from typing import Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def embedding_bag_forward_kernel(
out_ptr, # [B, D]
indices_ptr, # [B, K]
weight_ptr, # [F, D]
vals_ptr, # [B, K]
D: tl.constexpr,
K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
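    # One program per (batch row b, BLOCK_D-wide slice of D): accumulate
    # sum_i vals[b, i] * weight[indices[b, i], off_d] in fp32, then store the
    # tile into out[b, off_d] (cast to out_ptr's element type on store).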
tl.static_assert((D % BLOCK_D) == 0)
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
b = tl.program_id(axis=0).to(tl.int64)
pid_d = tl.program_id(axis=1).to(tl.int64)
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
out_value = tl.zeros([BLOCK_D], dtype=tl.float32)
for i in tl.range(K):
my_index = tl.load(indices_ptr + b * K + i).to(tl.int64)
my_scaling = tl.load(vals_ptr + b * K + i)
w_tile = tl.load(weight_ptr + my_index * D + off_d).to(tl.float32)
out_value += w_tile * my_scaling
    tl.store(out_ptr + b * D + off_d, out_value)


def embedding_bag_forward(
indices: torch.Tensor, # [B, K]
weight: torch.Tensor, # [F, D]
vals: torch.Tensor, # [B, K]
) -> torch.Tensor:
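    """Sparse decode: out[b] = sum_k vals[b, k] * weight[indices[b, k]].

    Accumulates in fp32 and stores the result in weight.dtype. Requires K to
    be a power of 2 and D to be divisible by BLOCK_D (64).
    """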
B, K = indices.shape
D = weight.shape[1]
trt_out = torch.empty([B, D], dtype=weight.dtype, device=weight.device)
def _forward_grid(meta):
return (B, D // meta["BLOCK_D"])
embedding_bag_forward_kernel[_forward_grid](
trt_out,
indices,
weight,
vals,
D=D,
K=K,
BLOCK_D=64,
num_warps=1,
num_stages=1,
)
    return trt_out


@triton.jit
def count_per_embedding_kernel(
count_per_emb_ptr, # [F + 1]
indices_ptr, # [B, K]
K: tl.constexpr,
):
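    # One program per batch row: histogram how many (b, t) slots select each
    # embedding row. Counts go to index `embedding_id + 1`, so a cumulative sum
    # over the buffer yields the begin offset of every row's bucket.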
batch_id = tl.program_id(axis=0).to(tl.int64)
for t in tl.range(K):
embedding_id = tl.load(indices_ptr + batch_id * K + t)
        tl.atomic_add(count_per_emb_ptr + embedding_id + 1, 1, sem="relaxed")


@triton.jit
def map_embeddings_and_outputs_kernel(
reverse_mapping_ptr, # [B * K]
mapping_write_pos_ptr, # [F]
indices_ptr, # [B, K]
K: tl.constexpr,
):
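    # One program per batch row: scatter each flat slot index (b * K + t) into
    # the contiguous bucket of its embedding row (CSR-style), using
    # mapping_write_pos as atomic write cursors initialised to the begin offsets.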
batch_id = tl.program_id(axis=0).to(tl.int64)
for t in tl.range(K):
embedding_id = tl.load(indices_ptr + batch_id * K + t)
write_pos = tl.atomic_add(mapping_write_pos_ptr + embedding_id, 1, sem="relaxed")
        tl.store(reverse_mapping_ptr + write_pos, batch_id * K + t)


@triton.jit
def aggregate_gradient_for_embedding_kernel(
weight_grad_ptr, # [F, D]
vals_grad_ptr, # [B, K]
weight_ptr, # [F, D]
emb_begin_pos_ptr, # [F + 1]
reverse_mapping_ptr, # [B * K]
vals_ptr, # [B, K]
gradient_ptr, # [B, D]
D: tl.constexpr,
K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
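    # One program per (embedding row e, BLOCK_D-wide slice of D). For every slot
    # that selected e: accumulate vals[slot] * grad[b, off_d] into the weight
    # gradient tile, and atomically add <grad[b, off_d], weight[e, off_d]> to
    # vals_grad[slot] (the full dot product is built up across the D grid).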
tl.static_assert((D % BLOCK_D) == 0)
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
e = tl.program_id(axis=0).to(tl.int64)
pid_d = tl.program_id(axis=1).to(tl.int64)
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
begin = tl.load(emb_begin_pos_ptr + e)
end = tl.load(emb_begin_pos_ptr + e + 1)
w_row_tile = tl.load(weight_ptr + e * D + off_d).to(tl.float32)
w_grad_tile = tl.zeros([BLOCK_D], dtype=tl.float32)
for idx in tl.range(begin, end):
out_linear = tl.load(reverse_mapping_ptr + idx).to(tl.int64)
b = out_linear // K
psw = tl.load(vals_ptr + out_linear)
g_tile = tl.load(gradient_ptr + b * D + off_d).to(tl.float32)
w_grad_tile += psw * g_tile
psw_grad_partial = tl.sum(g_tile * w_row_tile)
tl.atomic_add(vals_grad_ptr + out_linear, psw_grad_partial, sem="relaxed")
    tl.store(weight_grad_ptr + e * D + off_d, w_grad_tile)


def embedding_bag_backward(
indices: torch.Tensor, # [B, K]
weight: torch.Tensor, # [F, D]
vals: torch.Tensor, # [B, K]
gradient: torch.Tensor, # [B, D]
) -> Tuple[torch.Tensor, torch.Tensor]:
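    """Backward pass of the sparse decode.

    Three steps: (1) count how often each embedding row is selected,
    (2) build a CSR-style mapping from embedding rows to the (b, k) slots that
    used them, (3) aggregate the weight and vals gradients row by row.
    Returns (weight_grad, vals_grad), both in fp32.
    """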
F, D = weight.shape
B, K = indices.shape
count_per_emb = torch.zeros((F + 1,), dtype=torch.uint32, device=indices.device)
count_per_embedding_kernel[(B,)](count_per_emb, indices, K=K, num_warps=1)
emb_begin_pos = count_per_emb.cumsum(0) # [F + 1]
reverse_mapping = torch.empty([B * K], dtype=torch.uint32, device=indices.device)
assert B * K <= 2 ** (reverse_mapping.dtype.itemsize * 8) - 1
map_embeddings_and_outputs_kernel[(B,)](
reverse_mapping_ptr=reverse_mapping,
mapping_write_pos_ptr=emb_begin_pos.clone(),
indices_ptr=indices,
K=K,
num_warps=1,
)
weight_grad = torch.empty_like(weight, dtype=torch.float32) # [F, D]
vals_grad = torch.zeros_like(vals, dtype=torch.float32) # [B, K]
    def _backward_grid(meta):
        return (F, D // meta["BLOCK_D"])

    aggregate_gradient_for_embedding_kernel[_backward_grid](
weight_grad_ptr=weight_grad,
vals_grad_ptr=vals_grad,
weight_ptr=weight,
emb_begin_pos_ptr=emb_begin_pos,
reverse_mapping_ptr=reverse_mapping,
vals_ptr=vals,
gradient_ptr=gradient,
D=D,
K=K,
BLOCK_D=256,
num_warps=1,
num_stages=2,
)
    return weight_grad, vals_grad


class xFormersEmbeddingBag(torch.autograd.Function):
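    """Autograd wrapper around the Triton embedding-bag forward/backward kernels."""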
@staticmethod
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
indices: torch.Tensor, # [B, K]
weight: torch.Tensor, # [F, D]
vals: torch.Tensor, # [B, K]
) -> torch.Tensor:
ctx.save_for_backward(indices, weight, vals)
        return embedding_bag_forward(indices, weight, vals)  # [B, D]

    @staticmethod
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, gradient):
indices, weight, vals = ctx.saved_tensors
weight_g, vals_g = embedding_bag_backward(
indices,
weight,
vals,
gradient,
)
        return None, weight_g, vals_g


def triton_topk_sae_loss(
indices: torch.Tensor, # [B, K]
weight: torch.Tensor, # [F, D]
vals: torch.Tensor, # [B, K]
bias: torch.Tensor, # [D]
target: torch.Tensor, # [B, D]
) -> torch.Tensor:
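    """MSE between target and the Triton-decoded reconstruction (bias + sparse decode)."""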
recon = bias.to(torch.float32) + xFormersEmbeddingBag.apply(indices, weight, vals)
diff = recon.to(torch.float32) - target.to(torch.float32)
loss = diff.pow(2).mean()
    return loss


def topk_sae_loss(
indices: torch.Tensor, # [B, K]
weight: torch.Tensor, # [F, D]
vals: torch.Tensor, # [B, K]
bias: torch.Tensor, # [D]
target: torch.Tensor, # [B, D]
) -> torch.Tensor:
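    """Pure-PyTorch reference for triton_topk_sae_loss (dense gather, same math)."""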
    emb = weight[indices].to(torch.float32)  # [B, K, D]
recon = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).sum(dim=1)
diff = recon.to(torch.float32) - target.to(torch.float32)
loss = diff.pow(2).mean()
return loss
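

if __name__ == "__main__":
    # Minimal self-check sketch (not part of the original kernels): compares the
    # Triton-backed loss against the pure-PyTorch reference and exercises the
    # backward kernels. Assumes a CUDA device with Triton support; the shapes
    # below are arbitrary but respect the kernel constraints (K a power of 2,
    # D divisible by the BLOCK_D values used above, i.e. 64 and 256).
    torch.manual_seed(0)
    B, K, n_features, D = 8, 32, 1024, 512
    device = "cuda"
    indices = torch.randint(0, n_features, (B, K), device=device)
    weight = torch.randn(n_features, D, device=device, requires_grad=True)
    vals = torch.randn(B, K, device=device, requires_grad=True)
    bias = torch.randn(D, device=device)
    target = torch.randn(B, D, device=device)

    loss_triton = triton_topk_sae_loss(indices, weight, vals, bias, target)
    loss_ref = topk_sae_loss(indices, weight, vals, bias, target)
    torch.testing.assert_close(loss_triton, loss_ref, rtol=1e-4, atol=1e-4)

    loss_triton.backward()  # runs the Triton backward kernels
    assert weight.grad is not None and vals.grad is not None
    print(f"loss={loss_triton.item():.6f}  (Triton and reference paths agree)")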