|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from dataclasses import dataclass | 
					
						
						|  | from typing import List, Optional, Tuple, Union | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class AttentionMaskConverter: | 
					
						
						|  | """ | 
					
						
						|  | A utility attention mask class that allows one to: | 
					
						
						|  | - Create a causal 4d mask | 
					
						
						|  | - Create a causal 4d mask with slided window | 
					
						
						|  | - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, | 
					
						
						|  | key_value_length) that can be multiplied with attention scores | 
					
						
						|  |  | 
					
						
						|  | Examples: | 
					
						
						|  |  | 
					
						
						|  | ```python | 
					
						
						|  | >>> import torch | 
					
						
						|  | >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter | 
					
						
						|  |  | 
					
						
						|  | >>> converter = AttentionMaskConverter(True) | 
					
						
						|  | >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) | 
					
						
						|  | tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], | 
					
						
						|  | [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], | 
					
						
						|  | [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], | 
					
						
						|  | [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38], | 
					
						
						|  | [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]]) | 
					
						
						|  | ``` | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | is_causal (`bool`): | 
					
						
						|  | Whether the attention mask should be a uni-directional (causal) or bi-directional mask. | 
					
						
						|  |  | 
					
						
						|  | sliding_window (`int`, *optional*): | 
					
						
						|  | Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | is_causal: bool | 
					
						
						|  | sliding_window: int | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): | 
					
						
						|  | self.is_causal = is_causal | 
					
						
						|  | self.sliding_window = sliding_window | 
					
						
						|  |  | 
					
						
						|  | if self.sliding_window is not None and self.sliding_window <= 0: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def to_causal_4d( | 
					
						
						|  | self, | 
					
						
						|  | batch_size: int, | 
					
						
						|  | query_length: int, | 
					
						
						|  | key_value_length: int, | 
					
						
						|  | dtype: torch.dtype, | 
					
						
						|  | device: Union[torch.device, "str"] = "cpu", | 
					
						
						|  | ) -> Optional[torch.Tensor]: | 
					
						
						|  | """ | 
					
						
						|  | Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative | 
					
						
						|  | bias to upper right hand triangular matrix (causal mask). | 
					
						
						|  | """ | 
					
						
						|  | if not self.is_causal: | 
					
						
						|  | raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | input_shape = (batch_size, query_length) | 
					
						
						|  | past_key_values_length = key_value_length - query_length | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | causal_4d_mask = None | 
					
						
						|  | if input_shape[-1] > 1 or self.sliding_window is not None: | 
					
						
						|  | causal_4d_mask = self._make_causal_mask( | 
					
						
						|  | input_shape, | 
					
						
						|  | dtype, | 
					
						
						|  | device=device, | 
					
						
						|  | past_key_values_length=past_key_values_length, | 
					
						
						|  | sliding_window=self.sliding_window, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return causal_4d_mask | 
					
						
						|  |  | 
					
						
						|  | def to_4d( | 
					
						
						|  | self, | 
					
						
						|  | attention_mask_2d: torch.Tensor, | 
					
						
						|  | query_length: int, | 
					
						
						|  | dtype: torch.dtype, | 
					
						
						|  | key_value_length: Optional[int] = None, | 
					
						
						|  | ) -> torch.Tensor: | 
					
						
						|  | """ | 
					
						
						|  | Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, | 
					
						
						|  | key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is | 
					
						
						|  | causal, a causal mask will be added. | 
					
						
						|  | """ | 
					
						
						|  | input_shape = (attention_mask_2d.shape[0], query_length) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | causal_4d_mask = None | 
					
						
						|  | if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: | 
					
						
						|  | if key_value_length is None: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | past_key_values_length = key_value_length - query_length | 
					
						
						|  | causal_4d_mask = self._make_causal_mask( | 
					
						
						|  | input_shape, | 
					
						
						|  | dtype, | 
					
						
						|  | device=attention_mask_2d.device, | 
					
						
						|  | past_key_values_length=past_key_values_length, | 
					
						
						|  | sliding_window=self.sliding_window, | 
					
						
						|  | ) | 
					
						
						|  | elif self.sliding_window is not None: | 
					
						
						|  | raise NotImplementedError("Sliding window is currently only implemented for causal masking") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( | 
					
						
						|  | attention_mask_2d.device | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if causal_4d_mask is not None: | 
					
						
						|  | expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | expanded_4d_mask = expanded_attn_mask | 
					
						
						|  |  | 
					
						
						|  | return expanded_4d_mask | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _make_causal_mask( | 
					
						
						|  | input_ids_shape: torch.Size, | 
					
						
						|  | dtype: torch.dtype, | 
					
						
						|  | device: torch.device, | 
					
						
						|  | past_key_values_length: int = 0, | 
					
						
						|  | sliding_window: Optional[int] = None, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Make causal mask used for bi-directional self-attention. | 
					
						
						|  | """ | 
					
						
						|  | bsz, tgt_len = input_ids_shape | 
					
						
						|  | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) | 
					
						
						|  | mask_cond = torch.arange(mask.size(-1), device=device) | 
					
						
						|  | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) | 
					
						
						|  |  | 
					
						
						|  | mask = mask.to(dtype) | 
					
						
						|  |  | 
					
						
						|  | if past_key_values_length > 0: | 
					
						
						|  | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if sliding_window is not None: | 
					
						
						|  | diagonal = past_key_values_length - sliding_window - 1 | 
					
						
						|  |  | 
					
						
						|  | context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) | 
					
						
						|  | mask.masked_fill_(context_mask, torch.finfo(dtype).min) | 
					
						
						|  |  | 
					
						
						|  | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): | 
					
						
						|  | """ | 
					
						
						|  | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. | 
					
						
						|  | """ | 
					
						
						|  | bsz, src_len = mask.size() | 
					
						
						|  | tgt_len = tgt_len if tgt_len is not None else src_len | 
					
						
						|  |  | 
					
						
						|  | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) | 
					
						
						|  |  | 
					
						
						|  | inverted_mask = 1.0 - expanded_mask | 
					
						
						|  |  | 
					
						
						|  | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _unmask_unattended( | 
					
						
						|  | expanded_mask: torch.FloatTensor, | 
					
						
						|  | min_dtype: float, | 
					
						
						|  | ): | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when | 
					
						
						|  | using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. | 
					
						
						|  | Details: https://github.com/pytorch/pytorch/issues/110213 | 
					
						
						|  |  | 
					
						
						|  | `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. | 
					
						
						|  | `attention_mask` is [bsz, src_seq_len]. | 
					
						
						|  |  | 
					
						
						|  | The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. | 
					
						
						|  |  | 
					
						
						|  | For example, if `expanded_mask` is (e.g. here left-padding case) | 
					
						
						|  | ``` | 
					
						
						|  | [[[[0, 0, 0], | 
					
						
						|  | [0, 0, 0], | 
					
						
						|  | [0, 0, 1]]], | 
					
						
						|  | [[[1, 0, 0], | 
					
						
						|  | [1, 1, 0], | 
					
						
						|  | [1, 1, 1]]], | 
					
						
						|  | [[[0, 0, 0], | 
					
						
						|  | [0, 1, 0], | 
					
						
						|  | [0, 1, 1]]]] | 
					
						
						|  | ``` | 
					
						
						|  | then the modified `expanded_mask` will be | 
					
						
						|  | ``` | 
					
						
						|  | [[[[1, 1, 1],   <-- modified | 
					
						
						|  | [1, 1, 1],   <-- modified | 
					
						
						|  | [0, 0, 1]]], | 
					
						
						|  | [[[1, 0, 0], | 
					
						
						|  | [1, 1, 0], | 
					
						
						|  | [1, 1, 1]]], | 
					
						
						|  | [[[1, 1, 1],   <-- modified | 
					
						
						|  | [0, 1, 0], | 
					
						
						|  | [0, 1, 1]]]] | 
					
						
						|  | ``` | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | if expanded_mask.dtype == torch.bool: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _ignore_causal_mask_sdpa( | 
					
						
						|  | attention_mask: Optional[torch.Tensor], | 
					
						
						|  | inputs_embeds: torch.Tensor, | 
					
						
						|  | past_key_values_length: int, | 
					
						
						|  | sliding_window: Optional[int] = None, | 
					
						
						|  | is_training: bool = False, | 
					
						
						|  | ) -> bool: | 
					
						
						|  | """ | 
					
						
						|  | Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. | 
					
						
						|  |  | 
					
						
						|  | In case no token is masked in the `attention_mask` argument, if `query_length == 1` or | 
					
						
						|  | `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, | 
					
						
						|  | allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] | 
					
						
						|  | key_value_length = query_length + past_key_values_length | 
					
						
						|  |  | 
					
						
						|  | is_tracing = ( | 
					
						
						|  | torch.jit.is_tracing() | 
					
						
						|  | or isinstance(inputs_embeds, torch.fx.Proxy) | 
					
						
						|  | or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | ignore_causal_mask = False | 
					
						
						|  |  | 
					
						
						|  | if attention_mask is None: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if ( | 
					
						
						|  | (is_training or not is_tracing) | 
					
						
						|  | and (query_length == 1 or key_value_length == query_length) | 
					
						
						|  | and (sliding_window is None or key_value_length < sliding_window) | 
					
						
						|  | ): | 
					
						
						|  | ignore_causal_mask = True | 
					
						
						|  | elif sliding_window is None or key_value_length < sliding_window: | 
					
						
						|  | if len(attention_mask.shape) == 4: | 
					
						
						|  | return False | 
					
						
						|  | elif (is_training or not is_tracing) and torch.all(attention_mask == 1): | 
					
						
						|  | if query_length == 1 or key_value_length == query_length: | 
					
						
						|  |  | 
					
						
						|  | ignore_causal_mask = True | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return ignore_causal_mask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _prepare_4d_causal_attention_mask( | 
					
						
						|  | attention_mask: Optional[torch.Tensor], | 
					
						
						|  | input_shape: Union[torch.Size, Tuple, List], | 
					
						
						|  | inputs_embeds: torch.Tensor, | 
					
						
						|  | past_key_values_length: int, | 
					
						
						|  | sliding_window: Optional[int] = None, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape | 
					
						
						|  | `(batch_size, key_value_length)` | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | attention_mask (`torch.Tensor` or `None`): | 
					
						
						|  | A 2D attention mask of shape `(batch_size, key_value_length)` | 
					
						
						|  | input_shape (`tuple(int)` or `list(int)` or `torch.Size`): | 
					
						
						|  | The input shape should be a tuple that defines `(batch_size, query_length)`. | 
					
						
						|  | inputs_embeds (`torch.Tensor`): | 
					
						
						|  | The embedded inputs as a torch Tensor. | 
					
						
						|  | past_key_values_length (`int`): | 
					
						
						|  | The length of the key value cache. | 
					
						
						|  | sliding_window (`int`, *optional*): | 
					
						
						|  | If the model uses windowed attention, a sliding window should be passed. | 
					
						
						|  | """ | 
					
						
						|  | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) | 
					
						
						|  |  | 
					
						
						|  | key_value_length = input_shape[-1] + past_key_values_length | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if attention_mask is not None and len(attention_mask.shape) == 2: | 
					
						
						|  | attention_mask = attn_mask_converter.to_4d( | 
					
						
						|  | attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype | 
					
						
						|  | ) | 
					
						
						|  | elif attention_mask is not None and len(attention_mask.shape) == 4: | 
					
						
						|  | expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) | 
					
						
						|  | if tuple(attention_mask.shape) != expected_shape: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | inverted_mask = 1.0 - attention_mask | 
					
						
						|  | attention_mask = inverted_mask.masked_fill( | 
					
						
						|  | inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | attention_mask = attn_mask_converter.to_causal_4d( | 
					
						
						|  | input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return attention_mask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _prepare_4d_causal_attention_mask_for_sdpa( | 
					
						
						|  | attention_mask: Optional[torch.Tensor], | 
					
						
						|  | input_shape: Union[torch.Size, Tuple, List], | 
					
						
						|  | inputs_embeds: torch.Tensor, | 
					
						
						|  | past_key_values_length: int, | 
					
						
						|  | sliding_window: Optional[int] = None, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. | 
					
						
						|  |  | 
					
						
						|  | In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and | 
					
						
						|  | `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, | 
					
						
						|  | allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). | 
					
						
						|  | """ | 
					
						
						|  | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) | 
					
						
						|  |  | 
					
						
						|  | key_value_length = input_shape[-1] + past_key_values_length | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | is_tracing = ( | 
					
						
						|  | torch.jit.is_tracing() | 
					
						
						|  | or isinstance(inputs_embeds, torch.fx.Proxy) | 
					
						
						|  | or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( | 
					
						
						|  | attention_mask=attention_mask, | 
					
						
						|  | inputs_embeds=inputs_embeds, | 
					
						
						|  | past_key_values_length=past_key_values_length, | 
					
						
						|  | sliding_window=sliding_window, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if ignore_causal_mask: | 
					
						
						|  | expanded_4d_mask = None | 
					
						
						|  | elif attention_mask is None: | 
					
						
						|  | expanded_4d_mask = attn_mask_converter.to_causal_4d( | 
					
						
						|  | input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | if attention_mask.dim() == 4: | 
					
						
						|  |  | 
					
						
						|  | if attention_mask.max() != 0: | 
					
						
						|  | raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") | 
					
						
						|  | expanded_4d_mask = attention_mask | 
					
						
						|  | else: | 
					
						
						|  | expanded_4d_mask = attn_mask_converter.to_4d( | 
					
						
						|  | attention_mask, | 
					
						
						|  | input_shape[-1], | 
					
						
						|  | dtype=inputs_embeds.dtype, | 
					
						
						|  | key_value_length=key_value_length, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not is_tracing and expanded_4d_mask.device.type == "cuda": | 
					
						
						|  | expanded_4d_mask = AttentionMaskConverter._unmask_unattended( | 
					
						
						|  | expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return expanded_4d_mask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): | 
					
						
						|  | """ | 
					
						
						|  | Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape | 
					
						
						|  | `(batch_size, key_value_length)` | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | mask (`torch.Tensor`): | 
					
						
						|  | A 2D attention mask of shape `(batch_size, key_value_length)` | 
					
						
						|  | dtype (`torch.dtype`): | 
					
						
						|  | The torch dtype the created mask shall have. | 
					
						
						|  | tgt_len (`int`): | 
					
						
						|  | The target length or query length the created mask shall have. | 
					
						
						|  | """ | 
					
						
						|  | return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): | 
					
						
						|  | """ | 
					
						
						|  | Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape | 
					
						
						|  | `(batch_size, key_value_length)` | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | mask (`torch.Tensor`): | 
					
						
						|  | A 2D attention mask of shape `(batch_size, key_value_length)` | 
					
						
						|  | dtype (`torch.dtype`): | 
					
						
						|  | The torch dtype the created mask shall have. | 
					
						
						|  | tgt_len (`int`): | 
					
						
						|  | The target length or query length the created mask shall have. | 
					
						
						|  | """ | 
					
						
						|  | _, key_value_length = mask.shape | 
					
						
						|  | tgt_len = tgt_len if tgt_len is not None else key_value_length | 
					
						
						|  |  | 
					
						
						|  | is_tracing = ( | 
					
						
						|  | torch.jit.is_tracing() | 
					
						
						|  | or isinstance(mask, torch.fx.Proxy) | 
					
						
						|  | or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not is_tracing and torch.all(mask == 1): | 
					
						
						|  | return None | 
					
						
						|  | else: | 
					
						
						|  | return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _create_4d_causal_attention_mask( | 
					
						
						|  | input_shape: Union[torch.Size, Tuple, List], | 
					
						
						|  | dtype: torch.dtype, | 
					
						
						|  | device: torch.device, | 
					
						
						|  | past_key_values_length: int = 0, | 
					
						
						|  | sliding_window: Optional[int] = None, | 
					
						
						|  | ) -> Optional[torch.Tensor]: | 
					
						
						|  | """ | 
					
						
						|  | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | input_shape (`tuple(int)` or `list(int)` or `torch.Size`): | 
					
						
						|  | The input shape should be a tuple that defines `(batch_size, query_length)`. | 
					
						
						|  | dtype (`torch.dtype`): | 
					
						
						|  | The torch dtype the created mask shall have. | 
					
						
						|  | device (`int`): | 
					
						
						|  | The torch device the created mask shall have. | 
					
						
						|  | sliding_window (`int`, *optional*): | 
					
						
						|  | If the model uses windowed attention, a sliding window should be passed. | 
					
						
						|  | """ | 
					
						
						|  | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) | 
					
						
						|  |  | 
					
						
						|  | key_value_length = past_key_values_length + input_shape[-1] | 
					
						
						|  | attention_mask = attn_mask_converter.to_causal_4d( | 
					
						
						|  | input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return attention_mask | 
					
						
						|  |  |