Spaces: Running on Zero
| import torch | |
| import torch.nn as nn | |
class PartiallyFrozenEmbedding(nn.Module):
    """Split an existing `nn.Embedding` into a frozen part and a trainable part.

    - A frozen embedding for indices [0, freeze_until_idx), i.e. 0..freeze_until_idx-1.
    - A trainable embedding for indices [freeze_until_idx, vocab_size), i.e.
      freeze_until_idx..vocab_size-1.

    Both sub-embeddings are plain `nn.Embedding` modules, and both are invoked on every
    forward pass (possibly with an empty index tensor). Every process therefore executes
    the same sequence of submodule calls regardless of which token ids its batch
    contains. The intent is to work with both ZeRO-2 and ZeRO-3.
    """

    def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
        """
        :param original_embedding: The embedding layer to split. Its weights are copied
            into the two new sub-embeddings; the original module is not modified.
        :param freeze_until_idx: Exclusive upper bound of the frozen rows. Rows
            ``[0, freeze_until_idx)`` are frozen; row ``freeze_until_idx`` and above are
            trainable. Must satisfy ``0 <= freeze_until_idx <= num_embeddings``.
        :raises ValueError: If ``freeze_until_idx`` is outside that range.

        NOTE(review): the original weight is read by plain slicing. If it is partitioned
        (e.g. under ZeRO-3), it presumably has to be gathered before calling this — confirm
        with the caller.
        """
        super().__init__()
        num_embeddings = original_embedding.num_embeddings
        if not 0 <= freeze_until_idx <= num_embeddings:
            raise ValueError(
                f"freeze_until_idx must be in [0, {num_embeddings}], got {freeze_until_idx}"
            )
        self.freeze_until_idx = freeze_until_idx
        self.original_vocab_size = num_embeddings
        self.embedding_dim = original_embedding.embedding_dim
        # nn.Embedding stores padding_idx already normalised to a non-negative index (or None).
        self.padding_idx = original_embedding.padding_idx

        # Route padding_idx to whichever part owns that row (re-based for the trainable part)
        # so its "no gradient for the padding row" semantics survive the split.
        pad = self.padding_idx
        frozen_pad = pad if pad is not None and pad < freeze_until_idx else None
        trainable_pad = (
            pad - freeze_until_idx if pad is not None and pad >= freeze_until_idx else None
        )

        weight = original_embedding.weight
        factory_kwargs = {"dtype": weight.dtype, "device": weight.device}
        # Split the original embedding into frozen and trainable parts
        self.embedding_frozen = nn.Embedding(
            freeze_until_idx,
            self.embedding_dim,
            padding_idx=frozen_pad,
            **factory_kwargs,
        )
        self.embedding_trainable = nn.Embedding(
            num_embeddings - freeze_until_idx,
            self.embedding_dim,
            padding_idx=trainable_pad,
            **factory_kwargs,
        )
        # Copy weights from the original embedding into the frozen and trainable parts
        # (this also overwrites any row that nn.Embedding zeroed for padding_idx at init).
        with torch.no_grad():
            self.embedding_frozen.weight.copy_(weight[:freeze_until_idx])
            self.embedding_trainable.weight.copy_(weight[freeze_until_idx:])
        # Freeze the frozen embedding
        self.embedding_frozen.weight.requires_grad = False

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Look up embeddings, dispatching each id to the frozen or trainable part.

        :param input_ids: Integer tensor of any shape (e.g. ``[batch_size, seq_len]``)
            with values in ``[0, original_vocab_size - 1]``.
        :return: Tensor of shape ``(*input_ids.shape, embedding_dim)``.
        """
        # Masks to separate frozen and trainable indices (same shape as input_ids)
        mask_frozen = input_ids < self.freeze_until_idx
        mask_trainable = ~mask_frozen

        # Output tensor for embedding results
        embeddings = torch.zeros(
            *input_ids.shape,
            self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_frozen.weight.dtype,
        )

        # Always call both sub-embeddings, even when one receives an empty index tensor.
        # Skipping one based on the batch contents makes submodule execution data-dependent
        # (which can desynchronize processes when parameters are gathered per submodule, as
        # under ZeRO-3), and `if mask.any()` forces a device-to-host sync on CUDA. It also
        # matches a single nn.Embedding, whose weight gets a (possibly all-zero) gradient
        # on every backward.
        embeddings[mask_frozen] = self.embedding_frozen(input_ids[mask_frozen])
        # Shift trainable ids into the local index space of the trainable embedding.
        embeddings[mask_trainable] = self.embedding_trainable(
            input_ids[mask_trainable] - self.freeze_until_idx
        )
        return embeddings

    def to_unsplit(self) -> nn.Embedding:
        """Merge both parts back into one `nn.Embedding` (copied weights, padding_idx kept).

        The returned module's weight is entirely trainable; the frozen status is not carried over.
        """
        unsplit_embedding = nn.Embedding(
            self.original_vocab_size,
            self.embedding_dim,
            padding_idx=self.padding_idx,
            dtype=self.embedding_frozen.weight.dtype,
            device=self.embedding_frozen.weight.device,
        )
        with torch.no_grad():
            unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
            unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
        return unsplit_embedding
class PartiallyFrozenLinear(nn.Module):
    """A wrapper around nn.Linear that freezes the first `freeze_until_idx` output rows.

    The weight (and bias, if present) is split along the output dimension:

    - `linear_frozen` computes output features [0, freeze_until_idx) and is frozen.
    - `linear_trainable` computes output features [freeze_until_idx, out_features).

    The two outputs are concatenated on the last dimension, reproducing the original layer.
    """

    def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
        """
        :param original_linear: The original nn.Linear layer. Its weight (and bias, if
            any) is copied; the original module is not modified.
        :param freeze_until_idx: Exclusive upper bound of the frozen weight rows (output
            features). Must satisfy ``0 <= freeze_until_idx <= out_features``.
        :raises ValueError: If ``freeze_until_idx`` is outside that range.

        NOTE(review): the original parameters are read by plain slicing. If they are
        partitioned (e.g. under ZeRO-3), they presumably have to be gathered first — confirm
        with the caller.
        """
        super().__init__()
        out_features = original_linear.out_features
        if not 0 <= freeze_until_idx <= out_features:
            raise ValueError(
                f"freeze_until_idx must be in [0, {out_features}], got {freeze_until_idx}"
            )
        self.freeze_until_idx = freeze_until_idx
        self.input_dim = original_linear.in_features
        self.output_dim = out_features

        has_bias = original_linear.bias is not None
        weight = original_linear.weight
        factory_kwargs = {"dtype": weight.dtype, "device": weight.device}
        # Create frozen and trainable linear layers
        self.linear_frozen = nn.Linear(
            self.input_dim, freeze_until_idx, bias=has_bias, **factory_kwargs
        )
        self.linear_trainable = nn.Linear(
            self.input_dim, out_features - freeze_until_idx, bias=has_bias, **factory_kwargs
        )
        # Copy parameters from the original layer. Each output feature depends only on
        # its own weight row and bias entry, so splitting along rows is exact.
        with torch.no_grad():
            self.linear_frozen.weight.copy_(weight[:freeze_until_idx])
            self.linear_trainable.weight.copy_(weight[freeze_until_idx:])
            if has_bias:
                self.linear_frozen.bias.copy_(original_linear.bias[:freeze_until_idx])
                self.linear_trainable.bias.copy_(original_linear.bias[freeze_until_idx:])
        # Freeze every parameter (weight and, if present, bias) of the frozen part.
        self.linear_frozen.requires_grad_(False)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        :param input_tensor: Tensor of shape ``(*, in_features)``, e.g.
            ``(bsz, seq_len, hidden_state_dim)``.
        :return: Tensor of shape ``(*, out_features)``.
        """
        frozen_output = self.linear_frozen(input_tensor)
        trainable_output = self.linear_trainable(input_tensor)
        return torch.cat((frozen_output, trainable_output), dim=-1)

    def to_unsplit(self) -> nn.Linear:
        """Merge both parts back into one `nn.Linear` (copied weights and bias).

        The returned module's parameters are entirely trainable; the frozen status is not
        carried over.
        """
        has_bias = self.linear_frozen.bias is not None
        unsplit_linear = nn.Linear(
            self.input_dim,
            self.output_dim,
            bias=has_bias,
            dtype=self.linear_frozen.weight.dtype,
            device=self.linear_frozen.weight.device,
        )
        # Copy weights from the frozen and trainable layers into the unsplit linear layer
        with torch.no_grad():
            unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
            unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)
            if has_bias:
                unsplit_linear.bias[: self.freeze_until_idx].copy_(self.linear_frozen.bias)
                unsplit_linear.bias[self.freeze_until_idx :].copy_(self.linear_trainable.bias)
        return unsplit_linear