import copy
import importlib.metadata
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from packaging import version

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import is_torchdynamo_compiling, logging


logger = logging.get_logger(__name__)
					
						
class Cache(torch.nn.Module):
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types
                of cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without a size limit -> the whole cache is usable.
        # Cache with a size limit -> if the cached length plus the new input length exceeds the maximum cache
        # length, part of the cache will have to be evicted, so not all of it is usable.
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
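
# Worked example of `Cache.get_usable_length` (illustrative comment, not part of the API): with a bounded cache
# where `get_max_length()` returns 1024 and 1000 tokens already cached, feeding 50 new tokens gives
# 1000 + 50 > 1024, so only 1024 - 50 = 974 previously cached positions remain usable and the subclass has to
# evict the rest. For an unbounded cache (`get_max_length()` returns `None`), the full 1000 positions are usable.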
					
						
@dataclass
class CacheConfig:
    """
    Base class for cache configs.
    """

    cache_implementation: None

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a CacheConfig instance from a dictionary of parameters.

        Args:
            config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
            **kwargs: Additional keyword arguments to override dictionary values.

        Returns:
            CacheConfig: Instance of CacheConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        return config

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        return copy.deepcopy(self.__dict__)

    def __iter__(self):
        """Allows `dict(obj)` for situations where `obj` may be a dict or a `CacheConfig`."""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.

        Returns:
            str: JSON formatted string representing the configuration instance.
        """
        return json.dumps(self.__dict__, indent=2) + "\n"

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs
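
# Minimal sketch of how a concrete config could build on `CacheConfig` (hypothetical subclass and field names,
# shown for illustration only; not a class defined in this file):
#
#     @dataclass
#     class ExampleQuantizedCacheConfig(CacheConfig):
#         cache_implementation: str = "example_quantized"
#         nbits: int = 4
#
#     cfg = ExampleQuantizedCacheConfig.from_dict({"nbits": 8})  # dict values first, kwargs would override them
#     unused = cfg.update(nbits=2, unknown_flag=True)            # sets nbits, returns {"unknown_flag": True}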
					
						
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

    >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> past_key_values = DynamicCache()
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv = outputs.past_key_values  # access cache filled with key/values from generation
    ```
    """

    def __init__(self) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
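
    # Shape walkthrough for `update` above (illustrative comment): during prefill, layer 0 receives key_states of
    # shape [batch, num_heads, prompt_len, head_dim] and the tensors are simply appended. On the next decoding step
    # key_states has seq_len == 1 and is concatenated along dim=-2, so the cached tensor grows to
    # [batch, num_heads, prompt_len + 1, head_dim]. `_seen_tokens` is only incremented for layer 0 so tokens are
    # counted once, not once per layer.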
					
						
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format. Used for
        backward compatibility."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
        backward compatibility."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
        # A negative `max_length` means "drop that many tokens from the end"
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        if self.get_seq_length() <= max_length:
            return

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = DynamicCache()
            current_split._seen_tokens = self._seen_tokens
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
            out.append(current_split)
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        cache = cls()
        for idx in range(len(splits[0])):
            layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
            cache.update(layer_keys, layer_values, idx)
        return cache

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
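
# Round-trip sketch for the batch helpers above (hedged example; `cache` is assumed to be an already populated
# DynamicCache with batch size 4):
#
#     splits = cache.batch_split(full_batch_size=4, split_size=2)  # two DynamicCache objects, batch size 2 each
#     merged = DynamicCache.from_batch_splits(splits)              # concatenated back along the batch dimension
#
# `batch_split` slices along dim=0 and `from_batch_splits` re-concatenates with `torch.cat(..., dim=0)`, so
# `merged` holds the same key/value tensors as `cache`.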
					
						
class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    In addition to the default CUDA stream, where all forward() computations happen,
    this class uses another stream, the prefetch stream, which it creates itself.
    Since scheduling of operations on separate streams happens independently, this class uses
    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
    ensure the eviction is scheduled after all computations on that cache are finished.
    """

    def __init__(self) -> None:
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None

    def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the next layer cache"
        if layer_idx < len(self):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch the next layer's tensors back to their original (GPU) device
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        if len(self) > 2:
            # This runs on the default stream so it happens after all earlier computations on these tensors are done
            prev_layer_idx = (layer_idx - 1) % len(self)
            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
        if layer_idx < len(self):
            # Evict the previous layer once all pending work on the default stream is finished
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(layer_idx)
            # Wait until the prefetch of this layer (issued on the prefetch stream) has landed
            original_device = self.original_device[layer_idx]
            self.prefetch_stream.synchronize()
            key_tensor = self.key_cache[layer_idx]
            value_tensor = self.value_cache[layer_idx]
            # Apply any delayed beam-search reordering now that the tensors are back on their device
            if self.beam_idx is not None:
                self.beam_idx = self.beam_idx.to(original_device)
                key_tensor = key_tensor.index_select(0, self.beam_idx)
                value_tensor = value_tensor.index_select(0, self.beam_idx)
            # Kick off the prefetch of the next layer
            self.prefetch_layer((layer_idx + 1) % len(self))
            return (key_tensor, value_tensor)
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
        # The reordering is delayed until the tensors are back on their original device, because
        # performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()
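
    # Scheduling sketch (comment only): while layer k's attention runs on the default stream, `prefetch_layer(k + 1)`
    # has already issued asynchronous host-to-device copies on `self.prefetch_stream`, and `evict_previous_layer`
    # queues the device-to-host copy of layer k - 1 on the default stream so it only runs after every kernel that
    # still reads that layer's tensors. `__getitem__` synchronizes the prefetch stream before handing layer k's
    # tensors to the caller.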
					
						
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.original_device.append(key_states.device)
            self.evict_previous_layer(layer_idx)
        else:
            # `self[layer_idx]` goes through `__getitem__` above, which handles prefetching and eviction
            key_tensor, value_tensor = self[layer_idx]
            self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # Setting an inherited method to `None` is the documented way to signal that a subclass deliberately does not
    # support it: the legacy (tuple of tuples) cache format does not mix with offloading.
    from_legacy_cache = None

    to_legacy_cache = None
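
# Usage sketch (hedged; mirrors the DynamicCache docstring example above):
#
#     past_key_values = OffloadedCache()
#     outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
#
# The call pattern is identical to DynamicCache; the difference is that all layers except the ones currently in
# flight live in CPU memory, trading PCIe traffic for GPU memory.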
					
						
class SinkCache(Cache):
    """
    A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

    >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv = outputs.past_key_values  # access cache filled with key/values from generation
    ```
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_rerotation_cache = {}
        self._cos_cache = None
        self._sin_cache = None
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states
					
						
    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required to re-rotate the cached keys to their shifted positions
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[key_states.shape[-2]]
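
    # Why the products above re-rotate keys (comment derivation): a cached key was rotated by the angle of its
    # original position; after older tokens are evicted it must look as if it sits at a smaller, shifted position,
    # so the correction is a rotation by (theta_new - theta_old). Expanding with the angle-difference identities:
    #     cos(new - old) = cos(old) * cos(new) + sin(old) * sin(new)
    #     sin(new - old) = cos(old) * sin(new) - sin(old) * cos(new)
    # which are exactly `rerotation_cos` and `rerotation_sin` built from the original and shifted cos/sin slices
    # (computed in float32 above to limit rounding error, then cast back to the key dtype).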
					
						
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # that apply rotary embeddings to only a slice of the head dimension.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the sin/cos cache, which holds sin/cos values for all possible positions
        if using_rope and layer_idx == 0:
            # 2D sin/cos can be stored as-is; batched 3D sin/cos keeps (and grows) the first batch entry only,
            # up to `window_length` positions
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted (and possibly re-rotated) tokens, and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
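
# Cache layout maintained by `SinkCache.update` (comment only): once the window is full, each layer holds exactly
# `window_length` positions, where the first `num_sink_tokens` entries are never evicted and the remainder acts as
# a rolling window over the most recent tokens, with keys re-rotated so RoPE positions stay consistent.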
					
						
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

    >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
    >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv = outputs.past_key_values  # access cache filled with key/values from generation
    ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for idx in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            # Notes:
            # 1. `mark_static_address` tags the cache tensors as fixed data pointers, preventing compiled-graph
            #    breaks when the cache is updated in place.
            # 2. `torch.export()` requires mutated tensors to be registered as buffers; buffer registration and
            #    `mark_static_address` are skipped while torchdynamo is tracing this code.
            if not is_torchdynamo_compiling():
                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
        self._seen_tokens = 0
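
    # Typical pairing with `torch.compile` (hedged sketch, not executed here): the fixed shapes let a compiled
    # decode step be replayed without recompilation, e.g.
    #
    #     past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=512,
    #                                   device=model.device, dtype=model.dtype)
    #     compiled_forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    #
    # after which every call writes in place through `update` below instead of reallocating tensors.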
					
						
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know where to write in the cache.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        cache_position = cache_kwargs.get("cache_position")
        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # Note: `tensor.index_copy_(dim, index, tensor)` is equivalent to `tensor[:, :, index] = tensor`, but the
            # former is compile-friendly and explicitly in-place, avoiding copies and extra memory.
            try:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
                k_out[:, :, cache_position] = key_states
                v_out[:, :, cache_position] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        self._seen_tokens = 0
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
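
# How `cache_position` reaches `StaticCache.update` (hedged sketch of the calling convention, simplified from the
# attention layers that consume this class):
#
#     cache_kwargs = {"cache_position": cache_position}
#     key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
#
# where `cache_position` is a 1-D LongTensor of the absolute positions being written: the full prompt range during
# prefill, then a single index per generated token.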
					
						
class SlidingWindowCache(StaticCache):
    """
    Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window
    attention. Every time we try to update the cache, we compute the `indices` based on
    `cache_position >= self.config.sliding_window - 1`. If true (which means the cache can not hold all the old key
    value states and new states together because of the sliding window constraint), we need to do a cycle shift
    based on `indices` to replace the oldest states by the new key value states passed in.

    The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:

    indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
    tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
            37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
            55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

    We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache

    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

    >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
    >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv = outputs.past_key_values  # access cache filled with key/values from generation
    ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        # The cache never needs to be larger than the sliding window
        max_cache_len = min(config.sliding_window, max_cache_len)
        super().__init__(
            config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
        )
					
						
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # Assisted decoding (or a long prompt) may feed more new tokens than the window can hold: keep only the
        # most recent `max_cache_len` of them in the cache.
        if cache_position.shape[0] > self.max_cache_len:
            k_out = key_states[:, :, -self.max_cache_len :, :]
            v_out = value_states[:, :, -self.max_cache_len :, :]
            # Assumption: the caches are all zeros at this point, so `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # Return the whole states instead of k_out, v_out so the whole prompt is taken into consideration when
            # building the kv cache, instead of just throwing away tokens outside of the window
            return key_states, value_states

        slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
        to_shift = cache_position >= self.max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len

        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        try:
            cache_position = cache_position.to(device=k_out.device)
            k_out.index_copy_(2, cache_position, key_states)
            v_out.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        # `zero_()` followed by `+=` is equivalent to `=`, but keeps the static address intact for compilation
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out

        return k_out, v_out

    def get_max_length(self) -> Optional[int]:
        # In theory there is no limit: the sliding window size stays fixed no matter how long the sequence grows
        return None

    def reset(self):
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
					
						
class EncoderDecoderCache(Cache):
    """
    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
    cross-attention caches.

    Example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
        >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")

        >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")

        >>> # Prepare cache classes for encoder and decoder and pass them to the model's forward
        >>> self_attention_cache = DynamicCache()
        >>> cross_attention_cache = DynamicCache()
        >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv = outputs.past_key_values  # access cache filled with key/values from generation
        ```
    """

    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
        super().__init__()
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

        self.is_updated = {}
        for layer_idx in range(len(cross_attention_cache.key_cache)):
            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (
                self.self_attention_cache.key_cache[layer_idx],
                self.self_attention_cache.value_cache[layer_idx],
                self.cross_attention_cache.key_cache[layer_idx],
                self.cross_attention_cache.value_cache[layer_idx],
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.self_attention_cache)

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        if len(self.cross_attention_cache) > 0:
            for self_attn, cross_attn in zip(
                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
            ):
                legacy_cache += (self_attn + cross_attn,)
        else:
            legacy_cache = self.self_attention_cache.to_legacy_cache()
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "EncoderDecoderCache":
        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx][:2]
                cache.self_attention_cache.update(key_states, value_states, layer_idx)
                if len(past_key_values[layer_idx]) > 2:
                    key_states, value_states = past_key_values[layer_idx][2:]
                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                    cache.is_updated[layer_idx] = True
        return cache
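
    # For reference: in the legacy format each layer is a plain tuple of tensors. With a cross-attention
    # cache present, the layout produced and consumed above is
    #     (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
    # per layer; otherwise it is the usual (key, value) pair per layer.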
					
						
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.self_attention_cache.key_cache) <= layer_idx:
            return 0
        return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        if hasattr(self.self_attention_cache, "reset"):
            self.self_attention_cache.reset()
        if hasattr(self.cross_attention_cache, "reset"):
            self.cross_attention_cache.reset()
        elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
            raise ValueError(
                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
            )
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        if not (
            isinstance(self.self_attention_cache, DynamicCache)
            and isinstance(self.cross_attention_cache, DynamicCache)
        ):
            raise ValueError(
                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
            )

    def crop(self, maximum_length: int):
        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
        self.check_dynamic_cache(self.crop.__name__)
        self.self_attention_cache.crop(maximum_length)

    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        self.check_dynamic_cache(self.batch_split.__name__)
        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

        out = []
        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
            out.append(EncoderDecoderCache(self_attn, cross_attn))
        return out
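
    # Illustrative usage (hypothetical sizes): `cache.batch_split(full_batch_size=8, split_size=4)` is
    # expected to return two `EncoderDecoderCache` objects, each holding a batch of 4, mirroring how
    # `DynamicCache.batch_split` slices the batch dimension.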
					
						
    @classmethod
    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        self_attention_cache = DynamicCache()
        cross_attention_cache = DynamicCache()
        for idx in range(len(splits[0])):
            layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
            self_attention_cache.update(layer_keys, layer_values, idx)

            layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
            cross_attention_cache.update(layer_keys, layer_values, idx)
        return cls(self_attention_cache, cross_attention_cache)
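
    # Round-trip sketch (illustrative): `EncoderDecoderCache.from_batch_splits(cache.batch_split(8, 4))`
    # should reassemble a cache equivalent to the original, since the splits are concatenated back along
    # the batch dimension layer by layer.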
					
						
    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
        self.self_attention_cache.batch_repeat_interleave(repeats)
        self.cross_attention_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_select_indices.__name__)
        self.self_attention_cache.batch_select_indices(indices)
        self.cross_attention_cache.batch_select_indices(indices)


class HybridCache(Cache):
    """
    Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window
    attention and global attention in every other layer. Under the hood, Hybrid Cache leverages [`SlidingWindowCache`]
    for sliding window attention and [`StaticCache`] for global attention. For more information, see the documentation
    of each subcomponent cache class.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`, *optional*, defaults to `"cpu"`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv = outputs.past_key_values  # access cache filled with key/values from generation
        ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
        super().__init__()
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'hybrid' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size

        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )
        self.is_sliding = torch.tensor(
            [not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
        sliding_cache_shape = (
            max_batch_size,
            self.num_key_value_heads,
            min(config.sliding_window, max_cache_len),
            self.head_dim,
        )
        for i in range(config.num_hidden_layers):
            # Sliding-window layers get a buffer capped at the window size; global layers get the full
            # `max_cache_len`. `mark_static_address` keeps each buffer at a fixed address for `torch.compile`.
            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
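
    # With the `is_sliding` pattern above, even-numbered layers (0, 2, 4, ...) use the sliding-window
    # buffer and odd-numbered layers use the global buffer, mirroring how the class docstring describes
    # Gemma2 alternating local and global attention.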
					
						
    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        # Prefill longer than the window: keep only the last `max_cache_len` tokens in the cache, but
        # return the full states so the current forward pass can attend to the whole prompt.
        if cache_position.shape[0] > max_cache_len:
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # The cache is still all zeros here, so `+=` writes the values in place.
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out

            return key_states, value_states

        # Same ring-buffer update as `SlidingWindowCache.update`: roll the window once the write position
        # reaches the last slot, then scatter the new states at the clamped positions.
        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        to_shift = cache_position >= max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        # `zero_()` + `+=` is an in-place assignment that preserves the statically addressed buffers.
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out
        return k_out, v_out

    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out
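
    # `_static_update` is a plain scatter at `cache_position` (global-attention layers never evict),
    # while `_sliding_update` above maintains a rolling window of the most recent tokens.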
					
						
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        sliding_window = cache_kwargs.get("sliding_window")
        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],
        )
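
    # Note that `k_out.shape[2]` is passed as `max_cache_len`, so sliding layers see their (smaller)
    # window-sized buffer length while global layers see the full `max_cache_len` allocated in `__init__`.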
					
						
    def get_max_length(self) -> Optional[int]:
        # The configured maximum length; sliding layers keep at most `min(sliding_window, max_cache_len)`
        # tokens in their own buffers.
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # Sequence-length tracking is not implemented for this cache; callers are expected to rely on
        # `cache_position` instead.
        return None

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops preserve the statically addressed buffers
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()


class MambaCache:
    """
    Cache for the Mamba model, which has no attention mechanism and therefore no key/value states.

    Arguments:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        dtype (*optional*, defaults to `torch.float16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        dtype: (`torch.dtype`):
            The default `dtype` used to initialize the cache.
        intermediate_size: (`int`):
            Model's intermediate_size taken from config.
        ssm_state_size: (`int`):
            Model's state_size taken from config.
        conv_kernel_size: (`int`):
            Model's convolution kernel size taken from config.
        conv_states: (`torch.Tensor`):
            A tensor of shape `[num_hidden_layers, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
        ssm_states: (`torch.Tensor`):
            A tensor of shape `[num_hidden_layers, batch_size, intermediate_size, ssm_state_size]` that holds SSM states.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache

        >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

        >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv = outputs.past_key_values
        ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[str] = None,
        **kwargs,
    ):
        self.dtype = dtype
        self.max_batch_size = max_batch_size
        self.intermediate_size = config.intermediate_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel

        self.conv_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
            self.max_batch_size,
            self.intermediate_size,
            self.conv_kernel_size,
            device=device,
            dtype=dtype,
        )
        self.ssm_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
            self.max_batch_size,
            self.intermediate_size,
            self.ssm_state_size,
            device=device,
            dtype=dtype,
        )

        torch._dynamo.mark_static_address(self.conv_states)
        torch._dynamo.mark_static_address(self.ssm_states)
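
    # `mark_static_address` tags these buffers as having a fixed data pointer so that they can be treated
    # as static inputs under `torch.compile` (e.g. when CUDA graphs are used).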
					
						
    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        conv_state = self.conv_states[layer_idx]
        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

        conv_state = conv_state.roll(shifts=-1, dims=-1)
        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += conv_state
        return self.conv_states[layer_idx]
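
    # The convolution window follows the same pattern as the sliding attention caches above: roll the
    # kernel-sized buffer left by one, write the newest state at the clamped position, then copy back with
    # `zero_()` + `+=` so the statically addressed buffer is updated in place. For example, with
    # `conv_kernel_size = 4` and `cache_position = [7]`, the write lands in the last slot (index 3).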
					
						
    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
        return self.ssm_states[layer_idx]

    def reset(self):
        self.conv_states.zero_()
        self.ssm_states.zero_()