Commit bccf74a (1 parent: 566ec8f)
Extend CLIP text encoder to support 97 tokens

Files changed:
- adaface/adaface_wrapper.py (+21 -3)
- adaface/util.py (+20 -0)
- app.py (+5 -1)
adaface/adaface_wrapper.py CHANGED

@@ -14,7 +14,7 @@ from diffusers import (
     LCMScheduler,
 )
 from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
-from adaface.util import UNetEnsemble
+from adaface.util import UNetEnsemble, extend_nn_embedding
 from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
 from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
 from safetensors.torch import load_file as safetensors_load_file

@@ -27,7 +27,7 @@ class AdaFaceWrapper(nn.Module):
                  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
                  enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
                  num_inference_steps=50, subject_string='z', negative_prompt=None,
-                 use_840k_vae=False, use_ds_text_encoder=False,
+                 max_prompt_length=77, use_840k_vae=False, use_ds_text_encoder=False,
                  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
                  attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,

@@ -56,6 +56,9 @@ class AdaFaceWrapper(nn.Module):

         self.default_scheduler_name = default_scheduler_name
         self.num_inference_steps = num_inference_steps if not use_lcm else 4
+
+        self.max_prompt_length = max_prompt_length
+
         self.use_840k_vae = use_840k_vae
         self.use_ds_text_encoder = use_ds_text_encoder
         self.main_unet_filepath = main_unet_filepath

@@ -199,6 +202,21 @@ class AdaFaceWrapper(nn.Module):

             pipeline.unet = unet2

+        # Extending the prompt length is for SD 1.5 only.
+        if (self.pipeline_name == "text2img") and (self.max_prompt_length > 77):
+            # pipeline.text_encoder.text_model.embeddings.position_embedding.weight: [77, 768] -> [max_length, 768].
+            # We reuse the last EL position embeddings for the new position embeddings.
+            # If we use the "neat" way, i.e., initialize CLIPTextModel with a CLIPTextConfig with
+            # a larger max_position_embeddings and set ignore_mismatched_sizes=True,
+            # then the old position embeddings won't be loaded from the pretrained ckpt,
+            # leading to degenerated performance.
+            EL = self.max_prompt_length - 77
+            # position_embedding.weight: [77, 768] -> [max_length, 768]
+            new_position_embedding = extend_nn_embedding(pipeline.text_encoder.text_model.embeddings.position_embedding,
+                                                         pipeline.text_encoder.text_model.embeddings.position_embedding.weight[-EL:])
+            pipeline.text_encoder.text_model.embeddings.position_embedding = new_position_embedding
+            pipeline.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_prompt_length).unsqueeze(0)
+
         if self.use_840k_vae:
             pipeline.vae = vae
             print("Replaced the VAE with the 840k-step VAE.")

@@ -715,7 +733,7 @@ class AdaFaceWrapper(nn.Module):
                  ref_img_strength=0.8, generator=None,
                  ablate_prompt_only_placeholders=False,
                  ablate_prompt_no_placeholders=False,
-                 ablate_prompt_embed_type='ada',  # 'ada', 'ada-nonmix', '
+                 ablate_prompt_embed_type='ada',  # 'ada', 'ada-nonmix', 'img1', 'img2'.
                  nonmix_prompt_emb_weight=0,
                  repeat_prompt_for_each_encoder=True,
                  verbose=False):
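The block added around new line 205 is the core of this commit: rather than re-initializing a larger position-embedding table (e.g., via a CLIPTextConfig with a bigger max_position_embeddings and ignore_mismatched_sizes=True), it keeps the 77 pretrained rows and reuses the last EL = max_prompt_length - 77 rows for the new positions. Below is a minimal standalone sketch of the same idea with Hugging Face transformers, for illustration only; the model id, prompt, and the assumption of an SD 1.5-style CLIPTextModel with a [77, 768] position embedding are mine, not code from this repo.

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer

MAX_LEN = 97  # 77 pretrained positions + 20 extended ones

# Illustrative model id; any SD 1.5-style text encoder with 77 positions works the same way.
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
tokenizer    = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")

emb     = text_encoder.text_model.embeddings
old_pos = emb.position_embedding               # nn.Embedding(77, 768)
extra   = MAX_LEN - old_pos.num_embeddings     # EL = 20

new_pos = nn.Embedding(MAX_LEN, old_pos.embedding_dim,
                       device=old_pos.weight.device, dtype=old_pos.weight.dtype)
with torch.no_grad():
    # Keep the 77 pretrained rows, then reuse the last `extra` rows for positions 77..96,
    # instead of leaving them randomly initialized.
    new_pos.weight[:old_pos.num_embeddings] = old_pos.weight
    new_pos.weight[old_pos.num_embeddings:] = old_pos.weight[-extra:]

emb.position_embedding = new_pos
emb.position_ids       = torch.arange(MAX_LEN).unsqueeze(0)

input_ids = tokenizer("a long prompt " * 30, padding="max_length", truncation=True,
                      max_length=MAX_LEN, return_tensors="pt").input_ids
print(text_encoder(input_ids).last_hidden_state.shape)  # expected: torch.Size([1, 97, 768])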
adaface/util.py CHANGED

@@ -73,6 +73,26 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
     print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))


+# new_token_embeddings: [new_num_tokens, 768].
+def extend_nn_embedding(old_nn_embedding, new_token_embeddings):
+    emb_dim        = old_nn_embedding.embedding_dim
+    num_old_tokens = old_nn_embedding.num_embeddings
+    num_new_tokens = new_token_embeddings.shape[0]
+    num_tokens2    = num_old_tokens + num_new_tokens
+
+    new_nn_embedding = nn.Embedding(num_tokens2, emb_dim,
+                                    device=old_nn_embedding.weight.device,
+                                    dtype=old_nn_embedding.weight.dtype)
+
+    old_num_tokens = old_nn_embedding.weight.shape[0]
+    # Copy the first old_num_tokens embeddings from old_nn_embedding to new_nn_embedding.
+    new_nn_embedding.weight.data[:old_num_tokens] = old_nn_embedding.weight.data
+    # Copy the new embeddings to new_nn_embedding.
+    new_nn_embedding.weight.data[old_num_tokens:] = new_token_embeddings
+
+    print(f"Extended nn.Embedding from {num_old_tokens} to {num_tokens2} tokens.")
+    return new_nn_embedding
+
 # Revised from RevGrad, by removing the grad negation.
 class ScaleGrad(torch.autograd.Function):
     @staticmethod
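The helper above is generic: it builds a larger nn.Embedding and fills it with the old rows followed by the caller-supplied new rows. A hedged usage sketch with toy sizes (not part of the commit; the wrapper calls it with the 77x768 CLIP position embedding):

import torch
import torch.nn as nn
from adaface.util import extend_nn_embedding

old = nn.Embedding(5, 4)                   # stand-in for the [77, 768] position embedding
new_rows = old.weight.data[-2:]            # reuse the last 2 rows for the 2 extra positions
ext = extend_nn_embedding(old, new_rows)   # prints "Extended nn.Embedding from 5 to 7 tokens."

assert ext.num_embeddings == 7
assert torch.equal(ext.weight.data[:5], old.weight.data)
assert torch.equal(ext.weight.data[5:], new_rows)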
app.py CHANGED

@@ -34,6 +34,8 @@ parser.add_argument('--num_inference_steps', type=int, default=50,
 parser.add_argument('--ablate_prompt_embed_type', type=str, default='ada',
                     choices=["ada", "arc2face", "consistentID"],
                     help="Ablate to use the image ID embs instead of Ada embs")
+parser.add_argument('--max_prompt_length', type=int, default=97,
+                    help="Maximum length of the prompt. If > 77, the CLIP text encoder will be extended.")

 parser.add_argument('--gpu', type=int, default=None)
 parser.add_argument('--ip', type=str, default="0.0.0.0")

@@ -79,6 +81,7 @@ adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_
                          adaface_encoder_types=args.adaface_encoder_types,
                          adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu',
                          num_inference_steps=args.num_inference_steps,
+                         max_prompt_length=args.max_prompt_length,
                          is_on_hf_space=is_on_hf_space)

 basedir = os.getcwd()

@@ -208,7 +211,7 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
     if args.ablate_prompt_embed_type != "ada":
         # Find the prompt_emb_type index in adaface_encoder_types
         # adaface_encoder_types: ["consistentID", "arc2face"]
-        ablate_prompt_embed_index = args.adaface_encoder_types.index(args.ablate_prompt_embed_type)
+        ablate_prompt_embed_index = args.adaface_encoder_types.index(args.ablate_prompt_embed_type) + 1
         ablate_prompt_embed_type = f"img{ablate_prompt_embed_index}"
     else:
         ablate_prompt_embed_type = "ada"

@@ -270,6 +273,7 @@ def check_prompt_and_model_type(prompt, model_style_type, progress=gr.Progress()
                                 adaface_encoder_types=args.adaface_encoder_types,
                                 adaface_ckpt_paths=[args.adaface_ckpt_path], device='cpu',
                                 num_inference_steps=args.num_inference_steps,
+                                max_prompt_length=args.max_prompt_length,
                                 is_on_hf_space=is_on_hf_space)
     # Update base model type.
     args.model_style_type = model_style_type
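Two things the app.py changes wire together: the new --max_prompt_length flag (default 97) is forwarded to both AdaFaceWrapper constructions, and the ablation index in generate_video is now 1-based so that it matches the 'img1'/'img2' labels accepted by ablate_prompt_embed_type in the wrapper. A small illustrative sketch of that mapping, with the encoder list taken from the comment in the diff:

adaface_encoder_types = ["consistentID", "arc2face"]
for enc in adaface_encoder_types:
    idx = adaface_encoder_types.index(enc) + 1   # 1-based after this commit
    print(f"{enc} -> img{idx}")                  # consistentID -> img1, arc2face -> img2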