File size: 20,594 Bytes
11d5b8c f3d35ec 11d5b8c 4ec421a 11d5b8c 4ec421a f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec 11d5b8c 4ec421a 1e4dad3 4ec421a f3d35ec 4ec421a f3d35ec 4ec421a 11d5b8c f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec 24bcb4b f3d35ec 24bcb4b f3d35ec 24bcb4b f3d35ec 24bcb4b f3d35ec 24bcb4b f3d35ec 24bcb4b f3d35ec 24bcb4b 11d5b8c f3d35ec 24bcb4b f3d35ec 24bcb4b f3d35ec 24bcb4b f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec 11d5b8c f3d35ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 |
from collections import OrderedDict
from hamcrest import is_
import torch
import logging
import torch.nn as nn
import torch.nn.functional as F
import copy
from dataclasses import dataclass
from torch import Tensor
from .BranchyModelConfig import BranchyModelConfig
from typing import List, Optional, Dict, Tuple
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
# NOTE(review): calling logging.basicConfig at import time configures the root
# logger for every program that imports this module; applications usually own
# logging configuration. Left as-is to preserve current behavior.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def breaking_ties(tensor: torch.Tensor):
    """
    Confidence score: margin between the two largest values along the last dimension.

    Computes ``top1 - top2`` over the last axis. A result of 0 means the two best
    candidates are tied; larger values mean a more decisive prediction.

    Args:
        tensor (torch.Tensor): Scores to rank, shape [..., vocab_size]. The last
            dimension needs at least 2 entries (``torch.topk`` with k=2).

    Returns:
        torch.Tensor: The top-1 minus top-2 margin, shape [...].

    Example:
        Input : Tensor of shape [head_number, batch, seq_len, vocab_size]
        Output: Tensor of shape [head_number, batch, seq_len]
    """
    # One topk call yields both the highest and the second-highest values; the
    # previous version ran the (vocab-sized) topk twice on the same tensor.
    top2 = torch.topk(tensor, 2, dim=-1).values
    return top2[..., 0] - top2[..., 1]
