import copy
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

from .BranchyModelConfig import BranchyModelConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def breaking_ties(tensor: torch.Tensor) -> torch.Tensor:
    """
    Break ties in a tensor by subtracting the second highest value from the highest value.

    Args:
        tensor (torch.Tensor): The tensor to break ties in. Shape [..., vocab_size].

    Returns:
        torch.Tensor: The tensor with ties broken. Shape [...].

    Example:
        Input : Tensor of shape [head_number, batch, seq_len, vocab_size]
        Output: Tensor of shape [head_number, batch, seq_len]
    """
    top2 = torch.topk(tensor, 2, dim=-1).values
    return top2[..., 0] - top2[..., 1]

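# A minimal worked example of the confidence score above (illustrative values, not taken from the original file):
#   logits = torch.tensor([[2.0, 0.5, 0.1]])   # shape [1, vocab_size]
#   breaking_ties(logits)                      # -> tensor([1.5]), the margin between the top-1 and top-2 logits
# A larger margin means the head is more certain about its top prediction.
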

class Branch(nn.Module):
    """
    A branch module for use in the BranchyModel, representing an auxiliary output head attached at a specified layer
    within a transformer model. Each branch processes the output of its corresponding layer and produces an output
    which can be used for early exits or auxiliary tasks.

    This class is designed to be flexible, allowing for different configurations of the linear layer based on the
    underlying model's architecture.

    Attributes:
        layernorm (torch.nn.LayerNorm): Applies Layer Normalization over a mini-batch of inputs.
        lm_head (torch.nn.Linear): The linear layer that maps the hidden states to the vocabulary size, producing
            the output logits for each token in the sequence.

    Example Usage:
        # Assuming `config` is an instance of the model's configuration class with attributes `hidden_size` and
        # `vocab_size` properly set.
        branch = Branch(config)

        # `x` is a tensor representing the output from a transformer layer, shaped as
        # [batch_size, seq_length, hidden_size].
        output_logits = branch(x)
    """

    def __init__(self, config: BranchyModelConfig):
        """
        Initializes the Branch module.

        Args:
            config (PretrainedConfig): The configuration object containing parameters like hidden size and vocabulary
                size. This object provides the necessary settings for initializing the layer normalization and linear
                layers within the Branch.
        """
        super().__init__()
        self.layernorm: nn.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head: nn.Linear = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the Branch module.

        Args:
            x (Tensor): Input tensor of shape [batch_size, seq_length, hidden_size], representing the output
                from a transformer layer.

        Returns:
            Tensor: Output logits of shape [batch_size, seq_length, vocab_size], resulting from passing the
                input through layer normalization and a linear layer.
        """
        x = self.layernorm(x)
        x = self.lm_head(x)
        return x


class BranchyCausalModel(PreTrainedModel):
    """
    A causal branchy model. It integrates the early-exit mechanism and, like a conventional model, outputs a single
    set of logits at each step.
    """

    config_class = BranchyModelConfig

    def __init__(self, config: BranchyModelConfig):
        super().__init__(config)
        self.model = AutoModelForCausalLM.from_pretrained(config.model_str)
        self.lm_head = self.model.lm_head
        self.vocab_size = self.model.vocab_size
        self.model = self.model.model
        self.head_thresholds = torch.tensor(config.head_thresholds)
        self.confidence_metric_fn = breaking_ties

        if hasattr(self.model.config, "n_layer") or hasattr(self.model.config, "num_hidden_layers"):
            self.num_layers = (
                self.model.config.n_layer
                if hasattr(self.model.config, "n_layer")
                else self.model.config.num_hidden_layers
            )
            assert self.num_layers is not None and self.num_layers > 0, "n_layer must be a positive integer."
        else:
            raise ValueError("cannot find n_layer in config")

        assert (
            config.branch_number < self.num_layers
        ), "branch_number must be a positive integer less than the number of layers in the model."

        # If no explicit locations are given, spread the branches evenly across the decoder layers.
        if config.branch_locations is None:
            interval = self.num_layers // (config.branch_number + 1)
            config.branch_locations = [i * interval for i in range(1, config.branch_number + 1)]
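        # Worked example of the default placement (illustrative numbers, not taken from any shipped config):
        # with num_layers = 24 and branch_number = 3, interval = 24 // (3 + 1) = 6, so branches attach after
        # decoder layers 6, 12 and 18, and the base lm_head remains the final exit.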

        if any(loc >= self.num_layers for loc in config.branch_locations):
            raise ValueError("Branch location exceeds the number of layers in the model.")

        self.branches = torch.nn.ModuleList()
        if config.copy_lm_head:
            logger.info("Initializing branches as copies of the base lm_head")
            for _ in config.branch_locations:
                self.branches.append(copy.deepcopy(self.lm_head))
        else:
            for _ in config.branch_locations:
                new_branch = Branch(self.model.config)
                new_branch.apply(self.model._init_weights)
                self.branches.append(new_branch)

        self.gradient_checkpointing = False
        self.post_init()

    def to(self, *args, **kwargs):
        self = super().to(*args, **kwargs)
        self.model = self.model.to(*args, **kwargs)
        # head_thresholds is a plain tensor attribute (not a registered buffer), so it must be moved explicitly.
        self.head_thresholds = self.head_thresholds.to(*args, **kwargs)
        return self

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the attention_mask is longer than input_ids, some inputs are passed exclusively through the
            #     cache (e.g. when inputs_embeds were used on the first step), so keep only the tail of input_ids.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If past_length is smaller than input_ids' length, input_ids holds all tokens; drop the cached prefix.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise, input_ids already holds only the unprocessed tokens.

            # If we are about to exceed the maximum cache length, crop the attention mask accordingly.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # inputs_embeds are only used on the first generation step; afterwards input_ids are used.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )

        return model_inputs

    def model_pre_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.model.config.use_cache
        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0

        if self.model.gradient_checkpointing and self.model.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        use_legacy_cache = None
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
        inputs_embeds = self.model.embed_dropout(inputs_embeds)

        if self.model._use_flash_attention_2:
            # A 2d mask is passed through the layers; flash attention only needs it when there is padding.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self.model._use_sdpa and not output_attentions:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # A 4d causal mask is passed through the layers.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        return (
            inputs_embeds,
            use_legacy_cache,
            attention_mask,
            position_ids,
            past_key_values,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        head_window_size: Optional[int] = None,
    ):
        # Caching is force-disabled for the branchy forward pass.
        use_cache = False
        (
            inputs_embeds,
            use_legacy_cache,
            attention_mask,
            position_ids,
            past_key_values,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        ) = self.model_pre_forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )

        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_logits = ()
        is_early_exited = False
        next_decoder_cache = None

        batch_size = hidden_states.shape[0]
        seq_length = hidden_states.shape[1]
        device = hidden_states.device

        # Per-sample early-exit bookkeeping: which samples have exited, at which layer, and with which logits.
        early_exit_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
        exit_layer = torch.full((batch_size,), self.num_layers, dtype=torch.long, device=device)
        final_logits = torch.zeros((batch_size, seq_length, self.vocab_size), device=device)

        for layer, decoder_layer in enumerate(self.model.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.model.gradient_checkpointing and self.model.training:
                layer_outputs = self.model._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                )
                hidden_states = layer_outputs[0]
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            if layer in self.config.branch_locations:
                branch_logits = self.branches[self.config.branch_locations.index(layer)](layer_outputs[0])
                if not self.training:
                    # Confidence of this branch on the last token; samples above the head's threshold exit here.
                    scores = self.confidence_metric_fn(branch_logits)[..., -1]
                    exit_samples = (
                        scores > self.head_thresholds[self.config.branch_locations.index(layer)]
                    ) & ~early_exit_mask
                    early_exit_mask |= exit_samples
                    exit_layer[exit_samples] = layer
                    final_logits[exit_samples] = branch_logits[exit_samples]

                    if early_exit_mask.all():
                        break
                else:
                    # During training, every branch's logits are kept for the self-supervision loss.
                    all_logits += (branch_logits,)

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Samples that never exited go through the final layernorm and the base lm_head.
        if not early_exit_mask.all():
            remaining_hidden_states = hidden_states[~early_exit_mask]
            remaining_hidden_states = self.model.final_layernorm(remaining_hidden_states)
            remaining_logits = self.lm_head(remaining_hidden_states)
            final_logits[~early_exit_mask] = remaining_logits

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        loss = [None, None, None, None]
        if self.training:
            loss = self.compute_self_supervision_loss(torch.stack(all_logits), hidden_states)

        if not return_dict:
            raise NotImplementedError("return_dict=False is not implemented")

        return CausalBranchyLLMOutputWithPast(
            loss=loss[0],
            head_loss=loss[1],
            entropies=loss[2],
            entropy=loss[3],
            logits=final_logits,
            head_logits=all_logits,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            head_indices=exit_layer,
        )
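    # Minimal usage sketch for the early-exit forward pass (illustrative values; the model string and threshold
    # numbers are assumptions, not part of this file):
    #   config = BranchyModelConfig(model_str="microsoft/phi-1_5", branch_number=3,
    #                               head_thresholds=[4.0, 3.0, 2.0], copy_lm_head=True)
    #   model = BranchyCausalModel(config).eval()
    #   out = model(input_ids)          # input_ids: LongTensor of token ids from the base model's tokenizer
    #   out.logits                      # [batch, seq_len, vocab]; early-exited rows come from the branch heads
    #   out.head_indices                # per-sample layer index at which each sample exited (num_layers = no exit)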

    def compute_self_supervision_loss(
        self,
        aux_logits: torch.Tensor,
        lm_logits: torch.Tensor,
        return_dict: bool = True,
    ) -> Dict[str, torch.Tensor]:
        # Only the last position is supervised.
        last_aux_logits = aux_logits[..., -1, :]
        last_lm_logits = lm_logits[..., -1, :]

        losses = ()
        entropies = ()

        for head_logit in last_aux_logits:
            # Cross-entropy against the argmax of `lm_logits` at the last position (self-supervision targets).
            ce_loss = nn.CrossEntropyLoss(reduction="mean")(
                head_logit, torch.argmax(last_lm_logits, dim=-1)
            )
            probas = F.softmax(head_logit, dim=-1)
            log_probas = torch.log(probas + 1e-8)
            assert not torch.isnan(log_probas).any(), "NaNs found in log_probas"
            entropy = -torch.sum(probas * log_probas, dim=-1)
            assert not torch.isnan(entropy).any(), "NaNs found in entropy before mean"
            entropy = torch.mean(entropy)
            entropies += (entropy,)
            losses += ((1 - self.config.penalty_weight) * ce_loss - self.config.penalty_weight * entropy,)

        loss = torch.stack(losses, dim=0).mean(dim=-1)
        entropy = torch.stack(entropies, dim=0).mean(dim=-1)
        if not return_dict:
            return tuple(v for v in (loss, losses, entropy, entropies) if v is not None)
        return SelfSupervisedLossOutput(
            loss=loss,
            head_losses=losses,
            entropies=entropies,
            entropy=entropy,
        )
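    # The per-head objective implemented above, in formula form (w = config.penalty_weight):
    #   L_h = (1 - w) * CE(head_logits_h, argmax(lm_logits)) - w * H(softmax(head_logits_h))
    # i.e. match the main model's prediction, with an entropy term that lowers the loss for higher-entropy
    # (less overconfident) branch distributions; the reported loss is the mean of L_h over all branch heads.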


@dataclass
class CausalBranchyLLMOutputWithPast(ModelOutput):
    loss: Optional[torch.Tensor] = None
    head_loss: Optional[torch.Tensor] = None
    entropy: Optional[torch.Tensor] = None
    entropies: Optional[Tuple[torch.Tensor]] = None
    logits: torch.Tensor = None
    head_logits: Optional[torch.Tensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    head_indices: Optional[torch.Tensor] = None


@dataclass
class SelfSupervisedLossOutput(ModelOutput):
    loss: torch.Tensor = None
    head_losses: torch.Tensor = None
    entropy: torch.Tensor = None
    entropies: torch.Tensor = None