import logging
from typing import List, Optional

from transformers import PretrainedConfig

logger = logging.getLogger(__name__)


class BranchyModelConfig(PretrainedConfig):
    """
    Configuration class for BranchyModel. This class extends the PretrainedConfig class from the
    Transformers library, providing configuration specific to models with branch functionality.

    Attributes:
        branch_locations (List[int]): Indices of the layers after which branches are added. Indices
            start from 0, and each index refers to a layer in the underlying transformer model.
        penalty_weight (Optional[float]): Weight of the penalty term used in the
            "penalized_cross_entropy" loss. Must lie in the range [0, 1].
        head_window_size (int): Number of tokens each branch considers from the input sequence. This
            allows for reducing the computational load by limiting the context size each branch
            processes.

    Example:
        config = BranchyModelConfig(
            branch_locations=[2, 4, 6],
            head_window_size=512,
        )

    Note:
        This configuration class is specifically designed for use with the BranchyModel class,
        enabling flexible and customizable branching within transformer models.
    """

    model_type = "branchy"

    def __init__(
        self,
        model_str: Optional[str] = None,
        head_thresholds: Optional[List[float]] = None,
        confidence_metric: Optional[str] = "breaking_ties",
        branch_locations: Optional[List[int]] = None,
        branch_number: Optional[int] = 3,
        penalty_weight: Optional[float] = 0.0,
        head_window_size: int = 512,
        copy_lm_head: Optional[bool] = False,
        **kwargs,
    ):
        """
        Initializes the BranchyModelConfig.

        Args:
            model_str (str, optional): Identifier of the base model on Hugging Face's model hub.
            head_thresholds (List[float], optional): Per-branch confidence thresholds compared
                against the confidence metric. Defaults to None.
            confidence_metric (str, optional): Metric used against the thresholds, either
                "breaking_ties" or "max". Defaults to "breaking_ties".
            branch_locations (List[int], optional): Locations of the branches. Defaults to None,
                indicating no branches.
            branch_number (int, optional): Number of branches if branch_locations is not provided.
                Defaults to 3.
            penalty_weight (float, optional): Weight for the penalty term in the loss calculation,
                in the range [0, 1]. Defaults to 0.0.
            head_window_size (int, optional): Number of tokens each branch can see. Defaults to 512.
            copy_lm_head (bool, optional): Whether each branch head starts as a copy of the model's
                LM head. Defaults to False.
        """
|
|
self.model_str = model_str |
|
|
self.head_thresholds = head_thresholds |
|
|
self.confidence_metric = confidence_metric |
|
|
assert self.confidence_metric in ["breaking_ties", "max"], "confidence_metric must be 'breaking_ties' or 'max'. It should depend on how you found the thresholds." |
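        # Assumed semantics (not spelled out in this file): "breaking_ties" scores confidence
        # as the margin between the two highest token probabilities, while "max" uses the top
        # probability alone; thresholds calibrated under one metric do not transfer to the other.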
        self.branch_locations = branch_locations
        self.penalty_weight = penalty_weight
        self.head_window_size = head_window_size
        if branch_locations is not None and branch_number is not None:
            logger.warning("Both branch_locations and branch_number are provided. Using branch_locations.")
        # branch_locations takes precedence: when given, it also determines the branch count.
        self.branch_number = branch_number if branch_locations is None else len(branch_locations)
        self.copy_lm_head = copy_lm_head

        # Sanity-check the resolved configuration before handing off to PretrainedConfig.
        assert self.branch_number > 0, "branch_number must be a positive integer."
        assert isinstance(self.penalty_weight, (float, int)), "penalty_weight must be a float or an integer."
        assert 0 <= self.penalty_weight <= 1, "penalty_weight must be in the range [0, 1]."
        if branch_locations is not None:
            assert all(isinstance(loc, int) for loc in self.branch_locations), "Branch locations must be integers."
            assert all(loc >= 0 for loc in self.branch_locations), "Branch locations must be non-negative."
        if self.head_window_size is not None:
            assert self.head_window_size > 0, "head_window_size must be a positive integer or None."
        if isinstance(self.head_thresholds, list):
            assert len(self.head_thresholds) == self.branch_number, "Number of thresholds must match number of branches."
            assert all(isinstance(threshold, float) for threshold in self.head_thresholds), "Thresholds must be floats."
        super().__init__(**kwargs)
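

if __name__ == "__main__":
    # Minimal usage sketch, assuming only that `transformers` is installed; the hub
    # identifier below is a hypothetical placeholder, not a requirement of this config.
    config = BranchyModelConfig(
        model_str="org/some-causal-lm",  # hypothetical Hugging Face hub identifier
        branch_locations=[2, 4, 6],
        penalty_weight=0.5,
        head_window_size=512,
    )
    # branch_number is derived from branch_locations; a warning notes that the default
    # branch_number=3 was ignored in favor of the explicit locations.
    print(config.branch_number)  # -> 3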