from typing import Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def embedding_bag_forward_kernel(
    out_ptr,
    indices_ptr,
    weight_ptr,
    vals_ptr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    tl.static_assert((D % BLOCK_D) == 0)
    tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
    tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")

    # One program per (batch row, block of the feature dimension).
    b = tl.program_id(axis=0).to(tl.int64)
    pid_d = tl.program_id(axis=1).to(tl.int64)

    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Accumulate the weighted sum of the K selected embedding rows in float32.
    out_value = tl.zeros([BLOCK_D], dtype=tl.float32)
    for i in tl.range(K):
        my_index = tl.load(indices_ptr + b * K + i).to(tl.int64)
        my_scaling = tl.load(vals_ptr + b * K + i)
        w_tile = tl.load(weight_ptr + my_index * D + off_d).to(tl.float32)
        out_value += w_tile * my_scaling

    tl.store(out_ptr + b * D + off_d, out_value)


def embedding_bag_forward(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
) -> torch.Tensor:
    B, K = indices.shape
    D = weight.shape[1]

    trt_out = torch.empty([B, D], dtype=weight.dtype, device=weight.device)

    def _forward_grid(meta):
        return (B, D // meta["BLOCK_D"])

    embedding_bag_forward_kernel[_forward_grid](
        trt_out,
        indices,
        weight,
        vals,
        D=D,
        K=K,
        BLOCK_D=64,
        num_warps=1,
        num_stages=1,
    )
    return trt_out
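

# Pure-PyTorch reference for embedding_bag_forward, useful for testing. This
# helper is an addition, not part of the kernel API above; it is a sketch
# assuming the layouts used throughout this file: indices (B, K) selecting rows
# of weight (F, D), each scaled by the matching entry of vals (B, K).
def _embedding_bag_forward_reference(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
) -> torch.Tensor:
    # out[b, :] = sum_k vals[b, k] * weight[indices[b, k], :]
    emb = weight[indices].to(torch.float32)
    return (emb * vals.unsqueeze(-1).to(torch.float32)).sum(dim=1).to(weight.dtype)

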
@triton.jit
def count_per_embedding_kernel(
    count_per_emb_ptr,
    indices_ptr,
    K: tl.constexpr,
):
    # Histogram of how many (batch, slot) pairs reference each embedding row.
    # Counts are written at offset embedding_id + 1 so that a cumulative sum of
    # the buffer directly yields exclusive begin positions per embedding.
    batch_id = tl.program_id(axis=0).to(tl.int64)
    for t in tl.range(K):
        embedding_id = tl.load(indices_ptr + batch_id * K + t)
        tl.atomic_add(count_per_emb_ptr + embedding_id + 1, 1, sem="relaxed")


@triton.jit
def map_embeddings_and_outputs_kernel(
    reverse_mapping_ptr,
    mapping_write_pos_ptr,
    indices_ptr,
    K: tl.constexpr,
):
    # Scatter the flattened (batch, slot) positions into a CSR-style layout:
    # every position that references embedding e lands in the contiguous slice
    # starting at mapping_write_pos[e].
    batch_id = tl.program_id(axis=0).to(tl.int64)
    for t in tl.range(K):
        embedding_id = tl.load(indices_ptr + batch_id * K + t)
        write_pos = tl.atomic_add(mapping_write_pos_ptr + embedding_id, 1, sem="relaxed")
        tl.store(reverse_mapping_ptr + write_pos, batch_id * K + t)


@triton.jit
def aggregate_gradient_for_embedding_kernel(
    weight_grad_ptr,
    vals_grad_ptr,
    weight_ptr,
    emb_begin_pos_ptr,
    reverse_mapping_ptr,
    vals_ptr,
    gradient_ptr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    tl.static_assert((D % BLOCK_D) == 0)
    tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
    tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")

    # One program per (embedding row, block of the feature dimension).
    e = tl.program_id(axis=0).to(tl.int64)
    pid_d = tl.program_id(axis=1).to(tl.int64)

    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # The slice [begin, end) of reverse_mapping lists every flattened
    # (batch, slot) position that selected embedding row e.
    begin = tl.load(emb_begin_pos_ptr + e)
    end = tl.load(emb_begin_pos_ptr + e + 1)

    w_row_tile = tl.load(weight_ptr + e * D + off_d).to(tl.float32)
    w_grad_tile = tl.zeros([BLOCK_D], dtype=tl.float32)

    for idx in tl.range(begin, end):
        out_linear = tl.load(reverse_mapping_ptr + idx).to(tl.int64)
        b = out_linear // K

        psw = tl.load(vals_ptr + out_linear)
        g_tile = tl.load(gradient_ptr + b * D + off_d).to(tl.float32)

        # Gradient w.r.t. weight[e] accumulates vals[b, k] * gradient[b, :].
        w_grad_tile += psw * g_tile

        # Gradient w.r.t. vals[b, k] is <gradient[b, :], weight[e, :]>; each
        # D-block contributes its partial dot product via an atomic add.
        psw_grad_partial = tl.sum(g_tile * w_row_tile)
        tl.atomic_add(vals_grad_ptr + out_linear, psw_grad_partial, sem="relaxed")

    tl.store(weight_grad_ptr + e * D + off_d, w_grad_tile)


def embedding_bag_backward(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    gradient: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    F, D = weight.shape
    B, K = indices.shape

    # Count how often each embedding row is used (offset by one), then take a
    # cumulative sum to get the begin position of each row's slice.
    count_per_emb = torch.zeros((F + 1,), dtype=torch.uint32, device=indices.device)
    count_per_embedding_kernel[(B,)](count_per_emb, indices, K=K, num_warps=1)

    emb_begin_pos = count_per_emb.cumsum(0)

    # reverse_mapping stores flattened (batch, slot) positions grouped by the
    # embedding row they reference.
    reverse_mapping = torch.empty([B * K], dtype=torch.uint32, device=indices.device)
    assert B * K <= 2 ** (reverse_mapping.dtype.itemsize * 8) - 1

    map_embeddings_and_outputs_kernel[(B,)](
        reverse_mapping_ptr=reverse_mapping,
        mapping_write_pos_ptr=emb_begin_pos.clone(),
        indices_ptr=indices,
        K=K,
        num_warps=1,
    )

    weight_grad = torch.empty_like(weight, dtype=torch.float32)
    vals_grad = torch.zeros_like(vals, dtype=torch.float32)

    def _backward_grid(meta):
        return (F, D // meta["BLOCK_D"])

    aggregate_gradient_for_embedding_kernel[_backward_grid](
        weight_grad_ptr=weight_grad,
        vals_grad_ptr=vals_grad,
        weight_ptr=weight,
        emb_begin_pos_ptr=emb_begin_pos,
        reverse_mapping_ptr=reverse_mapping,
        vals_ptr=vals,
        gradient_ptr=gradient,
        D=D,
        K=K,
        BLOCK_D=256,
        num_warps=1,
        num_stages=2,
    )
    return weight_grad, vals_grad


class xFormersEmbeddingBag(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        indices: torch.Tensor,
        weight: torch.Tensor,
        vals: torch.Tensor,
    ) -> torch.Tensor:
        ctx.save_for_backward(indices, weight, vals)
        return embedding_bag_forward(indices, weight, vals)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, gradient):
        indices, weight, vals = ctx.saved_tensors
        weight_g, vals_g = embedding_bag_backward(
            indices,
            weight,
            vals,
            gradient,
        )
        # indices is an integer tensor and receives no gradient.
        return None, weight_g, vals_g


def triton_topk_sae_loss(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    bias: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    recon = bias.to(torch.float32) + xFormersEmbeddingBag.apply(indices, weight, vals)
    diff = recon.to(torch.float32) - target.to(torch.float32)
    loss = diff.pow(2).mean()
    return loss


def topk_sae_loss(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    bias: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    # Dense PyTorch reference implementation of the loss above.
    emb = weight[indices].to(torch.float32)
    recon = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).sum(dim=1)
    diff = recon.to(torch.float32) - target.to(torch.float32)
    loss = diff.pow(2).mean()
    return loss
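

# Minimal self-check sketch; not part of the original module. It compares the
# Triton loss and its gradients against the dense PyTorch reference on random
# data. The shapes below are arbitrary assumptions, chosen so that K is a power
# of two and D is divisible by both BLOCK_D values used above (64 and 256).
# Requires a CUDA device with Triton available.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda"
    B, K, F, D = 32, 16, 1024, 512

    indices = torch.randint(0, F, (B, K), device=device)
    weight = torch.randn(F, D, device=device, requires_grad=True)
    vals = torch.randn(B, K, device=device, requires_grad=True)
    bias = torch.randn(D, device=device)
    target = torch.randn(B, D, device=device)

    loss_triton = triton_topk_sae_loss(indices, weight, vals, bias, target)
    loss_ref = topk_sae_loss(indices, weight, vals, bias, target)
    torch.testing.assert_close(loss_triton, loss_ref, rtol=1e-4, atol=1e-4)

    # Gradients flow through the custom autograd.Function for weight and vals.
    grad_w_triton, grad_v_triton = torch.autograd.grad(loss_triton, (weight, vals))
    grad_w_ref, grad_v_ref = torch.autograd.grad(loss_ref, (weight, vals))
    torch.testing.assert_close(grad_w_triton, grad_w_ref, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(grad_v_triton, grad_v_ref, rtol=1e-3, atol=1e-3)
    print("triton_topk_sae_loss matches the PyTorch reference")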