Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| from typing import List | |
| from diffusers.models.embeddings import Timesteps, TimestepEmbedding | |
| # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py | |
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build 2x2 rotary-embedding rotation matrices for a set of positions.

    Args:
        pos: Position tensor with at least one dimension, e.g. ``(batch, seq)``.
        dim: Even number of channels to rotate; yields ``dim // 2`` frequencies.
        theta: Base of the frequency schedule ``theta ** (-2k / dim)``.

    Returns:
        A float32 tensor of shape ``(*pos.shape, dim // 2, 2, 2)``. Each
        trailing 2x2 block is ``[[cos a, -sin a], [sin a, cos a]]`` with
        ``a = pos * omega_k``.
    """
    assert dim % 2 == 0, "The dimension must be even."
    # Frequencies are built in float64 and only the final result is cast
    # down to float32.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    # Outer product of every position with every frequency.
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(angles)
    sin_out = torch.sin(angles)
    stacked = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    # Split the trailing 4 entries into a 2x2 block using the explicit leading
    # shape. Inferring a dimension with -1 fails on zero-element tensors and
    # needlessly restricted the input to exactly two dimensions.
    out = stacked.view(*stacked.shape[:-1], 2, 2)
    return out.float()
| # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py | |
class EmbedND(nn.Module):
    """Rotary position embedding over several position axes.

    Each entry along the last dimension of ``ids`` is treated as one position
    axis. Axis ``i`` is embedded with ``rope`` using ``axes_dim[i]`` channels,
    and the per-axis results are concatenated.
    """

    def __init__(self, theta: int, axes_dim: List[int]):
        """Store the rope base ``theta`` and the channel count for each axis."""
        super().__init__()
        self.axes_dim = axes_dim
        self.theta = theta

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Embed ``ids`` axis by axis and join the results.

        The per-axis rope outputs are concatenated along dim ``-3`` and a
        singleton dimension is then inserted at index 2.
        NOTE(review): presumably index 2 is a broadcast (head) axis for the
        attention layout -- confirm against the caller.
        """
        per_axis = []
        for axis in range(ids.shape[-1]):
            coords = ids[..., axis]
            per_axis.append(rope(coords, self.axes_dim[axis], self.theta))
        joined = torch.cat(per_axis, dim=-3)
        return joined.unsqueeze(2)
class PatchEmbed(nn.Module):
    """Project patch feature vectors to ``out_channels`` with one linear layer.

    The projection expects inputs whose last dimension is
    ``in_channels * patch_size * patch_size``.
    """

    def __init__(
        self,
        patch_size=2,
        in_channels=4,
        out_channels=1024,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        patch_features = in_channels * patch_size * patch_size
        self.proj = nn.Linear(patch_features, out_channels, bias=True)
        # Re-initialise every Linear submodule (here: ``proj``).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero bias for ``nn.Linear`` modules."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, latent):
        """Return the linear projection of ``latent``."""
        return self.proj(latent)
class PooledEmbed(nn.Module):
    """Map a pooled text embedding to ``hidden_size`` via ``TimestepEmbedding``."""

    def __init__(self, text_emb_dim, hidden_size):
        super().__init__()
        self.pooled_embedder = TimestepEmbedding(
            in_channels=text_emb_dim,
            time_embed_dim=hidden_size,
        )
        # Override the default init of every Linear inside the embedder.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Normal(std=0.02) weights and zero bias for ``nn.Linear`` modules."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, pooled_embed):
        """Return the embedder output for ``pooled_embed``."""
        embedded = self.pooled_embedder(pooled_embed)
        return embedded
class TimestepEmbed(nn.Module):
    """Embed timesteps: frequency projection followed by ``TimestepEmbedding``."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size,
            time_embed_dim=hidden_size,
        )
        # Override the default init of every Linear inside the embedder.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Normal(std=0.02) weights and zero bias for ``nn.Linear`` modules."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, timesteps, wdtype):
        """Project ``timesteps``, cast to ``wdtype``, then run the embedder."""
        freq_features = self.time_proj(timesteps)
        # Cast before the embedder so its input uses the requested dtype.
        freq_features = freq_features.to(dtype=wdtype)
        return self.timestep_embedder(freq_features)
class OutEmbed(nn.Module):
    """Final layer: modulated LayerNorm followed by a linear projection.

    The projection produces ``patch_size * patch_size * out_channels``
    features. Shift and scale for the normalised input come from
    ``adaLN_modulation`` applied to a conditioning vector.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        # LayerNorm has no learned affine; shift/scale come from the
        # conditioning branch instead.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        out_features = patch_size * patch_size * out_channels
        self.linear = nn.Linear(hidden_size, out_features, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Zero both weight and bias of every ``nn.Linear`` module."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.zeros_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x, adaln_input):
        """Normalise ``x``, apply conditioning shift/scale, then project.

        ``adaln_input`` is split in two halves along dim 1 (shift, then
        scale), and each half is broadcast over dim 1 of ``x``.
        """
        modulation = self.adaLN_modulation(adaln_input)
        shift, scale = modulation.chunk(2, dim=1)
        normed = self.norm_final(x)
        modulated = normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.linear(modulated)