from typing import Tuple

import torch
import triton
import triton.language as tl
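

# Loss definition (mirrored by the pure-PyTorch reference
# `hierarchical_sae_loss` at the bottom of this file). For batch row b with
# top-k feature indices i_{b,1..K} and coefficients v_{b,1..K}, every
# cumulative reconstruction is penalized, not just the final one:
#
#     recon_{b,t} = bias + sum_{s<=t} v_{b,s} * W[i_{b,s}]
#     L = (1 / (B*K*D)) * sum_{b,t} ||recon_{b,t} - target_b||^2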
@triton.jit
def hierarchical_sae_forward_kernel(
    loss_per_batch_ptr,
    final_recon_ptr,
    indices_ptr,
    weight_ptr,
    bias_ptr,
    vals_ptr,
    target_ptr,
    B: tl.constexpr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_D: tl.constexpr,
    LOOP_NUM_STAGES: tl.constexpr,
    BLOCK_B: tl.constexpr,
):
    tl.static_assert((D % BLOCK_D) == 0)
    tl.static_assert((B % BLOCK_B) == 0)
    tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
    tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
    tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")

    pid_b = tl.program_id(axis=0).to(tl.int64)
    pid_d = tl.program_id(axis=1).to(tl.int64)

    # 64-bit offsets guard against overflow when indexing into a large weight
    # matrix. The alignment hints are assigned back so that the annotated
    # tensors are the ones used below.
    batch_offsets = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)).to(tl.int64)
    batch_offsets = tl.multiple_of(batch_offsets, BLOCK_B)

    offset_d = (pid_d * BLOCK_D + tl.arange(0, BLOCK_D)).to(tl.int64)
    offset_d = tl.max_contiguous(tl.multiple_of(offset_d, BLOCK_D), BLOCK_D)

    batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]

    bias_tile = tl.load(bias_ptr + offset_d).to(tl.float32)

    # The reconstruction starts from the decoder bias; each of the K steps
    # adds one weighted dictionary row, and the loss accumulates the squared
    # error of every partial reconstruction.
    recon = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
    recon += bias_tile[None, :]

    target = tl.load(target_ptr + batch_d_offset).to(tl.float32)

    loss_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)

    row_idx_ptr = indices_ptr + batch_offsets * K
    row_val_ptr = vals_ptr + batch_offsets * K

    # Software-pipelined prologue: prefetch step 0 before entering the loop.
    idx = tl.load(row_idx_ptr).to(tl.int64)
    val = tl.load(row_val_ptr).to(tl.float32)
    val = val[:, None]
    weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)

    for t in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
        recon += weight_tile * val
        diff = recon - target
        loss_accum += diff * diff

        # Prefetch the next step's index, coefficient, and weight row while
        # the current step's arithmetic is in flight.
        if t + 1 < K:
            idx_next = tl.load(row_idx_ptr + (t + 1)).to(tl.int64)
            val_next = tl.load(row_val_ptr + (t + 1)).to(tl.float32)
            weight_next = tl.load(weight_ptr + idx_next[:, None] * D + offset_d[None, :]).to(tl.float32)

            idx = idx_next
            val = val_next[:, None]
            weight_tile = weight_next

    # Each program owns only a BLOCK_D slice of the feature dimension, so the
    # per-row loss is combined across the axis-1 grid with atomics.
    loss_tile = tl.sum(loss_accum, axis=1)
    tl.atomic_add(
        loss_per_batch_ptr + batch_offsets,
        loss_tile,
        sem="relaxed",
    )
    tl.store(
        final_recon_ptr + batch_d_offset,
        recon,
    )
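

# Backward pass, derived from the loss above. With diff_t = recon_t - target
# and scale = 2 / (B*K*D):
#
#     dL/dbias    = scale * sum_t diff_t
#     dL/dv_s     = scale * <W[i_s], sum_{t>=s} diff_t>
#     dL/dW[i_s]  = scale * v_s * sum_{t>=s} diff_t
#
# The kernel below walks t = K-1, ..., 0, maintaining the suffix sum
# scale * sum_{t>=s} diff_t and peeling one step off the saved final
# reconstruction per iteration, so no intermediate recon_t is ever stored.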
@triton.jit
def hierarchical_sae_backward_kernel(
    weight_grad_ptr,
    vals_grad_ptr,
    bias_grad_ptr,
    final_recon_ptr,
    indices_ptr,
    weight_ptr,
    vals_ptr,
    target_ptr,
    B: tl.constexpr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_D: tl.constexpr,
    LOOP_NUM_STAGES: tl.constexpr,
    BLOCK_B: tl.constexpr,
):
    tl.static_assert((D % BLOCK_D) == 0)
    tl.static_assert((B % BLOCK_B) == 0)
    tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
    tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
    tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")

    pid_b = tl.program_id(axis=0).to(tl.int64)
    pid_d = tl.program_id(axis=1).to(tl.int64)

    batch_offsets = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)).to(tl.int64)
    batch_offsets = tl.multiple_of(batch_offsets, BLOCK_B)

    offset_d = (pid_d * BLOCK_D + tl.arange(0, BLOCK_D)).to(tl.int64)
    offset_d = tl.max_contiguous(tl.multiple_of(offset_d, BLOCK_D), BLOCK_D)

    batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]

    # Start from the full reconstruction saved by the forward pass and peel
    # one step off per iteration, walking t = K-1, ..., 0.
    recon = tl.load(final_recon_ptr + batch_d_offset).to(tl.float32)
    target = tl.load(target_ptr + batch_d_offset).to(tl.float32)

    # suffix holds sum over t >= current step of dL/drecon_t; bias_accum
    # collects the same quantity over all steps, which is exactly dL/dbias.
    suffix = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
    bias_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
    scale = tl.full((), 2.0 / (B * K * D), dtype=tl.float32)

    row_idx_ptr = indices_ptr + batch_offsets * K
    row_val_ptr = vals_ptr + batch_offsets * K
    k_offsets = tl.arange(0, K)
    val_grad_tile = tl.zeros([BLOCK_B, K], dtype=tl.float32)

    # Software-pipelined prologue: prefetch the last step before the loop.
    step = K - 1
    idx = tl.load(row_idx_ptr + step).to(tl.int64)
    val = tl.load(row_val_ptr + step).to(tl.float32)
    weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)

    for _ in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
        curr_step = step

        diff = recon - target
        grad_curr = diff * scale
        suffix += grad_curr
        bias_accum += grad_curr

        # dL/dW[idx] accumulates suffix * val; different batch rows may hit
        # the same dictionary row, so the scatter has to be atomic.
        val_broadcast = val[:, None]
        contrib = suffix * val_broadcast
        tl.atomic_add(
            weight_grad_ptr + idx[:, None] * D + offset_d[None, :],
            contrib,
            sem="relaxed",
        )

        # dL/dval[curr_step] = <W[idx], suffix>; this program only sees a
        # BLOCK_D slice of D, so it records a partial dot product.
        dot_partial = tl.sum(weight_tile * suffix, axis=1)
        mask_curr = k_offsets[None, :] == curr_step
        val_grad_tile = tl.where(mask_curr, dot_partial[:, None], val_grad_tile)

        # Remove the current step's contribution to recover recon_{t-1}.
        recon -= weight_tile * val_broadcast

        # Prefetch the previous step for the next iteration.
        if curr_step > 0:
            step = curr_step - 1
            idx = tl.load(row_idx_ptr + step).to(tl.int64)
            val = tl.load(row_val_ptr + step).to(tl.float32)
            weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)

    bias_grad_tile = tl.sum(bias_accum, axis=0)
    tl.atomic_add(
        bias_grad_ptr + offset_d,
        bias_grad_tile,
        sem="relaxed",
    )

    # Combine the per-program partial dot products over the D axis.
    row_val_grad_ptr = vals_grad_ptr + batch_offsets[:, None] * K + k_offsets[None, :]
    tl.atomic_add(
        row_val_grad_ptr,
        val_grad_tile,
        sem="relaxed",
    )


def _hierarchical_sae_forward(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    bias: torch.Tensor,
    target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, K = indices.shape
    F, D = weight.shape

    # loss_per_batch is accumulated atomically by the kernel, so it must be
    # zero-initialized; final_recon is fully overwritten.
    loss_per_batch = torch.zeros((B,), dtype=torch.float32, device=weight.device)
    final_recon = torch.empty((B, D), dtype=torch.float32, device=weight.device)

    # 2D launch grid: one program per (batch tile, feature tile).
    def _forward_grid(meta):
        return (
            B // meta["BLOCK_B"],
            D // meta["BLOCK_D"],
        )

    hierarchical_sae_forward_kernel[_forward_grid](
        loss_per_batch,
        final_recon,
        indices,
        weight,
        bias,
        vals,
        target,
        B=B,
        D=D,
        K=K,
        BLOCK_D=64,
        LOOP_NUM_STAGES=4,
        BLOCK_B=1,
        num_warps=2,
        num_stages=2,
    )
    # The kernel accumulates unnormalized squared errors; normalize to a mean.
    loss = loss_per_batch.sum() / (B * K * D)
    return loss, final_recon


def _hierarchical_sae_backward(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    target: torch.Tensor,
    final_recon: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    device = weight.device
    B, K = indices.shape
    F, D = weight.shape

    # All three gradient buffers are accumulated into with atomics, so they
    # must start at zero.
    dW = torch.zeros((F, D), dtype=torch.float32, device=device)
    dVals = torch.zeros((B, K), dtype=torch.float32, device=device)
    db = torch.zeros((D,), dtype=torch.float32, device=device)

    def _backward_grid(meta):
        return (
            B // meta["BLOCK_B"],
            D // meta["BLOCK_D"],
        )

    hierarchical_sae_backward_kernel[_backward_grid](
        dW,
        dVals,
        db,
        final_recon,
        indices,
        weight,
        vals,
        target,
        B=B,
        D=D,
        K=K,
        BLOCK_D=32,
        LOOP_NUM_STAGES=16,
        BLOCK_B=16,
        num_warps=8,
        num_stages=8,
    )

    return dW, dVals, db


class HierarchicalSAELossFunction(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        indices: torch.Tensor,
        weight: torch.Tensor,
        vals: torch.Tensor,
        bias: torch.Tensor,
        target: torch.Tensor,
    ):
        loss, final_recon = _hierarchical_sae_forward(indices, weight, vals, bias, target)
        ctx.save_for_backward(indices, weight, vals, target, final_recon)
        return loss

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        indices, weight, vals, target, final_recon = ctx.saved_tensors
        dW, dVals, db = _hierarchical_sae_backward(indices, weight, vals, target, final_recon)

        # Chain rule for the incoming scalar gradient (1.0 for a plain
        # loss.backward() call).
        if grad is not None:
            dW.mul_(grad)
            dVals.mul_(grad)
            db.mul_(grad)

        # Gradients are returned in the order of forward's inputs; the integer
        # indices and the target receive none.
        return None, dW, dVals, db, None


def triton_hierarchical_sae_loss(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    bias: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    return HierarchicalSAELossFunction.apply(indices, weight, vals, bias, target)


def hierarchical_sae_loss(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    bias: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    # Pure-PyTorch reference: materializes all K cumulative reconstructions,
    # shape (B, K, D), and averages the squared error over every element.
    emb = weight[indices].to(torch.float32)
    recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
    diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
    loss = diff.pow(2).mean()
    return loss
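

# A minimal smoke test comparing the Triton path against the PyTorch
# reference. The shapes below are illustrative assumptions: B and D must be
# divisible by the block sizes hard-coded in the launchers above, K must be a
# power of two, and a CUDA device is required.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, F, D, K = 64, 4096, 512, 32

    indices = torch.randint(0, F, (B, K), device="cuda")
    weight = torch.randn(F, D, device="cuda", requires_grad=True)
    vals = torch.rand(B, K, device="cuda", requires_grad=True)
    bias = torch.randn(D, device="cuda", requires_grad=True)
    target = torch.randn(B, D, device="cuda")

    # Forward values should agree up to float32 atomic-accumulation noise.
    loss_triton = triton_hierarchical_sae_loss(indices, weight, vals, bias, target)
    loss_ref = hierarchical_sae_loss(indices, weight, vals, bias, target)
    torch.testing.assert_close(loss_triton, loss_ref, rtol=1e-4, atol=1e-5)

    # Exercise the custom backward kernel as well.
    loss_triton.backward()
    print(f"losses match: triton={loss_triton.item():.6f} reference={loss_ref.item():.6f}")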