import torch
import torch.nn as nn
import numpy as np
from einops import rearrange

from .normalize import Normalize
from .conv import CausalConv3d
from .block import Block
from .ops import video_to_image


class LinearAttention(Block):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        # Contract values against the softmax-normalized keys first, so the
        # pairwise (h*w) x (h*w) attention matrix is never materialized.
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)
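

# Shape sketch for the two einsums above (d = e = dim_head, n = h * w):
#     context: (b, heads, d, n) x (b, heads, e, n) -> (b, heads, d, e)
#     out:     (b, heads, d, e) x (b, heads, d, n) -> (b, heads, e, n)
# The cost is linear in the number of positions n, unlike the quadratic
# (n x n) attention matrix built by AttnBlock below.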


class LinAttnBlock(LinearAttention):
    """To match AttnBlock usage."""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock3D(Block):
    """Kept for compatibility with old versions; its reshape order is
    incorrect (see the note in forward), so use with caution."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, t, h, w = q.shape
        # NOTE: reshaping (b, c, t, h, w) straight to (b*t, c, h*w) does not
        # move t next to b first, so batch, channel, and time axes get mixed.
        # AttnBlock3DFix below permutes before reshaping.
        q = q.reshape(b * t, c, h * w)
        q = q.permute(0, 2, 1)  # (b*t, hw, c)
        k = k.reshape(b * t, c, h * w)  # (b*t, c, hw)
        w_ = torch.bmm(q, k)  # (b*t, hw, hw); w_[b, i, j] = sum_c q[b, i, c] k[b, c, j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b * t, c, h * w)
        w_ = w_.permute(0, 2, 1)  # (b*t, hw, hw): first hw is of k, second of q
        h_ = torch.bmm(v, w_)  # (b*t, c, hw): h_[b, c, j] = sum_i v[b, c, i] w_[b, i, j]
        h_ = h_.reshape(b, c, t, h, w)

        h_ = self.proj_out(h_)

        return x + h_
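

# A quick way to see the AttnBlock3D layout bug (a sketch, not part of the
# module API): for x of shape (b, c, t, h, w),
#     x.reshape(b * t, c, h * w)
# is generally NOT equal to
#     x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h * w)
# because reshape flattens in row-major order, so without the permute each
# (c, h*w) slice interleaves data from different channels and frames.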


class AttnBlock3DFix(nn.Module):
    """
    Corrected 3D attention reshape order.
    Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
        b, c, t, h, w = q.shape
        q = q.permute(0, 2, 1, 3, 4)
        q = q.reshape(b * t, c, h * w)
        q = q.permute(0, 2, 1)

        # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
        k = k.permute(0, 2, 1, 3, 4)
        k = k.reshape(b * t, c, h * w)

        # w_: (b*t, h*w, h*w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        # v: (b c t h w) -> (b t c h w) -> (b*t c h*w)
        v = v.permute(0, 2, 1, 3, 4)
        v = v.reshape(b * t, c, h * w)

        w_ = w_.permute(0, 2, 1)  # (b*t, hw, hw): first hw is of k, second of q
        h_ = torch.bmm(v, w_)  # (b*t, c, hw): h_[b, c, j] = sum_i v[b, c, i] w_[b, i, j]

        # h_: (b*t c h*w) -> (b t c h w) -> (b c t h w)
        h_ = h_.reshape(b, t, c, h, w)
        h_ = h_.permute(0, 2, 1, 3, 4)

        h_ = self.proj_out(h_)

        return x + h_


class AttnBlock(Block):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # (b, hw, c)
        k = k.reshape(b, c, h * w)  # (b, c, hw)
        w_ = torch.bmm(q, k)  # (b, hw, hw); w_[b, i, j] = sum_c q[b, i, c] k[b, c, j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # (b, hw, hw): first hw is of k, second of q
        h_ = torch.bmm(v, w_)  # (b, c, hw): h_[b, c, j] = sum_i v[b, c, i] w_[b, i, j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class TemporalAttnBlock(Block):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: each spatial location attends across time
        b, c, t, h, w = q.shape
        q = rearrange(q, "b c t h w -> (b h w) t c")
        k = rearrange(k, "b c t h w -> (b h w) c t")
        v = rearrange(v, "b c t h w -> (b h w) c t")
        w_ = torch.bmm(q, k)  # (b*h*w, t, t)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)  # (b*h*w, c, t)
        h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "vanilla3D":
        return AttnBlock3D(in_channels)
    elif attn_type == "none":
        return nn.Identity()
    else:
        return LinAttnBlock(in_channels)
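

# A minimal smoke-test sketch (not part of the original module). It assumes
# this file sits inside its package so the relative imports of Normalize,
# CausalConv3d, and Block resolve; shapes below are illustrative only.
if __name__ == "__main__":
    x_img = torch.randn(2, 64, 16, 16)  # (b, c, h, w)
    attn = make_attn(64, attn_type="vanilla")
    print(attn(x_img).shape)  # residual attention preserves shape: (2, 64, 16, 16)

    x_vid = torch.randn(2, 64, 5, 16, 16)  # (b, c, t, h, w)
    attn3d = make_attn(64, attn_type="vanilla3D")
    print(attn3d(x_vid).shape)  # (2, 64, 5, 16, 16)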