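"""Monkey-patched forward passes for sgm's video diffusion modules.

These functions mirror the stock sgm forwards but additionally thread
per-frame camera poses (`RT`, 12 values per frame, presumably a flattened
3x4 [R|t] extrinsic) from the UNet down to the temporal attention blocks,
where they are concatenated to the features and projected back to the
model width via a `cc_projection` layer.
"""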
from functools import partial
from typing import List, Optional, Union

import torch
from einops import rearrange, repeat

from sgm.modules.attention import checkpoint, exists
from sgm.modules.diffusionmodules.util import timestep_embedding

#### VideoUnet ####
def forward_VideoUnet(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    time_context: Optional[torch.Tensor] = None,
    num_video_frames: Optional[int] = None,
    image_only_indicator: Optional[torch.Tensor] = None,
    RT: Optional[torch.Tensor] = None,
):
    # When camera poses are given, bundle them with the usual conditioning so
    # the transformer blocks below can route them (see the dict handling in
    # the block forwards further down).
    if RT is not None:
        context = {'RT': RT, 'context': context}
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    hs = []
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)

    # TODO: check the role of "image_only_indicator".
    # The module's own frame count takes precedence over the argument.
    num_video_frames = self.num_frames
    image_only_indicator = (
        torch.zeros(x.shape[0] // num_video_frames, num_video_frames).to(x.device)
        if image_only_indicator is None
        else image_only_indicator
    )

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y)

    h = x
    for module in self.input_blocks:
        h = module(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        )
        hs.append(h)
    h = self.middle_block(
        h,
        emb,
        context=context,
        image_only_indicator=image_only_indicator,
        time_context=time_context,
        num_video_frames=num_video_frames,
    )
    for module in self.output_blocks:
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        )
    h = h.type(x.dtype)
    return self.out(h)
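
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The module paths below are
# assumptions based on the layout of Stability AI's `generative-models` repo,
# where these classes are defined; adjust to your checkout:
#
#   from sgm.modules.attention import BasicTransformerBlock
#   from sgm.modules.diffusionmodules.video_model import VideoUNet
#   from sgm.modules.video_attention import (
#       SpatialVideoTransformer,
#       VideoTransformerBlock,
#   )
#
#   VideoUNet.forward = forward_VideoUnet
#   VideoTransformerBlock.forward = forward_VideoTransformerBlock
#   VideoTransformerBlock._forward = _forward_VideoTransformerBlock_attan2
#   BasicTransformerBlock._forward = _forward_BasicTransformerBlock
#   SpatialVideoTransformer.forward = forward_SpatialVideoTransformer
# ---------------------------------------------------------------------------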

#### VideoTransformerBlock ####
def forward_VideoTransformerBlock(self, x, context, timesteps):
    if self.checkpoint:
        return checkpoint(self._forward, x, context, timesteps)
    else:
        return self._forward(x, context, timesteps=timesteps)
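
# Note: `checkpoint` above is the wrapper imported from sgm at the top of
# this file; when `self.checkpoint` is set it recomputes `_forward` during
# the backward pass, trading compute for activation memory.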

def _forward_VideoTransformerBlock_attan2(self, x, context=None, timesteps=None):
    assert self.timesteps or timesteps
    assert not (self.timesteps and timesteps) or self.timesteps == timesteps
    timesteps = self.timesteps or timesteps
    B, S, C = x.shape
    # Swap the batch grouping from spatial tokens to frames so attention runs
    # along the time axis: one sequence of `timesteps` tokens per
    # (video, spatial location) pair.
    x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)

    # Unpack the camera poses if the UNet bundled them into the context dict.
    if isinstance(context, dict):
        RT = context['RT']  # (b, t, 12)
        context = context['context']
    else:
        RT = None

    if self.ff_in:
        x_skip = x
        x = self.ff_in(self.norm_in(x))
        if self.is_res:
            x += x_skip

    if self.disable_self_attn:
        x = self.attn1(self.norm1(x), context=context) + x
    else:
        x = self.attn1(self.norm1(x)) + x

    if RT is not None:
        # Broadcast each per-frame pose to every spatial location, append it
        # to the features, and project back to the model width.
        RT = RT.repeat_interleave(repeats=S, dim=0)  # (b*s, t, 12)
        x = torch.cat([x, RT], dim=-1)
        x = self.cc_projection(x)

    if self.attn2 is not None:
        if self.switch_temporal_ca_to_sa:
            x = self.attn2(self.norm2(x)) + x
        else:
            x = self.attn2(self.norm2(x), context=context) + x

    x_skip = x
    x = self.ff(self.norm3(x))
    if self.is_res:
        x += x_skip

    x = rearrange(
        x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
    )
    return x
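
# The RT branch in `_forward_VideoTransformerBlock_attan2` assumes each
# patched temporal block carries a `cc_projection` layer mapping
# (c + 12) -> c. This file does not define it; below is a minimal sketch of
# how one could be attached, assuming Zero-1-to-3-style identity/zero init
# so the patched block initially reproduces the stock output:

import torch.nn as nn

def attach_cc_projection(block, inner_dim: int, pose_dim: int = 12) -> None:
    proj = nn.Linear(inner_dim + pose_dim, inner_dim)
    with torch.no_grad():
        proj.weight.zero_()                                      # pose columns start at zero
        proj.weight[:, :inner_dim].copy_(torch.eye(inner_dim))   # identity on the feature columns
        proj.bias.zero_()
    block.cc_projection = proj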

#### BasicTransformerBlock ####
def _forward_BasicTransformerBlock(
    self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
    # Spatial blocks ignore the camera poses; unwrap the plain conditioning.
    if isinstance(context, dict):
        context = context['context']
    x = (
        self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
            if not self.disable_self_attn
            else 0,
        )
        + x
    )
    x = (
        self.attn2(
            self.norm2(x), context=context, additional_tokens=additional_tokens
        )
        + x
    )
    x = self.ff(self.norm3(x)) + x
    return x
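
# Conditioning routing summary (descriptive comment, no new behavior):
# `forward_VideoUnet` wraps conditioning as {'RT': poses, 'context': emb}
# whenever poses are given. The spatial path (_forward_BasicTransformerBlock)
# unwraps the dict and ignores RT; the temporal path
# (_forward_VideoTransformerBlock_attan2) consumes RT via `cc_projection`.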

#### SpatialVideoTransformer ####
def forward_SpatialVideoTransformer(
    self,
    x: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    time_context: Optional[torch.Tensor] = None,
    timesteps: Optional[int] = None,
    image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    _, _, h, w = x.shape
    x_in = x

    # Keep the camera poses aside for the temporal blocks; the spatial
    # blocks only see the plain conditioning.
    if isinstance(context, dict):
        RT = context['RT']
        context = context['context']
    else:
        RT = None

    spatial_context = None
    if exists(context):
        spatial_context = context

    if self.use_spatial_context:
        assert (
            context.ndim == 3
        ), f"n dims of spatial context should be 3 but are {context.ndim}"
        time_context = context
        time_context_first_timestep = time_context[::timesteps]
        time_context = repeat(
            time_context_first_timestep, "b ... -> (b n) ...", n=h * w
        )
    elif time_context is not None and not self.use_spatial_context:
        time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
        if time_context.ndim == 2:
            time_context = rearrange(time_context, "b c -> b 1 c")

    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, "b c h w -> b (h w) c")
    if self.use_linear:
        x = self.proj_in(x)

    # Sinusoidal embedding of the frame index, added to the temporal stream.
    num_frames = torch.arange(timesteps, device=x.device)
    num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
    num_frames = rearrange(num_frames, "b t -> (b t)")
    t_emb = timestep_embedding(
        num_frames,
        self.in_channels,
        repeat_only=False,
        max_period=self.max_time_embed_period,
    )
    emb = self.time_pos_embed(t_emb)
    emb = emb[:, None, :]

    for it_, (block, mix_block) in enumerate(
        zip(self.transformer_blocks, self.time_stack)
    ):
        x = block(
            x,
            context=spatial_context,
        )

        x_mix = x
        x_mix = x_mix + emb

        # Re-bundle the poses for the temporal block only.
        if RT is not None:
            x_mix = mix_block(
                x_mix,
                context={'context': time_context, 'RT': RT},
                timesteps=timesteps,
            )
        else:
            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)

        x = self.time_mixer(
            x_spatial=x,
            x_temporal=x_mix,
            image_only_indicator=image_only_indicator,
        )

    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
    if not self.use_linear:
        x = self.proj_out(x)
    out = x + x_in
    return out
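
# Standalone shape check for the temporal rearrange + pose concat used in
# `_forward_VideoTransformerBlock_attan2` (illustrative sizes only):
if __name__ == "__main__":
    b, t, s, c = 2, 14, 64, 320
    x = torch.randn(b * t, s, c)                      # (b t) s c, spatial layout
    x = rearrange(x, "(b t) s c -> (b s) t c", t=t)   # -> (128, 14, 320)
    RT = torch.randn(b, t, 12)                        # one flattened pose per frame
    RT = RT.repeat_interleave(repeats=s, dim=0)       # -> (128, 14, 12)
    fused = torch.cat([x, RT], dim=-1)                # -> (128, 14, 332)
    assert fused.shape == (b * s, t, c + 12)
    print("temporal pose-fusion shapes check out:", tuple(fused.shape))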