Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from torch.distributions import Categorical | |
| import models.pos_encoding as pos_encoding | |
class Text2Motion_Transformer(nn.Module):
    """Autoregressive transformer over code indices, conditioned on a feature vector.

    The model is a ``CrossCondTransBase`` (embeddings + first stack of blocks)
    followed by a ``CrossCondTransHead`` (second stack of blocks + projection
    to logits). The index ``num_vq`` is treated as the end-of-sequence token
    by :meth:`sample`.

    Args:
        num_vq: Number of regular code indices; ``num_vq`` itself is the end token.
        embed_dim: Width of the token embeddings / transformer channels.
        clip_dim: Width of the conditioning feature passed as ``clip_feature``.
        block_size: Maximum sequence length handled by the attention mask.
        num_layers: Number of blocks in the base and, separately, in the head.
        n_head: Number of attention heads per block.
        drop_out_rate: Dropout probability used inside the blocks.
        fc_rate: Expansion factor of the MLP hidden layer inside each block.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        self.trans_base = CrossCondTransBase(num_vq, embed_dim, clip_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
        self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
        self.block_size = block_size
        self.num_vq = num_vq

    def get_block_size(self):
        """Return the ``block_size`` the model was built with."""
        return self.block_size

    def forward(self, idxs, clip_feature):
        """Compute logits for every position of the conditioned sequence.

        Args:
            idxs: Code indices, or an empty sequence (e.g. ``[]``) to run on
                the condition token alone.
            clip_feature: Conditioning feature fed to the base model.

        Returns:
            The logits produced by the head for each input position.
        """
        feat = self.trans_base(idxs, clip_feature)
        logits = self.trans_head(feat)
        return logits

    def sample(self, clip_feature, if_categorial=False):
        """Generate code indices one step at a time until the end token appears.

        At most ``block_size`` steps are run. Generation stops early when the
        end token ``num_vq`` is produced (the end token is not included in the
        result). If the step budget is exhausted without seeing the end token,
        the last generated index is dropped before returning.

        NOTE(review): the end-token checks compare a whole tensor against
        ``num_vq`` (categorical branch) or only look at the first row (top-1
        branch), so this method effectively assumes a batch of one.

        Args:
            clip_feature: Conditioning feature; assumed to be batch-first.
            if_categorial: If True, draw each index from the softmax
                distribution; otherwise take the arg-max (top-1) index.

        Returns:
            A long tensor of shape ``(batch, length)`` with the generated
            indices. ``length`` is 0 when the end token is produced at the
            very first step (previously this case raised
            ``UnboundLocalError``).
        """
        xs = None  # generated indices so far; None until the first one is kept
        for k in range(self.block_size):
            # The base model expects an empty sequence on the first step.
            x = [] if xs is None else xs
            logits = self.forward(x, clip_feature)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            if if_categorial:
                dist = Categorical(probs)
                idx = dist.sample()
                if idx == self.num_vq:
                    break
                idx = idx.unsqueeze(-1)
            else:
                _, idx = torch.topk(probs, k=1, dim=-1)
                if idx[0] == self.num_vq:
                    break
            # append to the sequence and continue
            xs = idx if xs is None else torch.cat((xs, idx), dim=1)
            if k == self.block_size - 1:
                return xs[:, :-1]
        if xs is None:
            # End token came first (or block_size == 0): nothing was generated.
            return torch.empty(
                (clip_feature.shape[0], 0),
                dtype=torch.long,
                device=clip_feature.device,
            )
        return xs
class CausalCrossConditionalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Each position may attend only to itself and to earlier positions. The
    mask is a ``block_size x block_size`` buffer, so inputs longer than
    ``block_size`` along the sequence axis cannot be processed.

    Args:
        embed_dim: Channel width of the input; must be divisible by ``n_head``.
        block_size: Maximum sequence length covered by the causal mask.
        n_head: Number of attention heads.
        drop_out_rate: Dropout probability applied to the attention weights
            and to the projected output.
    """

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
        super().__init__()
        # forward() splits the channels evenly across heads, so the width has
        # to be a multiple of n_head (not of a hard-coded 8).
        assert embed_dim % n_head == 0, "embed_dim must be divisible by n_head"
        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(drop_out_rate)
        self.resid_drop = nn.Dropout(drop_out_rate)
        self.proj = nn.Linear(embed_dim, embed_dim)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))
        self.n_head = n_head

    def forward(self, x):
        """Apply masked multi-head self-attention.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``T <= block_size``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Future positions get -inf so they receive zero weight after softmax.
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_drop(self.proj(y))
        return y
class Block(nn.Module):
    """Transformer block: causal self-attention then an MLP, each with a residual.

    LayerNorm is applied to the input of each sub-layer (before attention and
    before the MLP), and the sub-layer output is added back to its input.

    Args:
        embed_dim: Channel width of the block.
        block_size: Maximum sequence length, forwarded to the attention layer.
        n_head: Number of attention heads.
        drop_out_rate: Dropout probability for attention and the MLP output.
        fc_rate: The MLP hidden width is ``fc_rate * embed_dim``.
    """

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4):
        super().__init__()
        hidden_dim = fc_rate * embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = CausalCrossConditionalSelfAttention(embed_dim, block_size, n_head, drop_out_rate)
        feed_forward = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(drop_out_rate),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Run the attention and MLP sub-layers on ``x`` and return the result."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))
class CrossCondTransBase(nn.Module):
    """Embedding stage and first stack of transformer blocks.

    The conditioning feature is projected to ``embed_dim`` and prepended as
    the first token of the sequence; the code indices are embedded and follow
    it. Position information is added by ``pos_encoding.PositionEmbedding``
    before the blocks are applied.

    Args:
        num_vq: Number of regular code indices (the embedding table has
            ``num_vq + 2`` rows).
        embed_dim: Channel width of the embeddings and blocks.
        clip_dim: Width of the conditioning feature.
        block_size: Maximum total sequence length (condition token included).
        num_layers: Number of transformer blocks.
        n_head: Number of attention heads per block.
        drop_out_rate: Dropout probability inside the blocks.
        fc_rate: Expansion factor of the MLP hidden layer in each block.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)
        self.cond_emb = nn.Linear(clip_dim, embed_dim)
        # pos_embedding and drop are not used by forward(); they are kept so
        # the module's parameter names (state_dict keys) stay the same.
        self.pos_embedding = nn.Embedding(block_size, embed_dim)
        self.drop = nn.Dropout(drop_out_rate)
        # transformer block
        self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
        self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
        self.block_size = block_size
        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the ``block_size`` the module was built with."""
        return self.block_size

    def _init_weights(self, module):
        """Initialise one submodule; used as the callback for ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02), Linear biases
        are zeroed, and LayerNorm is reset to weight 1 / bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, clip_feature):
        """Embed the condition and the indices, then run the blocks.

        Args:
            idx: 2-D tensor ``(b, t)`` of code indices, or an empty sequence
                (``len(idx) == 0``) to process the condition token alone.
            clip_feature: Conditioning feature; projected by ``cond_emb`` and
                inserted as the first token.

        Returns:
            Features of shape ``(b, t + 1, embed_dim)`` (``t = 0`` for an
            empty ``idx``).

        Raises:
            AssertionError: If the sequence including the condition token
                would be longer than ``block_size``.
        """
        if len(idx) == 0:
            token_embeddings = self.cond_emb(clip_feature).unsqueeze(1)
        else:
            _, t = idx.size()
            # The condition token takes one slot, so the blocks see t + 1
            # positions; their causal mask only covers block_size positions.
            assert t + 1 <= self.block_size, "Cannot forward, model block size is exhausted."
            # forward the Trans model
            token_embeddings = self.tok_emb(idx)
            token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), token_embeddings], dim=1)
        x = self.pos_embed(token_embeddings)
        x = self.blocks(x)
        return x
class CrossCondTransHead(nn.Module):
    """Second stack of transformer blocks plus the projection to logits.

    The input features pass through ``num_layers`` blocks and a final
    LayerNorm, then a bias-free linear layer maps them to ``num_vq + 1``
    logits per position.

    Args:
        num_vq: Number of regular code indices; the output has one extra logit.
        embed_dim: Channel width of the incoming features.
        block_size: Maximum sequence length, forwarded to every block.
        num_layers: Number of transformer blocks.
        n_head: Number of attention heads per block.
        drop_out_rate: Dropout probability inside the blocks.
        fc_rate: Expansion factor of the MLP hidden layer in each block.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        layers = [
            Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate)
            for _ in range(num_layers)
        ]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vq + 1, bias=False)
        self.block_size = block_size
        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the ``block_size`` the module was built with."""
        return self.block_size

    def _init_weights(self, module):
        """Initialise one submodule; used as the callback for ``self.apply``."""
        if isinstance(module, nn.LayerNorm):
            # LayerNorm starts as the identity affine transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x):
        """Map features ``x`` to logits over ``num_vq + 1`` classes."""
        normed = self.ln_f(self.blocks(x))
        return self.head(normed)