Spaces:
Paused
Paused
Update models/controlnet.py
Browse files- models/controlnet.py +2 -44
models/controlnet.py
CHANGED
|
@@ -43,10 +43,9 @@ class ControlNetOutput(BaseOutput):
|
|
| 43 |
down_block_res_samples: Tuple[torch.Tensor]
|
| 44 |
mid_block_res_sample: torch.Tensor
|
| 45 |
|
| 46 |
-
|
| 47 |
class ControlNetConditioningEmbedding(nn.Module):
|
| 48 |
"""
|
| 49 |
-
"""
|
| 50 |
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
| 51 |
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
| 52 |
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
|
@@ -54,7 +53,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
|
| 54 |
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
| 55 |
model) to encode image-space conditions ... into feature maps ..."
|
| 56 |
"""
|
| 57 |
-
|
| 58 |
|
| 59 |
def __init__(
|
| 60 |
self,
|
|
@@ -89,48 +88,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
|
| 89 |
embedding = self.conv_out(embedding)
|
| 90 |
|
| 91 |
return embedding
|
| 92 |
-
"""
|
| 93 |
-
|
| 94 |
-
class ControlNetConditioningEmbedding(nn.Module):
|
| 95 |
-
def __init__(
|
| 96 |
-
self,
|
| 97 |
-
conditioning_embedding_channels: int,
|
| 98 |
-
conditioning_channels: int = 3,
|
| 99 |
-
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
| 100 |
-
):
|
| 101 |
-
super().__init__()
|
| 102 |
|
| 103 |
-
self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
| 104 |
-
self.bn_in = nn.BatchNorm3d(block_out_channels[0])
|
| 105 |
-
|
| 106 |
-
self.blocks = nn.ModuleList([])
|
| 107 |
-
self.bns = nn.ModuleList([])
|
| 108 |
-
|
| 109 |
-
for i in range(len(block_out_channels) - 1):
|
| 110 |
-
channel_in = block_out_channels[i]
|
| 111 |
-
channel_out = block_out_channels[i + 1]
|
| 112 |
-
self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
|
| 113 |
-
self.bns.append(nn.BatchNorm3d(channel_in))
|
| 114 |
-
self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
| 115 |
-
self.bns.append(nn.BatchNorm3d(channel_out))
|
| 116 |
-
|
| 117 |
-
self.conv_out = zero_module(
|
| 118 |
-
InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
def forward(self, conditioning):
|
| 122 |
-
embedding = self.conv_in(conditioning)
|
| 123 |
-
embedding = self.bn_in(embedding)
|
| 124 |
-
embedding = F.silu(embedding)
|
| 125 |
-
|
| 126 |
-
for block, bn in zip(self.blocks, self.bns):
|
| 127 |
-
embedding = block(embedding)
|
| 128 |
-
embedding = bn(embedding)
|
| 129 |
-
embedding = F.silu(embedding)
|
| 130 |
-
|
| 131 |
-
embedding = self.conv_out(embedding)
|
| 132 |
-
|
| 133 |
-
return embedding
|
| 134 |
|
| 135 |
class ControlNetModel3D(ModelMixin, ConfigMixin):
|
| 136 |
_supports_gradient_checkpointing = True
|
|
|
|
| 43 |
down_block_res_samples: Tuple[torch.Tensor]
|
| 44 |
mid_block_res_sample: torch.Tensor
|
| 45 |
|
| 46 |
+
|
| 47 |
class ControlNetConditioningEmbedding(nn.Module):
|
| 48 |
"""
|
|
|
|
| 49 |
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
| 50 |
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
| 51 |
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
|
|
|
| 53 |
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
| 54 |
model) to encode image-space conditions ... into feature maps ..."
|
| 55 |
"""
|
| 56 |
+
|
| 57 |
|
| 58 |
def __init__(
|
| 59 |
self,
|
|
|
|
| 88 |
embedding = self.conv_out(embedding)
|
| 89 |
|
| 90 |
return embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
class ControlNetModel3D(ModelMixin, ConfigMixin):
|
| 94 |
_supports_gradient_checkpointing = True
|