import torch
import torch.nn.functional as F


def add_first_frame_conditioning(
    latent_model_input,
    first_frame,
    vae
):
    """
    Adds first frame conditioning to a video diffusion model input.

    Args:
        latent_model_input: Original latent input (bs, channels, num_latent_frames, latent_height, latent_width)
        first_frame: Tensor of first frame to condition on (bs, channels, height, width)
        vae: VAE model for encoding the conditioning

    Returns:
        conditioned_latent: The complete conditioned latent input (bs, 36, num_latent_frames, latent_height, latent_width)
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)

    # Get the number of latent frames from the latent model input
    _, _, num_latent_frames, _, _ = latent_model_input.shape

    # Calculate the original number of frames.
    # For n original frames there are (n - 1) // 4 + 1 latent frames,
    # so to recover n: n = (num_latent_frames - 1) * 4 + 1
    num_frames = (num_latent_frames - 1) * 4 + 1

    if len(first_frame.shape) == 3:
        # We have a single image; add a batch dimension
        first_frame = first_frame.unsqueeze(0)

    # If it doesn't match the batch size, expand it
    if first_frame.shape[0] != latent_model_input.shape[0]:
        first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)

    # Resize the first frame to match the latent model input
    vae_scale_factor = vae.config.scale_factor_spatial
    first_frame = F.interpolate(
        first_frame,
        size=(latent_model_input.shape[3] * vae_scale_factor,
              latent_model_input.shape[4] * vae_scale_factor),
        mode='bilinear',
        align_corners=False
    )

    # Add a temporal dimension to the first frame
    first_frame = first_frame.unsqueeze(2)

    # Create the video condition: the first frame followed by zeros for the remaining frames
    zero_frame = torch.zeros_like(first_frame)
    video_condition = torch.cat([
        first_frame,
        *[zero_frame for _ in range(num_frames - 1)]
    ], dim=2)

    # video_condition is already (bs, channels, num_frames, height, width),
    # so no permute is needed before VAE encoding

    # Encode with the VAE
    latent_condition = vae.encode(
        video_condition.to(device, dtype)
    ).latent_dist.sample()
    latent_condition = latent_condition.to(device, dtype)

    latents_mean = (
        torch.tensor(vae.config.latents_mean)
        .view(1, vae.config.z_dim, 1, 1, 1)
        .to(device, dtype)
    )
    latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(
        1, vae.config.z_dim, 1, 1, 1
    ).to(device, dtype)
    latent_condition = (latent_condition - latents_mean) * latents_std

    # Create mask: 1 for conditioning frames, 0 for frames to generate
    batch_size = first_frame.shape[0]
    latent_height = latent_condition.shape[3]
    latent_width = latent_condition.shape[4]

    # Initialize mask for all frames
    mask_lat_size = torch.ones(
        batch_size, 1, num_frames, latent_height, latent_width)

    # Set all non-first frames to 0
    mask_lat_size[:, :, list(range(1, num_frames))] = 0

    # Special handling for the first frame: repeat it to fill a full temporal block
    first_frame_mask = mask_lat_size[:, :, 0:1]
    first_frame_mask = torch.repeat_interleave(
        first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)

    # Combine the first frame mask with the rest
    mask_lat_size = torch.concat(
        [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)

    # Reshape and transpose for model input
    mask_lat_size = mask_lat_size.view(
        batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
    mask_lat_size = mask_lat_size.transpose(1, 2)
    mask_lat_size = mask_lat_size.to(device, dtype)

    # Combine the mask and the encoded conditioning, then append them to the latent input
    first_frame_condition = torch.concat(
        [mask_lat_size, latent_condition], dim=1)
    conditioned_latent = torch.cat(
        [latent_model_input, first_frame_condition], dim=1)

    return conditioned_latent
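

# Usage sketch for add_first_frame_conditioning (illustrative only).
# The stub VAE below is an assumption introduced for this example: it mimics just
# the attributes the helper touches (temperal_downsample, config.scale_factor_spatial,
# config.z_dim, config.latents_mean/std, and encode(...).latent_dist.sample()) so the
# shape flow can be checked without loading a real Wan checkpoint.
def _example_add_first_frame_conditioning():
    """Shape-level example; the stub VAE and all sizes are illustrative assumptions."""
    from types import SimpleNamespace

    class _StubDist:
        def __init__(self, x):
            self._x = x

        def sample(self):
            return self._x

    class _StubVAE:
        # 2 ** sum(temperal_downsample) == 4x temporal compression
        temperal_downsample = [False, True, True]
        config = SimpleNamespace(
            scale_factor_spatial=8,
            z_dim=16,
            latents_mean=[0.0] * 16,
            latents_std=[1.0] * 16,
        )

        def encode(self, video):
            # Stand-in for the real encoder:
            # (bs, 3, T, H, W) -> (bs, 16, (T - 1) // 4 + 1, H // 8, W // 8)
            bs, _, t, h, w = video.shape
            latent = torch.zeros(bs, 16, (t - 1) // 4 + 1, h // 8, w // 8)
            return SimpleNamespace(latent_dist=_StubDist(latent))

    vae = _StubVAE()
    latents = torch.randn(1, 16, 21, 60, 104)   # (bs, C, T_latent, H_latent, W_latent)
    first_frame = torch.rand(1, 3, 480, 832)    # pixel-space conditioning image
    conditioned = add_first_frame_conditioning(latents, first_frame, vae)
    # 16 original + 4 mask + 16 condition channels -> (1, 36, 21, 60, 104)
    print(conditioned.shape)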


def add_first_frame_conditioning_v22(
    latent_model_input,
    first_frame,
    vae,
    last_frame=None
):
    """
    Overwrites the first few time steps in latent_model_input with the VAE-encoded
    first_frame, and returns the modified latent plus a binary mask
    (0 = conditioned, 1 = noise).

    Args:
        latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
        first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
        vae: VAE model with .encode() and .config.latents_mean/std
        last_frame: Optional torch.Tensor of shape (bs, 3, H*scale, W*scale);
            if provided, it is encoded and written into the last latent frame(s).

    Returns:
        latent: (bs, 48, T, H, W) - modified input latent
        mask: (bs, 1, T, H, W) - binary mask
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    bs, _, T, H, W = latent_model_input.shape

    scale = vae.config.scale_factor_spatial
    target_h = H * scale
    target_w = W * scale

    # Ensure batch shape
    if first_frame.ndim == 3:
        first_frame = first_frame.unsqueeze(0)
    if first_frame.shape[0] != bs:
        first_frame = first_frame.expand(bs, -1, -1, -1)

    # Resize and encode
    first_frame_up = F.interpolate(first_frame, size=(target_h, target_w),
                                   mode="bilinear", align_corners=False)
    first_frame_up = first_frame_up.unsqueeze(2)  # (bs, 3, 1, H, W)
    encoded = vae.encode(
        first_frame_up.to(device, dtype)
    ).latent_dist.sample().to(device, dtype)

    # Normalize
    mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
    std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
    encoded = (encoded - mean) * std

    # Replace in latent
    latent = latent_model_input.clone()
    latent[:, :, :encoded.shape[2]] = encoded  # typically just the first latent frame: [:, :, 0]

    # Mask: 0 where conditioned, 1 otherwise
    mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
    mask[:, :, :encoded.shape[2]] = 0.0

    if last_frame is not None:
        # If last_frame is provided, encode it the same way
        if last_frame.ndim == 3:
            last_frame = last_frame.unsqueeze(0)
        if last_frame.shape[0] != bs:
            last_frame = last_frame.expand(bs, -1, -1, -1)
        last_frame_up = F.interpolate(last_frame, size=(target_h, target_w),
                                      mode="bilinear", align_corners=False)
        last_frame_up = last_frame_up.unsqueeze(2)
        last_encoded = vae.encode(
            last_frame_up.to(device, dtype)
        ).latent_dist.sample().to(device, dtype)
        last_encoded = (last_encoded - mean) * std
        latent[:, :, -last_encoded.shape[2]:] = last_encoded  # replace the last frame(s)
        mask[:, :, -last_encoded.shape[2]:] = 0.0

    # Ensure the mask is still binary
    mask = mask.clamp(0.0, 1.0)

    return latent, mask
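

# Usage sketch for add_first_frame_conditioning_v22 (illustrative only).
# As above, the stub VAE is an assumption made for this example; the 48-channel
# latent and 16x spatial compression mirror the shapes named in the docstring,
# not a loaded checkpoint.
def _example_add_first_frame_conditioning_v22():
    """Shape-level example; the stub VAE and all sizes are illustrative assumptions."""
    from types import SimpleNamespace

    class _StubDist:
        def __init__(self, x):
            self._x = x

        def sample(self):
            return self._x

    class _StubVAE:
        config = SimpleNamespace(
            scale_factor_spatial=16,
            latents_mean=[0.0] * 48,
            latents_std=[1.0] * 48,
        )

        def encode(self, video):
            # Stand-in for the real encoder:
            # (bs, 3, 1, H * 16, W * 16) -> (bs, 48, 1, H, W)
            bs, _, t, h, w = video.shape
            latent = torch.zeros(bs, 48, t, h // 16, w // 16)
            return SimpleNamespace(latent_dist=_StubDist(latent))

    vae = _StubVAE()
    latents = torch.randn(1, 48, 21, 30, 52)    # (bs, 48, T, H, W)
    first_frame = torch.rand(1, 3, 480, 832)    # pixel-space first frame
    latent, mask = add_first_frame_conditioning_v22(latents, first_frame, vae)
    # latent keeps its shape; mask is (1, 1, 21, 30, 52) and is 0 at the conditioned frame
    print(latent.shape, mask.shape, mask[:, :, 0].max().item())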