import logging
import numpy as np
import torch
from torch import nn
from enum import Enum, auto
from transformers import BartModel, BartForConditionalGeneration, \
    T5Model, T5ForConditionalGeneration, \
    LEDModel, LEDForConditionalGeneration, \
    AutoModelForCausalLM, AutoModelForSeq2SeqLM, \
    MODEL_WITH_LM_HEAD_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
from typing import TypeVar, Generic

from .index_building import Datastore, DatastoreBatch

logger = logging.getLogger('Unlimiformer')
logger.setLevel(logging.INFO)

ModelType = TypeVar('ModelType')


class Unlimiformer(Generic[ModelType]):
    def __init__(self, model: ModelType,
                 layer_begin=-1, layer_end=None,
                 unlimiformer_head_num=None,
                 exclude_attention=False,
                 model_encoder_max_len=None,
                 chunk_overlap=0,
                 verbose=False, save_heatmap=False,
                 tokenizer=None, unlimiformer_training=False,
                 use_datastore=False,
                 flat_index=False,
                 test_datastore=False, reconstruct_embeddings=False,
                 gpu_datastore=False, gpu_index=False,
                 index_devices=(0,), datastore_device=0,
                 ):
        super().__init__()
        self.model = model
        model.unlimiformer = self
        self.layer_begin = layer_begin
        self.layer_end = layer_end
        self.specific_head = unlimiformer_head_num
        self.exclude_attention = exclude_attention
        self.actual_model_window_size = None
        self.model_encoder_max_len = model_encoder_max_len
        self.chunk_overlap = chunk_overlap
        self.verbose = verbose
        self.save_heatmap = save_heatmap
        self.tokenizer = tokenizer
        self.unlimiformer_training = unlimiformer_training
        self.use_datastore = use_datastore
        self.flat_index = flat_index
        self.reconstruct_embeddings = reconstruct_embeddings
        self.gpu_datastore = gpu_datastore
        self.gpu_index = gpu_index
        if torch.cuda.is_available() and gpu_index:
            self.index_devices = [torch.device(f'cuda:{i}') for i in index_devices]
        else:
            self.index_devices = [torch.device('cpu')]
        self.datastore_device = torch.device('cpu')
        self.test_datastore = test_datastore  # flag for debugging
        self.device = torch.device('cpu')
        self.activation_capturer = None
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.hook_handles = []
        self.is_input_encoding_pass = False
        self.is_first_test_decoding_step = False
        self.prev_tokens = None
        self.last_beam_idx = None
        self.heatmap = None
        self.cur_decoder_layer_index = None
        self.datastore = None

        self.break_into(model)

    def break_into(self, model):
        self.actual_model_window_size = self.window_size()
        if self.model_encoder_max_len is None:
            self.model_encoder_max_len = self.actual_model_window_size
        self.window_margin = int(self.model_encoder_max_len * self.chunk_overlap / 2)
        self.num_heads = model.config.num_attention_heads
        if self.specific_head is None:
            self.head_nums = Ellipsis  # torch.arange(0, self.num_heads, device=self.device)
        else:
            self.head_nums = self.specific_head
        self.hooks_injected = False
        self.training_hooks_injected = False
        self.original_forward_func = model.forward

        # Activate Unlimiformer when calling model.eval(), deactivate for model.train()
        self.original_model_eval_func = model.eval
        model.eval = self.pre_eval_hook
        self.original_model_train_func = model.train
        model.train = self.pre_train_hook

    def pre_eval_hook(self):
        self.remove_training_hooks(self.model)
        self.inject_hooks(self.model)
        self.original_model_eval_func()

    def pre_train_hook(self, mode=True):
        # mode=True means model.train() is called
        # mode=False means model.eval() is called
        torch.cuda.empty_cache()
        if mode is True:
            self.break_out(self.model)
            if self.unlimiformer_training:
                self.inject_training_hooks(self.model)
        self.original_model_train_func(mode)

    def inject_hooks(self, model):
        if self.hooks_injected:
            return
        # Inject our activation_capturer to capture the activations at every forward pass
        attention_layers_to_capture = self.activation_to_capture(self.layer_begin, self.layer_end)
        self.activation_capturer = []
        for layer in attention_layers_to_capture:
            if type(layer) is list:
                layer_capturers = []
                for k_or_v in layer:
                    capturer = ActivationCapturer(k_or_v, capture_input=False)
                    layer_capturers.append(capturer)
                    self.register_hook(k_or_v, capturer)
                self.activation_capturer.append(layer_capturers)
            else:
                capturer = ActivationCapturer(layer, capture_input=False)
                self.register_hook(layer, capturer)
                self.activation_capturer.append(capturer)

        # Inject our main function after the main attention function
        attention_layers_to_run = self.attention_op_to_run(self.layer_begin, self.layer_end)
        for layer in attention_layers_to_run:
            self.register_hook(layer, self.attention_forward_hook)

        decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
        self.original_decoder_layer_cross_attn_forward_funcs = []
        for i, decoder_layer in enumerate(decoder_layers_to_run):
            decoder_layer_cross_attention = self.cross_attention(decoder_layer)
            self.original_decoder_layer_cross_attn_forward_funcs.append(decoder_layer_cross_attention.forward)
            decoder_layer_cross_attention.forward = self.create_cross_attn_pre_forward_hook(decoder_layer_cross_attention.forward, i)

        # Inject our hook function at the beginning of generation.
        # When model.generate() is called, it first runs our pre_generate_hook()
        # (which calls reset_memory()), and only then calls the original model.generate()
        self.original_generate_func = model.generate
        model.generate = self.pre_generate_hook

        model.forward = self.pre_forward_hook

        self.original_reorder_cache_func = model._reorder_cache
        model._reorder_cache = self.reorder_cache_hook
        self.hooks_injected = True

    def inject_training_hooks(self, model):
        if self.training_hooks_injected:
            return
        # self.original_forward_func = model.forward
        model.forward = self.pre_forward_hook

        decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)

        self.original_decoder_layer_self_attn_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            attention = self.self_attention(decoder_layer)
            self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward)
            attention.forward = self.create_self_attn_pre_forward_hook(attention.forward)

        self.original_decoder_layer_cross_attn_forward_funcs = []
        for i, decoder_layer in enumerate(decoder_layers_to_run):
            decoder_layer_cross_attention = self.cross_attention(decoder_layer)
            self.original_decoder_layer_cross_attn_forward_funcs.append(decoder_layer_cross_attention.forward)
            decoder_layer_cross_attention.forward = self.create_cross_attn_pre_forward_hook(decoder_layer_cross_attention.forward, i)

        self.original_decoder_layer_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            self.original_decoder_layer_forward_funcs.append(decoder_layer.forward)
            decoder_layer.forward = self.create_decoder_layer_func(decoder_layer.forward, decoder_layer)

        self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run)

        attention_layers_to_run = self.attention_op_to_run(self.layer_begin, self.layer_end)
        for layer in attention_layers_to_run:
            self.register_hook(layer, self.train_attention_forward_hook)

        self.training_hooks_injected = True

    def inject_hooks_for_unaffected_layers(self, model, decoder_layers_to_run):
        self.original_non_injected_decoder_layer_forward_funcs = []
        non_injected_decoder_layers = [l for l in self.attention_layer_to_run(0, None)
                                       if l not in decoder_layers_to_run]
        for decoder_layer in non_injected_decoder_layers:
            self.original_non_injected_decoder_layer_forward_funcs.append(decoder_layer.forward)
            decoder_layer.forward = self.create_noninjected_decoder_layer_func(decoder_layer.forward, decoder_layer)

    def create_self_attn_pre_forward_hook(self, original_self_attn_forward_func):
        def self_attention_pre_forward_hook(*args, **kwargs):
            kwargs['past_key_value'] = None
            return original_self_attn_forward_func(*args, **kwargs)

        return self_attention_pre_forward_hook

    def create_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer):
        def checkpointed_decoder_layer(
                hidden_states: torch.Tensor,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                layer_head_mask=None,
                cross_attn_layer_head_mask=None,
                past_key_value=None,
                output_attentions=False,
                position_bias=None,
                encoder_decoder_position_bias=None,
                use_cache=True):

            def forward_with_all_keys(hidden_states, attention_mask,
                                      encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                                      cross_attn_layer_head_mask, past_key_value,
                                      output_attentions, use_cache, long_inputs, long_inputs_mask,
                                      position_bias, encoder_decoder_position_bias):
                key, value = self.create_key_value(long_inputs, decoder_layer)
                decoder_layer_args = self.create_decoder_layer_args(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    use_cache=use_cache,
                    key=key, value=value)
                return decoder_layer_original_forward_func(**decoder_layer_args)

            return torch.utils.checkpoint.checkpoint(
                forward_with_all_keys, hidden_states, attention_mask,
                encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                cross_attn_layer_head_mask, None,
                output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask,
                position_bias, encoder_decoder_position_bias)

        return checkpointed_decoder_layer

    def create_noninjected_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer):
        def checkpointed_decoder_layer(
                hidden_states: torch.Tensor,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                layer_head_mask=None,
                cross_attn_layer_head_mask=None,
                past_key_value=None,
                output_attentions=False,
                position_bias=None,
                encoder_decoder_position_bias=None,
                use_cache=True):

            def forward_with_all_keys(hidden_states, attention_mask,
                                      encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                                      cross_attn_layer_head_mask, past_key_value,
                                      output_attentions, use_cache, long_inputs, long_inputs_mask,
                                      position_bias, encoder_decoder_position_bias):
                decoder_layer_args = self.create_decoder_layer_args(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    use_cache=use_cache, key=None, value=None)
                return decoder_layer_original_forward_func(**decoder_layer_args)

            return torch.utils.checkpoint.checkpoint(
                forward_with_all_keys, hidden_states, attention_mask,
                encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                cross_attn_layer_head_mask, None,
                output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask,
                position_bias, encoder_decoder_position_bias)

        return checkpointed_decoder_layer
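
    # Note on the two checkpointed wrappers above: at training time each decoder layer is
    # re-run under torch.utils.checkpoint. The injected layers recompute their cross-attention
    # keys/values over the full chunk-encoded input (self.long_inputs_encoded) via
    # create_key_value() inside the checkpointed call instead of storing them, while the
    # remaining layers are simply checkpointed as-is; this is roughly what keeps training on
    # long inputs within memory.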

    def register_hook(self, layer, func, pre=False):
        handle = layer.register_forward_pre_hook(func) if pre else layer.register_forward_hook(func)
        self.hook_handles.append(handle)

    def break_out(self, model):
        self.prompt_keys = []
        self.prompt_values = []
        self.prompt_attention_mask = []
        self.generated_input_ids = []
        torch.cuda.empty_cache()
        if not self.hooks_injected:
            return

        for h in self.hook_handles:
            h.remove()
        model.generate = self.original_generate_func
        model.forward = self.original_forward_func
        model._reorder_cache = self.original_reorder_cache_func

        decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
        for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs):
            self.cross_attention(decoder_layer).forward = original_func
        self.hooks_injected = False

    def remove_training_hooks(self, model):
        self.long_inputs_encoded, self.long_inputs_mask = None, None
        if not self.training_hooks_injected:
            return
        for h in self.hook_handles:
            h.remove()
        model.forward = self.original_forward_func

        decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
        for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_self_attn_forward_funcs):
            self.self_attention(decoder_layer).forward = original_func
        for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs):
            self.cross_attention(decoder_layer).forward = original_func
        for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_forward_funcs):
            decoder_layer.forward = original_func

        non_injected_decoder_layers = [l for l in self.attention_layer_to_run(0, None)
                                       if l not in decoder_layers_to_run]
        for decoder_layer, original_func in zip(non_injected_decoder_layers, self.original_non_injected_decoder_layer_forward_funcs):
            decoder_layer.forward = original_func

        self.training_hooks_injected = False

    def reset_memory(self, input_ids, attention_mask):
        if self.use_datastore:
            if self.is_encoder_decoder:
                self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0], flat_index=self.flat_index,
                                                 gpu_index=self.gpu_index, index_device=self.index_devices[0])]
                self.hidden_states = [[]]
            else:
                self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0], flat_index=self.flat_index,
                                                 gpu_index=self.gpu_index, index_device=self.index_devices[i % len(self.index_devices)])
                                  for i in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
                self.hidden_states = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
        torch.cuda.empty_cache()

        self.prompt_input_ids = input_ids
        self.input_ids_size = input_ids.shape[-1]
        self.prompt_keys, self.prompt_values = None, None
        self.prev_tokens = [None for _ in range(len(self.original_decoder_layer_cross_attn_forward_funcs))]
        self.last_beam_idx = None
        self.cur_layer_key_value_placeholder = None
        self.is_input_encoding_pass = True
        if self.is_encoder_decoder:
            dummy_labels = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
        else:
            dummy_labels = None
        if self.save_heatmap:
            if self.heatmap is not None:
                print(f'Generated: {self.tokenizer.decode(self.generated_input_ids[0])}')
                self.plot_heatmap(self.heatmap[0].detach().cpu().numpy())
            self.heatmap = torch.tensor([], dtype=torch.float, device=input_ids.device)
        self.generated_input_ids = torch.tensor([], dtype=torch.long, device=input_ids.device)

        self.prompt_keys = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
        self.prompt_values = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
        self.prompt_attention_mask = []
        window_indices = self.window_indices(input_ids.shape[-1])

        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            logger.info(f'Encoding {context_start_ind} to {context_end_ind} out of {input_ids.shape[-1]}')
            chunk = input_ids[:, context_start_ind:context_end_ind].to(self.device)
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind].to(self.device)
            with torch.inference_mode():
                _ = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels)  # , return_dict=True, output_hidden_states=True)
            if self.use_datastore:
                # TODO: verify with BART as well
                # hidden_states_to_index = [hidden_states.encoder_last_hidden_state]  # list of length 1 of (batch, chunked_source_len, dim)
                hidden_states_to_index = [
                    layer_capturer.captured for layer_capturer in self.activation_capturer
                ]
                # hidden_states_to_index = list(hidden_states.hidden_states)[:-1][self.layer_begin:self.layer_end]
                to_add = [state[:, update_start_ind:update_end_ind].detach() for state in hidden_states_to_index]
                to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]
                # to_apply_mask = to_apply_mask.log().to(to_add[0].dtype)
                to_apply_mask = to_apply_mask.to(to_add[0].dtype)
                if not self.reconstruct_embeddings:
                    to_add_embeddings = to_add
                    if not self.gpu_datastore:
                        to_add_embeddings = [states.cpu() for states in to_add_embeddings]
                        to_apply_mask = to_apply_mask.cpu()
                    for i, layer_states in enumerate(to_add_embeddings):
                        layer_states = layer_states * to_apply_mask.unsqueeze(-1)
                        self.hidden_states[i].append(layer_states.to(self.datastore_device))
                # list of len layers, inside it there is a list of len batch, each item is (masked_time, dim)
                # for i, to_add_layer in enumerate(to_add):
                #     keys = [key[mask.bool()] for key, mask in zip(to_add_layer, to_apply_mask)]
                #     self.datastore[i].add_keys(keys)
            if (not self.use_datastore) or self.test_datastore:
                layers_kv = [
                    self.process_key_value(layer_capturer)  # (batch, head, time, dim)
                    for layer_capturer in self.activation_capturer
                ]  # list of pairs of (batch, head, time, dim)

                # list of (batch, head, chunked_time, dim)
                key = [layer[0][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
                value = [layer[1][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
                chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]  # (batch, chunked_time)

                # key = torch.stack(key, dim=0)  # (num_layers, batch, head, time, dim)
                # value = torch.stack(value, dim=0)  # (num_layers, batch, head, time, dim)

                for i, (layer_key, layer_value) in enumerate(zip(key, value)):
                    self.prompt_keys[i].append(layer_key)  # (num_layers, batch, head, chunked_source_len, dim)
                    self.prompt_values[i].append(layer_value)  # (num_layers, batch, head, chunked_source_len, dim)
                self.prompt_attention_mask.append(chunk_attention_mask)  # (batch, chunked_source_len)

        if self.use_datastore:
            # keys are all in datastore already!
            if not self.reconstruct_embeddings:
                # self.hidden_states = [torch.cat(layer_hidden_states, axis=1) for layer_hidden_states in self.hidden_states]
                concat_hidden_states = []
                for i in range(len(self.hidden_states)):
                    concat_hidden_states.append(torch.cat(self.hidden_states[i], axis=1))
                    self.hidden_states[i] = None
                self.hidden_states = concat_hidden_states
            for datastore, layer_hidden_states in zip(self.datastore, self.hidden_states):
                datastore.train_index(layer_hidden_states)
        if (not self.use_datastore) or self.test_datastore:
            for i, (layer_keys, layer_values) in enumerate(zip(self.prompt_keys, self.prompt_values)):
                self.prompt_keys[i] = torch.cat(layer_keys, dim=-2)
                self.prompt_values[i] = torch.cat(layer_values, dim=-2)
            # self.prompt_keys = torch.cat(self.prompt_keys, dim=-2)  # (num_layers, batch, head, source_len, dim)
            # self.prompt_values = torch.cat(self.prompt_values, dim=-2)  # (num_layers, batch, head, source_len, dim)
            self.prompt_attention_mask = torch.cat(self.prompt_attention_mask, dim=-1)  # (batch, source_len)

        self.is_input_encoding_pass = False
        if self.verbose:
            print(f'Input: '
                  f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
                  f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
            print()

    def chunked_encode_input(self, input_ids, attention_mask):
        long_inputs_encoded = []
        long_inputs_mask = []
        window_indices = self.window_indices(input_ids.shape[-1])

        self.is_input_encoding_pass = True
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            chunk = input_ids[:, context_start_ind:context_end_ind]
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
            output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True)
            encoder_last_hidden_state = output.last_hidden_state  # (batch, time, dim)

            # list of (batch, head, chunked_time, dim)
            encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind]  # (batch, chunked_time, dim)
            chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]  # (batch, chunked_time)

            long_inputs_encoded.append(encoder_last_hidden_state)  # (batch, chunked_source_len, dim)
            long_inputs_mask.append(chunk_attention_mask)  # (batch, chunked_source_len)

        long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1)  # (batch, source_len, dim)
        long_inputs_mask = torch.cat(long_inputs_mask, dim=1)  # (batch, source_len)

        self.is_input_encoding_pass = False
        if self.verbose:
            print(f'Input: '
                  f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
                  f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
            print()
        return long_inputs_encoded, long_inputs_mask

    def window_indices(self, total_seq_len):
        # Copied from SLED (Ivgi et al., 2022)
        # https://github.com/Mivg/SLED/blob/main/sled/modeling_sled.py#L467
        if total_seq_len <= self.model_encoder_max_len:
            return [(0, total_seq_len, 0, total_seq_len)]
        else:
            results = []
            # if self.chunk_overlap == 0:
            #     stride = self.model_encoder_max_len
            stride = self.model_encoder_max_len - 2 * self.window_margin
            context_start = update_start_ind = 0
            context_end = self.model_encoder_max_len
            if self.is_encoder_decoder:
                update_end_ind = context_end - self.window_margin
            else:
                update_end_ind = context_end
            # first window always should update from the beginning
            results.append((context_start, context_end, update_start_ind, update_end_ind))

            while context_end < total_seq_len:
                context_end = min(total_seq_len, context_end + stride)
                context_start = (
                    context_start + stride if context_end < total_seq_len else total_seq_len - self.model_encoder_max_len
                )
                update_start_ind = max(update_start_ind + stride, update_end_ind)
                # last window always should update until the end
                update_end_ind = (
                    min(total_seq_len, update_end_ind + stride) if context_end < total_seq_len else total_seq_len
                )
                cs, ce, us, ue = context_start, context_end, update_start_ind - context_start, \
                    update_end_ind - context_start

                results.append((cs, ce, us, ue))
            return results
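
    # Illustrative example (hypothetical numbers, for intuition only): with
    # model_encoder_max_len=8, window_margin=1 (i.e. chunk_overlap=0.25) and an
    # encoder-decoder model, window_indices(20) returns
    #     [(0, 8, 0, 7), (6, 14, 1, 7), (12, 20, 1, 8)]
    # Each tuple is (context_start, context_end, update_start, update_end); the update
    # indices are relative to the chunk, so every source position is contributed by
    # exactly one window even though the context windows overlap.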

    def pre_generate_hook(self, input_ids, **kwargs):
        if 'attention_mask' not in kwargs:
            kwargs['attention_mask'] = torch.ones_like(input_ids)
        self.reset_memory(input_ids, kwargs['attention_mask'])
        new_kwargs = kwargs
        if 'attention_mask' in kwargs:
            new_kwargs = {k: v for k, v in kwargs.items() if k != 'attention_mask'}
            new_kwargs['attention_mask'] = kwargs['attention_mask'][:, :self.actual_model_window_size].to(self.device)
        new_kwargs['use_cache'] = True
        if self.is_encoder_decoder:
            input_ids_prefix = input_ids[:, :self.actual_model_window_size]
        else:
            input_ids_prefix = input_ids[:, -self.actual_model_window_size:]
        input_ids_prefix = input_ids_prefix.to(self.device)
        return self.original_generate_func(input_ids_prefix, **new_kwargs)

    def pre_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        self.set_gradient_checkpointing(False)
        if not self.is_input_encoding_pass:
            if self.model.training:
                # self.reset_memory(input_ids, attention_mask)
                self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask)
                input_ids = input_ids[:, :self.actual_model_window_size]
                attention_mask = attention_mask[:, :self.actual_model_window_size] if attention_mask is not None else None
                # input_ids = input_ids[:, :self.model_encoder_max_len]
                # labels = labels[:, :self.model_encoder_max_len] if labels is not None else None
            else:
                if kwargs.get('past_key_values') is None:
                    self.is_first_test_decoding_step = True

                if input_ids is not None:
                    # self.input_ids_size += input_ids.shape[-1]
                    self.input_ids_size += 1
                if kwargs.get('decoder_input_ids') is not None:
                    self.generated_input_ids = torch.cat([self.generated_input_ids, kwargs['decoder_input_ids']], axis=-1)

        result = self.original_forward_func(input_ids=input_ids, labels=labels, attention_mask=attention_mask, **kwargs)
        self.is_first_test_decoding_step = False
        return result

    def create_cross_attn_pre_forward_hook(self, original_cross_attn_forward_func, cur_layer_num):
        def attention_pre_forward_hook(hidden_states, attention_mask=None, *args, **kwargs):
            self.cur_decoder_layer_index = cur_layer_num
            if kwargs.get('past_key_value') is not None:
                # it's a tuple, and we convert it to a list to be able to perform assignment
                # and modify its items from our attention_forward_hook
                self.cur_layer_key_value_placeholder = \
                    kwargs['past_key_value'] = list(kwargs['past_key_value'])  # (batch, head, time, attn_dim)

            batch_size, tgt_len, dim = hidden_states.shape
            if self.model.training:
                # from: (batch, tgt_len, dim) to: (batch * tgt_len, 1, dim)
                hidden_states = hidden_states.reshape(-1, 1, hidden_states.shape[-1])
                # from: (batch, 1, tgt_len, dim) to: (batch * tgt_len, 1, 1, dim)
                attention_mask = attention_mask.reshape(-1, 1, 1, attention_mask.shape[-1])

                attn_output, attn_weights_reshaped, past_key_value = original_cross_attn_forward_func(hidden_states=hidden_states, attention_mask=attention_mask, *args, **kwargs)
                attn_output = attn_output.reshape(batch_size, tgt_len, dim)
                result = (attn_output, attn_weights_reshaped, past_key_value)
            else:
                result = original_cross_attn_forward_func(hidden_states=hidden_states, attention_mask=attention_mask, *args, **kwargs)

            # Uri: this part adds the generated tokens to the prompt.
            # However it was commented out because currently we always keep the generated tokens in the attention window
            # if not self.is_encoder_decoder and not self.is_input_encoding_pass and \
            #         past_key_value[0].shape[2] > self.prompt_keys[self.cur_decoder_layer_index].shape[2]:
            #     self.prompt_keys[self.cur_decoder_layer_index] = torch.cat([self.prompt_keys[self.cur_decoder_layer_index], past_key_value[0][:,:,-1:]], dim=-2)
            #     self.prompt_values[self.cur_decoder_layer_index] = torch.cat([self.prompt_values[self.cur_decoder_layer_index], past_key_value[1][:,:,-1:]], dim=-2)
            #     if self.cur_decoder_layer_index == self.model.config.num_hidden_layers - 1:
            #         self.prompt_attention_mask = torch.cat([
            #             self.prompt_attention_mask,
            #             torch.ones([self.prompt_attention_mask.shape[0], 1], dtype=self.prompt_attention_mask.dtype).to(self.device)], dim=-1)

            return result

        return attention_pre_forward_hook

    def attention_forward_hook(self, module, input, output):
        # output: (batch, time, 3 * heads * attention_dim)
        if self.is_input_encoding_pass or self.is_first_test_decoding_step:
            return
        with torch.no_grad():
            prompt_size = self.prompt_input_ids.shape[1]
            generated_size = self.input_ids_size - prompt_size
            window_size = self.cur_layer_key_value_placeholder[0].shape[-2]
            # topk = min(self.actual_model_window_size, attn_weights.shape[-1])
            topk = min(prompt_size, window_size)
            if not self.is_encoder_decoder:
                topk = min(topk, window_size - generated_size + 1)
            if self.gpu_index:
                topk = min(topk, 2048)

            query = self.process_query(output)[:, -1]  # (batch * beam, head, dim)
            query = query[:, self.head_nums]  # (batch * beam, head, dim)

            if self.use_datastore:
                # query: (batch, beam, head, dim)
                # need to multiply by key vector
                # query.view(query.shape[0], query.shape[1] * query.shape[2])
                # k_proj in attention?
                datastore_index = 0 if self.is_encoder_decoder else self.cur_decoder_layer_index
                attention_layer_list = self.get_kv_projections(self.layer_begin, self.layer_end)
                k_proj_layer = [layers[0] for layers in attention_layer_list][self.cur_decoder_layer_index]
                v_proj_layer = [layers[1] for layers in attention_layer_list][self.cur_decoder_layer_index]

                # modify query by k_projs
                k_proj = k_proj_layer.weight
                datastore_query = self.preprocess_query(query, k_proj)  # (batch * beam, num_heads, embed_dim)
                batch_size = self.datastore[datastore_index].batch_size
                datastore_query = datastore_query.view((batch_size, -1, datastore_query.shape[2]))  # (batch, beam * num_heads, embed_dim)
                # then search
                if self.reconstruct_embeddings:
                    # embeddings: (batch, beam * head, actual_model_window_size, dim)
                    _, top_search_key_indices, embeddings = self.datastore[datastore_index].search_and_reconstruct(datastore_query, k=topk)
                else:
                    _, top_search_key_indices = self.datastore[datastore_index].search(datastore_query, k=topk)
                    # self.embeddings: (batch, src_len, dim)
                    # indices: (batch, beam * head, actual_model_window_size)
                    # embeddings: (batch, beam * head, actual_model_window_size, dim)
                    embeddings = torch.take_along_dim(input=self.hidden_states[datastore_index].unsqueeze(1),
                                                      indices=top_search_key_indices.unsqueeze(-1).to(self.hidden_states[datastore_index].device), dim=-2)
                    embeddings = embeddings.to(self.device)
                # (batch, beam, head, actual_model_window_size)
                # top_search_key_scores = top_search_key_scores.reshape(batch_size, -1, *top_search_key_scores.shape[1:])
                top_search_key_indices = top_search_key_indices.reshape(batch_size, -1, *top_search_key_indices.shape[1:])
                # embeddings: (batch, beam, head, actual_model_window_size, dim)
                embeddings = embeddings.reshape(batch_size, -1, self.num_heads, *embeddings.shape[2:])

            # raw_values are actually token indices; need to look them up
            if (not self.use_datastore) or self.test_datastore:
                this_layer_prompt_keys = self.prompt_keys[self.cur_decoder_layer_index]
                this_layer_prompt_values = self.prompt_values[self.cur_decoder_layer_index]
                # query: (batch * beam, head, dim)
                batch_size = self.prompt_input_ids.shape[0]
                beam_size = query.shape[0] // batch_size
                # query: (batch, beam, head, dim)
                query = query.reshape(batch_size, beam_size, *query.shape[1:])

                # this_layer_prompt_keys: (batch, head, source_len, dim)
                # this_layer_prompt_keys.unsqueeze(1): (batch, 1, head, source_len, dim)
                # query.unsqueeze(-1): (batch, beam, head, dim, 1)
                # attn_weights: (batch, beam, head, source_len)
                attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums], query.unsqueeze(-1)).squeeze(-1)
                # attn_weights = torch.matmul(query.unsqueeze(-2), this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums]).squeeze(-2)
                prompt_attention_mask_to_add = (1 - self.prompt_attention_mask) * -1e9  # (batch, source_len)
                prompt_attention_mask_to_add = prompt_attention_mask_to_add.unsqueeze(1).unsqueeze(1)
                attn_weights += prompt_attention_mask_to_add  # (batch, beam, head, source_len)
                if self.exclude_attention and attn_weights.shape[-1] > self.actual_model_window_size:
                    attn_weights[..., :self.actual_model_window_size] -= 1e9

                # target_keys, target_values, topk = self.get_target_slices(output)
                top_key_scores, top_key_indices = torch.topk(attn_weights, k=topk, dim=-1, sorted=True)  # (batch, beam, head, trunc_source)
                if self.save_heatmap:
                    # heatrow: (beam, heads, source_len)
                    heatrow = torch.zeros([top_key_indices.shape[1], top_key_indices.shape[2], this_layer_prompt_keys.shape[-2]], dtype=torch.float)
                    heatrow = heatrow.scatter(index=top_key_indices[0], src=torch.ones_like(top_key_scores[0]), dim=-1)
                    # heatrow = torch.nn.functional.softmax(heatrow, dim=-1)
                    # self.heatmap: (beam, heads, targets, source_len)
                    self.heatmap = torch.cat([self.heatmap, heatrow.unsqueeze(-2)], axis=-2)

            if self.test_datastore:
                assert top_key_indices.shape == top_search_key_indices.shape
                assert torch.mean((top_key_indices == top_search_key_indices).float()) > 0.99

            if self.verbose:
                if self.is_encoder_decoder:
                    for i, beam in enumerate(self.generated_input_ids):
                        print(f'({i}) Generated: {self.tokenizer.decode(beam)}')
                # else:
                #     print(f'Generated: {self.tokenizer.decode(self.input_ids)}')
                print()

            if self.use_datastore:
                # k_proj_layer.weight, v_proj_layer.weight: (embed_dim, embed_dim)
                # embeddings: (batch, beam, head, encoder_len, embed_dim)
                retrieved_keys, retrieved_values = self.post_process_retrieved(embeddings, k_proj_layer, v_proj_layer, top_search_key_indices)
            else:
                # this_layer_prompt_keys: (batch, head, source_len, dim)
                # top_key_indices: (batch, beam, head, trunc_source)
                retrieved_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                                      dim=-2)  # (batch, head, trunc_source, attn_dim)
                retrieved_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                                        dim=-2)  # (batch, head, trunc_source, attn_dim)

            if self.test_datastore:
                correct_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                                    dim=-2)  # (batch, head, trunc_source, attn_dim)
                correct_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                                      dim=-2)  # (batch, head, trunc_source, attn_dim)
                assert correct_keys.shape == retrieved_keys.shape
                assert correct_values.shape == retrieved_values.shape
                assert torch.mean(torch.isclose(correct_keys, retrieved_keys, rtol=1e-3, atol=1e-3).float()) > 0.99
                assert torch.mean(torch.isclose(correct_values, retrieved_values, rtol=1e-3, atol=1e-3).float()) > 0.99

            # retrieved_keys, retrieved_values: (batch * beam, head, encoder_len, attn_dim)
            retrieved_keys = retrieved_keys.flatten(0, 1)[:, :, :topk]
            retrieved_values = retrieved_values.flatten(0, 1)[:, :, :topk]
            self.cur_layer_key_value_placeholder[0] = torch.cat([retrieved_keys, self.cur_layer_key_value_placeholder[0][:, :, topk:]], dim=-2)
            self.cur_layer_key_value_placeholder[1] = torch.cat([retrieved_values, self.cur_layer_key_value_placeholder[1][:, :, topk:]], dim=-2)
        return

    def train_attention_forward_hook(self, module, input, output):
        # output: (batch, time, 3 * heads * attention_dim)
        if self.is_input_encoding_pass or self.is_first_test_decoding_step:
            return
        this_layer_prompt_keys = self.cur_layer_key_value_placeholder[0]
        this_layer_prompt_values = self.cur_layer_key_value_placeholder[1]
        with torch.no_grad():
            query = self.process_query(output)  # (batch * beam, tgt_len, head, dim)
            # query = query[:, :, self.head_nums]  # (batch * beam, head, dim)

            # query: (batch * beam, tgt_len, head, dim)
            batch_size = this_layer_prompt_keys.shape[0]
            tgt_len = query.shape[0] // batch_size
            # query: (batch, tgt, head, dim)
            query = query.reshape(batch_size, tgt_len, *query.shape[2:])
            # this_layer_prompt_keys: (batch, head, source_len, dim)
            # this_layer_prompt_keys.unsqueeze(1): (batch, 1, head, source_len, dim)
            # attn_weights: (batch, tgt_len, head, 1, source_len)
            # attn_weights = torch.matmul(query.unsqueeze(-2), this_layer_prompt_keys.unsqueeze(1).permute(0,1,2,4,3))
            attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1), query.unsqueeze(-1)) \
                .reshape(batch_size, tgt_len, query.shape[-2], 1, this_layer_prompt_keys.shape[-2])
            # attn_weights = torch.matmul(query.unsqueeze(-2), this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums]).squeeze(-2)
            prompt_attention_mask_to_add = (1 - self.long_inputs_mask) * -1e9  # (batch, source_len)
            prompt_attention_mask_to_add = prompt_attention_mask_to_add.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            attn_weights += prompt_attention_mask_to_add  # (batch, beam, head, source_len)

            # target_keys, target_values, topk = self.get_target_slices(output)
            topk = min(self.actual_model_window_size, attn_weights.shape[-1])
            top_key_scores, top_key_indices = torch.topk(attn_weights, k=min(topk, attn_weights.shape[-1]), dim=-1, sorted=True)  # (batch, beam, head, tgt, trunc_source)

        # this_layer_prompt_keys: (batch, head, source_len, dim)
        # top_key_indices: (batch, tgt_len, head, 1, trunc_source)
        new_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(2).unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                        dim=-2)  # (batch, tgt_len, head, 1, trunc_source, attn_dim)
        new_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(2).unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                          dim=-2)  # (batch, tgt_len, head, 1, trunc_source, attn_dim)

        # (batch * beam, head, tgt_len, trunc_source, attn_dim)
        self.cur_layer_key_value_placeholder[0] = new_keys.flatten(0, 1).squeeze(2)
        self.cur_layer_key_value_placeholder[1] = new_values.flatten(0, 1).squeeze(2)
        return

    def preprocess_query(self, query, k_proj_weight):
        k_proj = k_proj_weight.view(1, self.num_heads, query.shape[-1], k_proj_weight.shape[0])  # (1, num_heads, attn_dim, embed_dim)
        datastore_query = query.unsqueeze(-2)  # (batch * beam, num_heads, 1, attn_dim)
        datastore_query = torch.matmul(datastore_query, k_proj)  # (batch * beam, num_heads, 1, embed_dim)
        datastore_query = datastore_query.squeeze(-2)  # (batch * beam, num_heads, embed_dim)
        return datastore_query
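
    # The datastore indexes encoder hidden states h (one vector per source token) rather than
    # per-head keys. preprocess_query() folds the key projection into the query, so the index's
    # inner product matches the usual attention score:
    #     q_head . k_head = q_head . (W_k^head h) = (W_k^head.T q_head) . h
    # post_process_retrieved() below then applies W_k / W_v to the retrieved hidden states to
    # recover the actual keys and values spliced into the decoder's cross-attention cache.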

    def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices):
        embed_dim = embeddings.shape[-1]
        k_weight = k_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2, -1)  # (1, 1, heads, embed_dim, attn_dim)
        k_bias = 0
        if k_proj_layer.bias is not None:
            k_bias = k_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0)
        v_weight = v_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2, -1)  # (1, heads, embed_dim, attn_dim)
        v_bias = 0
        if v_proj_layer.bias is not None:
            v_bias = v_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0)

        # new_keys, new_values: (batch, beam, head, encoder_len, attn_dim)
        retrieved_keys = torch.matmul(embeddings, k_weight) + k_bias  # (beam, head, encoder_len, embed_dim)
        retrieved_values = torch.matmul(embeddings, v_weight) + v_bias  # (beam, head, encoder_len, embed_dim)
        return retrieved_keys, retrieved_values

    def set_gradient_checkpointing(self, value):
        self.model.base_model.decoder.gradient_checkpointing = value

    def reorder_cache_hook(self, past, beam_idx):
        self.last_beam_idx = beam_idx
        self.generated_input_ids = self.generated_input_ids[beam_idx]
        for i, layer_prev_tokens in enumerate(self.prev_tokens):
            if layer_prev_tokens is not None:
                self.prev_tokens[i] = layer_prev_tokens.flatten(0, 1)[beam_idx].reshape(layer_prev_tokens.shape)
        if self.save_heatmap and self.heatmap.numel() > 0:
            self.heatmap = self.heatmap[beam_idx]
        return self.original_reorder_cache_func(past, beam_idx)

    @classmethod
    def convert_model(cls, model, *args, **kwargs):
        # if type(model.config) in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
        # elif type(model.config) in MODEL_WITH_LM_HEAD_MAPPING:
        # else:
        #     raise ValueError(f'Unsupported model type: {type(model.config)}')

        # if model.config.is_encoder_decoder:
        #     model_clone = AutoModelForSeq2SeqLM.from_config(model.config)
        # else:
        #     model_clone = AutoModelForCausalLM.from_config(model.config)
        # model_clone.load_state_dict(model.state_dict()).to(args.device)
        type_to_class = {
            BartModel: UnlimiformerBART,
            BartForConditionalGeneration: UnlimiformerBART,
            T5Model: UnlimiformerT5,
            T5ForConditionalGeneration: UnlimiformerT5,
            LEDModel: UnlimiformerLED,
            LEDForConditionalGeneration: UnlimiformerLED,
            # LlamaModel: UnlimiformerLLaMa,
            # LlamaForCausalLM: UnlimiformerLLaMa,
        }
        type_to_class[type(model)](model, *args, **kwargs)
        return model

    def plot_heatmap(self, data, xticklabels='auto', yticklabels='auto'):
        # data: (heads, targets, source_len)
        import seaborn as sb
        import matplotlib.pyplot as plt
        # print('gat = np.array([')
        # for row in data[0]:
        #     rowstr = ', '.join([f'{x:.2f}' for x in row])
        #     print(f'    [{rowstr}],')
        # print(']')

        # sb.set(font_scale=1.5, rc={'text.usetex': True})
        for i in range(data.shape[0]):
            fig, axes = plt.subplots(1, 1, figsize=(40, 100))
            axes.set_title(f'Head #{i}, length: {data.shape[2]}, target length: {data.shape[1]}')
            cur_ax = axes

            # annot = [[x for x in row] for row in data]
            ax = sb.heatmap(data[i], annot=False, fmt='.2f',
                            xticklabels=512, yticklabels=yticklabels, ax=cur_ax)
            ax.xaxis.tick_top()
            plt.savefig(f'knns_head{i}.pdf')
            # plt.savefig('gat_s10_contrast.pdf')
            plt.show()


class UnlimiformerBART(Unlimiformer[BartModel]):
    def __init__(self, model: BartModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        # (batch, time, hidden_dim)
        attention = decoder_layer.encoder_attn
        # key, value: (batch, heads, time, attn_dim)
        key = attention.k_proj(encoder_hidden_states)
        key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
        value = attention.v_proj(encoder_hidden_states)
        value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
        # key, value: (batch, heads, time, attn_dim)
        return key, value

    def process_key_value(self, capturers):
        key_capturer, value_capturer = capturers
        key, value = key_capturer.captured, value_capturer.captured
        # (batch, time, heads, attn_dim)
        attention = self.model.base_model.decoder.layers[-1].encoder_attn

        # query, key, value: (batch, heads, time, attn_dim)
        # query = query.view(query.shape[0], query.shape[1], attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
        key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
        value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()

        return key, value

    def process_query(self, output):
        # (batch, time, heads, attn_dim)
        attention = self.model.base_model.decoder.layers[-1].encoder_attn

        # query: (batch, heads, time, attn_dim)
        # query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
        query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).contiguous()
        return query

    def get_kv_projections(self, layer_begin, layer_end):
        return [
            [layer.encoder_attn.k_proj, layer.encoder_attn.v_proj]
            for layer in self.model.base_model.decoder.layers[layer_begin:layer_end]
        ]

    def activation_to_capture(self, layer_begin, layer_end):
        if self.use_datastore:
            return [self.model.base_model.encoder.layers[-1]]
        else:
            return self.get_kv_projections(layer_begin, layer_end)

    def attention_op_to_run(self, layer_begin, layer_end):
        return [
            layer.encoder_attn.q_proj
            for layer in self.model.base_model.decoder.layers[layer_begin:layer_end]
        ]

    def attention_layer_to_run(self, layer_begin, layer_end):
        return self.model.base_model.decoder.layers[layer_begin:layer_end]

    def self_attention(self, decoder_layer):
        return decoder_layer.self_attn

    def cross_attention(self, decoder_layer):
        return decoder_layer.encoder_attn

    def window_size(self):
        return self.model.config.max_position_embeddings

    def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states,
                                  encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                  past_key_value, output_attentions, position_bias,
                                  encoder_decoder_position_bias, use_cache, key, value):
        args = {'hidden_states': hidden_states,
                'attention_mask': attention_mask,
                'encoder_hidden_states': encoder_hidden_states,
                'encoder_attention_mask': encoder_attention_mask,
                'layer_head_mask': layer_head_mask,
                'cross_attn_layer_head_mask': cross_attn_layer_head_mask,
                'past_key_value': (None, None, key, value),
                'output_attentions': output_attentions,
                'use_cache': use_cache}
        if key is None and value is None:
            args['past_key_value'] = None
        return args


class UnlimiformerT5(Unlimiformer[T5Model]):
    def __init__(self, model: T5Model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        # (batch, time, hidden_dim)
        attention = decoder_layer.layer[1].EncDecAttention
        # key, value: (batch, heads, time, attn_dim)
        key = attention.k(encoder_hidden_states)
        key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        value = attention.v(encoder_hidden_states)
        value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()

        return key, value

    def process_key_value(self, capturers):
        key_capturer, value_capturer = capturers
        key, value = key_capturer.captured, value_capturer.captured
        # (batch, time, heads, attn_dim)
        attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention

        # query, key, value: (batch, heads, time, attn_dim)
        # query = query.view(query.shape[0], query.shape[1], attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
        key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()

        return key, value

    def process_query(self, output):
        # (batch, time, heads, attn_dim)
        attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention

        # query: (batch, heads, time, attn_dim)
        query = output.view(output.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).contiguous()
        return query

    def get_kv_projections(self, layer_begin, layer_end):
        return [
            [layer.layer[1].EncDecAttention.k, layer.layer[1].EncDecAttention.v]
            for layer in self.model.base_model.decoder.block[layer_begin:layer_end]
        ]

    def activation_to_capture(self, layer_begin, layer_end):
        if self.use_datastore:
            # T5 stacks expose their layers as `block` (there is no `encoder.layers` attribute)
            return [self.model.base_model.encoder.block[-1]]
        else:
            return self.get_kv_projections(layer_begin, layer_end)

    def attention_op_to_run(self, layer_begin, layer_end):
        return [
            layer.layer[1].EncDecAttention.q
            for layer in self.model.base_model.decoder.block[layer_begin:layer_end]
        ]

    def attention_layer_to_run(self, layer_begin, layer_end):
        return self.model.base_model.decoder.block[layer_begin:layer_end]

    def self_attention(self, decoder_layer):
        return decoder_layer.layer[0]

    def cross_attention(self, decoder_layer):
        return decoder_layer.layer[1]

    def window_size(self):
        try:
            size = self.model.config.n_positions
        except AttributeError:
            size = 1024
        return size

    def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states,
                                  encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                  past_key_value, output_attentions, position_bias,
                                  encoder_decoder_position_bias, use_cache, key, value):
        args = {'hidden_states': hidden_states,
                'attention_mask': attention_mask,
                'position_bias': position_bias,
                'encoder_hidden_states': encoder_hidden_states,
                'encoder_attention_mask': encoder_attention_mask,
                'encoder_decoder_position_bias': encoder_decoder_position_bias,
                'layer_head_mask': layer_head_mask,
                'cross_attn_layer_head_mask': cross_attn_layer_head_mask,
                'past_key_value': (None, None, key, value),
                'use_cache': use_cache,
                'output_attentions': output_attentions}
        if key is None and value is None:
            args['past_key_value'] = None
        return args


class UnlimiformerLED(UnlimiformerBART):
    def __init__(self, model: LEDModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def window_size(self):
        return self.model.config.max_encoder_position_embeddings


# class UnlimiformerLLaMa(Unlimiformer[LlamaModel]):
#     def __init__(self, model: LlamaModel, *args, **kwargs):
#         super().__init__(model, *args, **kwargs)
#
#     def get_kv_projections(self, layer_begin, layer_end):
#         return [
#             [layer.self_attn.k_proj, layer.self_attn.v_proj]
#             for layer in self.model.base_model.layers[layer_begin:layer_end]
#         ]
#
#     def activation_to_capture(self, layer_begin, layer_end):
#         if self.use_datastore:
#             return [
#                 layer.input_layernorm
#                 for layer in self.model.base_model.layers[layer_begin:layer_end]
#             ]
#         else:
#             return self.get_kv_projections(layer_begin, layer_end)
#
#     def attention_op_to_run(self, layer_begin, layer_end):
#         return [
#             layer.self_attn.q_proj
#             for layer in self.model.base_model.layers[layer_begin:layer_end]
#         ]
#
#     def attention_layer_to_run(self, layer_begin, layer_end):
#         return self.model.base_model.layers[layer_begin:layer_end]
#
#     def self_attention(self, decoder_layer):
#         return decoder_layer.self_attn
#
#     def cross_attention(self, decoder_layer):
#         return decoder_layer.self_attn
#
#     def window_size(self):
#         return self.model.config.max_position_embeddings
#
#     def set_gradient_checkpointing(self, value):
#         self.model.base_model.gradient_checkpointing = value
#
#     def process_key_value(self, capturers):
#         key_capturer, value_capturer = capturers
#         # (batch, time, heads * attn_dim)
#         key, value = key_capturer.captured, value_capturer.captured
#         attention = self.model.base_model.layers[-1].self_attn
#         # (batch, heads, time, attn_dim)
#         key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
#         value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
#         return key, value
#
#     def process_query(self, output):
#         # output: (batch, time, heads * attn_dim)
#         attention = self.model.base_model.layers[-1].self_attn
#         # query: (batch, time, heads, attn_dim)
#         query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).contiguous()
#         return query
#
#     def rotate_half(self, x):
#         """Rotates half the hidden dims of the input."""
#         x1 = x[..., : x.shape[-1] // 2]
#         x2 = x[..., x.shape[-1] // 2 :]
#         return torch.cat((-x2, x1), dim=-1)
#
#     def preprocess_query(self, query, k_proj_weight):
#         # query: (batch * time, head, dim)
#         attention = self.model.base_model.layers[-1].self_attn
#         num_generated = min(self.input_ids_size - self.prompt_input_ids.shape[1], self.actual_model_window_size)
#         cos, sin = attention.rotary_emb(query, seq_len=num_generated)
#         cos = cos[:,:,-1]  # [1, 1, dim]
#         sin = sin[:,:,-1]  # [1, 1, dim]
#         # cos = cos[-1].unsqueeze(0).unsqueeze(0)  # [bs, 1, seq_len, dim]
#         # sin = sin[-1].unsqueeze(0)  # [bs, 1, seq_len, dim]
#         query = (query * cos) + (self.rotate_half(query) * sin)
#
#         k_proj = k_proj_weight.view(1, self.num_heads, query.shape[-1], k_proj_weight.shape[0])  # (1, num_heads, attn_dim, embed_dim)
#         k_proj_l = k_proj[..., :k_proj.shape[-2] // 2, :]
#         k_proj_r = k_proj[..., k_proj.shape[-2] // 2:, :]
#         k_proj_rotated = torch.cat([-k_proj_l, k_proj_r], dim=-2)
#
#         datastore_query = query.unsqueeze(-2)  # (batch * beam, num_heads, 1, attn_dim)
#         datastore_query = torch.matmul(datastore_query, k_proj + k_proj_rotated)  # (batch * beam, num_heads, 1, embed_dim)
#         datastore_query = datastore_query.squeeze(-2)  # (batch * beam, num_heads, embed_dim)
#         return datastore_query
#
#     def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices):
#         embed_dim = embeddings.shape[-1]
#         k_weight = k_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1)  # (1, 1, heads, embed_dim, attn_dim)
#         k_bias = 0
#         if k_proj_layer.bias is not None:
#             k_bias = k_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0)
#         v_weight = v_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1)  # (1, heads, embed_dim, attn_dim)
#         v_bias = 0
#         if v_proj_layer.bias is not None:
#             v_bias = v_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0)
#         # new_keys, new_values: (batch, beam, head, encoder_len, attn_dim)
#         retrieved_keys = torch.matmul(embeddings, k_weight) + k_bias  # (beam, head, encoder_len, embed_dim)
#         retrieved_values = torch.matmul(embeddings, v_weight) + v_bias  # (beam, head, encoder_len, embed_dim)
#
#         attention = self.model.base_model.layers[-1].self_attn
#         cos, sin = attention.rotary_emb(retrieved_values, seq_len=self.hidden_states[0].shape[1])
#         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
#         sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
#         if self.prompt_input_ids.shape[1] > self.actual_model_window_size:
#             # scale the top key indices to the actual model window size, such that the model will not see
#             # positional embeddings that did not appear at training time
#             scaled_key_indices = ((top_search_key_indices / self.prompt_input_ids.shape[1]) * self.actual_model_window_size).int()
#         else:
#             scaled_key_indices = top_search_key_indices
#         # top_search_key_indices = top_search_key_indices.to(cos.device)
#         scaled_key_indices = scaled_key_indices.to(cos.device)
#         cos = cos[scaled_key_indices]  # [bs, 1, seq_len, dim]
#         sin = sin[scaled_key_indices]  # [bs, 1, seq_len, dim]
#         retrieved_keys = (retrieved_keys * cos) + (self.rotate_half(retrieved_keys) * sin)
#         return retrieved_keys, retrieved_values


class ActivationCapturer(nn.Module):
    def __init__(self, layer, capture_input=False):
        super().__init__()
        self.layer = layer
        self.capture_input = capture_input
        self.captured = None

    def unwrap_tuple(self, t):
        if isinstance(t, tuple) and len(t) == 1:
            t = t[0]
        return t

    def forward(self, module, layer_input, layer_output):
        if self.capture_input:
            self.captured = self.unwrap_tuple(layer_input)
        else:
            self.captured = self.unwrap_tuple(layer_output)
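

# Minimal usage sketch (illustrative, not a definitive recipe). It assumes the public
# 'facebook/bart-base' checkpoint, a hypothetical 'my_long_input.txt' file, and a
# transformers version compatible with this module (it relies on model._reorder_cache
# and the tuple-style cross-attention cache). Any model type listed in
# Unlimiformer.convert_model's type_to_class mapping can be wrapped the same way:
#
#     from transformers import AutoTokenizer, BartForConditionalGeneration
#
#     tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
#     model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
#
#     # convert_model() mutates the model in place; eval() injects the Unlimiformer hooks,
#     # and generate() then encodes the long input in chunks and retrieves top-k keys
#     # at every decoding step.
#     model = Unlimiformer.convert_model(model, tokenizer=tokenizer, use_datastore=False)
#     model.eval()
#
#     long_document = open('my_long_input.txt').read()
#     inputs = tokenizer(long_document, return_tensors='pt')
#     output_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'],
#                                 max_new_tokens=64, num_beams=2)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))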