class Branch(nn.Module):
    """
    Auxiliary output head attached after one transformer layer of a BranchyModel.

    The head normalises the hidden states produced by its layer and projects them onto
    the vocabulary, giving logits that can drive an early exit or an auxiliary loss.

    Attributes:
        layernorm (torch.nn.LayerNorm): LayerNorm over the hidden dimension, built with
            ``config.layer_norm_eps``.
        lm_head (torch.nn.Linear): Biased projection from ``hidden_size`` to
            ``vocab_size``.

    Example Usage:
        # `config` must expose `hidden_size`, `vocab_size` and `layer_norm_eps`.
        branch = Branch(config)
        # `x`: output of a transformer layer, [batch_size, seq_length, hidden_size]
        output_logits = branch(x)
    """

    def __init__(self, config: BranchyModelConfig):
        """
        Create the normalisation and projection layers of the branch.

        Args:
            config: Configuration read for ``hidden_size``, ``vocab_size`` and
                ``layer_norm_eps``. BranchyCausalModel passes the wrapped backbone's
                config here (``Branch(self.model.config)``).
        """
        super().__init__()
        width = config.hidden_size
        n_tokens = config.vocab_size
        self.layernorm: nn.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.lm_head: nn.Linear = nn.Linear(width, n_tokens, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """
        Map layer outputs to vocabulary logits.

        Args:
            x: Hidden states of shape [batch_size, seq_length, hidden_size].

        Returns:
            Logits of shape [batch_size, seq_length, vocab_size]: ``lm_head(layernorm(x))``.
        """
        normalised = self.layernorm(x)
        return self.lm_head(normalised)
class BranchyCausalModel(PreTrainedModel):
    """Causal language model with early-exit branches.

    Wraps the pretrained causal LM named by ``config.model_str`` and attaches auxiliary
    heads ("branches") after selected decoder layers. At inference time each sample
    leaves the layer stack at the first branch whose confidence score
    (``breaking_ties`` on the last position) exceeds that branch's threshold, so only
    one logits tensor is returned, as with a conventional model. In training mode the
    logits of every branch are kept and a self-supervised loss against the main LM
    head is computed.
    """

    config_class = BranchyModelConfig

    def __init__(self,
                 config: BranchyModelConfig):
        """
        Load the backbone and build the branches.

        Args:
            config: Read for ``model_str``, ``head_thresholds``, ``branch_number``,
                ``branch_locations`` and ``copy_lm_head``. When ``branch_locations`` is
                None it is filled in (mutating ``config``) with evenly spaced layers.

        Raises:
            ValueError: if the backbone config has neither ``n_layer`` nor
                ``num_hidden_layers``, or a branch location is >= the layer count.
            AssertionError: if the layer count is not positive, or ``branch_number`` is
                not smaller than it.
        """
        super().__init__(config)
        # Keep the causal-LM head, then replace the full LM by its backbone. Assigning
        # ``self.model`` twice keeps the submodule order (model, lm_head, branches).
        self.model = AutoModelForCausalLM.from_pretrained(config.model_str)
        self.lm_head = self.model.lm_head
        self.vocab_size = self.model.vocab_size
        self.model = self.model.model
        # Plain tensor attribute (not a registered buffer); see ``to`` below.
        self.head_thresholds = torch.tensor(config.head_thresholds)
        self.confidence_metric_fn = breaking_ties

        # Depth of the backbone; configs call it either ``n_layer`` or ``num_hidden_layers``.
        base_config = self.model.config
        if hasattr(base_config, "n_layer"):
            self.num_layers = base_config.n_layer
        elif hasattr(base_config, "num_hidden_layers"):
            self.num_layers = base_config.num_hidden_layers
        else:
            raise ValueError("cannot find n_layer in config")
        assert self.num_layers is not None and self.num_layers > 0, "n_layer must be a positive integer."
        assert config.branch_number < self.num_layers , "branch_number must be a positive integer less than the number of layers in the model."

        # With only a branch count, spread the branches evenly across the layers.
        if config.branch_locations is None:
            interval = self.num_layers // (config.branch_number + 1)
            config.branch_locations = [i * interval for i in range(1, config.branch_number + 1)]
        # Every branch must sit on an existing layer.
        if any(loc >= self.num_layers for loc in config.branch_locations):
            raise ValueError("Branch location exceeds the number of layers in the model.")

        self.branches = torch.nn.ModuleList()
        if config.copy_lm_head:
            # Each branch starts as a copy of the backbone's LM head.
            logger.info("Fine-tuning branches")
            for _ in config.branch_locations:
                self.branches.append(copy.deepcopy(self.lm_head))
        else:
            # Fresh LayerNorm + Linear heads initialised with the backbone's init scheme.
            for _ in config.branch_locations:
                new_branch = Branch(self.model.config)
                new_branch.apply(self.model._init_weights)
                self.branches.append(new_branch)
        self.gradient_checkpointing = False
        self.post_init()

    def to(self, *args, **kwargs):
        """Move/cast like ``nn.Module.to`` and also move ``head_thresholds``.

        ``head_thresholds`` is a plain tensor attribute, so the base implementation does
        not move it. NOTE(review): ``.cuda()``/``.half()`` do not go through this method
        and leave the thresholds where they are.
        """
        self = super().to(*args, **kwargs)
        self.model = self.model.to(*args, **kwargs)
        self.head_thresholds = self.head_thresholds.to(*args, **kwargs)
        return self

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Build the keyword arguments for one ``generate`` step.

        Trims ``input_ids`` to the tokens not yet covered by ``past_key_values``, crops
        ``attention_mask`` when a bounded cache would overflow, derives ``position_ids``
        from the attention mask when none are given, and uses ``inputs_embeds`` only
        while there is no cache yet.

        Returns:
            dict with ``input_ids`` or ``inputs_embeds`` plus ``position_ids``,
            ``past_key_values``, ``use_cache`` and ``attention_mask``.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def model_pre_forward(self,
                          input_ids: torch.LongTensor = None,
                          attention_mask: Optional[torch.Tensor] = None,
                          position_ids: Optional[torch.LongTensor] = None,
                          past_key_values: Optional[List[torch.FloatTensor]] = None,
                          inputs_embeds: Optional[torch.FloatTensor] = None,
                          use_cache: Optional[bool] = None,
                          output_attentions: Optional[bool] = None,
                          output_hidden_states: Optional[bool] = None,
                          return_dict: Optional[bool] = None,
                          ):
        """
        Reproduce the backbone's input preprocessing that precedes its layer loop.

        This method:

        - Resolves ``output_attentions``, ``output_hidden_states``, ``use_cache`` and
          ``return_dict`` from the backbone config when they are None.
        - Checks that exactly one of ``input_ids``/``inputs_embeds`` is given.
        - Disables ``use_cache`` under gradient checkpointing in training.
        - Wraps a legacy cache in ``DynamicCache``.
        - Builds default ``position_ids``.
        - Embeds the tokens (then applies ``embed_dropout``).
        - Converts ``attention_mask`` for the active attention implementation.

        NOTE(review): relies on backbone attributes ``embed_tokens``, ``embed_dropout``,
        ``_use_flash_attention_2`` and ``_use_sdpa`` (Phi-style models) — confirm for
        other architectures.

        Returns:
            tuple ``(inputs_embeds, use_legacy_cache, attention_mask, position_ids,
            past_key_values, use_cache, output_attentions, output_hidden_states,
            return_dict)``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.model.config.use_cache
        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict
        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        past_key_values_length = 0
        if self.model.gradient_checkpointing and self.model.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        use_legacy_cache = None
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
        inputs_embeds = self.model.embed_dropout(inputs_embeds)
        # Attention mask.
        if self.model._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self.model._use_sdpa and not output_attentions:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )
        return inputs_embeds, use_legacy_cache, attention_mask, position_ids, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict

    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                head_window_size: Optional[int] = None,
                ):
        """
        Run the backbone layer by layer, evaluating the branches along the way.

        Inference (``self.training`` is False): after each branch layer, the
        ``confidence_metric_fn`` score of the LAST sequence position is compared with
        that branch's entry in ``head_thresholds``. Samples scoring above it take the
        branch's logits as their final logits. The loop stops once every sample has
        exited. The remaining samples use ``final_layernorm`` + ``lm_head`` on the last
        hidden states.

        Training: every branch's logits are collected in ``head_logits``, and
        ``compute_self_supervision_loss`` is evaluated against the final logits.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
            output_attentions, output_hidden_states, return_dict: as for the backbone;
                normalised by ``model_pre_forward``.
            use_cache: ignored; caching is forced off (TODO: cache handling must be
                updated before it can coexist with early exits).
            head_window_size: currently unused.

        Returns:
            CausalBranchyLLMOutputWithPast with:

            - ``logits`` of shape [batch, seq_len, vocab_size].
            - ``head_indices``: the layer each sample exited at, or ``num_layers`` if it
              did not exit early.
            - In training, the loss fields as well.

        Raises:
            NotImplementedError: if ``return_dict`` resolves to False.
        """
        use_cache = False  # Disable it for now TODO Update how cache is handled to allow early exits
        (
            inputs_embeds,
            use_legacy_cache,
            attention_mask,
            position_ids,
            past_key_values,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        ) = self.model_pre_forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_logits = ()
        next_decoder_cache = None

        batch_size, seq_length = hidden_states.shape[:2]
        device = hidden_states.device

        # Layer index -> position in self.branches / self.head_thresholds, built once
        # instead of calling list.index per layer. setdefault keeps the first position
        # of a repeated location, exactly like list.index.
        branch_index = {}
        for position, location in enumerate(self.config.branch_locations):
            branch_index.setdefault(location, position)

        # Per-sample early-exit bookkeeping.
        early_exit_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
        exit_layer = torch.full((batch_size,), self.num_layers, dtype=torch.long, device=device)
        final_logits = torch.zeros((batch_size, seq_length, self.vocab_size), device=device)

        for layer, decoder_layer in enumerate(self.model.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.model.gradient_checkpointing and self.model.training:
                # The checkpointed call returns whatever decoder_layer returns (its
                # output tuple), so it is assigned as a whole, not unpacked.
                layer_outputs = self.model._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            branch_idx = branch_index.get(layer)
            if branch_idx is not None:
                branch_logits = self.branches[branch_idx](layer_outputs[0])
                if not self.training:
                    # During inference, the last position's score decides the exit.
                    scores = self.confidence_metric_fn(branch_logits)[..., -1]
                    exit_samples = (scores > self.head_thresholds[branch_idx]) & ~early_exit_mask
                    early_exit_mask |= exit_samples
                    exit_layer[exit_samples] = layer
                    final_logits[exit_samples] = branch_logits[exit_samples]
                    if early_exit_mask.all():
                        break  # All samples have exited early
                else:
                    # In training every branch's full logits are kept for the loss.
                    all_logits += (branch_logits,)

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Samples that never exited go through the final norm and the main LM head.
        if not early_exit_mask.all():
            remaining_hidden_states = self.model.final_layernorm(hidden_states[~early_exit_mask])
            final_logits[~early_exit_mask] = self.lm_head(remaining_hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        loss_output = None
        if self.training:
            # The distillation target is the main head's prediction (final_logits);
            # passing raw hidden states took the argmax over the hidden dimension.
            loss_output = self.compute_self_supervision_loss(torch.stack(all_logits), final_logits)

        if not return_dict:
            raise NotImplementedError("return_dict=False is not implemented")
        # Named access: positional indexing of the loss output swapped entropy/entropies.
        return CausalBranchyLLMOutputWithPast(
            loss=None if loss_output is None else loss_output.loss,
            head_loss=None if loss_output is None else loss_output.head_losses,
            entropies=None if loss_output is None else loss_output.entropies,
            entropy=None if loss_output is None else loss_output.entropy,
            logits=final_logits,
            head_logits=all_logits,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            head_indices=exit_layer,
        )

    def compute_self_supervision_loss(self,
                                      aux_logits: torch.Tensor,
                                      lm_logits: torch.Tensor,
                                      return_dict: bool = True
                                      ) -> Dict[str, torch.Tensor]:
        """
        Self-distillation loss of the branches against the main LM head.

        Only the last sequence position is used. For each head ``h`` the loss is
        ``(1 - w) * CE(h, argmax(lm)) - w * mean_batch(entropy(softmax(h)))``, where
        ``w`` is ``config.penalty_weight``. The entropy term is subtracted, so minimising
        the loss pushes branch entropy up as ``w`` grows.

        Args:
            aux_logits: stacked branch logits, expected [num_heads, batch, seq_len, vocab_size].
            lm_logits: main-head logits, expected [batch, seq_len, vocab_size]; the argmax
                of the last position is the pseudo-label for every head.
            return_dict: return a ``SelfSupervisedLossOutput`` when True, otherwise the
                tuple ``(loss, head_losses, entropy, entropies)``.

        Returns:
            - ``loss``: mean of the per-head losses.
            - ``head_losses``: tuple with one scalar per head.
            - ``entropy``: mean of the per-head entropies.
            - ``entropies``: tuple with one scalar per head.
        """
        last_aux_logits = aux_logits[..., -1, :]
        last_lm_logits = lm_logits[..., -1, :]
        # Pseudo-labels are shared by all heads; compute them once.
        targets = torch.argmax(last_lm_logits, dim=-1)
        penalty = self.config.penalty_weight
        losses = ()
        entropies = ()
        # Can be useful to have detailed loss per head for comparison of performance
        for head_logit in last_aux_logits:
            ce_loss = F.cross_entropy(head_logit, targets, reduction="mean")
            probas = F.softmax(head_logit, dim=-1)
            log_probas = torch.log(probas + 1e-8)
            assert not torch.isnan(log_probas).any(), "NaNs found in log_probas"
            entropy = -torch.sum(probas * log_probas, dim=-1)
            assert not torch.isnan(entropy).any(), "NaNs found in entropy before mean"
            entropy = torch.mean(entropy)
            entropies += (entropy,)
            losses += ((1 - penalty) * ce_loss - penalty * entropy,)
        loss = torch.stack(losses, dim=0).mean(dim=-1)
        entropy = torch.stack(entropies, dim=0).mean(dim=-1)
        if not return_dict:
            return tuple(v for v in (loss, losses, entropy, entropies) if v is not None)
        return SelfSupervisedLossOutput(
            loss=loss,
            head_losses=losses,
            entropies=entropies,
            entropy=entropy,
        )
@dataclass
class CausalBranchyLLMOutputWithPast(ModelOutput):
    """
    Output container returned by ``BranchyCausalModel.forward``.

    The loss-related fields are only populated in training mode. ``head_logits`` is a
    tuple of per-branch logits that is only filled in training mode.

    NOTE(review): ModelOutput presumably supports positional indexing, so the field
    order is part of the interface — do not reorder fields.
    """
    loss: Optional[torch.Tensor] = None  # Main loss: mean self-supervised loss over the branches
    head_loss: Optional[Tuple[torch.Tensor, ...]] = None  # Per-branch losses (a tuple is stored here)
    entropy: Optional[torch.Tensor] = None  # Intended: mean entropy of the branch predictions
    entropies: Optional[Tuple[torch.Tensor]] = None  # Intended: per-branch entropies
    logits: Optional[torch.Tensor] = None  # Final logits per sample, [batch, seq_len, vocab_size]
    head_logits: Optional[Tuple[torch.Tensor, ...]] = None  # Per-branch logits collected in training
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # None while forward forces use_cache=False
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # Layer inputs plus last hidden states, if requested
    attentions: Optional[Tuple[torch.FloatTensor]] = None  # Per-layer attentions, if requested
    head_indices: Optional[torch.Tensor] = None  # Exit layer per sample (num_layers if no early exit)
@dataclass
class SelfSupervisedLossOutput(ModelOutput):
    """
    Result of ``BranchyCausalModel.compute_self_supervision_loss``.

    The field order (loss, head_losses, entropy, entropies) matches the tuple that
    method returns when called with ``return_dict=False``.
    """
    loss: Optional[torch.Tensor] = None  # mean of the per-head losses
    head_losses: Optional[Tuple[torch.Tensor, ...]] = None  # one scalar loss per head
    entropy: Optional[torch.Tensor] = None  # mean of the per-head entropies
    entropies: Optional[Tuple[torch.Tensor, ...]] = None  # one scalar entropy per head