File size: 4,499 Bytes
794f115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import List, Optional
from transformers import PretrainedConfig
import logging

logger = logging.getLogger(__name__)

class BranchyConfigError(ValueError, AssertionError):
    """
    Raised when a BranchyModelConfig argument fails validation.

    It subclasses ValueError, the conventional type for invalid arguments, and AssertionError, which is what the
    previous `assert`-based checks raised, so existing `except AssertionError` handlers keep working. Unlike
    `assert`, these explicit raises are not stripped when Python runs with -O.
    """


class BranchyModelConfig(PretrainedConfig):
    """
    Configuration class for BranchyModel. This class extends the PretrainedConfig class from the Transformers
    library, providing configuration specific to models with branch functionality.

    Attributes:
        model_str (Optional[str]): Identifier of the underlying model on Huggingface's model hub.
        head_thresholds (Optional[List[float]]): Per-branch thresholds. When given as a list or tuple, it must
            contain exactly `branch_number` floats.
        confidence_metric (str): Either "breaking_ties" or "max". It should match the metric that was used to
            find `head_thresholds`.
        branch_locations (Optional[List[int]]): Specifies the indices of layers after which branches are added.
            These indices start from 0, and each index represents a layer in the underlying transformer model.
        branch_number (int): Number of branches. Equal to `len(branch_locations)` whenever locations are given.
        penalty_weight (float): The weight of the penalty term used in the "penalized_cross_entropy" loss.
            Must be in the range [0, 1]; defaults to 0.
        head_window_size (Optional[int]): Determines the number of tokens each branch considers from the input
            sequence. This allows for reducing the computational load by limiting the context size each branch
            processes. Must be positive when not None.
        copy_lm_head (bool): Flag stored on the config for BranchyModel; presumably controls whether branch
            heads reuse the base model's LM head — confirm in BranchyModel.

    Example:
        config = BranchyModelConfig(
            branch_locations=[2, 4, 6],
            head_window_size=512
        )

    Note:
        This configuration class is specifically designed for use with the BranchyModel class, enabling flexible
        and customizable branching within transformer models.
    """
    model_type = "branchy"  # Optional, but useful for identifying the model type in the Transformers library

    def __init__(
        self,
        model_str: Optional[str] = None,
        head_thresholds: Optional[List[float]] = None,
        confidence_metric: Optional[str] = "breaking_ties",
        branch_locations: Optional[List[int]] = None,
        branch_number: Optional[int] = 3,
        penalty_weight: Optional[float] = 0,
        head_window_size: Optional[int] = 512,
        copy_lm_head: Optional[bool] = False,
        **kwargs
    ):
        """
        Initializes the BranchyModelConfig.

        Args:
            model_str (Optional[str]): The model string to be used for the model, from Huggingface's model hub.
                Defaults to None.
            head_thresholds (Optional[List[float]]): Per-branch thresholds. If a list or tuple, its length must
                equal the resulting branch_number and every entry must be a float. Defaults to None.
            confidence_metric (Optional[str]): "breaking_ties" or "max". Defaults to "breaking_ties".
            branch_locations (Optional[List[int]]): Non-negative integer layer indices of the branches. When
                given, it takes precedence over branch_number. Defaults to None.
            branch_number (Optional[int]): Number of branches, used only if branch_locations is not provided.
                Defaults to 3.
            penalty_weight (Optional[float]): Weight for the penalty in loss calculation, in the range [0, 1].
                Defaults to 0.
            head_window_size (Optional[int]): Number of tokens each branch can see; a positive integer or None.
                Defaults to 512.
            copy_lm_head (Optional[bool]): Stored as-is on the config. Defaults to False.
            **kwargs: Forwarded unchanged to PretrainedConfig.__init__.

        Raises:
            BranchyConfigError: If any argument fails validation (also catchable as ValueError or AssertionError).
        """
        self.model_str = model_str
        self.head_thresholds = head_thresholds
        self.confidence_metric = confidence_metric
        self._require(
            self.confidence_metric in ("breaking_ties", "max"),
            "confidence_metric must be 'breaking_ties' or 'max'. It should depend on how you found the thresholds.",
        )
        self.branch_locations = branch_locations
        self.penalty_weight = penalty_weight
        self.head_window_size = head_window_size
        if branch_locations is None:
            self.branch_number = branch_number
        else:
            # branch_locations always wins. Warn only on a real conflict: branch_number has a non-None default,
            # and because both values are stored as attributes, a saved-and-reloaded config passes both back in,
            # so warning whenever both were non-None fired on nearly every construction.
            if branch_number is not None and branch_number != len(branch_locations):
                logger.warning(
                    "Both branch_locations and branch_number are provided and disagree (%s != %d). "
                    "Using branch_locations.",
                    branch_number,
                    len(branch_locations),
                )
            self.branch_number = len(branch_locations)
        self.copy_lm_head = copy_lm_head

        # Check for None explicitly: `None > 0` would raise an unhelpful TypeError.
        self._require(
            self.branch_number is not None and self.branch_number > 0,
            "branch_number must be a positive integer.",
        )
        # The type check runs first so the range comparison below only ever sees numbers.
        self._require(
            isinstance(self.penalty_weight, (float, int)),
            "penalty_weight must be a float or an integer.",
        )
        self._require(0 <= self.penalty_weight <= 1, "penalty_weight must be in the range [0, 1].")
        if branch_locations is not None:
            self._require(
                all(isinstance(loc, int) for loc in self.branch_locations),
                "Branch locations must be integers.",
            )
            self._require(
                all(loc >= 0 for loc in self.branch_locations),
                "Branch locations must be non-negative.",
            )
        if self.head_window_size is not None:
            self._require(self.head_window_size > 0, "head_window_size must be a positive integer or None.")
        # Validate any sequence of thresholds, not only exact `list` instances (tuples previously slipped through).
        if isinstance(self.head_thresholds, (list, tuple)):
            self._require(
                len(self.head_thresholds) == self.branch_number,
                "Number of thresholds must match number of branches.",
            )
            self._require(
                all(isinstance(threshold, float) for threshold in self.head_thresholds),
                "Thresholds must be floats.",
            )
        super().__init__(**kwargs)  # Initialize with base class parameters

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise BranchyConfigError with `message` when `condition` is false (survives `python -O`)."""
        if not condition:
            raise BranchyConfigError(message)