Spaces:
Running
on
Zero
Running
on
Zero
| # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Module-level state shared by the processor instances defined in this file.
# Per-concept masks: ConceptrolAttnProcessor.__call__ rebinds this to a list of
# `None` placeholders whenever it finds it empty, and fills slots when
# global masking is enabled.
global_concept_mask = []
# Visualization logs keyed by processor name; ConceptrolAttnProcessor.__call__
# appends one NumPy array per concept on every call and never trims them.
attn_mask_logs = {}
text_attn_map_logs = {}
image_attn_map_logs = {}
class AttnProcessor(nn.Module):
    r"""
    Baseline attention processor: explicit softmax attention evaluated with `torch.bmm`.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Neither argument is stored or used by this processor.
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run attention through the layers and helpers exposed by `attn`.

        A 4-D `hidden_states` is flattened to a token sequence before attention
        and reshaped back afterwards. When `encoder_hidden_states` is `None`
        the keys and values are computed from `hidden_states` itself.
        """
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            batch_size, channel, height, width = hidden_states.shape
            flat = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flat.transpose(1, 2)

        # The mask is sized against whatever sequence supplies keys and values.
        source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = (
            attn.head_to_batch_dim(tensor) for tensor in (query, key, value)
        )

        weights = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(weights, value))

        # to_out[0] is the output projection, to_out[1] the dropout.
        project, dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = dropout(project(hidden_states))

        if is_feature_map:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # Dedicated key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Attend to the text tokens and, separately, to the image-prompt tokens.

        The last `self.num_tokens` entries of `encoder_hidden_states` are
        treated as image-prompt tokens; their attention output is scaled by
        `self.scale` and added to the text attention output. The image
        attention probabilities are kept on `self.attn_map`.

        When `encoder_hidden_states` is `None` there are no image tokens to
        split off, so the call reduces to plain attention over
        `hidden_states` and `self.attn_map` is left untouched.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        # Stays None on the self-attention path so the image branch is skipped
        # instead of failing on an unbound name.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            # No mask is applied to the image tokens.
            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class ConceptrolAttnProcessor(nn.Module):
    r"""
    IP-Adapter attention processor that gates each image prompt with a mask
    derived from the text attention of its matching textual concept.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features of ONE concept.
        textual_concept_idxs (sequence of `(start, end)` pairs):
            For every concept, the slice of text-token positions whose
            attention is averaged into that concept's mask. Required at call
            time.
        name (`str`):
            Key used for the module-level visualization logs and compared
            against `concept_mask_layer[0]`.
        global_masking (`bool`, defaults to `False`):
            When true, the processor named `concept_mask_layer[0]` stores its
            masks in the module-level `global_concept_mask`, and every other
            processor reuses a resized copy of the stored mask once available.
        adaptive_scale_mask (`bool`, defaults to `False`):
            Deprecated; calling the processor with it enabled raises `ValueError`.
        concept_mask_layer (`list[str]`, optional):
            Only the first element is read. Defaults to the SD mid-block
            cross-attention processor name.
    """

    def __init__(
        self,
        hidden_size,
        cross_attention_dim=None,
        scale=1.0,
        num_tokens=4,
        textual_concept_idxs=None,
        name=None,
        global_masking=False,
        adaptive_scale_mask=False,
        concept_mask_layer=None,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.textual_concept_idxs = textual_concept_idxs
        self.name = name

        # Dedicated key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )

        self.global_masking = global_masking
        self.adaptive_scale_mask = adaptive_scale_mask

        if concept_mask_layer is None:
            concept_mask_layer = [
                "mid_block.attentions.0.transformer_blocks.0.attn2.processor"
            ]  # For SD
            print("Warning: Using default concept mask layer for SD. For SDXL, use 'up_blocks.0.attentions.1.transformer_blocks.5.attn2.processor'")
        self.concept_mask_layer = concept_mask_layer

    def set_global_view(self, attn_procs):
        """Keep a reference to the full collection of attention processors."""
        self.attn_procs = attn_procs

    @staticmethod
    def _normalized_map(attn, probs):
        """Collapse attention over a token span into one peak-normalized map.

        `probs` is a slice of the head-batched attention probabilities. The
        result is averaged over the selected tokens and over heads, divided by
        its maximum along the query axis, and cut down to batch entry 1.
        """
        token_mean = torch.mean(probs, dim=-1, keepdim=True)
        per_head = attn.batch_to_head_dim(token_mean)
        head_mean = torch.mean(per_head, dim=-1, keepdim=True)
        normalized = head_mean / (
            torch.amax(head_mean, dim=-2, keepdim=True) + 1e-6
        )
        # NOTE(review): the original comment says "use the classifier one";
        # presumably batch index 1 is the conditional half of a
        # classifier-free-guidance batch of size 2 - confirm. A batch of
        # size 1 yields an empty slice here.
        return normalized[1:2]

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Text attention plus one masked image-prompt attention per concept.

        Tokens after position 77 of `encoder_hidden_states` are split into
        per-concept image prompts. Each prompt's attention output is
        multiplied by the concept's mask and added with weight `self.scale`.
        When `encoder_hidden_states` is `None` there are no image prompts and
        the call reduces to plain attention over `hidden_states`.

        Raises:
            ValueError: if `textual_concept_idxs` was not provided, if
                `adaptive_scale_mask` is enabled, or if the number of image
                prompts does not match `len(textual_concept_idxs)`.
        """
        # Rebound below when the shared mask list is still empty.
        global global_concept_mask

        if self.textual_concept_idxs is None:
            raise ValueError(
                "textual_concept_idxs should be provided for ConceptrolAttnProcessor"
            )
        # Checked up front: the flag is constant per processor, so there is no
        # point in doing attention work before rejecting it.
        if self.adaptive_scale_mask:
            raise ValueError("adaptive_scale_mask is deprecated already")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        # Empty on the self-attention path so the concept branch is skipped
        # instead of failing on an unbound name.
        ip_hidden_states_list = ()
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = 77  # Both SD and SDXL use 77 as length of text tokens
            ip_hidden_states_cat = encoder_hidden_states[:, end_pos:, :]
            encoder_hidden_states = encoder_hidden_states[:, :end_pos, :]

            num_concepts = ip_hidden_states_cat.shape[1] // self.num_tokens
            if num_concepts > 0:
                # torch.chunk rejects a chunk count of 0, hence the guard.
                ip_hidden_states_list = torch.chunk(
                    ip_hidden_states_cat, num_concepts, dim=1
                )
            # Explicit raise rather than `assert`, which is stripped under -O.
            if len(ip_hidden_states_list) != len(self.textual_concept_idxs):
                raise ValueError(
                    "textual_concept_idxs should have the same length as the "
                    f"number of concepts, but got {len(ip_hidden_states_list)} "
                    f"and {len(self.textual_concept_idxs)}"
                )

            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states_list:
            if len(global_concept_mask) == 0:
                global_concept_mask = [None] * len(self.textual_concept_idxs)

            # Visualization logs. A missing attn_mask_logs entry resets all
            # three lists; the setdefault calls additionally cover a caller
            # that cleared only one of the other two dicts.
            if self.name not in attn_mask_logs:
                attn_mask_logs[self.name] = []
                text_attn_map_logs[self.name] = []
                image_attn_map_logs[self.name] = []
            text_attn_map_logs.setdefault(self.name, [])
            image_attn_map_logs.setdefault(self.name, [])

            concept_pairs = zip(ip_hidden_states_list, self.textual_concept_idxs)
            for i, (ip_hidden_states, concept_span) in enumerate(concept_pairs):
                textual_concept_start_idx, textual_concept_end_idx = concept_span

                ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
                ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))

                # Mask: text attention averaged over this concept's tokens.
                ip_attention_mask = self._normalized_map(
                    attn,
                    attention_probs[
                        :, :, textual_concept_start_idx:textual_concept_end_idx
                    ],
                )

                # Visualization (both logs receive the same mask values).
                attn_mask_logs[self.name].append(
                    ip_attention_mask.detach().cpu().numpy()[0, :, 0]
                )
                text_attn_map_logs[self.name].append(
                    ip_attention_mask.detach().cpu().numpy()[0, :, 0]
                )

                if self.global_masking:
                    if self.name == self.concept_mask_layer[0]:
                        # This layer is the source of the shared mask.
                        global_concept_mask[i] = ip_attention_mask
                    elif global_concept_mask[i] is not None:
                        # Reuse the shared mask, resized to this layer's
                        # token count.
                        # NOTE(review): int(sqrt(...)) assumes a square token
                        # grid - confirm for non-square images.
                        shared_mask = global_concept_mask[i]
                        original_dim = int(shared_mask.shape[1] ** 0.5)
                        target_dim = int(hidden_states.shape[1] ** 0.5)
                        shared_mask_2d = shared_mask.view(
                            shared_mask.shape[0], 1, original_dim, original_dim
                        )
                        resized_mask_2d = F.interpolate(
                            shared_mask_2d,
                            size=(target_dim, target_dim),
                            mode="nearest",
                        )
                        ip_attention_mask = resized_mask_2d.view(
                            shared_mask.shape[0], -1, 1
                        )

                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)

                # Visualization.
                # NOTE(review): text token 15 is hard-coded and the result is
                # stored in the *image* log - confirm this is intended.
                ip_attention_map = self._normalized_map(
                    attn, attention_probs[:, :, 15:16]
                )
                image_attn_map_logs[self.name].append(
                    ip_attention_map.detach().cpu().numpy()[0, :, 0]
                )

                ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
                ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
                # The single-entry mask broadcasts over batch and channels.
                ip_hidden_states = ip_hidden_states * ip_attention_mask

                hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor built on `F.scaled_dot_product_attention` (PyTorch 2.0 and later).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Neither argument is stored or used by this processor.
        super().__init__()
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run fused attention through the layers exposed by `attn`.

        A 4-D `hidden_states` is flattened to a token sequence before attention
        and reshaped back afterwards. When `encoder_hidden_states` is `None`
        the keys and values are computed from `hidden_states` itself.
        """
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            batch_size, channel, height, width = hidden_states.shape
            flat = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flat.transpose(1, 2)

        source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects the mask shaped as
            # (batch, heads, source_length, target_length).
            attention_mask = prepared.view(
                batch_size, attn.heads, -1, prepared.shape[-1]
            )

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back: (batch, heads, tokens, head_dim) -> (batch, tokens, channels).
        merged = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        hidden_states = merged.to(query.dtype)

        # to_out[0] is the output projection, to_out[1] the dropout.
        project, dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = dropout(project(hidden_states))

        if is_feature_map:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # Dedicated key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Attend to the text tokens and, separately, to the image-prompt tokens.

        The last `self.num_tokens` entries of `encoder_hidden_states` are
        treated as image-prompt tokens; their attention output is scaled by
        `self.scale` and added to the text attention output. The image
        attention probabilities (softmax over the image tokens, scaled by
        `head_dim ** -0.5`) are kept on `self.attn_map` with shape
        `(batch, heads, query_tokens, num_tokens)`.

        When `encoder_hidden_states` is `None` there are no image tokens to
        split off, so the call reduces to plain attention over
        `hidden_states` and `self.attn_map` is left untouched.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        # Stays None on the self-attention path so the image branch is skipped
        # instead of failing on an unbound name.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(
                1, 2
            )

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            with torch.no_grad():
                # Softmax must be taken over the query-key scores, with the
                # same head_dim ** -0.5 scaling the fused kernel applies by
                # default, so the stored map matches the attention above.
                ip_scores = (query @ ip_key.transpose(-2, -1)) * head_dim ** -0.5
                self.attn_map = ip_scores.softmax(dim=-1)

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
| ## for controlnet | |
class CNAttnProcessor:
    r"""
    Attention processor that drops the trailing image-prompt tokens and attends to the text tokens only.
    """

    def __init__(self, num_tokens=4):
        # Number of trailing tokens stripped from `encoder_hidden_states`.
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run `torch.bmm` attention, ignoring the last `num_tokens` context tokens.

        A 4-D `hidden_states` is flattened to a token sequence before attention
        and reshaped back afterwards. When `encoder_hidden_states` is `None`
        the keys and values are computed from `hidden_states` itself.
        """
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            batch_size, channel, height, width = hidden_states.shape
            flat = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flat.transpose(1, 2)

        # The mask is sized against the context as passed in, before trimming.
        source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_length = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_length]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = (
            attn.head_to_batch_dim(tensor) for tensor in (query, key, value)
        )

        weights = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(weights, value))

        # to_out[0] is the output projection, to_out[1] the dropout.
        project, dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = dropout(project(hidden_states))

        if is_feature_map:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class CNAttnProcessor2_0:
    r"""
    `F.scaled_dot_product_attention` processor that drops the trailing image-prompt tokens and attends to the text tokens only.
    """

    def __init__(self, num_tokens=4):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        # Number of trailing tokens stripped from `encoder_hidden_states`.
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run fused attention, ignoring the last `num_tokens` context tokens.

        A 4-D `hidden_states` is flattened to a token sequence before attention
        and reshaped back afterwards. When `encoder_hidden_states` is `None`
        the keys and values are computed from `hidden_states` itself.
        """
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            batch_size, channel, height, width = hidden_states.shape
            flat = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flat.transpose(1, 2)

        # The mask is sized against the context as passed in, before trimming.
        source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects the mask shaped as
            # (batch, heads, source_length, target_length).
            attention_mask = prepared.view(
                batch_size, attn.heads, -1, prepared.shape[-1]
            )

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_length = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_length]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back: (batch, heads, tokens, head_dim) -> (batch, tokens, channels).
        merged = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        hidden_states = merged.to(query.dtype)

        # to_out[0] is the output projection, to_out[1] the dropout.
        project, dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = dropout(project(hidden_states))

        if is_feature_map:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor