# STAR/fairseq/modules/kmeans_attention.py
import math
from functools import reduce, wraps
from inspect import isfunction
from operator import mul
import torch
import torch.nn as nn
import torch.nn.functional as F
from aml.multimodal_video.utils.einops.lib import rearrange, repeat
from aml.multimodal_video.utils.einops.lib.layers.torch import Rearrange
from fairseq.modules.local_attention import LocalAttention
# constants
TOKEN_SELF_ATTN_VALUE = -5e4
KMEAN_INIT_ITERS = 10
# helper functions
def exists(val):
return val is not None
def identity(x, *args, **kwargs):
return x
def default(x, d):
if not exists(x):
return d if not isfunction(d) else d()
return x
def cast_tuple(x):
return x if isinstance(x, tuple) else (x,)
def cache_fn(f):
cache = None
@wraps(f)
def cached_fn(*args, **kwargs):
nonlocal cache
if exists(cache):
return cache
cache = f(*args, **kwargs)
return cache
return cached_fn
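# A minimal sketch of how `cache_fn` behaves: the wrapped function runs once
# and every later call returns the memoized result (arguments are ignored
# after the first call). The names below are illustrative only.
def _demo_cache_fn():
    calls = []

    @cache_fn
    def expensive():
        calls.append(1)
        return 42

    assert expensive() == 42 and expensive() == 42
    return len(calls)  # 1 -- the second call hit the cache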
def to(t):
return {"device": t.device, "dtype": t.dtype}
def find_modules(nn_module, module_type):
    return [module for module in nn_module.modules() if isinstance(module, module_type)]
def is_empty(t):
return t.nelement() == 0
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
def batched_index_select(values, indices):
last_dim = values.shape[-1]
return values.gather(2, expand_dim(indices, -1, last_dim))
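# Usage sketch for `batched_index_select` (illustrative shapes): for every
# (batch, head) pair it gathers the rows of `values` named by `indices` along
# the sequence dimension.
def _demo_batched_index_select():
    values = torch.randn(2, 4, 10, 8)           # (batch, heads, seq, dim)
    indices = torch.randint(0, 10, (2, 4, 3))   # 3 positions per batch/head
    out = batched_index_select(values, indices)
    return out.shape  # torch.Size([2, 4, 3, 8])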
def merge_dims(ind_from, ind_to, tensor):
shape = list(tensor.shape)
arr_slice = slice(ind_from, ind_to + 1)
shape[arr_slice] = [reduce(mul, shape[arr_slice])]
return tensor.reshape(*shape)
def expand_dim(t, dim, k):
t = t.unsqueeze(dim)
expand_shape = [-1] * len(t.shape)
expand_shape[dim] = k
return t.expand(*expand_shape)
def scatter_mean(src, t, index, dim, eps=1e-5):
numer = src.scatter_add(dim, index, t)
denom = src.scatter_add(dim, index, torch.ones_like(t))
return numer / (denom + eps)
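# Usage sketch for `scatter_mean` (illustrative values): with a zero `src`,
# entries of `t` that share an index are averaged into that slot; this is how
# `KmeansAttention` merges duplicate token routes back into its output.
def _demo_scatter_mean():
    src = torch.zeros(1, 4, 2)
    t = torch.tensor([[[1.0, 1.0], [2.0, 2.0], [4.0, 4.0]]])
    index = torch.tensor([[[0, 0], [0, 0], [2, 2]]])
    # row 0 -> ~1.5 (mean of 1 and 2), row 2 -> ~4.0, rows 1 and 3 stay 0
    return scatter_mean(src, t, index, dim=1)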
def split_at_index(dim, index, t):
pre_slices = (slice(None),) * dim
l = (*pre_slices, slice(None, index))
r = (*pre_slices, slice(index, None))
return t[l], t[r]
def reshape_dim(t, dim, split_dims):
shape = list(t.shape)
num_dims = len(shape)
dim = (dim + num_dims) % num_dims
shape[dim : dim + 1] = split_dims
return t.reshape(shape)
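# A quick sketch of the three shape helpers above (shapes are illustrative):
def _demo_shape_helpers():
    t = torch.randn(2, 3, 4, 5)
    assert merge_dims(1, 2, t).shape == (2, 12, 5)   # fold dims 1..2 into one
    left, right = split_at_index(2, 1, t)            # cut dim 2 at index 1
    assert left.shape == (2, 3, 1, 5) and right.shape == (2, 3, 3, 5)
    assert reshape_dim(t, -1, (5, 1)).shape == (2, 3, 4, 5, 1)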
def ema(old, new, decay):
if not exists(old):
return new
return old * decay + new * (1 - decay)
def ema_inplace(moving_avg, new, decay):
if is_empty(moving_avg):
moving_avg.data.copy_(new)
return
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
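# A small numeric sketch of the two EMA helpers: `ema` returns `new` when
# there is no history yet, and `ema_inplace` blends in place with the decay.
def _demo_ema():
    assert ema(None, 5.0, decay=0.9) == 5.0
    avg = torch.zeros(3)
    ema_inplace(avg, torch.ones(3), decay=0.9)
    return avg  # tensor([0.1, 0.1, 0.1]) -- 0.0 * 0.9 + 1.0 * 0.1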
# helper classes
def map_first_tuple_or_el(x, fn):
if isinstance(x, tuple):
return (fn(x[0]),) + x[1:]
return fn(x)
class Chunk(nn.Module):
def __init__(self, chunks, fn, along_dim=-1):
super().__init__()
self.dim = along_dim
self.chunks = chunks
self.fn = fn
def forward(self, x, **kwargs):
if self.chunks <= 1:
return self.fn(x, **kwargs)
chunks = x.chunk(self.chunks, dim=self.dim)
return torch.cat([self.fn(c, **kwargs) for c in chunks], dim=self.dim)
class PreNorm(nn.Module):
def __init__(self, norm_class, dim, fn):
super().__init__()
self.norm = norm_class(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class ReZero(nn.Module):
def __init__(self, fn):
super().__init__()
self.residual_weight = nn.Parameter(torch.zeros(1))
self.fn = fn
def forward(self, x, **kwargs):
x = self.fn(x, **kwargs)
return map_first_tuple_or_el(x, lambda t: t * self.residual_weight)
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.g = nn.Parameter(torch.ones(1))
self.eps = eps
def forward(self, x):
def norm(t):
n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps)
return t / n * self.g
return map_first_tuple_or_el(x, norm)
class ProjectInOut(nn.Module):
def __init__(self, fn, dim_in, dim_out, project_out=True):
super().__init__()
self.fn = fn
self.project_in = nn.Linear(dim_in, dim_out)
self.project_out = nn.Linear(dim_out, dim_in) if project_out else identity
def forward(self, x, **kwargs):
x = self.project_in(x)
x, loss = self.fn(x, **kwargs)
x = self.project_out(x)
return x, loss
class MatrixMultiply(nn.Module):
def __init__(self, tensor, transpose=False):
super().__init__()
self.tensor = tensor
self.transpose = transpose
def forward(self, x):
tensor = self.tensor
if self.transpose:
tensor = tensor.t()
return x @ tensor
# positional embeddings
class DepthWiseConv1d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False):
super().__init__()
self.padding = (
((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2)
)
self.net = nn.Sequential(
nn.Conv1d(
dim_in,
dim_in,
kernel_size=kernel_size,
groups=dim_in,
stride=stride,
bias=bias,
),
nn.Conv1d(dim_in, dim_out, 1, bias=bias),
)
def forward(self, x):
x = F.pad(x, self.padding, value=0.0)
return self.net(x)
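# Usage sketch for `DepthWiseConv1d` (illustrative sizes): causal padding
# keeps the sequence length unchanged while mixing channels cheaply.
def _demo_depthwise_conv():
    conv = DepthWiseConv1d(dim_in=64, dim_out=128, kernel_size=9, causal=True)
    x = torch.randn(2, 64, 100)  # (batch, channels, length)
    return conv(x).shape  # torch.Size([2, 128, 100])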
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
position = torch.arange(0, max_seq_len, dtype=torch.float)
sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
self.register_buffer("emb", emb)
def forward(self, x):
return self.emb[None, : x.shape[1], :].to(x)
def rotate_every_two(x):
x = rearrange(x, "... (d j) -> ... d j", j=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d j -> ... (d j)")
def apply_rotary_pos_emb(q, k, sinu_pos):
sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2)
sin, cos = sinu_pos.unbind(dim=-2)
sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos))
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
return q, k
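# A minimal sketch of the rotary-embedding path (sizes are illustrative): the
# sinusoidal table is built once, sliced to the input length, and rotated into
# the queries and keys head-wise.
def _demo_rotary():
    pos_emb = FixedPositionalEmbedding(dim=64, max_seq_len=128)
    sinu = pos_emb(torch.randn(2, 16, 64))  # (1, seq, dim)
    q = torch.randn(2, 8, 16, 64)           # (batch, heads, seq, head_dim)
    k = torch.randn(2, 8, 16, 64)
    q, k = apply_rotary_pos_emb(q, k, sinu)
    return q.shape, k.shape                 # both shapes unchanged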
# kmeans related function and class
def update_kmeans_on_backwards(module):
    # gather every Kmeans submodule once, then fold their accumulated new
    # means into the EMA buffers after each backward pass
    module.kmean_modules = find_modules(module, Kmeans)

    def hook(_, grad_in, grad_out):
        for m in module.kmean_modules:
            m.update()

    return module.register_backward_hook(hook)
def similarity(x, means):
return torch.einsum("bhld,hcd->bhlc", x, means)
def dists_and_buckets(x, means):
dists = similarity(x, means)
_, buckets = torch.max(dists, dim=-1)
return dists, buckets
def batched_bincount(index, num_classes, dim=-1):
shape = list(index.shape)
shape[dim] = num_classes
out = index.new_zeros(shape)
out.scatter_add_(dim, index, torch.ones_like(index, dtype=index.dtype))
return out
def kmeans_iter(x, means, buckets=None):
b, h, _, d, dtype, num_clusters = *x.shape, x.dtype, means.shape[1]
if not exists(buckets):
_, buckets = dists_and_buckets(x, means)
bins = batched_bincount(buckets, num_clusters).sum(0, keepdim=True)
zero_mask = bins.long() == 0
means_ = buckets.new_zeros(b, h, num_clusters, d, dtype=dtype)
means_.scatter_add_(-2, expand_dim(buckets, -1, d), x)
means_ = F.normalize(means_.sum(0, keepdim=True), dim=-1).type(dtype)
means = torch.where(zero_mask.unsqueeze(-1), means, means_)
means = means.squeeze(0)
return means
def distribution(dists, window_size):
_, topk_indices = dists.topk(k=window_size, dim=-2)
indices = topk_indices.transpose(-2, -1)
return indices.reshape(*indices.size()[:2], -1)
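# A routing sketch (illustrative sizes): tokens are scored against per-head
# cluster means, and `distribution` picks the top `window_size` tokens for
# each cluster, flattened into one index tensor per head.
def _demo_routing():
    x = F.normalize(torch.randn(2, 8, 128, 64), dim=-1)  # (batch, heads, seq, head_dim)
    means = F.normalize(torch.randn(8, 16, 64), dim=-1)  # (heads, clusters, head_dim)
    dists, buckets = dists_and_buckets(x, means)         # dists: (2, 8, 128, 16)
    indices = distribution(dists, window_size=32)
    return indices.shape  # torch.Size([2, 8, 512]) -- 16 clusters x 32 tokens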
class Kmeans(nn.Module):
def __init__(
self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4
):
super().__init__()
self.commitment = commitment
self.ema_decay = ema_decay
self.register_buffer("means", torch.randn(num_heads, num_clusters, head_dim))
self.register_buffer("initted", torch.tensor(False))
self.num_new_means = 0
self.new_means = None
@torch.no_grad()
def init(self, x):
if self.initted:
return
_, h, _, d, device, _ = *x.shape, x.device, x.dtype
num_clusters = self.means.shape[1]
means = x.transpose(0, 1).contiguous().view(h, -1, d)
num_samples = means.shape[1]
if num_samples >= num_clusters:
indices = torch.randperm(num_samples, device=device)[:num_clusters]
else:
indices = torch.randint(0, num_samples, (num_clusters,), device=device)
means = means[:, indices]
for _ in range(KMEAN_INIT_ITERS):
means = kmeans_iter(x, means)
self.num_new_means = 0
self.means.data.copy_(means)
self.initted.data.copy_(torch.tensor(True))
@torch.no_grad()
def update(self, new_means=None):
new_means = default(new_means, self.new_means)
        assert exists(new_means), "new kmeans means have not been supplied"
ema_inplace(self.means, new_means, self.ema_decay)
del self.new_means
self.new_means = None
self.num_new_means = 0
def forward(self, x, update_means=False):
self.init(x)
b, dtype = x.shape[0], x.dtype
means = self.means.type(dtype)
x = F.normalize(x, 2, dim=-1).type(dtype)
with torch.no_grad():
dists, buckets = dists_and_buckets(x, means)
routed_means = batched_index_select(expand_dim(means, 0, b), buckets)
loss = F.mse_loss(x, routed_means) * self.commitment
if update_means:
with torch.no_grad():
means = kmeans_iter(x, means, buckets)
self.new_means = ema(
self.new_means, means, self.num_new_means / (self.num_new_means + 1)
)
self.num_new_means += 1
return dists, loss
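# Usage sketch for `Kmeans` (illustrative sizes): the first forward pass
# initializes the cluster means, `update_means=True` accumulates candidate
# means, and `update()` folds them into the EMA buffer (normally triggered by
# the backward hook registered in `update_kmeans_on_backwards`).
def _demo_kmeans():
    kmeans = Kmeans(num_heads=8, head_dim=64, num_clusters=16)
    x = torch.randn(2, 8, 128, 64)  # (batch, heads, seq, head_dim)
    dists, aux_loss = kmeans(x, update_means=True)
    kmeans.update()
    return dists.shape, aux_loss    # torch.Size([2, 8, 128, 16]), scalar loss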
# kmeans attention class
class KmeansAttention(nn.Module):
def __init__(
self,
num_clusters,
window_size,
num_heads,
head_dim,
causal=False,
dropout=0.0,
ema_decay=0.999,
commitment=1e-4,
context_window_size=None,
receives_context=False,
num_mem_kv=0,
shared_qk=False,
):
super().__init__()
self.num_heads = num_heads
self.num_clusters = num_clusters
self.head_dim = head_dim
self.window_size = window_size
self.context_window_size = default(context_window_size, window_size)
self.causal = causal
self.shared_qk = shared_qk
self.receives_context = receives_context
self.kmeans = Kmeans(num_heads, head_dim, num_clusters, ema_decay, commitment)
self.dropout = nn.Dropout(dropout)
self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0)
self.mem_key = nn.Parameter(
torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
)
self.mem_value = nn.Parameter(
torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
)
def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs):
b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = (
*q.shape,
k.shape[2],
self.window_size,
self.context_window_size,
self.num_clusters,
q.device,
q.dtype,
)
is_reverse = kwargs.pop("_reverse", False)
out = torch.zeros_like(q, dtype=dtype)
update_kmeans = self.training and not is_reverse
key_mask = (
default(key_mask, query_mask) if not self.receives_context else key_mask
)
kv_wsz = wsz if not self.receives_context else c_wsz
wsz = min(wsz, t)
kv_wsz = min(kv_wsz, kv_t)
if not self.shared_qk or self.receives_context:
dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
q_dists, k_dists = split_at_index(2, t, dists)
indices = distribution(q_dists, wsz)
kv_indices = distribution(k_dists, kv_wsz)
else:
dists, aux_loss = self.kmeans(q, update_kmeans)
k = F.normalize(k, dim=-1).to(q)
indices = distribution(dists, wsz)
kv_indices = indices
q = batched_index_select(q, indices)
k = batched_index_select(k, kv_indices)
v = batched_index_select(v, kv_indices)
reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
q, k, v = map(reshape_with_window, (q, k, v))
m_k, m_v = map(
lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)
)
k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))
dots = torch.einsum("bhnid,bhnjd->bhnij", q, k) * (d**-0.5)
mask_value = max_neg_value(dots)
if exists(query_mask) or exists(key_mask):
query_mask = default(
query_mask, lambda: torch.ones((b, t), device=device).bool()
)
key_mask = default(
key_mask, lambda: torch.ones((b, kv_t), device=device).bool()
)
q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)
q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask))
mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
dots.masked_fill_(~mask, mask_value)
del mask
if self.causal:
q_mask, kv_mask = map(
lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
)
mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
dots.masked_fill_(~mask, mask_value)
del mask
if self.shared_qk:
q_mask, kv_mask = map(
lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
)
mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=0)
dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
del mask
dots = dots.softmax(dim=-1)
dots = self.dropout(dots)
bo = torch.einsum("bhcij,bhcjd->bhcid", dots, v)
so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
return out, aux_loss
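# Usage sketch for `KmeansAttention` (illustrative sizes): queries, keys and
# values arrive pre-split into heads; each token attends only within its
# k-means cluster, and the auxiliary commitment loss is returned alongside.
def _demo_kmeans_attention():
    attn = KmeansAttention(num_clusters=16, window_size=32, num_heads=8, head_dim=64)
    q, k, v = (torch.randn(2, 8, 512, 64) for _ in range(3))
    out, aux_loss = attn(q, k, v)
    return out.shape  # torch.Size([2, 8, 512, 64])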
# feedforward
class GELU_(nn.Module):
def forward(self, x):
return (
0.5
* x
* (
1
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))
)
)
GELU = nn.GELU if hasattr(nn, "GELU") else GELU_
class FeedForward(nn.Module):
def __init__(self, dim, mult=4, dropout=0.0, activation=None, glu=False):
super().__init__()
activation = default(activation, GELU)
self.glu = glu
self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
self.act = activation()
self.dropout = nn.Dropout(dropout)
self.w2 = nn.Linear(dim * mult, dim)
def forward(self, x, **kwargs):
if not self.glu:
x = self.w1(x)
x = self.act(x)
else:
x, v = self.w1(x).chunk(2, dim=-1)
x = self.act(x) * v
x = self.dropout(x)
x = self.w2(x)
return x
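# Usage sketch for `FeedForward` (illustrative sizes): with `glu=True` the
# first projection is doubled and half of it gates the other half.
def _demo_feedforward():
    ff = FeedForward(dim=512, mult=4, glu=True)
    x = torch.randn(2, 128, 512)
    return ff(x).shape  # torch.Size([2, 128, 512])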
# self attention
class SelfAttention(nn.Module):
def __init__(
self,
dim,
max_seq_len,
heads,
local_attn_heads,
window_size,
dim_head=None,
local_attn_window_size=None,
local_attn_radius_blocks=1,
causal=False,
attn_dropout=0.0,
dropout=0.0,
kmeans_ema_decay=0.999,
commitment_factor=1e-4,
receives_context=False,
context_window_size=None,
rel_pos_emb=True,
num_mem_kv=0,
shared_qk=False,
conv_query_kernel=9,
):
super().__init__()
        assert (
            dim_head or (dim % heads) == 0
        ), "hidden dimension must be divisible by the number of heads when dim_head is not given"
assert (
max_seq_len % window_size
) == 0, "maximum sequence length must be divisible by the target window size"
        assert (
            local_attn_heads <= heads
        ), "number of local attention heads must not exceed the total number of heads"
assert not (
receives_context and local_attn_heads > 0
), "local attention cannot be used for self attention with context"
assert not (
receives_context and causal
), "contextual attention layer cannot be causal"
local_attn_window_size = default(local_attn_window_size, window_size)
context_window_size = default(context_window_size, window_size)
self.shared_qk = shared_qk
self.receives_context = receives_context
self.heads = heads
self.local_attn_heads = local_attn_heads
self.global_attn_heads = heads - local_attn_heads
self.causal = causal
self.window_size = window_size
dim_head = default(dim_head, dim // heads)
dim_heads = dim_head * heads
self.dim_head = dim_head
num_clusters = max_seq_len // window_size
# local
local_dim_heads = dim_head * self.local_attn_heads
if self.local_attn_heads > 0:
rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
self.local_attn = LocalAttention(
dim=dim_head,
window_size=local_attn_window_size,
causal=causal,
dropout=attn_dropout,
rel_pos_emb_config=rel_pos_emb_config,
look_backward=local_attn_radius_blocks,
look_forward=0 if causal else local_attn_radius_blocks,
)
self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)
# global
global_dim_heads = dim_head * self.global_attn_heads
if self.global_attn_heads > 0:
self.global_attn = KmeansAttention(
num_clusters,
window_size,
self.global_attn_heads,
dim_head,
causal=causal,
dropout=attn_dropout,
ema_decay=kmeans_ema_decay,
commitment=commitment_factor,
receives_context=receives_context,
num_mem_kv=num_mem_kv,
shared_qk=shared_qk,
)
self.to_q = nn.Sequential(
Rearrange("b n c -> b c n"),
DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal),
Rearrange("b c n -> b n c"),
)
self.to_v = nn.Linear(dim, global_dim_heads, bias=False)
if not self.shared_qk:
self.to_k = nn.Linear(dim, global_dim_heads, bias=False)
# out
self.to_out = nn.Linear(dim_heads, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(
self,
query,
key,
value,
context=None,
key_padding_mask=None,
context_mask=None,
pos_emb=None,
**kwargs
):
assert not (
self.receives_context and not exists(context)
), "context must be passed if self attention is set to receive context"
input_mask = key_padding_mask
x = query.transpose(0, 1)
b, t, _, h, dh = *x.shape, self.heads, self.dim_head
has_local, has_global = map(
lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads)
)
split_heads = (
lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()
)
if has_local:
local_qkv = self.local_to_qkv(x).chunk(3, dim=-1)
lq, lk, lv = map(split_heads, local_qkv)
if has_global:
kv_input = x if not self.receives_context else context
q, v = self.to_q(x), self.to_v(kv_input)
if not self.shared_qk:
k = self.to_k(kv_input)
else:
k = self.to_q(kv_input) if self.receives_context else q
q, k, v = map(split_heads, (q, k, v))
out = []
total_loss = torch.tensor(0.0, requires_grad=True, **to(x))
if has_local:
local_out = self.local_attn(lq, lk, lv, input_mask=input_mask)
out.append(local_out)
if has_global:
if not self.receives_context and exists(pos_emb):
q, k = apply_rotary_pos_emb(q, k, pos_emb)
global_out, loss = self.global_attn(
q, k, v, query_mask=input_mask, key_mask=context_mask
)
total_loss = total_loss + loss
out.append(global_out)
out = torch.cat(out, dim=1)
out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1)
out = self.dropout(out.transpose(0, 1))
        # NOTE: the final output projection is bypassed here; `self.to_out`
        # is constructed but left unused.
        # out = self.to_out(out)
return out, total_loss
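# A minimal end-to-end sketch (sizes are illustrative; `query` follows the
# fairseq (seq, batch, dim) layout this module expects). Half the heads run
# local attention, half run the k-means routed global attention.
if __name__ == "__main__":
    attn = SelfAttention(
        dim=512, max_seq_len=512, heads=8, local_attn_heads=4, window_size=64
    )
    x = torch.randn(512, 2, 512)  # (seq, batch, dim)
    out, aux_loss = attn(x, x, x)
    print(out.shape, aux_loss)    # torch.Size([512, 2, 512]), scalar aux loss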